"""LLM utilities for prompt construction and summary generation.
Each summary helper keeps its own in-memory LRU cache keyed by ``(model, prompt)``.
Every cache is capped at :data:`SUMMARY_CACHE_SIZE` entries.
"""
from collections import OrderedDict
from collections.abc import Iterator
from typing import Any
from .ai_model_interface import AIModelInterface
from .utils.image_utils import add_watermark, generate_mni152_image
# Maximum number of entries each summary cache may hold. The summary helpers
# only read from and write to their caches while this value is greater than zero.
SUMMARY_CACHE_SIZE = 128
# ---------------------------------------------------------------------------
# Exposed prompt templates
# ---------------------------------------------------------------------------
# Templates for the introductory portion of LLM prompts. Users can inspect and
# customize these entries; :func:`generate_llm_prompt` selects one by its
# ``prompt_type`` key and falls back to ``"default"`` for unknown keys.
[docs]
LLM_PROMPT_TEMPLATES: dict[str, str] = {
    # Integrated synthesis across all supplied studies.
    "summary": (
        "You are an advanced AI with expertise in neuroanatomy and "
        "cognitive neuroscience. "
        "The user is interested in understanding the significance of "
        "MNI coordinate {coord}.\n\n"
        "Below is a list of neuroimaging studies that report activation "
        "at this coordinate. "
        "Your task is to integrate and synthesize the knowledge from these "
        "studies, focusing on:\n"
        "1) The anatomical structure(s) most commonly associated with this "
        "coordinate\n"
        "2) The typical functional roles or processes linked to activation "
        "in this region\n"
        "3) The main tasks or experimental conditions in which it was reported\n"
        "4) Patterns, contradictions, or debates in the findings\n\n"
        "Do NOT simply list each study separately. "
        "Provide an integrated, cohesive summary.\n"
    ),
    # Anatomical labelling with rationale and confidence.
    "region_name": (
        "You are a neuroanatomy expert. "
        "The user wants to identify the probable anatomical labels for MNI "
        "coordinate {coord}. "
        "The following studies reported activation around this location. "
        "Incorporate anatomical knowledge and any direct references to brain "
        "regions from these studies. "
        "If multiple labels are possible, mention all and provide rationale "
        "and confidence levels.\n\n"
    ),
    # Functional profile of the region.
    "function": (
        "You are a cognitive neuroscience expert. "
        "The user wants a deep functional profile of the brain region(s) "
        "around MNI coordinate {coord}. "
        "The studies below report activation at or near this coordinate. "
        "Synthesize a clear description of:\n"
        "1) Core functions or cognitive processes\n"
        "2) Typical experimental paradigms or tasks\n"
        "3) Known functional networks or connectivity\n"
        "4) Divergent or debated viewpoints in the literature\n\n"
    ),
    # Generic fallback template.
    "default": (
        "Please analyze the following neuroimaging studies reporting "
        "activation at MNI coordinate {coord} and provide a concise yet "
        "thorough discussion of its anatomical location and functional "
        "significance.\n\n"
    ),
}
# Templates for image prompt generation. Each template can be formatted with
# ``coordinate``, ``x_coord``, ``y_coord``, ``z_coord``, ``first_paragraph``,
# ``atlas_context``, and ``study_context`` variables.
[docs]
IMAGE_PROMPT_TEMPLATES: dict[str, str] = {
    # Plain orthogonal-slice rendering. This is the only built-in template
    # that uses the per-axis ``x_coord``/``y_coord``/``z_coord`` placeholders.
    # Adjacent literals must end with a space so the sentences do not run
    # together once Python concatenates them.
    "anatomical": (
        "Create a scientific brain visualization showing exactly three orthogonal MRI "
        "slices arranged horizontally: coronal (left), sagittal (middle), "
        "and axial (right) views. "
        "Use grayscale T1-weighted MRI brain anatomy on a black background. "
        "Place bright yellow or white crosshairs (+) at MNI coordinate {coordinate}, "
        "with the crosshairs extending across each slice to mark the exact location. "
        "Label each view with the coordinate values shown. "
        "Add L/R orientation markers. The style should match standard neuroimaging "
        "software output like FSLeyes or Nilearn, with no artistic interpretation. "
        "Ensure the crosshairs intersect precisely at the specified coordinate point.\n"
        "Coordinate location: x={x_coord}, y={y_coord}, z={z_coord}\n"
        "{atlas_context}"
    ),
    # Activation heat map over the MNI152 template.
    "functional": (
        "Produce a Nilearn-style activation map with sagittal, coronal, and axial "
        "panels centred on coordinate {coordinate}.\n"
        "Functional interpretation: {first_paragraph}\n"
        "{atlas_context}{study_context}"
        "Overlay activation intensities as a heat map on the MNI152 template, include "
        "legend ticks, slice coordinates, and crosshairs precisely at the specified "
        "location."
    ),
    # Network diagram anchored on the coordinate.
    "schematic": (
        "Draw a network schematic anchored on MNI coordinate {coordinate}. Include an "
        "inset miniature of the Nilearn-style orthogonal slices marking the focus.\n"
        "Conceptual summary: {first_paragraph}\n"
        "{atlas_context}{study_context}"
        "Label interacting regions, indicate connectivity directions when supported, "
        "and keep the overall style technical and publication-ready."
    ),
    # Stylised rendering that keeps the slice framing.
    "artistic": (
        "Create a stylised yet anatomically faithful visualization spotlighting "
        "coordinate {coordinate}. Retain Nilearn-like slice framing so the activation "
        "can be compared to reference material.\n"
        "Narrative focus: {first_paragraph}\n"
        "{atlas_context}{study_context}"
        "Blend scientific structure with thoughtful lighting or texture while keeping "
        "the coordinate marker and orthogonal slices clear."
    ),
    # Fallback used when an unknown image type is requested.
    "default": (
        "Render a Nilearn-style comparative figure centred on coordinate "
        "{coordinate}. Provide orthogonal MNI152 slices with crosshairs, legend, and "
        "activation emphasis.\n"
        "Primary description: {first_paragraph}\n"
        "{atlas_context}{study_context}"
        "Ensure the output resembles neuroimaging data ready for side-by-side "
        "comparison with a deterministic Nilearn export."
    ),
}
[docs]
def generate_llm_prompt(
studies: list[dict[str, Any]],
coordinate: list[float] | tuple[float, float, float],
prompt_type: str = "summary",
prompt_template: str | None = None,
) -> str:
"""Generate a detailed prompt for language models based on studies.
Parameters
----------
studies : list of dict
Study metadata dictionaries that describe the activation evidence.
coordinate : sequence of float
MNI coordinate used for formatting the prompt header.
prompt_type : str, optional
Key that selects a built-in template from :data:`LLM_PROMPT_TEMPLATES`.
prompt_template : str, optional
Custom template string requiring ``coord`` and ``studies`` placeholders.
Returns
-------
str
Fully formatted prompt ready for submission to a language model.
"""
# Format coordinate string safely.
try:
coord_str = (
f"[{float(coordinate[0]):.2f}, "
f"{float(coordinate[1]):.2f}, "
f"{float(coordinate[2]):.2f}]"
)
except Exception:
coord_str = str(coordinate)
if not studies:
return (
"No neuroimaging studies were found reporting activation at "
f"MNI coordinate {coord_str}."
)
# Build the studies section efficiently.
study_lines: list[str] = []
for i, study in enumerate(studies, start=1):
study_lines.append(f"\n--- STUDY {i} ---\n")
study_lines.append(f"ID: {study.get('id', 'Unknown ID')}\n")
study_lines.append(f"Title: {study.get('title', 'No title available')}\n")
abstract_text = study.get("abstract", "No abstract available")
study_lines.append(f"Abstract: {abstract_text}\n")
studies_section = "".join(study_lines)
if prompt_type == "custom" and not prompt_template:
raise ValueError("prompt_template must be provided when prompt_type='custom'")
# If a custom template is provided, use it.
if prompt_template:
return prompt_template.format(coord=coord_str, studies=studies_section)
# Build the prompt header using the templates dictionary.
template = LLM_PROMPT_TEMPLATES.get(prompt_type, LLM_PROMPT_TEMPLATES["default"])
prompt_intro = template.format(coord=coord_str)
prompt_body = (
"STUDIES REPORTING ACTIVATION AT MNI COORDINATE "
+ coord_str
+ ":\n"
+ studies_section
)
prompt_outro = (
"\nUsing ALL of the information above, produce a single cohesive "
"synthesis. Avoid bullet-by-bullet summaries of each study. Instead, "
"integrate the findings across them to describe the region's "
"location, function, and context."
)
return prompt_intro + prompt_body + prompt_outro
[docs]
def generate_region_image_prompt(
coordinate: list[float] | tuple[float, float, float],
region_info: dict[str, Any],
image_type: str = "anatomical",
include_atlas_labels: bool = True,
prompt_template: str | None = None,
) -> str:
"""Generate a prompt for creating images of brain regions.
Parameters
----------
coordinate : sequence of float
Target MNI coordinate to highlight in the visualization.
region_info : dict
Metadata describing the region, such as ``summary`` and atlas labels.
image_type : str, optional
Template key selecting the style of the image prompt.
include_atlas_labels : bool, optional
Whether atlas label descriptions should be inserted into the prompt.
prompt_template : str, optional
Custom template string overriding the built-in prompt dictionary.
Returns
-------
str
Fully formatted prompt with atlas and study context injected.
"""
# Safely get the summary and a short first paragraph.
summary = region_info.get("summary", "No summary available.")
first_paragraph = summary.split("\n\n", 1)[0]
# Format the coordinate for inclusion in the prompt.
try:
x_val = float(coordinate[0])
y_val = float(coordinate[1])
z_val = float(coordinate[2])
coord_str = f"[{x_val:.2f}, {y_val:.2f}, {z_val:.2f}]"
x_coord = f"{x_val:.0f}"
y_coord = f"{y_val:.0f}"
z_coord = f"{z_val:.0f}"
except Exception:
# Fallback to the raw coordinate representation.
coord_str = str(coordinate)
x_coord = y_coord = z_coord = "0"
# Build atlas context if requested and available.
atlas_context = ""
atlas_labels = region_info.get("atlas_labels") or {}
if include_atlas_labels and isinstance(atlas_labels, dict) and atlas_labels:
atlas_parts = [
f"{atlas_name}: {label}" for atlas_name, label in atlas_labels.items()
]
atlas_context = (
"According to brain atlases, this region corresponds to: "
+ ", ".join(atlas_parts)
+ ". "
)
# Build study context - not used for anatomical images but
# needed for template compatibility
study_context = ""
studies = region_info.get("studies") or []
if studies and image_type in ["functional", "schematic", "artistic", "default"]:
study_lines = []
for i, study in enumerate(studies[:3], 1): # Limit to first 3 studies
title = study.get("title", "").strip()
if title:
study_lines.append(f"Study {i}: {title[:80]}...")
if study_lines:
study_context = "Related research: " + "; ".join(study_lines) + ". "
# If a custom template is provided, use it directly.
if prompt_template:
return prompt_template.format(
coordinate=coord_str,
x_coord=x_coord,
y_coord=y_coord,
z_coord=z_coord,
first_paragraph=first_paragraph,
atlas_context=atlas_context,
study_context=study_context,
)
# Retrieve prompt template by image type or fall back to default.
template = IMAGE_PROMPT_TEMPLATES.get(image_type, IMAGE_PROMPT_TEMPLATES["default"])
return template.format(
coordinate=coord_str,
x_coord=x_coord,
y_coord=y_coord,
z_coord=z_coord,
first_paragraph=first_paragraph,
atlas_context=atlas_context,
study_context=study_context,
)
# ---------------------------------------------------------------------------
# Image generation
# ---------------------------------------------------------------------------
[docs]
def generate_region_image(
    ai: "AIModelInterface",
    coordinate: list[float] | tuple[float, float, float],
    region_info: dict[str, Any],
    image_type: str = "anatomical",
    model: str = "stabilityai/stable-diffusion-2",
    include_atlas_labels: bool = True,
    prompt_template: str | None = None,
    retries: int = 3,
    watermark: bool = True,
    **kwargs: Any,
) -> bytes:
    """Build an image prompt for a brain region and render it with an AI model.

    Parameters
    ----------
    ai : AIModelInterface
        Interface whose ``generate_image`` method produces the image.
    coordinate : sequence of float
        MNI coordinate for the target region.
    region_info : dict
        Region metadata forwarded to :func:`generate_region_image_prompt`.
    image_type : str, optional
        Type of image to generate. Defaults to ``"anatomical"``.
    model : str, optional
        Name of the AI model to use. Defaults to
        ``"stabilityai/stable-diffusion-2"``.
    include_atlas_labels : bool, optional
        Whether to include atlas label context in the prompt. Defaults to ``True``.
    prompt_template : str, optional
        Custom template overriding default prompts.
    retries : int, optional
        Forwarded to ``ai.generate_image`` as ``retries``. Defaults to ``3``.
    watermark : bool, optional
        When ``True`` (default), the generated bytes are passed through
        ``add_watermark`` before being returned.
    **kwargs : Any
        Additional keyword arguments passed to ``ai.generate_image``.

    Returns
    -------
    bytes
        Image bytes from the AI interface, watermarked when requested.
    """
    image_prompt = generate_region_image_prompt(
        coordinate,
        region_info,
        image_type=image_type,
        include_atlas_labels=include_atlas_labels,
        prompt_template=prompt_template,
    )
    raw_image = ai.generate_image(
        model=model,
        prompt=image_prompt,
        retries=retries,
        **kwargs,
    )
    # Watermarking is the only post-processing step and is opt-out.
    return add_watermark(raw_image) if watermark else raw_image
