import itertools
from collections.abc import Iterable, Mapping, Sequence
from typing import cast
from aleph_alpha_client import (
Prompt,
PromptGranularity,
Text,
TextPromptItemExplanation,
TextScore,
)
from pydantic import BaseModel
from intelligence_layer.core.model import AlephAlphaModel, ExplainInput, ExplainOutput
from intelligence_layer.core.prompt_template import (
Cursor,
PromptRange,
RichPrompt,
TextCursor,
)
from intelligence_layer.core.task import Task
from intelligence_layer.core.tracer.tracer import TaskSpan
# Sphinx "[docs]" source-link marker (HTML extraction residue, not code).
class TextHighlightInput(BaseModel):
    """Input of a text highlighting task.

    Attributes:
        rich_prompt: A prompt as produced by the client's `PromptTemplate`. Besides the
            actual `Prompt` it carries the text ranges declared in the template with
            liquid-style `{% promptrange range_name %}`/`{% endpromptrange %}` tags.
        target: The text that should be explained. It is expected to follow the prompt.
        focus_ranges: Names of ranges in `rich_prompt` that the returned highlights must
            stem from, i.e. every returned highlight shares at least one character with
            one of the ranges named here. With an empty set (the default), highlights
            for the entire prompt are returned.
    """

    rich_prompt: RichPrompt
    target: str
    focus_ranges: frozenset[str] = frozenset()
# Sphinx "[docs]" source-link marker (HTML extraction residue, not code).
class ScoredTextHighlight(BaseModel):
    """A scored substring of the input prompt, rated for its relevance to the output.

    Attributes:
        start: Index of the first character of the highlight.
        end: End index of the highlight. `TextHighlight` fills it with
            `start + length` of the underlying score, i.e. one past the last character.
        score: Relevance of the highlight, normalized to lie between zero and one.
            Higher means more important.
    """

    start: int
    end: int
    score: float
# Sphinx "[docs]" source-link marker (HTML extraction residue, not code).
class TextHighlightOutput(BaseModel):
    """Result of a text highlighting task.

    Attributes:
        highlights: The `ScoredTextHighlight`s that were found.
    """

    highlights: Sequence[ScoredTextHighlight]
class TextPromptRange(PromptRange):
    """A `PromptRange` whose two cursors are both declared as `TextCursor`.

    The clamping code casts prompt ranges to this type so that the character
    `position` of `start` and `end` can be read without further narrowing.
    """

    start: TextCursor
    end: TextCursor
# Sphinx "[docs]" source-link marker (HTML extraction residue, not code).
class TextHighlight(Task[TextHighlightInput, TextHighlightOutput]):
    r"""Generates text highlights given a prompt and completion.

    For a given prompt and target (completion), this task extracts the parts of the
    prompt that were responsible for the generation. The prompt may only contain text.
    Ranges can be declared in the prompt via the liquid language (see the example);
    when focus ranges are requested, only highlights touching them are returned.

    Args:
        model: The model used throughout the task for model related API calls.
        granularity: The granularity at which the target is explained in terms of the
            prompt. `None` is passed on to the model unchanged.
        threshold: After normalization, every highlight with a score below this value
            is dropped.
        clamp: Whether a highlight that intersects a focus range is cut down to that
            range (its score is scaled by the fraction that is kept). Clamping is
            skipped when no focus ranges are given or when focus ranges overlap.

    Example:
        >>> import os
        >>> from intelligence_layer.core import (
        ...     InMemoryTracer,
        ...     PromptTemplate,
        ...     TextHighlight,
        ...     TextHighlightInput,
        ...     AlephAlphaModel
        ... )
        >>> model = AlephAlphaModel(name="luminous-base")
        >>> text_highlight = TextHighlight(model=model)
        >>> prompt_template_str = (
        ...     "{% promptrange r1 %}Question: What is 2 + 2?{% endpromptrange %}\nAnswer:"
        ... )
        >>> template = PromptTemplate(prompt_template_str)
        >>> rich_prompt = template.to_rich_prompt()
        >>> completion = " 4."
        >>> model = "luminous-base"
        >>> input = TextHighlightInput(
        ...     rich_prompt=rich_prompt, target=completion, focus_ranges=frozenset({"r1"})
        ... )
        >>> output = text_highlight.run(input, InMemoryTracer())
    """

    def __init__(
        self,
        model: AlephAlphaModel,
        granularity: PromptGranularity | None = None,
        threshold: float = 0.1,
        clamp: bool = False,
    ) -> None:
        """Store the configuration; no model call happens here."""
        super().__init__()
        self._model = model
        self._granularity = granularity
        self._threshold = threshold
        self._clamp_to_focus = clamp
