Source code for coord2region.ai_reports

"""High-level helpers for generating structured AI reports and image specs.

This module centralises the logic that was previously spread across gallery
examples, making it easier to build real applications around the AI features.
"""

from __future__ import annotations

import json
import re
import time
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any

from .ai_model_interface import AIModelInterface, build_generation_summary

# System prompt for reasoned reports: restricts the assistant to the data
# supplied in the user prompt and asks it to flag gaps as
# "insufficient_evidence" rather than inventing content.
DEFAULT_SYSTEM_MESSAGE = (
    "You are a careful neuroscience assistant. "
    "You convert MNI brain coordinates into clear, evidence-grounded "
    "explanations using ONLY the data provided in the prompt "
    "(atlas labels, neighbors, and the study list).\n"
    "If something is missing or conflicting, explicitly report "
    '"insufficient_evidence" for that part--do not invent facts or citations.'
)

# Comma-separated negative prompt for image generation, listing visual
# artefacts the figure should avoid.
DEFAULT_NEGATIVE_PROMPT = ", ".join(
    (
        "cartoon",
        "abstract art",
        "texture noise",
        "extra brains",
        "low resolution",
        "blurry",
        "distorted anatomy",
        "overexposed",
        "text blocks covering the brain",
        "bright multicolor palettes",
        "dramatic lighting",
        "artistic shadows",
    )
)
@dataclass
class ReasonedReportContext:
    """Everything the reasoned-report prompt needs to know about one coordinate.

    Only ``coordinate_mni`` is required. Every other field defaults either to
    ``None`` or to a fresh, per-instance empty container.
    """

    # Target MNI coordinate; its first element is read as x when the
    # hemisphere has to be inferred.
    coordinate_mni: Sequence[float]
    # Hemisphere label; when falsy, helpers in this module infer it instead.
    hemisphere: str | None = field(default=None)
    # Proximity to an atlas boundary in millimetres (meaning inferred from the
    # field name and not validated here -- confirm with callers).
    boundary_proximity_mm: float | None = field(default=None)
    # Atlas description; keys read in this module include "name", "version",
    # "primary_label", "label" and "hemisphere".
    atlas: dict[str, Any] = field(default_factory=dict)
    # Free-text notes about the atlas lookup, forwarded to the prompt.
    atlas_notes: list[str] = field(default_factory=list)
    # Supporting studies, forwarded to the prompt as-is; the per-study dict
    # schema is defined by callers (TODO confirm).
    studies: list[dict[str, Any]] = field(default_factory=list)
    # Domains forwarded to the prompt; ``None`` leaves them out entirely.
    allowed_domains: Sequence[str] | None = field(default=None)
    # Extra formatting instructions forwarded to the prompt.
    format_instructions: list[str] = field(default_factory=list)
@dataclass
class ReasonedReport:
    """Outcome of parsing a reasoned-report completion.

    Instances are produced by :func:`parse_reasoned_report_output` and are
    returned (alongside metadata) by :func:`run_reasoned_report`.
    """

    # Narrative text found before the JSON block, or the whole stripped
    # output when no JSON candidate was located.
    narrative: str
    # Raw JSON text that was extracted, if any.
    json_text: str | None = field(default=None)
    # Decoded JSON object; ``None`` when absent or when decoding failed.
    json_data: dict[str, Any] | None = field(default=None)
    # Why JSON is unavailable (decode error or "not found" message), if any.
    json_error: str | None = field(default=None)
