import warnings
from collections import defaultdict
from collections.abc import Iterable, Sequence
from typing import Optional, cast
from intelligence_layer.core import InMemoryTracer, Output, PydanticSerializable
from intelligence_layer.core.tracer.tracer import Tracer
from intelligence_layer.evaluation.run.domain import ExampleOutput, RunOverview
from intelligence_layer.evaluation.run.run_repository import RecoveryData, RunRepository
class InMemoryRunRepository(RunRepository):
    """Run repository that keeps all of its data in instance-level dictionaries.

    Run overviews, example outputs, per-example tracers and temporary recovery
    data are held in plain in-memory mappings on the instance. Nothing is
    written anywhere else, so the contents live only as long as the object.
    """

    def __init__(self) -> None:
        super().__init__()
        # run_id -> example outputs stored for that run, in insertion order.
        # A defaultdict so that writes (store_example_output) can append
        # without a prior registration; reads must NOT index it directly,
        # because that would insert an empty entry for unknown run ids.
        self._example_outputs: dict[str, list[ExampleOutput[PydanticSerializable]]] = (
            defaultdict(list)
        )
        # "<run_id>/<example_id>" -> tracer created by create_tracer_for_example.
        self._example_traces: dict[str, Tracer] = dict()
        # run_id -> stored run overview.
        self._run_overviews: dict[str, RunOverview] = dict()
        # temporary hash -> RecoveryData created by _create_temporary_run_data.
        self._recovery_data: dict[str, RecoveryData] = dict()

    @staticmethod
    def _tracer_key(run_id: str, example_id: str) -> str:
        """Return the key under which the tracer of one example is stored."""
        return f"{run_id}/{example_id}"

    def _has_run(self, run_id: str) -> bool:
        """Return whether any overview or output list exists for ``run_id``.

        Uses plain membership tests, which (unlike indexing the defaultdict)
        never create an entry as a side effect.
        """
        return run_id in self._example_outputs or run_id in self._run_overviews

    def _outputs_for(self, run_id: str) -> list[ExampleOutput[PydanticSerializable]]:
        """Return the stored outputs of ``run_id`` without creating an entry.

        ``dict.get`` does not trigger ``defaultdict.__missing__``, so an
        unknown run yields a fresh empty list and the repository is unchanged.
        """
        return self._example_outputs.get(run_id, [])

    def store_run_overview(self, overview: RunOverview) -> None:
        """Store ``overview`` under its id, replacing any previous overview.

        Also registers an (initially empty) output list for the run, so that
        reading outputs of a run that has none does not warn about a missing
        run. An existing output list is left untouched.
        """
        self._run_overviews[overview.id] = overview
        self._example_outputs.setdefault(overview.id, [])

    def _create_temporary_run_data(self, tmp_hash: str, run_id: str) -> None:
        """Create (or overwrite) the recovery data for ``tmp_hash``."""
        self._recovery_data[tmp_hash] = RecoveryData(run_id=run_id)

    def _delete_temporary_run_data(self, tmp_hash: str) -> None:
        """Delete the recovery data for ``tmp_hash``.

        Raises:
            KeyError: if no recovery data exists for ``tmp_hash``.
        """
        del self._recovery_data[tmp_hash]

    def _temp_store_finished_example(self, tmp_hash: str, example_id: str) -> None:
        """Record ``example_id`` as finished in the recovery data of ``tmp_hash``.

        Raises:
            KeyError: if ``_create_temporary_run_data`` was not called for
                ``tmp_hash`` first.
        """
        self._recovery_data[tmp_hash].finished_examples.append(example_id)

    def finished_examples(self, tmp_hash: str) -> Optional[RecoveryData]:
        """Return the recovery data stored under ``tmp_hash``, or ``None``."""
        return self._recovery_data.get(tmp_hash)

    def run_overview(self, run_id: str) -> Optional[RunOverview]:
        """Return the overview stored for ``run_id``, or ``None`` if absent."""
        return self._run_overviews.get(run_id)

    def run_overview_ids(self) -> Sequence[str]:
        """Return the ids of all stored run overviews, sorted ascending."""
        return sorted(self._run_overviews)

    def store_example_output(self, example_output: ExampleOutput[Output]) -> None:
        """Append ``example_output`` to the outputs of its run.

        The run's output list is created on demand; no overview is required.
        """
        self._example_outputs[example_output.run_id].append(
            cast(ExampleOutput[PydanticSerializable], example_output)
        )

    def example_output(
        self, run_id: str, example_id: str, output_type: type[Output]
    ) -> Optional[ExampleOutput[Output]]:
        """Return the first stored output of ``run_id`` for ``example_id``.

        ``output_type`` is only used for static typing: the ``cast`` below is
        a no-op at runtime and the stored output is not validated against it.

        Returns:
            The first matching output in insertion order, or ``None`` if the
            run has no such example. If the run is unknown, a ``UserWarning``
            is emitted and ``None`` is returned.
        """
        if not self._has_run(run_id):
            warnings.warn(
                f'Repository does not contain a run with id: "{run_id}"', UserWarning
            )
            return None
        for example_output in self._outputs_for(run_id):
            if example_output.example_id == example_id:
                return cast(ExampleOutput[Output], example_output)
        return None

    def example_outputs(
        self, run_id: str, output_type: type[Output]
    ) -> Iterable[ExampleOutput[Output]]:
        """Return all stored outputs of ``run_id`` ordered by ``example_id``.

        ``output_type`` is only used for static typing (runtime no-op cast).
        The sorted snapshot is taken at call time; the returned generator
        yields from that snapshot. If the run is unknown, a ``UserWarning``
        is emitted and an empty list is returned.
        """
        if not self._has_run(run_id):
            warnings.warn(
                f'Repository does not contain a run with id: "{run_id}"', UserWarning
            )
            return []
        return (
            cast(ExampleOutput[Output], example_output)
            for example_output in sorted(
                self._outputs_for(run_id),
                key=lambda example_output: example_output.example_id,
            )
        )

    def example_output_ids(self, run_id: str) -> Sequence[str]:
        """Return the example ids stored for ``run_id``, sorted ascending.

        An unknown run yields an empty list. Unlike indexing the defaultdict,
        this does not register the run, so later reads still warn about it.
        """
        return sorted(
            example_output.example_id for example_output in self._outputs_for(run_id)
        )

    def example_tracer(self, run_id: str, example_id: str) -> Optional[Tracer]:
        """Return the tracer created for this example, or ``None`` if none exists."""
        return self._example_traces.get(self._tracer_key(run_id, example_id))

    def create_tracer_for_example(self, run_id: str, example_id: str) -> Tracer:
        """Create, store and return a new ``InMemoryTracer`` for this example.

        Any tracer previously stored for the same run/example pair is replaced.
        """
        tracer = InMemoryTracer()
        self._example_traces[self._tracer_key(run_id, example_id)] = tracer
        return tracer