from collections.abc import Mapping, Sequence
from typing import Generic, Optional
from pydantic import BaseModel
from intelligence_layer.connectors.retrievers.base_retriever import (
ID,
BaseRetriever,
SearchResult,
)
from intelligence_layer.core.chunk import TextChunk
from intelligence_layer.core.detect_language import Language
from intelligence_layer.core.model import ControlModel, LuminousControlModel
from intelligence_layer.core.task import Task
from intelligence_layer.core.text_highlight import ScoredTextHighlight
from intelligence_layer.core.tracer.tracer import TaskSpan
from intelligence_layer.examples.search.expand_chunks import (
ExpandChunks,
ExpandChunksInput,
ExpandChunksOutput,
)
from intelligence_layer.examples.search.search import Search, SearchInput
from .retriever_based_qa import RetrieverBasedQaInput
from .single_chunk_qa import SingleChunkQa, SingleChunkQaInput, SingleChunkQaOutput
class EnrichedChunk(BaseModel, Generic[ID]):
    """A text chunk together with the document it belongs to and its position.

    Attributes:
        document_id: Identifier of the document the chunk is associated with.
        chunk: The text of the chunk.
        indices: A pair of positions describing the chunk. Within this module it is
            filled with the `start_index` and `end_index` reported by the
            chunk-expansion task.
    """

    document_id: ID
    chunk: TextChunk
    indices: tuple[int, int]
class AnswerSource(BaseModel, Generic[ID]):
    """A chunk that was used as a source, paired with the highlights found in it.

    Attributes:
        chunk: The enriched chunk that was inserted into the prompt.
        highlights: Scored highlights attributed to this chunk. Within this module
            their `start`/`end` values are shifted so that they are relative to the
            start of the chunk's text.
    """

    chunk: EnrichedChunk[ID]
    highlights: Sequence[ScoredTextHighlight]
[docs]
class MultipleChunkRetrieverQaOutput(BaseModel, Generic[ID]):
    """Returns the answer of a `QA` task together with the sources which support the answer.

    The important thing to note is that the answer is generated based on multiple chunks of text.
    Furthermore, there are potentially multiple sources associated with a single search result.
    This is because a search result may be expanded to include adjacent chunks of text.

    Attributes:
        answer: The answer generated by the QA task. May be `None` if no answer was found in the text.
        sources: A list of source chunks or passages that support or are relevant to
            the provided answer.
        search_results: A list of search results from the retriever, sorted by descending
            score, providing additional references.
    """

    answer: Optional[str]
    sources: Sequence[AnswerSource[ID]]
    search_results: Sequence[SearchResult[ID]]
