# Source code for intelligence_layer.evaluation.run.in_memory_run_repository

import warnings
from collections import defaultdict
from collections.abc import Iterable, Sequence
from typing import Optional, cast

from intelligence_layer.core import InMemoryTracer, Output, PydanticSerializable
from intelligence_layer.core.tracer.tracer import Tracer
from intelligence_layer.evaluation.run.domain import ExampleOutput, RunOverview
from intelligence_layer.evaluation.run.run_repository import RecoveryData, RunRepository


class InMemoryRunRepository(RunRepository):
    """A :class:`RunRepository` that keeps all run data in instance dictionaries.

    Run overviews, example outputs, per-example tracers and recovery data are
    held in plain in-memory mappings on the instance; nothing is written
    anywhere else, so the data lives only as long as this object.
    """

    def __init__(self) -> None:
        super().__init__()
        # run_id -> example outputs in insertion order. A defaultdict so that
        # store_example_output can append without a prior store_run_overview.
        # All read paths use .get() so that a lookup never creates an entry
        # (an entry's presence is what marks a run as "known" below).
        self._example_outputs: dict[str, list[ExampleOutput[PydanticSerializable]]] = (
            defaultdict(list)
        )
        # "<run_id>/<example_id>" -> tracer created for that example.
        self._example_traces: dict[str, Tracer] = dict()
        # run_id -> stored overview.
        self._run_overviews: dict[str, RunOverview] = dict()
        # temporary hash -> recovery bookkeeping for a run in progress.
        self._recovery_data: dict[str, RecoveryData] = dict()

    def store_run_overview(self, overview: RunOverview) -> None:
        """Store ``overview`` under its ``id``, replacing any previous overview.

        Also registers the run in the example-output store (with no outputs if
        none were stored yet) so the output lookups treat the run as known.
        """
        self._run_overviews[overview.id] = overview
        if overview.id not in self._example_outputs:
            self._example_outputs[overview.id] = []

    def _create_temporary_run_data(self, tmp_hash: str, run_id: str) -> None:
        """Start recovery bookkeeping for ``run_id`` under ``tmp_hash``."""
        self._recovery_data[tmp_hash] = RecoveryData(run_id=run_id)

    def _delete_temporary_run_data(self, tmp_hash: str) -> None:
        """Drop the recovery bookkeeping for ``tmp_hash``.

        Raises:
            KeyError: If no recovery data exists for ``tmp_hash``.
        """
        del self._recovery_data[tmp_hash]

    def _temp_store_finished_example(self, tmp_hash: str, example_id: str) -> None:
        """Record ``example_id`` as finished in the recovery data for ``tmp_hash``.

        Raises:
            KeyError: If ``_create_temporary_run_data`` was not called for
                ``tmp_hash`` first.
        """
        self._recovery_data[tmp_hash].finished_examples.append(example_id)

    def finished_examples(self, tmp_hash: str) -> Optional[RecoveryData]:
        """Return the recovery data stored for ``tmp_hash``, or ``None`` if absent."""
        return self._recovery_data.get(tmp_hash)

    def run_overview(self, run_id: str) -> Optional[RunOverview]:
        """Return the overview stored for ``run_id``, or ``None`` if absent."""
        return self._run_overviews.get(run_id)

    def run_overview_ids(self) -> Sequence[str]:
        """Return the ids of all stored run overviews, sorted ascending."""
        return sorted(self._run_overviews.keys())

    def store_example_output(self, example_output: ExampleOutput[Output]) -> None:
        """Append ``example_output`` to the outputs of its run.

        NOTE(review): storing the same ``example_id`` twice for one run keeps
        both entries; ``example_output`` then returns the first one stored and
        ``example_outputs`` yields both. Confirm this is intended.
        """
        self._example_outputs[example_output.run_id].append(
            cast(ExampleOutput[PydanticSerializable], example_output)
        )

    def example_output(
        self, run_id: str, example_id: str, output_type: type[Output]
    ) -> Optional[ExampleOutput[Output]]:
        """Return the first stored output of ``example_id`` within ``run_id``.

        ``output_type`` is not used by this implementation; the stored value is
        only cast, not validated.

        Returns:
            The matching output, or ``None`` if the example has no output.
            Also ``None`` (after emitting a ``UserWarning``) if the run is
            unknown.
        """
        if run_id not in self._example_outputs:
            warnings.warn(
                f'Repository does not contain a run with id: "{run_id}"', UserWarning
            )
            return None
        for example_output in self._example_outputs[run_id]:
            if example_output.example_id == example_id:
                return cast(ExampleOutput[Output], example_output)
        return None

    def example_outputs(
        self, run_id: str, output_type: type[Output]
    ) -> Iterable[ExampleOutput[Output]]:
        """Return all outputs stored for ``run_id``, ordered by ``example_id``.

        ``output_type`` is not used by this implementation; values are only
        cast. If the run is unknown a ``UserWarning`` is emitted and an empty
        list is returned.
        """
        if run_id not in self._example_outputs and run_id not in self._run_overviews:
            warnings.warn(
                f'Repository does not contain a run with id: "{run_id}"', UserWarning
            )
            return []
        # .get() instead of [] so a read never inserts into the defaultdict.
        # The outermost iterable of a generator expression (the sorted() call)
        # is evaluated immediately, so this is a snapshot taken at call time.
        return (
            cast(ExampleOutput[Output], example_output)
            for example_output in sorted(
                self._example_outputs.get(run_id, []),
                key=lambda example_output: example_output.example_id,
            )
        )

    def example_output_ids(self, run_id: str) -> Sequence[str]:
        """Return the sorted example ids with stored outputs for ``run_id``.

        Returns an empty list for an unknown run. Uses ``.get`` so that querying
        an unknown run does not register it; otherwise later calls to
        ``example_output`` or ``example_outputs`` would stop warning about it.
        """
        return sorted(
            example_output.example_id
            for example_output in self._example_outputs.get(run_id, [])
        )

    @staticmethod
    def _trace_key(run_id: str, example_id: str) -> str:
        """Build the key under which an example's tracer is stored."""
        return f"{run_id}/{example_id}"

    def example_tracer(self, run_id: str, example_id: str) -> Optional[Tracer]:
        """Return the tracer created for this example, or ``None`` if none exists."""
        return self._example_traces.get(self._trace_key(run_id, example_id))

    def create_tracer_for_example(self, run_id: str, example_id: str) -> Tracer:
        """Create, store and return a fresh ``InMemoryTracer`` for this example.

        Any tracer previously stored for the same run/example pair is replaced.
        """
        tracer = InMemoryTracer()
        self._example_traces[self._trace_key(run_id, example_id)] = tracer
        return tracer