def infer_hemisphere(
    coord: Sequence[float] | None,
    *,
    midline_tolerance_mm: float = 3.0,
) -> str:
    """Infer the hemisphere from an MNI coordinate.

    Only the x component is inspected. Values greater than
    ``midline_tolerance_mm`` are ``"right"``, values less than its negative
    are ``"left"``, and values inside the band (inclusive) are ``"midline"``.

    Parameters
    ----------
    coord : sequence of float or None
        MNI coordinate to evaluate. Lists, tuples and NumPy arrays are all
        accepted; ``None`` or an empty sequence yields ``"unknown"``.
    midline_tolerance_mm : float, optional
        Half-width of the midline band around x = 0. Defaults to 3.0.

    Returns
    -------
    str
        ``"left"``, ``"right"``, ``"midline"``, or ``"unknown"`` (empty input
        or a NaN x coordinate).

    Raises
    ------
    ValueError
        If ``midline_tolerance_mm`` is negative.
    """
    if midline_tolerance_mm < 0:
        raise ValueError(
            f"midline_tolerance_mm must be non-negative, got {midline_tolerance_mm!r}"
        )
    # ``not coord`` raises for NumPy arrays ("truth value ... is ambiguous"),
    # so test for emptiness explicitly via ``len``.
    if coord is None or len(coord) == 0:
        return "unknown"
    x = float(coord[0])
    if x > midline_tolerance_mm:
        return "right"
    if x < -midline_tolerance_mm:
        return "left"
    if -midline_tolerance_mm <= x <= midline_tolerance_mm:
        return "midline"
    # Only reachable when x is NaN: every comparison above evaluates False.
    return "unknown"
def _context_to_payload(context: ReasonedReportContext) -> dict[str, Any]:
    """Flatten a :class:`ReasonedReportContext` into a prompt payload dict.

    The coordinate and hemisphere are always present. Optional fields are
    added only when supplied, in a fixed key order.

    Parameters
    ----------
    context : ReasonedReportContext
        Structured context describing the coordinate, atlas, and studies.

    Returns
    -------
    dict
        Payload ready to be JSON-encoded into the user prompt. Sequence fields
        are copied into lists; the atlas mapping is included by reference.
    """
    # An explicit hemisphere wins; otherwise derive it from the x coordinate.
    resolved_hemisphere = context.hemisphere or infer_hemisphere(context.coordinate_mni)
    result: dict[str, Any] = {
        "coordinate_mni": list(context.coordinate_mni),
        "hemisphere": resolved_hemisphere,
    }

    proximity = context.boundary_proximity_mm
    if proximity is not None:
        result["boundary_proximity_mm"] = proximity

    if context.atlas:
        result["atlas"] = context.atlas

    # (key, value, include?) in output order. ``allowed_domains`` is gated on
    # ``None`` only, so an empty sequence is still emitted; the others are
    # skipped when empty.
    sequence_fields = (
        ("atlas_notes", context.atlas_notes, bool(context.atlas_notes)),
        ("studies", context.studies, bool(context.studies)),
        (
            "allowed_domains",
            context.allowed_domains,
            context.allowed_domains is not None,
        ),
        (
            "format_instructions",
            context.format_instructions,
            bool(context.format_instructions),
        ),
    )
    for key, values, include in sequence_fields:
        if include:
            result[key] = list(values)
    return result
def build_reasoned_report_messages(
    context: ReasonedReportContext,
    *,
    system_message: str = DEFAULT_SYSTEM_MESSAGE,
    max_words: int = 180,
) -> list[dict[str, str]]:
    """Assemble the system + user chat messages for a reasoned report.

    The user message embeds the context payload as pretty-printed JSON with
    sorted keys, followed by the output-format instructions.

    Parameters
    ----------
    context : ReasonedReportContext
        Context describing the coordinate, atlas, and studies.
    system_message : str, optional
        Content of the leading system message.
    max_words : int, optional
        Narrative word limit stated in the instructions.

    Returns
    -------
    list of dict
        Two ``{"role": ..., "content": ...}`` dicts: system first, then user.
    """
    serialized_context = json.dumps(
        _context_to_payload(context), indent=2, sort_keys=True
    )
    format_request = (
        f"Return the narrative (<= {max_words} words) followed immediately by the "
        "STRICT JSON object in ```json fences. Do not include extra commentary or "
        "surrounding text."
    )
    user_content = (
        "Coordinate context for the Coord2Region reasoned report:\n"
        + serialized_context
        + "\n\n"
        + format_request
    )
    system_entry = {"role": "system", "content": system_message}
    user_entry = {"role": "user", "content": user_content}
    return [system_entry, user_entry]
# A ```json fenced block whose body is a single {...} object. Non-greedy so
# the first fenced object wins; DOTALL lets the object span lines.
_FENCED_JSON_PATTERN = re.compile(
    r"```json\s*(\{.*?\})\s*```", re.DOTALL | re.IGNORECASE
)
# An opening fence (optionally tagged ``json``) sitting at the very end of a
# string, allowing trailing whitespace.
_DANGLING_FENCE_PATTERN = re.compile(r"```(?:json)?\s*\Z", re.IGNORECASE)