# ---------------------------------------------------------------------------
# Summary generation
# ---------------------------------------------------------------------------
[docs]
def generate_summary(
    ai: "AIModelInterface",
    studies: list[dict[str, Any]],
    coordinate: list[float] | tuple[float, float, float],
    prompt_type: str = "summary",
    model: str = "gemini-2.0-flash",
    atlas_labels: dict[str, str] | None = None,
    custom_prompt: str | None = None,
    max_tokens: int = 1000,
) -> str:
    """Generate a text summary for a coordinate based on studies.

    Parameters
    ----------
    ai : AIModelInterface
        AI backend whose ``generate_text`` method creates the summary.
    studies : list of dict
        Studies reporting activation at the target coordinate.
    coordinate : sequence of float
        MNI coordinate around which the summary should focus.
    prompt_type : str, optional
        Key into :data:`LLM_PROMPT_TEMPLATES`. Use ``"custom"`` with
        ``custom_prompt`` to provide a bespoke template.
    model : str, optional
        Name of the text generation model. Defaults to ``"gemini-2.0-flash"``.
    atlas_labels : dict, optional
        Atlas-derived labels inserted before the studies section of the
        prompt, or prepended when the prompt has no studies section.
    custom_prompt : str, optional
        Template string formatted with ``coord`` and ``studies`` placeholders.
        Only used when ``prompt_type`` is ``"custom"``.
    max_tokens : int, optional
        Maximum number of tokens requested from the language model.

    Returns
    -------
    str
        Textual summary returned by the AI model, or a cached copy of it.
    """
    # Build base prompt with study information.
    prompt = generate_llm_prompt(
        studies,
        coordinate,
        prompt_type=prompt_type,
        prompt_template=custom_prompt if prompt_type == "custom" else None,
    )

    # Insert atlas label information when provided.
    if atlas_labels:
        marker = "STUDIES REPORTING ACTIVATION AT MNI COORDINATE"
        atlas_info = "\nATLAS LABELS FOR THIS COORDINATE:\n" + "".join(
            f"- {atlas_name}: {label}\n"
            for atlas_name, label in atlas_labels.items()
        )
        # ``partition`` splits at the first marker only, so text after any
        # later occurrence of the marker is kept intact.
        intro, found, rest = prompt.partition(marker)
        if found:
            prompt = intro + atlas_info + "\n" + found + rest
        else:
            prompt = atlas_info + prompt

    # LRU lookup. NOTE: ``max_tokens`` is not part of the key, so a cached
    # result is returned regardless of the ``max_tokens`` now requested.
    key = (model, prompt)
    cache: OrderedDict[tuple[str, str], str] = generate_summary._cache
    if SUMMARY_CACHE_SIZE > 0:
        cached = cache.get(key)
        if cached is not None:
            cache.move_to_end(key)
            return cached

    result = ai.generate_text(model=model, prompt=prompt, max_tokens=max_tokens)

    if SUMMARY_CACHE_SIZE > 0:
        cache[key] = result
        cache.move_to_end(key)
        # Evict least-recently-used entries until the cache fits the limit.
        while len(cache) > SUMMARY_CACHE_SIZE:
            cache.popitem(last=False)
    return result