# Sphinx "[docs]" source-link marker (HTML extraction residue, not code).
def do_run(
    self, input: TextHighlightInput, task_span: TaskSpan
) -> TextHighlightOutput:
    """Run the highlighting task for one input.

    The input is validated first, then the model is asked to explain the target,
    and finally the returned scores are filtered by focus range, optionally
    clamped, normalized and thresholded.

    Args:
        input: Prompt, target and focus ranges to work on.
        task_span: Span used for the model call and for logging the raw scores.

    Returns:
        The highlights that survived filtering.

    Raises:
        ValueError: If a focus range is unknown or the prompt is not supported.
    """
    # Both checks run before the model is asked for an explanation.
    self._raise_on_invalid_focus_range(input)
    self._raise_on_incompatible_prompt(input)

    rich_prompt = input.rich_prompt
    explain_output = self._explain(
        prompt=rich_prompt,
        target=input.target,
        task_span=task_span,
    )
    selected_ranges = self._filter_and_flatten_prompt_ranges(
        input.focus_ranges, rich_prompt.ranges
    )
    scored_items = self._extract_text_prompt_item_explanations_and_item_index(
        explain_output
    )
    return TextHighlightOutput(
        highlights=self._to_highlights(selected_ranges, scored_items, task_span)
    )
def _raise_on_incompatible_prompt(self, input: TextHighlightInput) -> None:
"""Raises an error if the prompt contains anything other than text.
Currently, the text highlight task does not properly deal with multimodal prompts.
This is a result of returning indices instead of text.
Therefore, we disable running text highlighting on prompts with more than one index
for the moment. This also means we only deal with text items.
Args:
input: The input for a text highlighting task.
"""
n_items = len(input.rich_prompt.items)
# the last item is always the question
if n_items > 2:
raise ValueError(
f"Text highlighting currently only works correctly with a single Text item. Found {n_items-1}."
)
if any(not isinstance(item, Text) for item in input.rich_prompt.items):
raise ValueError("Text highlighting only supports text prompts.")
def _raise_on_invalid_focus_range(self, input: TextHighlightInput) -> None:
unknown_focus_ranges = input.focus_ranges - set(input.rich_prompt.ranges.keys())
if unknown_focus_ranges:
raise ValueError(f"Unknown focus ranges: {', '.join(unknown_focus_ranges)}")
def _explain(
    self, prompt: Prompt, target: str, task_span: TaskSpan
) -> ExplainOutput:
    """Ask the model to explain `target` in terms of `prompt`.

    Args:
        prompt: The prompt whose parts are to be scored.
        target: The text to be explained.
        task_span: Span handed to the model's `explain` call.

    Returns:
        Whatever the model's `explain` returns for the configured granularity.
    """
    explain_input = ExplainInput(
        prompt=prompt,
        target=target,
        prompt_granularity=self._granularity,
    )
    return self._model.explain(explain_input, task_span)