# Default per-language templates for the heading placed in front of every source that
# is inserted into the prompt. The `{i}` placeholder is filled with the 1-based number
# of the source via `str.format`.
SOURCE_PREFIX_CONFIG = {
    Language("en"): "Source {i}:\n",
    Language("de"): "Quelle {i}:\n",
    Language("fr"): "Source {i}:\n",
    Language("es"): "Fuente {i}:\n",
    Language("it"): "Fonte {i}:\n",
}
[docs]
class MultipleChunkRetrieverQa(
    Task[RetrieverBasedQaInput, MultipleChunkRetrieverQaOutput[ID]], Generic[ID]
):
    """Answer a question based on documents found by a retriever.

    `MultipleChunkRetrieverQa` is a task that answers a question based on a set of documents.
    It relies on some retriever of type `BaseRetriever` that has the ability to access texts.
    In contrast to the regular `RetrieverBasedQa`, this task injects multiple chunks into one
    `SingleChunkQa` task run.

    We recommend using this task instead of `RetrieverBasedQa`.

    Note:
        `model` provided should be a control-type model.

    Args:
        retriever: Used to access and return a set of texts.
        insert_chunk_number: Number of top chunks to inject into the :class:`SingleChunkQa` task.
        model: The model used throughout the task for model related API calls.
            Defaults to `LuminousControlModel("luminous-supreme-control")`.
        expand_chunks: The task used to fetch adjacent chunks to the search results. These
            "expanded" chunks will be injected into the prompt. Defaults to :class:`ExpandChunks`.
        single_chunk_qa: The task used to generate an answer for a single chunk (retrieved through
            the retriever). Defaults to :class:`SingleChunkQa`.
        source_prefix_config: A mapping that describes the source section string for different
            languages. Every value must contain the placeholder `{i}`, which is replaced by the
            1-based number of the source. Defaults to the equivalent of "Source {i}:".

    Raises:
        ValueError: If a value of `source_prefix_config` does not contain `{i}`.
    """

    def __init__(
        self,
        retriever: BaseRetriever[ID],
        insert_chunk_number: int = 5,
        model: ControlModel | None = None,
        expand_chunks: Task[ExpandChunksInput[ID], ExpandChunksOutput] | None = None,
        single_chunk_qa: Task[SingleChunkQaInput, SingleChunkQaOutput] | None = None,
        source_prefix_config: Mapping[Language, str] = SOURCE_PREFIX_CONFIG,
    ):
        super().__init__()
        self._search = Search(retriever)
        self._insert_chunk_number = insert_chunk_number
        self._model = model or LuminousControlModel("luminous-supreme-control")
        self._expand_chunks = expand_chunks or ExpandChunks(retriever, self._model)
        self._single_chunk_qa = single_chunk_qa or SingleChunkQa(self._model)
        # The prefix is rendered with `str.format(i=...)`, so a template without the
        # placeholder would silently produce unnumbered sources.
        if any("{i}" not in value for value in source_prefix_config.values()):
            raise ValueError("All values in `source_prefix_config` must contain '{i}'.")
        self._source_prefix_config = source_prefix_config

    def do_run(
        self, input: RetrieverBasedQaInput, task_span: TaskSpan
    ) -> MultipleChunkRetrieverQaOutput[ID]:
        """Search, expand the best hits, and answer the question over the combined chunks.

        Args:
            input: The question, its language and whether highlights should be generated.
            task_span: Span handed to every sub-task that is run.

        Returns:
            The answer (or `None`), one source per chunk that was put into the prompt,
            and the search results sorted by descending score.
        """
        search_output = self._search.run(
            SearchInput(query=input.question), task_span
        ).results
        sorted_search_results = sorted(
            search_output, key=lambda output: output.score, reverse=True
        )
        if not sorted_search_results:
            return MultipleChunkRetrieverQaOutput(
                answer=None,
                sources=[],
                search_results=[],
            )

        chunks_to_insert = self._expand_search_result_chunks(
            sorted_search_results, task_span
        )
        if not chunks_to_insert:
            # Nothing to put into the prompt (e.g. a non-positive `insert_chunk_number`
            # or an expansion that yields no chunks): do not ask the model to answer
            # from an empty text.
            return MultipleChunkRetrieverQaOutput(
                answer=None,
                sources=[],
                search_results=sorted_search_results,
            )

        # NOTE(review): `language_config` is defined outside this file; presumably it
        # selects the template matching the input language - confirm how unsupported
        # languages are handled.
        source_prefix = input.language.language_config(self._source_prefix_config)
        chunk_texts = [enriched.chunk for enriched in chunks_to_insert]
        chunk_for_prompt, chunk_start_indices = self._combine_input_texts(
            chunk_texts, source_prefix
        )

        single_chunk_qa_input = SingleChunkQaInput(
            chunk=chunk_for_prompt,
            question=input.question,
            language=input.language,
            generate_highlights=input.generate_highlights,
        )
        single_chunk_qa_output = self._single_chunk_qa.run(
            single_chunk_qa_input, task_span
        )

        # Passing the chunk lengths keeps every highlight inside the text of its own
        # chunk instead of letting it run into the separator or the next prefix.
        highlights_per_chunk = self._get_highlights_per_chunk(
            chunk_start_indices,
            single_chunk_qa_output.highlights,
            chunk_lengths=[len(text) for text in chunk_texts],
        )

        return MultipleChunkRetrieverQaOutput(
            answer=single_chunk_qa_output.answer,
            sources=[
                AnswerSource(
                    chunk=enriched_chunk,
                    highlights=highlights,
                )
                for enriched_chunk, highlights in zip(
                    chunks_to_insert, highlights_per_chunk, strict=True
                )
            ],
            search_results=sorted_search_results,
        )

    @staticmethod
    def _combine_input_texts(
        chunks: Sequence[str], source_appendix: str
    ) -> tuple[TextChunk, Sequence[int]]:
        """Join all chunks into one text, each preceded by its numbered source prefix.

        Args:
            chunks: The chunk texts, in the order they should appear.
            source_appendix: Template rendered with `str.format(i=<1-based number>)`
                and placed in front of each chunk.

        Returns:
            The combined text with surrounding whitespace stripped, and for every chunk
            the offset in that combined text at which the chunk's own text begins.
        """
        start_indices: list[int] = []
        parts: list[str] = []
        offset = 0
        for number, chunk in enumerate(chunks, start=1):
            prefix = source_appendix.format(i=number)
            start_indices.append(offset + len(prefix))
            parts.extend((prefix, chunk, "\n\n"))
            offset += len(prefix) + len(chunk) + 2

        combined_text = "".join(parts)
        # `strip()` also removes leading whitespace (possible with a custom prefix
        # template); shift the offsets so they still point into the stripped text.
        leading_whitespace = len(combined_text) - len(combined_text.lstrip())
        if leading_whitespace:
            start_indices = [
                max(0, index - leading_whitespace) for index in start_indices
            ]
        return (TextChunk(combined_text.strip()), start_indices)

    @staticmethod
    def _get_highlights_per_chunk(
        chunk_start_indices: Sequence[int],
        highlights: Sequence[ScoredTextHighlight],
        chunk_lengths: Sequence[int] | None = None,
    ) -> Sequence[Sequence[ScoredTextHighlight]]:
        """Distribute highlights of the combined text onto the individual chunks.

        Args:
            chunk_start_indices: For each chunk, the offset in the combined text at
                which the chunk's text begins.
            highlights: Highlights whose `start`/`end` refer to the combined text.
            chunk_lengths: Optional length of each chunk's text. When given, a chunk
                ends at `start + length`, so highlights are clipped to the chunk text.
                When omitted, a chunk is treated as reaching up to the start of the
                next chunk (the last one is unbounded).

        Returns:
            One list per chunk holding the overlapping highlights, with `start`/`end`
            made relative to the beginning of that chunk.

        Raises:
            ValueError: If `chunk_lengths` is given but its length differs from
                `chunk_start_indices`.
        """
        if chunk_lengths is not None and len(chunk_lengths) != len(
            chunk_start_indices
        ):
            raise ValueError(
                "`chunk_lengths` must have one entry per chunk start index."
            )

        overlapping_ranges = []
        for i, current_start in enumerate(chunk_start_indices):
            # Exclusive end of the current chunk in the combined text; `None` means
            # the chunk is unbounded.
            current_end: int | None
            if chunk_lengths is not None:
                current_end = current_start + chunk_lengths[i]
            elif i + 1 < len(chunk_start_indices):
                current_end = chunk_start_indices[i + 1]
            else:
                current_end = None

            current_overlaps = []
            for highlight in highlights:
                if highlight.end <= current_start:
                    continue
                if current_end is not None and highlight.start >= current_end:
                    continue
                clipped_end = (
                    highlight.end
                    if current_end is None
                    else min(current_end, highlight.end)
                )
                current_overlaps.append(
                    ScoredTextHighlight(
                        start=max(0, highlight.start - current_start),
                        end=clipped_end - current_start,
                        score=highlight.score,
                    )
                )
            overlapping_ranges.append(current_overlaps)
        return overlapping_ranges

    def _expand_search_result_chunks(
        self, search_results: Sequence[SearchResult[ID]], task_span: TaskSpan
    ) -> Sequence[EnrichedChunk[ID]]:
        """Expand search results into at most `insert_chunk_number` distinct chunks.

        Args:
            search_results: Search results in the order they should be considered.
            task_span: Span handed to the chunk-expansion task.

        Returns:
            The de-duplicated chunks, in the order in which they were produced.
        """
        chunks_to_insert: list[EnrichedChunk[ID]] = []
        for result in search_results:
            # Once the limit is reached, expanding further results cannot contribute
            # anything, so skip the remaining expansion runs entirely.
            if len(chunks_to_insert) >= self._insert_chunk_number:
                break

            expand_input = ExpandChunksInput(
                document_id=result.id, chunks_found=[result.document_chunk]
            )
            expand_chunks_output = self._expand_chunks.run(expand_input, task_span)
            # This loop causes the output to potentially contain more sources than search_results.
            for chunk in expand_chunks_output.chunks:
                if len(chunks_to_insert) >= self._insert_chunk_number:
                    break

                enriched_chunk = EnrichedChunk(
                    document_id=result.id,
                    chunk=chunk.chunk,
                    indices=(chunk.start_index, chunk.end_index),
                )
                # Different search results may expand to the same chunk.
                if enriched_chunk not in chunks_to_insert:
                    chunks_to_insert.append(enriched_chunk)
        return chunks_to_insert