[docs]
def generate_batch_summaries(
    ai: "AIModelInterface",
    coord_studies_pairs: list[
        tuple[list[float] | tuple[float, float, float], list[dict[str, Any]]]
    ],
    prompt_type: str = "summary",
    model: str = "gemini-2.0-flash",
    custom_prompt: str | None = None,
    max_tokens: int = 1000,
) -> list[str]:
    """Generate summaries for multiple coordinates.

    When ``ai.supports_batching(model)`` is false, each pair is summarised
    individually through :func:`generate_summary`. Otherwise all prompts are
    joined into one request and the single response is split on ``@@@``.

    Parameters
    ----------
    ai : AIModelInterface
        AI backend used to create the summaries.
    coord_studies_pairs : list of tuple
        ``(coordinate, studies)`` pairs to summarise.
    prompt_type : str, optional
        Template key used for each summary prompt.
    model : str, optional
        Model used for text generation. Defaults to ``"gemini-2.0-flash"``.
    custom_prompt : str, optional
        Template string overriding the built-in prompt for every coordinate.
        Only used when ``prompt_type`` is ``"custom"``.
    max_tokens : int, optional
        Maximum tokens requested from each AI call. In batched mode a single
        call covers all coordinates.

    Returns
    -------
    list of str
        Generated summaries. In batched mode the list holds the non-empty
        segments of the model response, so its length follows the response
        rather than the number of input pairs.
    """
    if not coord_studies_pairs:
        return []

    if not ai.supports_batching(model):
        return [
            generate_summary(
                ai,
                studies,
                coord,
                prompt_type=prompt_type,
                model=model,
                custom_prompt=custom_prompt,
                max_tokens=max_tokens,
            )
            for coord, studies in coord_studies_pairs
        ]

    delimiter = "\n@@@\n"
    token = delimiter.strip()
    prompts = [
        generate_llm_prompt(
            studies,
            coord,
            prompt_type=prompt_type,
            prompt_template=custom_prompt if prompt_type == "custom" else None,
        )
        for coord, studies in coord_studies_pairs
    ]
    combined_prompt = (
        "Provide separate summaries for each coordinate below. "
        f"Separate each summary with the delimiter '{token}'.\n\n"
        + delimiter.join(prompts)
    )

    key = (model, combined_prompt)
    cache: OrderedDict[tuple[str, str], list[str]] = generate_batch_summaries._cache
    if SUMMARY_CACHE_SIZE > 0:
        cached = cache.get(key)
        if cached is not None:
            cache.move_to_end(key)
            # Return a copy so callers cannot mutate the cached list.
            return list(cached)

    response = ai.generate_text(
        model=model, prompt=combined_prompt, max_tokens=max_tokens
    )
    # The model is asked for the bare token, so split on that rather than the
    # newline-padded form; surrounding whitespace is stripped per segment.
    results = [part.strip() for part in response.split(token) if part.strip()]

    if SUMMARY_CACHE_SIZE > 0:
        cache[key] = results
        cache.move_to_end(key)
        # Evict least-recently-used entries until the cache fits the limit.
        while len(cache) > SUMMARY_CACHE_SIZE:
            cache.popitem(last=False)
    # Hand out a copy for the same reason as on a cache hit.
    return list(results)