def _filter_and_flatten_prompt_ranges(
self,
focus_ranges: frozenset[str],
input_ranges: Mapping[str, Sequence[PromptRange]],
) -> Sequence[PromptRange]:
relevant_ranges = (
range for name, range in input_ranges.items() if name in focus_ranges
)
return list(itertools.chain.from_iterable(relevant_ranges))
def _extract_text_prompt_item_explanations_and_item_index(
self,
explain_output: ExplainOutput,
) -> Iterable[tuple[TextPromptItemExplanation, int]]:
return (
(explanation, index)
# we explain the complete target at once, therefore we have 1 explanation
for index, explanation in enumerate(explain_output.explanations[0].items)
if isinstance(explanation, TextPromptItemExplanation)
)
def _to_highlights(
self,
focus_ranges: Sequence[PromptRange],
text_prompt_item_explanations_and_indices: Iterable[
tuple[TextPromptItemExplanation, int]
],
task_span: TaskSpan,
) -> Sequence[ScoredTextHighlight]:
relevant_text_scores: list[TextScore] = []
for (
text_prompt_item_explanation,
explanation_idx,
) in text_prompt_item_explanations_and_indices:
for text_score in text_prompt_item_explanation.scores:
assert isinstance(text_score, TextScore) # for typing
if self._is_relevant_explanation(
explanation_idx, text_score, focus_ranges
):
relevant_text_scores.append(text_score)
task_span.log(
"Raw explanation scores",
[
{
"start": text_score.start,
"end": text_score.start + text_score.length,
"score": text_score.score,
}
for text_score in relevant_text_scores
],
)
if self._clamp_to_focus:
relevant_text_scores = self._clamp_ranges_to_focus(
focus_ranges, relevant_text_scores
)
text_highlights = [
ScoredTextHighlight(
start=text_score.start,
end=text_score.start + text_score.length,
score=text_score.score,
)
for text_score in relevant_text_scores
]
return self._normalize_and_filter(text_highlights)
def _clamp_ranges_to_focus(
    self,
    prompt_ranges: Sequence[PromptRange],
    relevant_highlights: list[TextScore],
) -> list[TextScore]:
    """Cut each highlight down to the focus range it overlaps with most.

    Args:
        prompt_ranges: The focus ranges; treated as text ranges.
        relevant_highlights: Scores that already passed the focus-range filter.

    Returns:
        `relevant_highlights` unchanged if clamping is not applicable (see
        `_should_clamp`), otherwise a new list with one clamped score per input.
    """
    text_prompt_ranges = [
        cast(TextPromptRange, prompt_range) for prompt_range in prompt_ranges
    ]
    if not self._should_clamp(text_prompt_ranges):
        return relevant_highlights

    def overlap_length(candidate: TextPromptRange, score: TextScore) -> int:
        # Negative when the two do not intersect at all.
        overlap_start = max(score.start, candidate.start.position)
        overlap_end = min(score.start + score.length, candidate.end.position)
        return overlap_end - overlap_start

    clamped_scores: list[TextScore] = []
    for highlight in relevant_highlights:
        # `max` keeps the first maximum it meets; iterating in reverse therefore
        # prefers the last-listed range on ties, like a stable sort's last element.
        best_range = max(
            reversed(text_prompt_ranges),
            key=lambda candidate: overlap_length(candidate, highlight),
        )
        clamped_scores.append(self._clamp_to_range(highlight, best_range))
    return clamped_scores
