Source code for intelligence_layer.connectors.argilla.argilla_wrapper_client

import itertools
import logging
import os
from collections.abc import Iterable, Sequence
from typing import (
    Any,
    Optional,
)

import argilla as rg  # type: ignore

from intelligence_layer.connectors.argilla.argilla_client import (
    ArgillaClient,
    ArgillaEvaluation,
    Record,
    RecordData,
)


class ArgillaWrapperClient(ArgillaClient):
    def __init__(
        self,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        disable_warnings: bool = True,
    ) -> None:
        if disable_warnings:
            import warnings

            warnings.filterwarnings("ignore", module="argilla.*")
            # this logger is set on info for some reason
            logging.getLogger("argilla.client.feedback.dataset.local.mixins").setLevel(
                logging.WARNING
            )
        self.client = rg.Argilla(
            api_url=api_url if api_url is not None else os.getenv("ARGILLA_API_URL"),
            api_key=api_key if api_key is not None else os.getenv("ARGILLA_API_KEY"),
        )
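
    # Illustrative usage sketch (not part of the original module): the client can be
    # constructed either from the ARGILLA_API_URL / ARGILLA_API_KEY environment
    # variables or with explicit credentials; the URL and key below are placeholders.
    #
    #     client = ArgillaWrapperClient()  # reads ARGILLA_API_URL / ARGILLA_API_KEY
    #     client = ArgillaWrapperClient(
    #         api_url="http://localhost:6900", api_key="owner.apikey"
    #     )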

    def create_dataset(
        self,
        workspace_name: str,
        dataset_name: str,
        fields: Sequence[rg.TextField],
        questions: Sequence[rg.QuestionType],
    ) -> str:
        """Creates and publishes a new feedback dataset in Argilla.

        Raises an error if the name already exists.

        Args:
            workspace_name: the name of the workspace the feedback dataset should be created in.
                The user executing this request must have corresponding permissions for this workspace.
            dataset_name: the name of the feedback dataset to be created.
            fields: all fields of this dataset.
            questions: all questions for this dataset.

        Returns:
            The id of the created dataset.
        """
        settings = rg.Settings(
            fields=list(fields),
            questions=list(questions),
            allow_extra_metadata=True,
        )
        workspace = self.ensure_workspace_exists(workspace_name)
        dataset = rg.Dataset(
            name=dataset_name,
            settings=settings,
            workspace=workspace,
            client=self.client,
        )
        dataset.create()
        return str(dataset.id)

    def ensure_dataset_exists(
        self,
        workspace_name: str,
        dataset_name: str,
        fields: Sequence[rg.TextField],
        questions: Sequence[rg.QuestionType],
    ) -> str:
        """Retrieves an existing feedback dataset or creates and publishes a new one in Argilla.

        Args:
            workspace_name: the name of the workspace the feedback dataset should be created in.
                The user executing this request must have corresponding permissions for this workspace.
            dataset_name: the name of the feedback dataset to be retrieved or created.
            fields: all fields of this dataset.
            questions: all questions for this dataset.

        Returns:
            The id of the existing or newly created dataset.
        """
        dataset = self.client.datasets(name=dataset_name, workspace=workspace_name)
        return (
            str(dataset.id)
            if dataset
            else self.create_dataset(workspace_name, dataset_name, fields, questions)
        )
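
    # Illustrative sketch (not part of the original module), assuming a reachable Argilla
    # server; the workspace, dataset, field, and question names are placeholders. Fields
    # and questions are plain argilla objects, so a dataset could be ensured like this:
    #
    #     dataset_id = client.ensure_dataset_exists(
    #         workspace_name="my-workspace",
    #         dataset_name="my-dataset",
    #         fields=[rg.TextField(name="input"), rg.TextField(name="output")],
    #         questions=[rg.RatingQuestion(name="quality", values=[1, 2, 3, 4, 5])],
    #     )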

    def add_record(self, dataset_id: str, record: RecordData) -> None:
        self.add_records(dataset_id=dataset_id, records=[record])

    def add_records(self, dataset_id: str, records: Sequence[RecordData]) -> None:
        remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
        if remote_dataset is None:
            raise ValueError(f"Dataset with id {dataset_id} does not exist.")
        argilla_records = [
            rg.Record(
                fields=dict(record.content),
                metadata={
                    **record.metadata,
                    "example_id": record.example_id,
                },
            )
            for record in records
        ]
        remote_dataset.records.log(records=argilla_records)

    def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
        remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
        status_filter = rg.Filter([("response.status", "==", "submitted")])
        query = rg.Query(filter=status_filter)
        for record in remote_dataset.records(query=query):
            metadata = record.metadata
            example_id = metadata.pop("example_id")
            yield ArgillaEvaluation(
                example_id=example_id,
                record_id="ignored",
                responses={
                    response.question_name: response.value
                    for response in record.responses
                    if response is not None
                },
                metadata=metadata,
            )
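
    # Illustrative sketch (not part of the original module): submitted responses can be
    # collected per example, e.g. to aggregate human ratings; `dataset_id` and the
    # question name "quality" are placeholders.
    #
    #     ratings = [
    #         evaluation.responses["quality"]
    #         for evaluation in client.evaluations(dataset_id)
    #     ]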

    def split_dataset(self, dataset_id: str, n_splits: int) -> None:
        """Adds a new metadata property to the dataset and assigns a split to each record.

        Deletes the property if `n_splits` is equal to one.

        Args:
            dataset_id: the id of the dataset.
            n_splits: the number of splits to create.
        """
        remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
        name = "split"
        metadata_config = remote_dataset.settings.metadata[name]
        if n_splits == 1:
            if metadata_config is None:
                return
            metadata_config.delete()
            self._delete_metadata_from_records(remote_dataset, name)
            return
        if metadata_config:
            metadata_config.delete()
            self._delete_metadata_from_records(remote_dataset, name)
        self._update_record_metadata(n_splits, remote_dataset, name)
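
    # Illustrative note (not part of the original module): with n_splits=3 and five
    # records, the "split" metadata values assigned round-robin below would be
    # 1, 2, 3, 1, 2.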

    def _update_record_metadata(
        self, n_splits: int, remote_dataset: rg.Dataset, metadata_name: str
    ) -> None:
        config = rg.IntegerMetadataProperty(
            name=metadata_name, visible_for_annotators=True, min=1, max=n_splits
        )
        remote_dataset.settings.metadata = [config]
        remote_dataset.update()
        updated_records = []
        for record, split in zip(
            remote_dataset.records, itertools.cycle(range(1, n_splits + 1))
        ):
            record.metadata[metadata_name] = split
            updated_records.append(record)
        remote_dataset.records.log(updated_records)

    def _delete_metadata_from_records(
        self, remote_dataset: rg.Dataset, metadata_name: str
    ) -> None:
        modified_records = []
        for record in remote_dataset.records:
            del record.metadata[metadata_name]
            modified_records.append(record)
        remote_dataset.records.log(modified_records)

    def ensure_workspace_exists(self, workspace_name: str) -> str:
        """Retrieves the name of an Argilla workspace with the specified name or creates a new workspace if necessary.

        Args:
            workspace_name: the name of the workspace to be retrieved or created.

        Returns:
            The name of an Argilla workspace with the given `workspace_name`.
        """
        workspace = self.client.workspaces(name=workspace_name)
        if workspace:
            return workspace_name
        workspace = rg.Workspace(name=workspace_name, client=self.client)
        workspace.create()
        if not workspace:
            raise ValueError(
                f"Workspace with name {workspace_name} could not be created."
            )
        return str(workspace.name)

    def records(self, dataset_id: str) -> Iterable[Record]:
        remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
        return (
            Record(
                id=str(record.id),
                example_id=record.metadata["example_id"],
                content=record.fields,
                metadata=record.metadata,
            )
            for record in remote_dataset.records
        )

    def _create_evaluation(
        self, dataset_id: str, record_id: str, data: dict[str, Any]
    ) -> None:
        dataset = self._dataset_from_id(dataset_id=dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset with id {dataset_id} does not exist.")
        records = dataset.records
        user_id = self.client.me.id
        if user_id is None:
            raise ValueError("user_id is not a UUID")
        # Argilla currently does not allow retrieving a record by id.
        # This could be optimized (when creating multiple evaluations) by passing a dict of
        # record ids to the function and updating all records for the given record id list.
        for record in records:
            if record.id == record_id:
                for question_name, response_value in data.items():
                    response = rg.Response(
                        question_name=question_name,
                        value=response_value,
                        status="submitted",
                        user_id=user_id,
                    )
                    record.responses.add(response)
                dataset.records.log([record])
                return

    def _delete_dataset(self, dataset_id: str) -> None:
        remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
        remote_dataset.delete()

    def _delete_workspace(self, workspace_name: str) -> None:
        workspace = self.client.workspaces(name=workspace_name)
        if workspace is None:
            raise ValueError("Workspace does not exist.")
        for dataset in workspace.datasets:
            dataset.delete()
        workspace.delete()

    def _dataset_from_id(self, dataset_id: str) -> rg.Dataset:
        dataset = self.client.datasets(id=dataset_id)
        if not dataset:
            raise ValueError(f"Dataset with id {dataset_id} does not exist.")
        # Fetch settings from the dataset in order to pass questions; necessary since Argilla V2.
        dataset.settings.get()
        return dataset
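

# Illustrative end-to-end sketch (not part of the original module). It assumes a running
# Argilla instance, ARGILLA_API_URL / ARGILLA_API_KEY set in the environment, and uses
# placeholder workspace, dataset, field, and question names.
if __name__ == "__main__":
    client = ArgillaWrapperClient()
    dataset_id = client.ensure_dataset_exists(
        workspace_name="demo-workspace",
        dataset_name="demo-dataset",
        fields=[rg.TextField(name="input"), rg.TextField(name="output")],
        questions=[rg.RatingQuestion(name="quality", values=[1, 2, 3, 4, 5])],
    )
    # Each RecordData carries the field contents, the id of the originating example,
    # and arbitrary extra metadata.
    client.add_records(
        dataset_id,
        [
            RecordData(
                content={"input": "What is 1 + 1?", "output": "2"},
                example_id="example-0",
                metadata={"model": "demo-model"},
            )
        ],
    )
    # Once annotators have submitted their responses in the Argilla UI, they can be
    # fetched as ArgillaEvaluation objects.
    for evaluation in client.evaluations(dataset_id):
        print(evaluation.example_id, evaluation.responses)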