[docs]
async def generate_summary_async(
    ai: "AIModelInterface",
    studies: list[dict[str, Any]],
    coordinate: list[float] | tuple[float, float, float],
    prompt_type: str = "summary",
    model: str = "gemini-2.0-flash",
    atlas_labels: dict[str, str] | None = None,
    custom_prompt: str | None = None,
    max_tokens: int = 1000,
) -> str:
    """Asynchronously generate a text summary for a coordinate.

    Parameters
    ----------
    ai : AIModelInterface
        AI backend whose ``generate_text_async`` coroutine creates the summary.
    studies : list of dict
        Studies reporting activation at the target coordinate.
    coordinate : sequence of float
        MNI coordinate for the summary.
    prompt_type : str, optional
        Prompt template key defaulting to ``"summary"``.
    model : str, optional
        Model name, defaulting to ``"gemini-2.0-flash"``.
    atlas_labels : dict, optional
        Atlas-derived labels inserted before the studies section of the
        prompt, or prepended when the prompt has no studies section.
    custom_prompt : str, optional
        User-supplied template applied via ``str.format``. Only used when
        ``prompt_type`` is ``"custom"``.
    max_tokens : int, optional
        Maximum number of tokens requested for the summary.

    Returns
    -------
    str
        Generated summary text, or a cached copy of it.
    """
    prompt = generate_llm_prompt(
        studies,
        coordinate,
        prompt_type=prompt_type,
        prompt_template=custom_prompt if prompt_type == "custom" else None,
    )

    if atlas_labels:
        marker = "STUDIES REPORTING ACTIVATION AT MNI COORDINATE"
        atlas_info = "\nATLAS LABELS FOR THIS COORDINATE:\n" + "".join(
            f"- {atlas_name}: {label}\n"
            for atlas_name, label in atlas_labels.items()
        )
        # ``partition`` splits at the first marker only, so text after any
        # later occurrence of the marker is kept intact.
        intro, found, rest = prompt.partition(marker)
        if found:
            prompt = intro + atlas_info + "\n" + found + rest
        else:
            prompt = atlas_info + prompt

    # This coroutine has its own cache, separate from the synchronous helper.
    # NOTE: ``max_tokens`` is not part of the key.
    key = (model, prompt)
    cache: OrderedDict[tuple[str, str], str] = generate_summary_async._cache
    if SUMMARY_CACHE_SIZE > 0:
        cached = cache.get(key)
        if cached is not None:
            cache.move_to_end(key)
            return cached

    result = await ai.generate_text_async(
        model=model, prompt=prompt, max_tokens=max_tokens
    )

    if SUMMARY_CACHE_SIZE > 0:
        cache[key] = result
        cache.move_to_end(key)
        # Evict least-recently-used entries until the cache fits the limit.
        while len(cache) > SUMMARY_CACHE_SIZE:
            cache.popitem(last=False)
    return result
