# Source code for intelligence_layer.connectors.argilla.argilla_wrapper_client

import itertools
import logging
import os
from collections.abc import Iterable, Sequence
from typing import (
    Any,
    Optional,
)

import argilla as rg  # type: ignore
from argilla.client.feedback.schemas.types import (  # type: ignore
    AllowedFieldTypes,
    AllowedQuestionTypes,
)

from intelligence_layer.connectors.argilla.argilla_client import (
    ArgillaClient,
    ArgillaEvaluation,
    Record,
    RecordData,
)


class ArgillaWrapperClient(ArgillaClient):
    """`ArgillaClient` implemented on top of the ``argilla`` SDK's ``FeedbackDataset`` API.

    Every record written by this client carries the id of the example it belongs to
    in its ``example_id`` metadata entry.
    """

    def __init__(
        self,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        disable_warnings: bool = True,
    ) -> None:
        """Initializes the ``argilla`` SDK client via ``rg.init``.

        Args:
            api_url: URL of the Argilla server. Falls back to the ``ARGILLA_API_URL``
                environment variable if not given.
            api_key: API key for the Argilla server. Falls back to the
                ``ARGILLA_API_KEY`` environment variable if not given.
            disable_warnings: if True, ignores warnings emitted from ``argilla.*``
                modules and raises the level of one noisy ``argilla`` logger to WARNING.
        """
        if disable_warnings:
            import warnings

            warnings.filterwarnings("ignore", module="argilla.*")
            # this logger is set on info for some reason
            logging.getLogger("argilla.client.feedback.dataset.local.mixins").setLevel(
                logging.WARNING
            )
        # Remember the resolved URL so that the raw HTTP call in `_create_evaluation`
        # targets the same server as the SDK, also when the URL was passed explicitly
        # rather than through the environment.
        self._api_url: Optional[str] = (
            api_url if api_url is not None else os.getenv("ARGILLA_API_URL")
        )
        rg.init(
            api_url=self._api_url,
            api_key=api_key if api_key is not None else os.getenv("ARGILLA_API_KEY"),
        )

    def create_dataset(
        self,
        workspace_id: str,
        dataset_name: str,
        fields: Sequence[AllowedFieldTypes],
        questions: Sequence[AllowedQuestionTypes],
    ) -> str:
        """Creates and publishes a new feedback dataset in Argilla.

        Raises an error if the name exists already.

        Args:
            workspace_id: the name of the workspace the feedback dataset should be
                created in. The user executing this request must have corresponding
                permissions for this workspace.
            dataset_name: the name of the feedback-dataset to be created.
            fields: all fields of this dataset.
            questions: all questions for this dataset.

        Returns:
            The id of the created dataset.
        """
        dataset = rg.FeedbackDataset(
            fields=fields, questions=questions, allow_extra_metadata=True
        )
        remote_dataset = dataset.push_to_argilla(
            name=dataset_name,
            workspace=rg.Workspace.from_name(workspace_id),
            show_progress=False,
        )
        return str(remote_dataset.id)

    def ensure_dataset_exists(
        self,
        workspace_id: str,
        dataset_name: str,
        fields: Sequence[AllowedFieldTypes],
        questions: Sequence[AllowedQuestionTypes],
    ) -> str:
        """Retrieves an existing dataset or creates and publishes a new feedback dataset in Argilla.

        Args:
            workspace_id: the name of the workspace the feedback dataset should be
                created in. The user executing this request must have corresponding
                permissions for this workspace.
            dataset_name: the name of the feedback-dataset to be created.
            fields: all fields of this dataset.
            questions: all questions for this dataset.

        Returns:
            The id of the retrieved or newly created dataset.
        """
        try:
            existing = rg.FeedbackDataset.from_argilla(
                name=dataset_name, workspace=rg.Workspace.from_name(workspace_id)
            )
        except ValueError:
            # Lookup failed (presumably the dataset does not exist yet), so create it.
            return self.create_dataset(workspace_id, dataset_name, fields, questions)
        return str(existing.id)

    def add_record(self, dataset_id: str, record: RecordData) -> None:
        """Adds a single record to the dataset with the given id."""
        self.add_records(dataset_id=dataset_id, records=[record])

    def add_records(self, dataset_id: str, records: Sequence[RecordData]) -> None:
        """Adds all given records to the dataset with the given id in one call.

        Each record's ``example_id`` is stored in its metadata under the key
        ``"example_id"``, overriding any metadata entry with the same key.
        """
        remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
        argilla_records = [
            rg.FeedbackRecord(
                fields=record.content,
                metadata={
                    **record.metadata,
                    "example_id": record.example_id,
                },
            )
            for record in records
        ]
        remote_dataset.add_records(argilla_records, show_progress=False)

    def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
        """Yields one evaluation per record of the dataset that has a submitted response.

        Args:
            dataset_id: the id of the dataset to read.

        Yields:
            An `ArgillaEvaluation` built from the first submitted response of each
            record. Its metadata is the record's metadata without ``example_id``.
        """
        remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
        filtered_dataset = remote_dataset.filter_by(response_status="submitted")
        for record in filtered_dataset.records:
            # A record can carry several responses in different states (e.g. draft or
            # discarded ones, whose values may be missing); only a submitted one counts.
            submitted_response = next(
                (
                    response
                    for response in record.responses
                    if getattr(response.status, "value", response.status) == "submitted"
                ),
                None,
            )
            if submitted_response is None:
                continue
            # Work on a copy so that removing `example_id` does not mutate the record.
            metadata = dict(record.metadata)
            example_id = metadata.pop("example_id")
            yield ArgillaEvaluation(
                example_id=example_id,
                # The record id is not populated by this client; a placeholder is used.
                record_id="ignored",
                responses={k: v.value for k, v in submitted_response.values.items()},
                metadata=metadata,
            )

    def split_dataset(self, dataset_id: str, n_splits: int) -> None:
        """Adds a new metadata property to the dataset and assigns a split to each record.

        Splits are numbered 1..n_splits and assigned round-robin in record order.
        Deletes the property (and the corresponding record metadata) if n_splits is
        equal to one.

        Args:
            dataset_id: the id of the dataset
            n_splits: the number of splits to create; must be at least 1.

        Raises:
            ValueError: if n_splits is smaller than 1.
        """
        if n_splits < 1:
            raise ValueError(f"n_splits must be at least 1, got {n_splits}")
        remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
        name = "split"
        metadata_config = remote_dataset.metadata_property_by_name(name)
        if n_splits == 1:
            if metadata_config is None:
                return
            remote_dataset.delete_metadata_properties(name)
            self._delete_metadata_from_records(remote_dataset, name)
            return

        if metadata_config is None:
            config = rg.IntegerMetadataProperty(
                name=name, visible_for_annotators=True, min=1, max=n_splits
            )
            remote_dataset.add_metadata_property(config)
        else:
            metadata_config.max = n_splits
            remote_dataset.update_metadata_properties(metadata_config)

        self._update_record_metadata(n_splits, remote_dataset, name)

    def _update_record_metadata(
        self, n_splits: int, remote_dataset: rg.FeedbackDataset, metadata_name: str
    ) -> None:
        """Writes split numbers 1..n_splits round-robin into every record's metadata."""
        modified_records = []
        for record, split in zip(
            remote_dataset.records, itertools.cycle(range(1, n_splits + 1))
        ):
            record.metadata[metadata_name] = split
            modified_records.append(record)
        remote_dataset.update_records(modified_records, show_progress=False)

    def _delete_metadata_from_records(
        self, remote_dataset: rg.FeedbackDataset, metadata_name: str
    ) -> None:
        """Removes the given metadata key from every record of the dataset.

        Records that never received the key (e.g. added after the last split) are
        tolerated instead of raising KeyError.
        """
        modified_records = []
        for record in remote_dataset.records:
            record.metadata.pop(metadata_name, None)
            modified_records.append(record)
        remote_dataset.update_records(modified_records, show_progress=False)

    def ensure_workspace_exists(self, workspace_name: str) -> str:
        """Retrieves the name of an argilla workspace with specified name or creates a new workspace if necessary.

        Args:
            workspace_name: the name of the workspace to be retrieved or created.

        Returns:
            The name of an argilla workspace with the given `workspace_name`.
        """
        try:
            workspace = rg.Workspace.from_name(workspace_name)
            return str(workspace.name)
        except ValueError:
            return str(rg.Workspace.create(name=workspace_name).name)

    def records(self, dataset_id: str) -> Iterable[Record]:
        """Returns a lazy iterable over all records of the dataset.

        The dataset is looked up immediately; records are converted while iterating.
        The returned metadata still contains the ``example_id`` entry.
        """
        remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
        return (
            Record(
                id=str(record.id),
                example_id=record.metadata["example_id"],
                content=record.fields,
                metadata=record.metadata,
            )
            for record in remote_dataset.records
        )

    def _create_evaluation(self, record_id: str, data: dict[str, Any]) -> None:
        """Submits a response for a record through Argilla's REST API.

        Args:
            record_id: the id of the record to respond to.
            data: mapping from question name to the response value.

        Raises:
            KeyError: if no API URL was configured at construction time and
                ``ARGILLA_API_URL`` is not set in the environment.
        """
        # Prefer the URL the client was initialized with; only fall back to the
        # environment if none was known at construction time.
        api_url = self._api_url
        if api_url is None:
            api_url = os.environ["ARGILLA_API_URL"]
        if not api_url.endswith("/"):
            api_url = api_url + "/"
        rg.active_client().http_client.post(
            f"{api_url}api/v1/records/{record_id}/responses",
            json={
                "status": "submitted",
                "values": {
                    question_name: {"value": response_value}
                    for question_name, response_value in data.items()
                },
            },
        )

    def _delete_dataset(self, dataset_id: str) -> None:
        """Deletes the dataset with the given id."""
        remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
        remote_dataset.delete()

    def _delete_workspace(self, workspace_name: str) -> None:
        """Deletes all datasets listed in the workspace, then the workspace itself."""
        workspace = rg.Workspace.from_name(workspace_name)
        datasets = rg.list_datasets(workspace=workspace.name)
        for dataset in datasets:
            dataset.delete()
        workspace.delete()

    def _dataset_from_id(self, dataset_id: str) -> rg.FeedbackDataset:
        """Loads the remote feedback dataset with the given id."""
        return rg.FeedbackDataset.from_argilla(id=dataset_id)