Source code for intelligence_layer.evaluation.aggregation.aggregator

from abc import ABC, abstractmethod
from import Callable, Iterable, Iterator, Mapping
from functools import cached_property
from typing import (
from uuid import uuid4

from intelligence_layer.connectors.base.json_serializable import (
from intelligence_layer.core import utc_now
from intelligence_layer.evaluation.aggregation.aggregation_repository import (
from intelligence_layer.evaluation.aggregation.domain import (
from intelligence_layer.evaluation.evaluation.domain import (
from intelligence_layer.evaluation.evaluation.evaluation_repository import (

T = TypeVar("T")


class CountingFilterIterable(Iterable[T]):
    """Single-pass filtering iterator that tallies accepted and rejected items.

    The wrapped iterable is consumed lazily: the counters only reflect the
    elements that have actually been pulled through this object so far.

    Args:
        wrapped_iterable: Source of the elements to filter.
        filter: Predicate; elements for which it returns a truthy value are
            yielded, all others are skipped and counted as excluded.
    """

    def __init__(
        self, wrapped_iterable: Iterable[T], filter: Callable[[T], bool]
    ) -> None:
        self._included_count = 0
        self._excluded_count = 0
        self._filter = filter
        self._wrapped_iterator = iter(wrapped_iterable)

    def __iter__(self) -> Iterator[T]:
        # The object is its own iterator, so it can only be traversed once.
        return self

    def __next__(self) -> T:
        """Return the next element accepted by the predicate.

        Raises:
            StopIteration: When the wrapped iterator is exhausted.
        """
        for candidate in self._wrapped_iterator:
            if self._filter(candidate):
                self._included_count += 1
                return candidate
            self._excluded_count += 1
        raise StopIteration

    def included_count(self) -> int:
        """Number of elements yielded so far."""
        return self._included_count

    def excluded_count(self) -> int:
        """Number of elements skipped so far because the predicate rejected them."""
        return self._excluded_count

class AggregationLogic(ABC, Generic[Evaluation, AggregatedEvaluation]):
    """Interface for turning individual evaluations into one aggregated result.

    Generics:
        Evaluation: Type of a single evaluation consumed by :meth:`aggregate`.
        AggregatedEvaluation: Type of the aggregated result returned by
            :meth:`aggregate`.
    """

    @abstractmethod
    def aggregate(self, evaluations: Iterable[Evaluation]) -> AggregatedEvaluation:
        """`Evaluator`-specific method for aggregating individual `Evaluations` into report-like `Aggregated Evaluation`.

        This method is responsible for taking the results of an evaluation run and
        aggregating all the results. It should create an `AggregatedEvaluation`
        class and return it at the end.

        The argument is only guaranteed to be an :class:`Iterable`; the
        `Aggregator` in this module passes a single-pass iterator, so
        implementations should iterate over it at most once.

        Args:
            evaluations: The results from running `eval_and_aggregate_runs` with a :class:`Task`.

        Returns:
            The aggregated results of an evaluation run with a :class:`Dataset`.
        """
        ...
class Aggregator(Generic[Evaluation, AggregatedEvaluation]):
    """Aggregator that can handle automatic aggregation of evaluation scenarios.

    This aggregator should be used for automatic eval. A user still has to implement
    :class:`AggregationLogic`.

    Arguments:
        evaluation_repository: The repository the evaluation overviews and example
            evaluations are loaded from.
        aggregation_repository: The repository that will be used to store aggregation results.
        description: Human-readable description for the evaluator.
        aggregation_logic: The logic to aggregate the evaluations.

    Generics:
        Evaluation: Interface of the metrics that come from the evaluated :class:`Task`.
        AggregatedEvaluation: The aggregated results of an evaluation run with a :class:`Dataset`.
    """

    def __init__(
        self,
        evaluation_repository: EvaluationRepository,
        aggregation_repository: AggregationRepository,
        description: str,
        aggregation_logic: AggregationLogic[Evaluation, AggregatedEvaluation],
    ) -> None:
        self._evaluation_repository = evaluation_repository
        self._aggregation_repository = aggregation_repository
        self._aggregation_logic = aggregation_logic
        self.description = description

    @cached_property
    def _get_types(self) -> Mapping[str, type]:
        """Type magic function that gets the actual types of the generic parameters.

        Traverses the inheritance history of `AggregationLogic`-subclass to find an actual type every time a TypeVar is replaced.

        Returns:
            Name of generic parameter to the type found. Parameters that are still
            unresolved TypeVars after the traversal are omitted.
        """

        def is_eligible_subclass(parent: type) -> bool:
            return hasattr(parent, "__orig_bases__") and issubclass(
                parent, AggregationLogic
            )

        def update_types() -> None:
            # Reads `current_types` and `type_list` from the enclosing scope.
            # For each concrete (non-TypeVar) argument of the current base,
            # locate the still-open TypeVar slot in `type_list` it corresponds
            # to and replace that slot with the concrete type.
            num_types_set = 0
            for current_index, current_type in enumerate(current_types):
                if type(current_type) is not TypeVar:
                    type_var_count = num_types_set - 1
                    final_index = -1
                    for element_index, element in enumerate(type_list):
                        final_index = element_index
                        if type(element) is TypeVar:
                            type_var_count += 1
                        if type_var_count == current_index:
                            break
                    assert type_var_count == current_index
                    type_list[final_index] = current_type
                    num_types_set += 1

        # mypy does not know __orig_bases__
        # Index 1 is the `Generic[...]` base of `AggregationLogic` (index 0 is `ABC`).
        base_types = AggregationLogic.__orig_bases__[1]  # type: ignore
        type_list: list[type | TypeVar] = list(get_args(base_types))
        # Reversed MRO: walk from the most generic ancestor down to the concrete
        # class of the configured aggregation logic.
        possible_parent_classes = [
            p
            for p in reversed(type(self._aggregation_logic).__mro__)
            if is_eligible_subclass(p)
        ]
        for parent in possible_parent_classes:
            # mypy does not know __orig_bases__
            for base in parent.__orig_bases__:  # type: ignore
                origin = get_origin(base)
                if origin is None or not issubclass(origin, AggregationLogic):
                    continue
                current_types = list(get_args(base))
                update_types()

        return {
            name: param_type
            for name, param_type in zip(
                (a.__name__ for a in get_args(base_types)), type_list, strict=False
            )
            if type(param_type) is not TypeVar
        }

    def evaluation_type(self) -> type[Evaluation]:
        """Returns the type of the evaluation result of an example.

        This can be used to retrieve properly typed evaluations of an evaluation run
        from a :class:`EvaluationRepository`

        Returns:
            Returns the type of the evaluation result of an example.

        Raises:
            TypeError: If no concrete type for `Evaluation` could be derived from
                the aggregation logic.
        """
        try:
            evaluation_type = self._get_types["Evaluation"]
        except KeyError:
            raise TypeError(
                "Could not determine the Evaluation type from "
                f"{type(self._aggregation_logic)}. "
                f"Alternatively overwrite evaluation_type() in {type(self)}"
            ) from None
        return cast(type[Evaluation], evaluation_type)

    def aggregated_evaluation_type(self) -> type[AggregatedEvaluation]:
        """Returns the type of the aggregated result of a run.

        Returns:
            Returns the type of the aggregation result.

        Raises:
            TypeError: If no concrete type for `AggregatedEvaluation` could be
                derived from the aggregation logic.
        """
        try:
            aggregated_evaluation_type = self._get_types["AggregatedEvaluation"]
        except KeyError:
            raise TypeError(
                "Could not determine the AggregatedEvaluation type from "
                f"{type(self._aggregation_logic)}. "
                f"Alternatively overwrite aggregated_evaluation_type() in {type(self)}"
            ) from None
        return cast(type[AggregatedEvaluation], aggregated_evaluation_type)

    @final
    def aggregate_evaluation(
        self,
        *eval_ids: str,
        description: Optional[str] = None,
        labels: set[str] | None = None,
        metadata: SerializableDict | None = None,
    ) -> AggregationOverview[AggregatedEvaluation]:
        """Aggregates all evaluations into an overview that includes high-level statistics.

        Aggregates :class:`Evaluation`s according to the implementation of :func:`AggregationLogic.aggregate`.

        Args:
            *eval_ids: IDs of the evaluations to be aggregated. Duplicates are
                ignored. The overviews and the actual evaluations are retrieved
                from the evaluation repository.
            description: Optional description of the aggregation. It is appended
                to the aggregator's own description. Defaults to None.
            labels: A set of labels for filtering. Defaults to an empty set.
            metadata: A dict for additional information about the aggregation overview. Defaults to an empty dict.

        Returns:
            An overview of the aggregated evaluation. The overview is also stored
            in the aggregation repository.

        Raises:
            ValueError: If the evaluation repository has no overview for one of
                the given IDs.
        """
        if metadata is None:
            metadata = {}
        if labels is None:
            labels = set()

        def load_eval_overview(evaluation_id: str) -> EvaluationOverview:
            evaluation_overview = self._evaluation_repository.evaluation_overview(
                evaluation_id
            )
            if not evaluation_overview:
                raise ValueError(
                    f"No PartialEvaluationOverview found for eval-id: {evaluation_id}"
                )
            return evaluation_overview

        evaluation_overviews = frozenset(
            load_eval_overview(evaluation_id) for evaluation_id in set(eval_ids)
        )

        nested_evaluations = [
            self._evaluation_repository.example_evaluations(
                overview.id, self.evaluation_type()
            )
            for overview in evaluation_overviews
        ]
        example_evaluations = [
            evaluation for sublist in nested_evaluations for evaluation in sublist
        ]

        # Failed example evaluations are skipped but counted, so the overview
        # can report them as crashed.
        successful_evaluations = CountingFilterIterable(
            (example_eval.result for example_eval in example_evaluations),
            lambda evaluation: not isinstance(evaluation, FailedExampleEvaluation),
        )
        start = utc_now()
        statistics = self._aggregation_logic.aggregate(
            cast(Iterable[Evaluation], successful_evaluations)
        )
        # The counters only advance while the iterator is consumed. Drain what
        # the aggregation logic may have left over so both counts cover every
        # evaluation; this is a no-op when it was already exhausted.
        for _ in successful_evaluations:
            pass

        full_description = (
            f"{self.description} : {description}" if description else self.description
        )

        aggregation_overview = AggregationOverview(
            evaluation_overviews=evaluation_overviews,
            id=str(uuid4()),
            start=start,
            end=utc_now(),
            successful_evaluation_count=successful_evaluations.included_count(),
            crashed_during_evaluation_count=successful_evaluations.excluded_count(),
            description=full_description,
            statistics=statistics,
            labels=labels,
            metadata=metadata,
        )
        self._aggregation_repository.store_aggregation_overview(aggregation_overview)
        return aggregation_overview