def _should_clamp(self, prompt_ranges: Sequence[TextPromptRange]) -> bool:
def are_overlapping(p1: TextPromptRange, p2: TextPromptRange) -> bool:
if p1.start.position > p2.start.position:
p1, p2 = p2, p1
return (
p1.start.position <= p2.end.position
and p1.end.position >= p2.start.position
)
if len(prompt_ranges) == 0:
return False
# this check is relatively expensive if no ranges are overlapping
has_overlapping_ranges = any(
are_overlapping(p1, p2)
for p1, p2 in itertools.permutations(prompt_ranges, 2)
)
if has_overlapping_ranges:
print(
"TextHighlighting with clamping is on, but focus ranges are overlapping. Disabling clamping."
)
return not has_overlapping_ranges
def _clamp_to_range(
    self, highlight: TextScore, focus_range: TextPromptRange
) -> TextScore:
    """Cut `highlight` so that it does not reach beyond `focus_range`.

    Args:
        highlight: The score to clamp.
        focus_range: The range the result must stay within.

    Returns:
        A new `TextScore`. If characters were cut off at either end, the score
        is scaled by the fraction of the original length that remains;
        otherwise start, length and score are taken over unchanged.
    """
    range_start = focus_range.start.position
    range_end = focus_range.end.position
    highlight_end = highlight.start + highlight.length

    # Number of characters sticking out on each side (zero if none do).
    cut_at_start = max(range_start - highlight.start, 0)
    cut_at_end = max(highlight_end - range_end, 0)

    new_start = highlight.start + cut_at_start
    new_length = highlight.length - cut_at_start - cut_at_end

    if cut_at_start + cut_at_end:
        new_score = highlight.score * (new_length / highlight.length)
    else:
        new_score = highlight.score
    return TextScore(new_start, new_length, new_score)
def _normalize_and_filter(
self, text_highlights: Sequence[ScoredTextHighlight]
) -> Sequence[ScoredTextHighlight]:
max_score = max(highlight.score for highlight in text_highlights)
divider = max(
1, max_score
) # We only normalize if the max score is above a threshold to avoid noisy attribution in case where
for highlight in text_highlights:
highlight.score = max(highlight.score / divider, 0)
return [
highlight
for highlight in text_highlights
if highlight.score >= self._threshold
]
def _is_relevant_explanation(
self,
explanation_idx: int,
text_score: TextScore,
prompt_ranges: Iterable[PromptRange],
) -> bool:
return (
any(
self._prompt_range_overlaps_with_text_score(
prompt_range, text_score, explanation_idx
)
for prompt_range in prompt_ranges
)
or not prompt_ranges
)
@classmethod
def _prompt_range_overlaps_with_text_score(
    cls,
    prompt_range: PromptRange,
    text_score: TextScore,
    explanation_item_idx: int,
) -> bool:
    """Check whether a score and a prompt range share at least one position.

    Three cases are tested in turn: the first character of the score lies in
    the range, the last character of the score lies in the range, or the
    start of the range lies inside the score.

    Args:
        prompt_range: The range to test against.
        text_score: The score to test.
        explanation_item_idx: Index of the prompt item the score belongs to.

    Returns:
        `True` as soon as one of the three cases applies, `False` otherwise.
    """
    first_position = text_score.start
    last_position = text_score.start + text_score.length - 1

    if cls._is_within_prompt_range(prompt_range, explanation_item_idx, first_position):
        return True
    if cls._is_within_prompt_range(prompt_range, explanation_item_idx, last_position):
        return True
    # Remaining case: the score starts before the range and spans its start.
    return cls._is_within_text_score(
        text_score, explanation_item_idx, prompt_range.start
    )
@staticmethod
def _is_within_text_score(
text_score: TextScore,
text_score_item: int,
prompt_range_cursor: Cursor,
) -> bool:
if text_score_item != prompt_range_cursor.item:
return False
assert isinstance(prompt_range_cursor, TextCursor)
return (
text_score.start
<= prompt_range_cursor.position
<= text_score.start + text_score.length - 1
)
@staticmethod
def _is_within_prompt_range(
prompt_range: PromptRange,
item_check: int,
pos_check: int,
) -> bool:
if item_check < prompt_range.start.item or item_check > prompt_range.end.item:
return False
if item_check == prompt_range.start.item:
# must be a text cursor, because has same index as TextScore
assert isinstance(prompt_range.start, TextCursor)
if pos_check < prompt_range.start.position:
return False
if item_check == prompt_range.end.item:
assert isinstance(prompt_range.end, TextCursor) # see above
if pos_check > prompt_range.end.position:
return False
return True