[docs]
def stream_summary(
    ai: "AIModelInterface",
    studies: list[dict[str, Any]],
    coordinate: list[float] | tuple[float, float, float],
    prompt_type: str = "summary",
    model: str = "gemini-2.0-flash",
    atlas_labels: dict[str, str] | None = None,
    custom_prompt: str | None = None,
    max_tokens: int = 1000,
) -> Iterator[str]:
    """Stream a text summary for a coordinate in chunks.

    This is a generator function: the prompt is built and the backend is
    contacted only once iteration starts.

    Parameters
    ----------
    ai : AIModelInterface
        Backend whose ``stream_generate_text`` method yields the summary chunks.
    studies : list of dict
        Studies reporting activation at the target coordinate.
    coordinate : sequence of float
        MNI coordinate for the summary.
    prompt_type : str, optional
        Prompt template key defaulting to ``"summary"``.
    model : str, optional
        Model name, defaulting to ``"gemini-2.0-flash"``.
    atlas_labels : dict, optional
        Atlas-derived labels inserted before the studies section of the
        prompt, or prepended when the prompt has no studies section.
    custom_prompt : str, optional
        User-supplied template applied via ``str.format``. Only used when
        ``prompt_type`` is ``"custom"``.
    max_tokens : int, optional
        Maximum number of tokens requested for the summary.

    Yields
    ------
    str
        Chunks of text from the streaming backend, or the chunks recorded by
        an earlier call with the same ``(model, prompt)`` key.
    """
    prompt = generate_llm_prompt(
        studies,
        coordinate,
        prompt_type=prompt_type,
        prompt_template=custom_prompt if prompt_type == "custom" else None,
    )

    if atlas_labels:
        marker = "STUDIES REPORTING ACTIVATION AT MNI COORDINATE"
        atlas_info = "\nATLAS LABELS FOR THIS COORDINATE:\n" + "".join(
            f"- {atlas_name}: {label}\n"
            for atlas_name, label in atlas_labels.items()
        )
        # ``partition`` splits at the first marker only, so text after any
        # later occurrence of the marker is kept intact.
        intro, found, rest = prompt.partition(marker)
        if found:
            prompt = intro + atlas_info + "\n" + found + rest
        else:
            prompt = atlas_info + prompt

    key = (model, prompt)
    cache: OrderedDict[tuple[str, str], list[str]] = stream_summary._cache
    if SUMMARY_CACHE_SIZE > 0:
        cached_chunks = cache.get(key)
        if cached_chunks is not None:
            cache.move_to_end(key)
            # Replay the recorded chunks instead of calling the backend.
            yield from cached_chunks
            return

    chunks: list[str] = []
    try:
        for chunk in ai.stream_generate_text(
            model=model, prompt=prompt, max_tokens=max_tokens
        ):
            chunks.append(chunk)
            yield chunk
    finally:
        # NOTE(review): because this runs in ``finally``, chunks collected
        # before an error or an early stop by the consumer are cached too and
        # replayed on later calls. Confirm that is intended.
        if SUMMARY_CACHE_SIZE > 0 and chunks:
            cache[key] = chunks
            cache.move_to_end(key)
            # Evict least-recently-used entries until the cache fits the limit.
            while len(cache) > SUMMARY_CACHE_SIZE:
                cache.popitem(last=False)
# Each summary helper stores its LRU cache as a ``_cache`` attribute on the
# function object itself, so the four caches are independent of one another.
# Assigning a fresh ``OrderedDict`` to one of these attributes empties that cache.
generate_summary._cache = OrderedDict()
generate_batch_summaries._cache = OrderedDict()
generate_summary_async._cache = OrderedDict()
stream_summary._cache = OrderedDict()
# Public API of this module. ``generate_mni152_image`` is not defined here; it
# is imported from ``.utils.image_utils`` and re-exported.
__all__ = [
"LLM_PROMPT_TEMPLATES",
"IMAGE_PROMPT_TEMPLATES",
"generate_llm_prompt",
"generate_region_image_prompt",
"generate_region_image",
"generate_mni152_image",
"generate_summary",
"generate_batch_summaries",
"generate_summary_async",
"stream_summary",
]