from collections.abc import Iterable
from typing import Optional
from intelligence_layer.connectors import (
SerializableDict,
StudioClient,
)
from intelligence_layer.connectors.studio.studio import (
StudioDataset,
StudioExample,
)
from intelligence_layer.core import Input
from intelligence_layer.evaluation.dataset.dataset_repository import DatasetRepository
from intelligence_layer.evaluation.dataset.domain import (
Dataset,
Example,
ExpectedOutput,
)
from intelligence_layer.evaluation.dataset.in_memory_dataset_repository import (
InMemoryDatasetRepository,
)
[docs]
class StudioDatasetRepository(DatasetRepository):
    """Dataset repository that stores its datasets in Studio.

    Datasets are submitted through a :class:`StudioClient`. Examples that are
    read back from Studio are cached per dataset ID in an
    :class:`InMemoryDatasetRepository`, so only the first read of a dataset
    calls the Studio API.

    Deleting datasets, looking up or listing datasets, and retrieving a
    single example are not implemented and raise :class:`NotImplementedError`.
    """

    def __init__(self, studio_client: StudioClient) -> None:
        """Initializes the StudioDatasetRepository.

        Args:
            studio_client: Client to interact with the Studio API.
        """
        self.studio_client = studio_client
        # Local cache for examples fetched from Studio, keyed by dataset ID.
        self._in_memory_dataset_repository = InMemoryDatasetRepository()

    def create_dataset(
        self,
        examples: Iterable[Example[Input, ExpectedOutput]],
        dataset_name: str,
        id: str | None = None,
        labels: set[str] | None = None,
        metadata: SerializableDict | None = None,
    ) -> Dataset:
        """Creates a dataset from given :class:`Example`s and submits it to Studio.

        Args:
            examples: An :class:`Iterable` of :class:`Example`s to be saved in the same dataset.
            dataset_name: A name for the dataset.
            id: Must be `None`. Custom IDs are not supported because the ID is
                taken from the value returned by the Studio client.
            labels: A set of labels for filtering. Defaults to None, which is
                treated as an empty set.
            metadata: A dict for additional information about the dataset.
                Defaults to None, which is treated as an empty dict.

        Returns:
            The created :class:`Dataset`, whose `id` is the ID returned by Studio.

        Raises:
            NotImplementedError: If `id` is not `None`.
        """
        if id is not None:
            raise NotImplementedError(
                "Custom dataset IDs are not supported by the StudioDataRepository"
            )

        created_dataset = Dataset(
            name=dataset_name,
            labels=labels or set(),
            metadata=metadata or {},
        )

        # The examples are mapped lazily; they are consumed by the client call.
        studio_dataset_id = self.studio_client.submit_dataset(
            dataset=self.map_to_studio_dataset(created_dataset),
            examples=self.map_to_many_studio_example(examples),
        )

        # Replace the locally assigned ID with the one returned by Studio.
        created_dataset.id = studio_dataset_id
        return created_dataset

    def delete_dataset(self, dataset_id: str) -> None:
        """Deletes a dataset identified by the given dataset ID.

        Not implemented for this repository.

        Args:
            dataset_id: Dataset ID of the dataset to delete.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def dataset(self, dataset_id: str) -> Optional[Dataset]:
        """Returns a dataset identified by the given dataset ID.

        Not implemented for this repository.

        Args:
            dataset_id: Dataset ID of the dataset to retrieve.

        Returns:
            :class:`Dataset` if it was found, `None` otherwise.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def datasets(self) -> Iterable[Dataset]:
        """Returns all :class:`Dataset`s. Sorting is not guaranteed.

        Not implemented for this repository.

        Returns:
            :class:`Iterable` of :class:`Dataset`s.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def dataset_ids(self) -> Iterable[str]:
        """Returns all sorted dataset IDs.

        Not implemented for this repository.

        Returns:
            :class:`Iterable` of dataset IDs.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def example(
        self,
        dataset_id: str,
        example_id: str,
        input_type: type[Input],
        expected_output_type: type[ExpectedOutput],
    ) -> Optional[Example[Input, ExpectedOutput]]:
        """Returns an :class:`Example` for the given dataset ID and example ID.

        Not implemented for this repository.

        Args:
            dataset_id: Dataset ID of the linked dataset.
            example_id: ID of the example to retrieve.
            input_type: Input type of the example.
            expected_output_type: Expected output type of the example.

        Returns:
            :class:`Example` if it was found, `None` otherwise.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def examples(
        self,
        dataset_id: str,
        input_type: type[Input],
        expected_output_type: type[ExpectedOutput],
        examples_to_skip: Optional[frozenset[str]] = None,
    ) -> Iterable[Example[Input, ExpectedOutput]]:
        """Returns all :class:`Example`s for the given dataset ID.

        On the first call for a dataset ID the examples are fetched from Studio
        and stored in the internal in-memory repository; later calls for the
        same ID are served from that cache without contacting Studio.

        The result (including its ordering) is produced by the in-memory
        repository. NOTE(review): presumably sorted by example ID -- confirm
        against :class:`InMemoryDatasetRepository`.

        Args:
            dataset_id: Dataset ID whose examples should be retrieved.
            input_type: Input type of the example.
            expected_output_type: Expected output type of the example.
            examples_to_skip: Optional set of example IDs. Those examples will
                be excluded from the output. Defaults to None.

        Returns:
            :class:`Iterable` of :class:`Example`s.
        """
        cache = self._in_memory_dataset_repository

        if cache.dataset(dataset_id) is None:
            # Cache miss: load the full dataset once. The skip filter is
            # deliberately NOT applied here, so the cache stays complete for
            # later calls that use a different `examples_to_skip`.
            fetched_examples = self.map_to_many_example(
                self.studio_client.get_dataset_examples(
                    dataset_id,
                    input_type=input_type,
                    expected_output_type=expected_output_type,
                )
            )
            cache.create_dataset(
                fetched_examples, id=dataset_id, dataset_name="in_memory_dataset"
            )

        # Forward the skip set so the documented exclusion actually happens;
        # previously this argument was accepted but silently ignored.
        return cache.examples(
            dataset_id,
            input_type,
            expected_output_type,
            examples_to_skip=examples_to_skip,
        )

    @staticmethod
    def map_to_studio_example(
        example_to_map: Example[Input, ExpectedOutput],
    ) -> StudioExample[Input, ExpectedOutput]:
        """Converts an :class:`Example` into a :class:`StudioExample`.

        Args:
            example_to_map: The example to convert.

        Returns:
            A :class:`StudioExample` built from the dumped fields of the example.
        """
        return StudioExample(**example_to_map.model_dump())

    @staticmethod
    def map_to_many_studio_example(
        examples_to_map: Iterable[Example[Input, ExpectedOutput]],
    ) -> Iterable[StudioExample[Input, ExpectedOutput]]:
        """Lazily converts :class:`Example`s into :class:`StudioExample`s.

        Args:
            examples_to_map: The examples to convert.

        Returns:
            A generator yielding one :class:`StudioExample` per input example.
            It can only be iterated once.
        """
        return (
            StudioDatasetRepository.map_to_studio_example(example)
            for example in examples_to_map
        )

    @staticmethod
    def map_to_studio_dataset(dataset_to_map: Dataset) -> StudioDataset:
        """Converts a :class:`Dataset` into a :class:`StudioDataset`.

        Args:
            dataset_to_map: The dataset to convert.

        Returns:
            A :class:`StudioDataset` built from the dumped fields of the dataset.
        """
        return StudioDataset(**dataset_to_map.model_dump())

    @staticmethod
    def map_to_example(
        example_to_map: StudioExample[Input, ExpectedOutput],
    ) -> Example[Input, ExpectedOutput]:
        """Converts a :class:`StudioExample` into an :class:`Example`.

        Args:
            example_to_map: The Studio example to convert.

        Returns:
            An :class:`Example` carrying the same input, expected output, ID
            and metadata.
        """
        return Example[Input, ExpectedOutput](
            input=example_to_map.input,
            expected_output=example_to_map.expected_output,
            id=example_to_map.id,
            metadata=example_to_map.metadata,
        )

    @staticmethod
    def map_to_many_example(
        examples_to_map: Iterable[StudioExample[Input, ExpectedOutput]],
    ) -> Iterable[Example[Input, ExpectedOutput]]:
        """Lazily converts :class:`StudioExample`s into :class:`Example`s.

        Args:
            examples_to_map: The Studio examples to convert.

        Returns:
            A generator yielding one :class:`Example` per input Studio example.
            It can only be iterated once.
        """
        return (
            StudioDatasetRepository.map_to_example(example)
            for example in examples_to_map
        )