def parse_reasoned_report_output(output: str) -> ReasonedReport:
    """Split the reasoned report narrative and JSON payload.

    The first ```json fenced object is preferred. Without one, the text from
    the first ``{`` to the last ``}`` is treated as the JSON candidate. A
    candidate that fails to decode is still returned as ``json_text``, with
    the decoder message in ``json_error``.

    Parameters
    ----------
    output : str
        Raw text returned by the AI assistant.

    Returns
    -------
    ReasonedReport
        Parsed narrative along with optional JSON payload information. When no
        JSON candidate is found, ``narrative`` is the whole stripped output and
        ``json_error`` is ``"JSON block not found."``.
    """
    narrative = output.strip()
    json_text: str | None = None
    json_data: dict[str, Any] | None = None
    json_error: str | None = None

    fenced_match = _FENCED_JSON_PATTERN.search(output)
    if fenced_match:
        json_text = fenced_match.group(1).strip()
        narrative = output[: fenced_match.start()].strip()
        try:
            json_data = json.loads(json_text)
        except json.JSONDecodeError as exc:
            json_error = f"{exc.__class__.__name__}: {exc}"
    else:
        # Fallback: everything from the first "{" to the last "}" is the
        # candidate (it starts/ends with braces, so no stripping is needed).
        first_brace = output.find("{")
        last_brace = output.rfind("}")
        if first_brace != -1 and last_brace > first_brace:
            json_text = output[first_brace : last_brace + 1]
            prefix = output[:first_brace]
            # An odd fence count means a fence is still open right before the
            # JSON (untagged fence, or the closing fence was cut off). Keep it
            # out of the narrative.
            if prefix.count("```") % 2 == 1:
                prefix = _DANGLING_FENCE_PATTERN.sub("", prefix)
            narrative = prefix.strip()
            try:
                json_data = json.loads(json_text)
            except json.JSONDecodeError as exc:
                json_error = f"{exc.__class__.__name__}: {exc}"

    if json_text is None:
        json_error = "JSON block not found."

    return ReasonedReport(
        narrative=narrative,
        json_text=json_text,
        json_data=json_data,
        json_error=json_error,
    )
def run_reasoned_report(
    ai: AIModelInterface,
    model: str,
    context: ReasonedReportContext,
    *,
    max_tokens: int = 512,
    retries: int = 3,
    system_message: str = DEFAULT_SYSTEM_MESSAGE,
    max_words: int = 180,
) -> tuple[ReasonedReport, dict[str, Any]]:
    """Generate and parse a reasoned report, returning metadata alongside.

    Parameters
    ----------
    ai : AIModelInterface
        AI interface used to generate text.
    model : str
        Model name to use for the completion.
    context : ReasonedReportContext
        Structured context describing the coordinate, atlas, and studies.
    max_tokens : int, optional
        Maximum number of tokens requested from the model.
    retries : int, optional
        Retry count forwarded to ``ai.generate_text``.
    system_message : str, optional
        System message that guides the assistant's tone and scope.
    max_words : int, optional
        Narrative word limit stated in the user prompt. The default of 180
        matches :func:`build_reasoned_report_messages`.

    Returns
    -------
    tuple
        ``(report, metadata)``. ``report`` is the parsed
        :class:`ReasonedReport`. ``metadata`` has the keys ``"model"``,
        ``"provider"``, ``"duration_s"`` (wall-clock seconds spent in
        ``ai.generate_text``), ``"summary"``, ``"raw_text"`` (the unparsed
        completion) and ``"messages"`` (the prompt that was sent).
    """
    messages = build_reasoned_report_messages(
        context, system_message=system_message, max_words=max_words
    )

    # Time only the generation call itself.
    start = time.perf_counter()
    completion = ai.generate_text(
        model=model,
        prompt=messages,
        max_tokens=max_tokens,
        retries=retries,
    )
    duration = time.perf_counter() - start

    provider = ai.provider_name(model)
    summary = build_generation_summary(model, completion, provider)
    report = parse_reasoned_report_output(completion)
    metadata = {
        "model": model,
        "provider": provider,
        "duration_s": duration,
        "summary": summary,
        "raw_text": completion,
        "messages": messages,
    }
    return report, metadata
def build_region_image_request(
    coord: Sequence[float],
    context: ReasonedReportContext,
    *,
    sphere_radius_mm: float = 6,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
) -> dict[str, Any]:
    """Return prompt components for :meth:`AIModelInterface.generate_image`.

    Parameters
    ----------
    coord : sequence of float
        Target MNI coordinate for the image. Components are rounded to whole
        millimetres for display. At least three are required (fewer raises
        ``IndexError``).
    context : ReasonedReportContext
        Context used to describe the coordinate and atlas. Its
        ``coordinate_mni`` is not used here; ``coord`` is.
    sphere_radius_mm : float, optional
        Radius of the spherical highlight in millimetres.
    negative_prompt : str, optional
        Negative prompt guiding the image generation.

    Returns
    -------
    dict
        ``{"spec": ..., "positive_prompt": ..., "negative_prompt": ...}``:
        a structured figure specification plus the text prompts.
    """
    atlas = context.atlas or {}
    # Precedence: explicit context, then the atlas entry, then inference
    # from the x coordinate.
    hemisphere = context.hemisphere or atlas.get("hemisphere")
    if not hemisphere:
        hemisphere = infer_hemisphere(coord)
    primary_label = atlas.get("primary_label") or atlas.get("label") or "Unknown Region"

    coord_int = [int(round(float(value))) for value in coord]
    coord_str = f"[{coord_int[0]},{coord_int[1]},{coord_int[2]}]"

    spec = {
        "figure_goal": (
            "Show the anatomical location of the given MNI coordinate on an "
            "MNI152-like template with overlays."
        ),
        "coordinate_mni": coord_int,
        "atlas": {
            "name": atlas.get("name", "Unknown Atlas"),
            "version": atlas.get("version", "unknown"),
            "primary_label": primary_label,
            "hemisphere": hemisphere,
        },
        "views": [
            {"plane": "axial", "z_mm": coord_int[2]},
            {"plane": "coronal", "y_mm": coord_int[1]},
            {"plane": "sagittal", "x_mm": coord_int[0]},
        ],
        "overlays": [
            {"type": "crosshair", "thickness": "thin"},
            {"type": "sphere", "radius_mm": sphere_radius_mm, "opacity": 0.65},
        ],
        "annotations": [
            {"text": f"{primary_label} ({hemisphere})", "anchor": "near_coordinate"},
            {"text": f"MNI {coord_str}", "anchor": "lower_left"},
        ],
        "style": {
            "figure_type": "clean medical figure",
            "background": "T1-weighted MNI152 appearance",
            "palette": "grayscale anatomy with a single red overlay",
            "resolution": "1024x1024",
            "layout": "3-view grid (axial, coronal, sagittal)",
            "watermark": "Illustrative",
        },
        "constraints": [
            "No diagnostic claims.",
            "No artistic textures.",
            "Crisp labels and legible fonts.",
            "Avoid cartoonish elements.",
        ],
    }

    positive_prompt = (
        "Publication-quality structural brain figure on an MNI152 template, "
        "3-view grid (axial, coronal, sagittal), grayscale T1 anatomy, "
        f"bold red spherical highlight (radius {sphere_radius_mm} mm) centered at "
        f"MNI {coord_str}, thin white crosshair at the coordinate, "
        f'labels: "{primary_label} ({hemisphere})" near the highlight '
        f'and "MNI {coord_str}" bottom-left, '
        # Trailing space matters: implicit concatenation would otherwise glue
        # "1024x1024," to the next phrase.
        "clean medical styling, high contrast, crisp lines, 1024x1024, "
        "subtle 'Illustrative' watermark in a corner."
    )
    return {
        "spec": spec,
        "positive_prompt": positive_prompt,
        "negative_prompt": negative_prompt,
    }
# Public API of this module (order preserved).
__all__ = [
    # Data containers
    "ReasonedReportContext",
    "ReasonedReport",
    # Prompt defaults
    "DEFAULT_SYSTEM_MESSAGE",
    "DEFAULT_NEGATIVE_PROMPT",
    # Helpers
    "infer_hemisphere",
    "build_reasoned_report_messages",
    "parse_reasoned_report_output",
    "run_reasoned_report",
    "build_region_image_request",
]