Source code for coord2region.pipeline

"""High-level analysis pipeline for Coord2Region.

This module exposes a single convenience function :func:`run_pipeline` which
coordinates the existing building blocks in the package to provide an
end-to-end workflow. Users can submit coordinates or region names and request
different types of outputs such as atlas labels, textual summaries, generated
images and the raw study metadata.

The implementation builds directly on the lower-level modules in the package.
Atlas lookups are performed via :mod:`coord2region.coord2region`, studies are
retrieved using :mod:`coord2region.coord2study`, and text or image generation is
handled through :mod:`coord2region.llm`.

The function also supports exporting the produced results to a variety of
formats.
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import pickle
from collections.abc import Callable, Sequence
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, cast

from .ai_model_interface import AIModelInterface
from .coord2region import MultiAtlasMapper
from .coord2study import get_studies_for_coordinate, prepare_datasets
from .fetching import AtlasFetcher  # noqa: F401 - used by tests via patching
from .llm import (
    generate_mni152_image,
    generate_region_image,
    generate_summary,
    generate_summary_async,
)
from .utils import resolve_working_directory
from .utils.file_handler import save_as_csv, save_as_pdf, save_batch_folder


@dataclass
class PipelineResult:
    """Per-input record produced by :func:`run_pipeline`.

    Every field has a default, so an instance can be created empty and
    filled in incrementally while an input item is processed.

    Parameters
    ----------
    coordinate : list of float or None
        Coordinate associated with this result (if available).
    mni_coordinates : list of float or None
        Representative MNI coordinate resolved from region name inputs when
        requested via ``outputs``.
    region_labels : dict of str to str
        Atlas region labels keyed by atlas name.
    summaries : dict of str to str
        Generated summaries keyed by language-model identifier.
    summary : str or None
        Primary summary (first entry in :attr:`summaries`), kept for
        backward compatibility.
    studies : list of dict
        Raw study metadata dictionaries.
    image : str or None
        Primary image path (first generated), kept for backward
        compatibility.
    images : dict of str to str
        Generated image paths keyed by image backend name.
    warnings : list of str
        Non-fatal issues encountered while processing the input item.
    """

    # Location information.
    coordinate: list[float] | None = None
    mni_coordinates: list[float] | None = None

    # Atlas lookups and language-model output.
    region_labels: dict[str, str] = field(default_factory=dict)
    summaries: dict[str, str] = field(default_factory=dict)
    summary: str | None = None

    # Literature hits.
    studies: list[dict[str, Any]] = field(default_factory=list)

    # Generated figures.
    image: str | None = None
    images: dict[str, str] = field(default_factory=dict)

    # Diagnostics collected along the way.
    warnings: list[str] = field(default_factory=list)
def _normalize_model_list(value: Any) -> list[str]:
    """Turn a configuration value into an ordered list of distinct model names.

    ``None`` yields an empty list. A string is treated as a single name, any
    other :class:`~collections.abc.Sequence` is expanded element by element,
    and every other object is treated as a single candidate. Candidates are
    converted with ``str`` and stripped; ``None`` entries, blank names and
    repeats are dropped while first-seen order is kept.
    """
    if value is None:
        return []

    if not isinstance(value, str) and isinstance(value, Sequence):
        raw_items = list(value)
    else:
        raw_items = [value]

    cleaned = (str(entry).strip() for entry in raw_items if entry is not None)
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(name for name in cleaned if name))


def _get_summary_models(config: dict[str, Any], default_model: str) -> list[str]:
    """Return the summary models listed in ``config["summary_models"]``.

    When the configuration provides no usable model name, the result is
    ``[default_model]`` if ``default_model`` is truthy, otherwise an empty
    list.
    """
    configured = _normalize_model_list(config.get("summary_models"))
    if configured:
        return configured
    return [default_model] if default_model else configured


def _export_results(results: list[PipelineResult], fmt: str, path: str) -> None:
    """Write ``results`` to ``path`` using the export format ``fmt``.

    ``"json"`` and ``"pickle"`` are serialised here from the dataclass
    dictionaries (the parent directory is created first). ``"csv"``,
    ``"pdf"`` and ``"directory"`` are delegated to the file-handler helpers,
    which receive the original result objects.

    Raises
    ------
    ValueError
        If ``fmt`` is not one of the formats listed above.
    """
    serializable = [asdict(entry) for entry in results]

    if fmt in ("json", "pickle"):
        parent = os.path.dirname(os.path.abspath(path)) or "."
        os.makedirs(parent, exist_ok=True)
        if fmt == "json":
            with open(path, "w", encoding="utf8") as handle:
                json.dump(serializable, handle, indent=2)
        else:
            with open(path, "wb") as handle:
                pickle.dump(serializable, handle)
    elif fmt == "csv":
        save_as_csv(results, path)
    elif fmt == "pdf":
        save_as_pdf(results, path)
    elif fmt == "directory":
        save_batch_folder(results, path)
    else:
        raise ValueError(f"Unknown export format: {fmt}")
def run_pipeline(
    inputs: Sequence[Any],
    input_type: str,
    outputs: Sequence[str],
    output_format: str | None = None,
    output_name: str | None = None,
    image_backend: str = "ai",
    *,
    config: dict[str, Any] | None = None,
    async_mode: bool = False,
    progress_callback: Callable[[int, int, PipelineResult], None] | None = None,
) -> list[PipelineResult]:
    """Run the Coord2Region analysis pipeline.

    All argument validation (input type, outputs, export format, export name
    and image backend) happens before any directory is created or any
    dataset, atlas or model is touched, so invalid requests fail fast.

    Parameters
    ----------
    inputs : sequence
        Iterable containing the inputs. The interpretation depends on
        ``input_type``.
    input_type : {"coords", "region_names"}
        Specifies how to treat ``inputs``.
    outputs : sequence of {"region_labels", "summaries", "images", \
"raw_studies", "mni_coordinates"}
        Requested pieces of information for each input item. The
        ``"mni_coordinates"`` option is only supported when
        ``input_type == "region_names"``.
    output_format : {"json", "pickle", "csv", "pdf", "directory"}, optional
        When provided, results are exported to the specified format.
    output_name : str, optional
        File or directory name to use when exporting results. The name is
        created inside the working directory's ``results`` subfolder and must
        not contain path separators. Required when ``output_format`` is
        specified.
    image_backend : {"ai", "nilearn", "both"}, optional
        Backend used to generate images when ``"images"`` is requested.
    config : dict, optional
        Additional settings. Keys read by this function:

        * ``working_directory`` - base directory for caches, images, results.
        * ``sources``, ``use_cached_dataset`` (default ``True``),
          ``email_for_abstracts``, ``study_search_radius`` (default ``0``) -
          dataset preparation and study search.
        * ``atlas_names`` (default ``["harvard-oxford", "juelich", "aal"]``),
          ``atlas_configs``, ``region_search_radius`` - atlas lookups.
        * ``providers`` or any of ``gemini_api_key``, ``openrouter_api_key``,
          ``openai_api_key``, ``openai_project``, ``anthropic_api_key``,
          ``huggingface_api_key`` - language/image model access. Without any
          of these no summaries or AI images are produced.
        * ``summary_models`` (default ``"gpt-4o-mini"``), ``prompt_type``
          (default ``"summary"``), ``custom_prompt``, ``summary_max_tokens``
          (default ``1000``) - summary generation.
        * ``image_model`` (default ``"stabilityai/stable-diffusion-2"``),
          ``image_prompt_type`` (default ``"anatomical"``),
          ``image_custom_prompt`` - AI image generation.
    async_mode : bool, optional
        When ``True``, processing occurs concurrently using asyncio and
        summaries are generated with :func:`generate_summary_async`. The
        coroutine is driven with :func:`asyncio.run`, which cannot be called
        while another event loop is running in the same thread.
    progress_callback : callable, optional
        Function invoked after each input is processed. Receives the number
        of completed items, the total count and the :class:`PipelineResult`
        for the processed item. When ``None``, progress is logged via
        ``logging``.

    Returns
    -------
    list of :class:`PipelineResult`
        One result object per item in ``inputs``. Failed atlas lookups, study
        searches and image generation do not abort the run; they are logged
        and recorded in the result's ``warnings`` list.

    Raises
    ------
    ValueError
        If ``input_type``, ``outputs``, ``output_format``, ``output_name`` or
        ``image_backend`` is invalid, or if no atlas name is configured.
    RuntimeError
        If the atlas mappers cannot be initialised.
    """
    input_type = input_type.lower()
    if input_type not in {"coords", "region_names"}:
        raise ValueError("input_type must be 'coords' or 'region_names'")

    outputs = [o.lower() for o in outputs]
    valid_outputs = {"region_labels", "summaries", "images", "raw_studies"}
    if input_type == "region_names":
        valid_outputs.add("mni_coordinates")
    invalid_outputs = sorted(set(outputs) - valid_outputs)
    if invalid_outputs:
        raise ValueError(
            "outputs must be a subset of "
            f"{sorted(valid_outputs)} for input_type='{input_type}'"
            f"; got unsupported value(s): {invalid_outputs}"
        )

    if output_format and output_name is None:
        raise ValueError("output_name must be provided when output_format is set")

    image_backend = image_backend.lower()
    if image_backend not in {"ai", "nilearn", "both"}:
        raise ValueError("image_backend must be 'ai', 'nilearn' or 'both'")

    # Check the export request now: discovering an unsupported format or an
    # unusable name only after every input was processed would throw away
    # potentially expensive model calls.
    export_format: str | None = None
    output_label = ""
    if output_format:
        export_format = output_format.lower()
        if export_format not in {"json", "pickle", "csv", "pdf", "directory"}:
            raise ValueError(f"Unknown export format: {export_format}")
        output_label = cast(str, output_name)
        if not output_label or Path(output_label).name != output_label:
            message = (
                "output_name must be a single file or directory name without path"
                " separators"
            )
            raise ValueError(message)

    if async_mode:
        return asyncio.run(
            _run_pipeline_async(
                inputs,
                input_type,
                outputs,
                output_format,
                output_name,
                image_backend=image_backend,
                config=config,
                progress_callback=progress_callback,
            )
        )

    kwargs = config or {}
    study_search_radius = float(kwargs.get("study_search_radius", 0))
    region_radius_value = kwargs.get("region_search_radius")
    region_search_radius = (
        float(region_radius_value) if region_radius_value is not None else None
    )
    # unified sources control both dataset preparation and study search
    sources = kwargs.get("sources")
    summary_models = _get_summary_models(kwargs, default_model="gpt-4o-mini")
    prompt_type = kwargs.get("prompt_type") or "summary"
    custom_prompt = kwargs.get("custom_prompt")
    summary_max_tokens = kwargs.get("summary_max_tokens", 1000)

    working_dir = resolve_working_directory(kwargs.get("working_directory"))
    working_dir.mkdir(parents=True, exist_ok=True)
    cache_dir = working_dir / "cached_data"
    image_dir = working_dir / "generated_images"
    results_dir = working_dir / "results"
    for p in (cache_dir, image_dir, results_dir):
        p.mkdir(parents=True, exist_ok=True)

    export_path: Path | None = results_dir / output_label if output_format else None

    email = kwargs.get("email_for_abstracts")
    use_cached_dataset = kwargs.get("use_cached_dataset", True)
    atlas_names = kwargs.get("atlas_names", ["harvard-oxford", "juelich", "aal"])
    provider_configs = kwargs.get("providers")
    gemini_api_key = kwargs.get("gemini_api_key")
    openrouter_api_key = kwargs.get("openrouter_api_key")
    openai_api_key = kwargs.get("openai_api_key")
    openai_project = kwargs.get("openai_project")
    anthropic_api_key = kwargs.get("anthropic_api_key")
    huggingface_api_key = kwargs.get("huggingface_api_key")
    image_model = kwargs.get("image_model", "stabilityai/stable-diffusion-2")
    image_prompt_type = kwargs.get("image_prompt_type") or "anatomical"
    image_custom_prompt = kwargs.get("image_custom_prompt")

    dataset = (
        prepare_datasets(str(working_dir), sources=sources)
        if use_cached_dataset
        else None
    )

    ai = None
    if provider_configs:
        ai = AIModelInterface()
        for name, cfg in provider_configs.items():
            ai.register_provider(name, **cfg)
    elif any(
        [
            gemini_api_key,
            openrouter_api_key,
            openai_api_key,
            anthropic_api_key,
            huggingface_api_key,
        ]
    ):
        ai = AIModelInterface(
            gemini_api_key=gemini_api_key,
            openrouter_api_key=openrouter_api_key,
            openai_api_key=openai_api_key,
            openai_project=openai_project,
            anthropic_api_key=anthropic_api_key,
            huggingface_api_key=huggingface_api_key,
        )

    atlas_configs = kwargs.get("atlas_configs") or {}
    atlas_dict = {name: dict(atlas_configs.get(name, {})) for name in atlas_names or []}
    if not atlas_dict:
        raise ValueError(
            "At least one atlas name must be provided to run the pipeline."
        )
    try:
        multi_atlas: MultiAtlasMapper = MultiAtlasMapper(str(working_dir), atlas_dict)
    except Exception as exc:  # pragma: no cover - defensive guard
        raise RuntimeError("Failed to initialize atlas mappers") from exc

    def _from_region_name(name: str) -> list[float] | None:
        """Return the first non-empty coordinate any atlas reports for ``name``."""
        coords_dict = multi_atlas.batch_region_name_to_mni([name])
        for atlas_coords in coords_dict.values():
            if atlas_coords:
                coord = atlas_coords[0]
                if coord is not None:
                    try:
                        return coord.tolist()  # type: ignore[attr-defined]
                    except Exception:
                        return list(coord)  # type: ignore[arg-type]
        return None

    def _note_failure(res: PipelineResult, what: str, exc: Exception) -> None:
        """Log a non-fatal failure and record it on the result."""
        message = f"{what} failed: {exc}"
        logging.warning(message, exc_info=True)
        res.warnings.append(message)

    def _write_image(data: bytes) -> str:
        """Store ``data`` under the next unused ``image_<n>.png`` name.

        Numbering starts from the directory's entry count as before, but the
        file is opened in exclusive mode so an existing image is never
        overwritten when the numbering has gaps.
        """
        index = sum(1 for _ in image_dir.iterdir()) + 1
        while True:
            candidate = image_dir / f"image_{index}.png"
            try:
                with open(candidate, "xb") as handle:
                    handle.write(data)
            except FileExistsError:
                index += 1
            else:
                return str(candidate)

    results: list[PipelineResult] = []

    def _report(res: PipelineResult) -> None:
        """Notify the caller (or the log) that one more input is done."""
        if progress_callback:
            progress_callback(len(results), len(inputs), res)
        else:
            logging.info("Processed %d/%d inputs", len(results), len(inputs))

    wants_studies = "raw_studies" in outputs or "summaries" in outputs

    for item in inputs:
        region_name_input: str | None = None
        if input_type == "coords":
            coord = list(item) if item is not None else None
        else:  # "region_names" - input_type was validated above
            region_name_input = str(item)
            coord = _from_region_name(region_name_input)

        res = PipelineResult(coordinate=coord)

        if coord is None:
            # Nothing else can be computed without a coordinate.
            if region_name_input is not None:
                message = (
                    f"Region '{region_name_input}' could not be resolved to "
                    "coordinates with the configured atlases."
                )
                logging.warning(message)
                res.warnings.append(message)
            results.append(res)
            _report(res)
            continue

        if "mni_coordinates" in outputs:
            res.mni_coordinates = list(coord)

        if "region_labels" in outputs:
            try:
                batch = multi_atlas.batch_mni_to_region_names(
                    [coord], max_distance=region_search_radius
                )
                # Extract first match per atlas
                res.region_labels = {
                    atlas: (names[0] if names else "Unknown")
                    for atlas, names in batch.items()
                }
            except Exception as exc:
                res.region_labels = {}
                _note_failure(res, "Region label lookup", exc)

        if wants_studies and dataset is not None:
            study_kwargs: dict[str, Any] = {
                "radius": study_search_radius,
                "email": email,
            }
            if isinstance(dataset, dict):
                # When using a single deduplicated Dataset (not a mapping),
                # 'sources' is not passed to avoid filtering out the combined
                # set; only the mapping form receives it.
                study_kwargs["sources"] = sources
            try:
                res.studies = get_studies_for_coordinate(
                    dataset, coord, **study_kwargs
                )
            except Exception as exc:
                res.studies = []
                _note_failure(res, "Study lookup", exc)

        if "summaries" in outputs and ai and summary_models:
            for model_idx, model_name in enumerate(summary_models):
                summary_text = generate_summary(
                    ai,
                    res.studies,
                    coord,
                    prompt_type=prompt_type,
                    model=model_name,
                    atlas_labels=res.region_labels or None,
                    custom_prompt=custom_prompt if prompt_type == "custom" else None,
                    max_tokens=summary_max_tokens,
                )
                res.summaries[model_name] = summary_text
                if model_idx == 0:
                    # The first model's output doubles as the legacy field.
                    res.summary = summary_text

        if "images" in outputs:
            os.makedirs(image_dir, exist_ok=True)
            if image_backend in {"ai", "both"} and ai:
                region_info = {
                    "summary": res.summary or "",
                    "atlas_labels": res.region_labels,
                }
                try:
                    img_bytes = generate_region_image(
                        ai,
                        coord,
                        region_info,
                        image_type=image_prompt_type,
                        model=image_model,
                        watermark=True,
                        prompt_template=(
                            image_custom_prompt
                            if image_prompt_type == "custom"
                            else None
                        ),
                    )
                    saved = _write_image(img_bytes)
                    res.image = res.image or saved
                    res.images["ai"] = saved
                except Exception as exc:
                    _note_failure(res, "AI image generation", exc)
            if image_backend in {"nilearn", "both"}:
                try:
                    img_bytes = generate_mni152_image(coord)
                    saved = _write_image(img_bytes)
                    if res.image is None:
                        res.image = saved
                    res.images["nilearn"] = saved
                except Exception as exc:
                    _note_failure(res, "nilearn image generation", exc)

        results.append(res)
        _report(res)

    if export_format is not None and export_path is not None:
        _export_results(results, export_format, str(export_path))

    return results
async def _run_pipeline_async(
    inputs: Sequence[Any],
    input_type: str,
    outputs: Sequence[str],
    output_format: str | None,
    output_name: str | None,
    image_backend: str,
    *,
    config: dict[str, Any] | None,
    progress_callback: Callable[[int, int, PipelineResult], None] | None,
) -> list[PipelineResult]:
    """Asynchronous implementation backing :func:`run_pipeline`.

    One task is created per input item. Blocking work (dataset preparation,
    atlas lookups, study search, image generation and export) is pushed to
    worker threads with :func:`asyncio.to_thread`, while summaries are
    awaited through :func:`generate_summary_async`.

    Parameters
    ----------
    inputs, input_type, outputs, output_format, output_name, image_backend
        Same meaning as in :func:`run_pipeline`. ``input_type``, ``outputs``
        and ``image_backend`` are used as given (no lower-casing or
        validation happens here).
    config : dict or None
        Settings dictionary; the same keys as in :func:`run_pipeline` are
        read.
    progress_callback : callable or None
        Called as ``progress_callback(completed, total, result)`` each time a
        task finishes (in completion order). When ``None``, progress is
        logged instead.

    Returns
    -------
    list of :class:`PipelineResult`
        Results in the order of ``inputs``. Failed atlas lookups, study
        searches and image generation are logged and recorded in the
        result's ``warnings`` list instead of aborting the run.

    Raises
    ------
    ValueError
        If ``output_format`` is unknown, ``output_name`` is missing or
        contains path separators, or no atlas name is configured.
    RuntimeError
        If the atlas mappers cannot be initialised.
    """
    # Validate the export request before doing any work so that a bad format
    # or name does not surface only after all inputs have been processed.
    export_format: str | None = None
    output_label = ""
    if output_format:
        export_format = output_format.lower()
        if export_format not in {"json", "pickle", "csv", "pdf", "directory"}:
            raise ValueError(f"Unknown export format: {export_format}")
        if output_name is None:
            raise ValueError("output_name must be provided when output_format is set")
        output_label = output_name
        if not output_label or Path(output_label).name != output_label:
            message = (
                "output_name must be a single file or directory name without path"
                " separators"
            )
            raise ValueError(message)

    kwargs = config or {}
    study_search_radius = float(kwargs.get("study_search_radius", 0))
    region_radius_value = kwargs.get("region_search_radius")
    region_search_radius = (
        float(region_radius_value) if region_radius_value is not None else None
    )
    # unified sources control both dataset preparation and study search
    sources = kwargs.get("sources")
    summary_models = _get_summary_models(kwargs, default_model="gpt-4o-mini")
    prompt_type = kwargs.get("prompt_type") or "summary"
    custom_prompt = kwargs.get("custom_prompt")
    summary_max_tokens = kwargs.get("summary_max_tokens", 1000)

    working_dir = resolve_working_directory(kwargs.get("working_directory"))
    working_dir.mkdir(parents=True, exist_ok=True)
    cache_dir = working_dir / "cached_data"
    image_dir = working_dir / "generated_images"
    results_dir = working_dir / "results"
    for p in (cache_dir, image_dir, results_dir):
        p.mkdir(parents=True, exist_ok=True)

    export_path: Path | None = results_dir / output_label if output_format else None

    email = kwargs.get("email_for_abstracts")
    use_cached_dataset = kwargs.get("use_cached_dataset", True)
    atlas_names = kwargs.get("atlas_names", ["harvard-oxford", "juelich", "aal"])
    provider_configs = kwargs.get("providers")
    gemini_api_key = kwargs.get("gemini_api_key")
    openrouter_api_key = kwargs.get("openrouter_api_key")
    openai_api_key = kwargs.get("openai_api_key")
    openai_project = kwargs.get("openai_project")
    anthropic_api_key = kwargs.get("anthropic_api_key")
    huggingface_api_key = kwargs.get("huggingface_api_key")
    image_model = kwargs.get("image_model", "stabilityai/stable-diffusion-2")
    image_prompt_type = kwargs.get("image_prompt_type") or "anatomical"
    image_custom_prompt = kwargs.get("image_custom_prompt")

    # ``sources`` is passed by keyword, matching the synchronous code path.
    dataset = (
        await asyncio.to_thread(prepare_datasets, str(working_dir), sources=sources)
        if use_cached_dataset
        else None
    )

    ai = None
    if provider_configs:
        ai = AIModelInterface()
        for name, cfg in provider_configs.items():
            ai.register_provider(name, **cfg)
    elif any(
        [
            gemini_api_key,
            openrouter_api_key,
            openai_api_key,
            anthropic_api_key,
            huggingface_api_key,
        ]
    ):
        ai = AIModelInterface(
            gemini_api_key=gemini_api_key,
            openrouter_api_key=openrouter_api_key,
            openai_api_key=openai_api_key,
            openai_project=openai_project,
            anthropic_api_key=anthropic_api_key,
            huggingface_api_key=huggingface_api_key,
        )

    atlas_configs = kwargs.get("atlas_configs") or {}
    atlas_dict = {name: dict(atlas_configs.get(name, {})) for name in atlas_names or []}
    if not atlas_dict:
        raise ValueError(
            "At least one atlas name must be provided to run the pipeline."
        )
    try:
        multi_atlas: MultiAtlasMapper = MultiAtlasMapper(str(working_dir), atlas_dict)
    except Exception as exc:  # pragma: no cover - defensive guard
        raise RuntimeError("Failed to initialize atlas mappers") from exc

    def _from_region_name(name: str) -> list[float] | None:
        """Return the first non-empty coordinate any atlas reports for ``name``."""
        coords_dict = multi_atlas.batch_region_name_to_mni([name])
        for atlas_coords in coords_dict.values():
            if atlas_coords:
                coord = atlas_coords[0]
                if coord is not None:
                    try:
                        return coord.tolist()  # type: ignore[attr-defined]
                    except Exception:
                        return list(coord)  # type: ignore[arg-type]
        return None

    def _note_failure(res: PipelineResult, what: str, exc: Exception) -> None:
        """Log a non-fatal failure and record it on the result."""
        message = f"{what} failed: {exc}"
        logging.warning(message, exc_info=True)
        res.warnings.append(message)

    def _write_image(data: bytes) -> str:
        """Store ``data`` under the next unused ``image_<n>.png`` name.

        Several worker threads may save images at the same time. Counting
        directory entries and then opening with ``"wb"`` let two of them pick
        the same name and overwrite each other; exclusive creation makes the
        loser move on to the next index instead.
        """
        index = sum(1 for _ in image_dir.iterdir()) + 1
        while True:
            candidate = image_dir / f"image_{index}.png"
            try:
                with open(candidate, "xb") as handle:
                    handle.write(data)
            except FileExistsError:
                index += 1
            else:
                return str(candidate)

    total = len(inputs)
    results: list[PipelineResult | None] = [None] * total
    wants_studies = "raw_studies" in outputs or "summaries" in outputs

    async def _process(idx: int, item: Any) -> tuple[int, PipelineResult]:
        """Process one input item and return it together with its position."""
        region_name_input: str | None = None
        if input_type == "coords":
            coord = list(item) if item is not None else None
        elif input_type == "region_names":
            region_name_input = str(item)
            coord = await asyncio.to_thread(_from_region_name, region_name_input)
        else:  # only "coords" or "region_names" supported
            coord = None

        res = PipelineResult(coordinate=coord)

        if coord is None:
            # Nothing else can be computed without a coordinate.
            if region_name_input is not None:
                message = (
                    f"Region '{region_name_input}' could not be resolved to "
                    "coordinates with the configured atlases."
                )
                logging.warning(message)
                res.warnings.append(message)
            return idx, res

        if "mni_coordinates" in outputs:
            res.mni_coordinates = list(coord)

        if "region_labels" in outputs:
            try:
                batch = await asyncio.to_thread(
                    multi_atlas.batch_mni_to_region_names,
                    [coord],
                    max_distance=region_search_radius,
                )
                res.region_labels = {
                    atlas: (names[0] if names else "Unknown")
                    for atlas, names in batch.items()
                }
            except Exception as exc:
                res.region_labels = {}
                _note_failure(res, "Region label lookup", exc)

        if wants_studies and dataset is not None:
            study_kwargs: dict[str, Any] = {
                "radius": study_search_radius,
                "email": email,
            }
            if isinstance(dataset, dict):
                # 'sources' is only forwarded for the mapping form of the
                # dataset, mirroring the synchronous implementation.
                study_kwargs["sources"] = sources
            try:
                res.studies = await asyncio.to_thread(
                    get_studies_for_coordinate, dataset, coord, **study_kwargs
                )
            except Exception as exc:
                res.studies = []
                _note_failure(res, "Study lookup", exc)

        if "summaries" in outputs and ai and summary_models:
            for model_idx, model_name in enumerate(summary_models):
                summary_text = await generate_summary_async(
                    ai,
                    res.studies,
                    coord,
                    prompt_type=prompt_type,
                    model=model_name,
                    atlas_labels=res.region_labels or None,
                    custom_prompt=custom_prompt if prompt_type == "custom" else None,
                    max_tokens=summary_max_tokens,
                )
                res.summaries[model_name] = summary_text
                if model_idx == 0:
                    # The first model's output doubles as the legacy field.
                    res.summary = summary_text

        if "images" in outputs:
            os.makedirs(image_dir, exist_ok=True)
            if image_backend in {"ai", "both"} and ai:
                region_info = {
                    "summary": res.summary or "",
                    "atlas_labels": res.region_labels,
                }

                def _save_ai_image() -> str:
                    img_bytes = generate_region_image(
                        ai,
                        coord,
                        region_info,
                        image_type=image_prompt_type,
                        model=image_model,
                        watermark=True,
                        prompt_template=(
                            image_custom_prompt
                            if image_prompt_type == "custom"
                            else None
                        ),
                    )
                    return _write_image(img_bytes)

                try:
                    path = await asyncio.to_thread(_save_ai_image)
                    res.image = res.image or path
                    res.images["ai"] = path
                except Exception as exc:
                    _note_failure(res, "AI image generation", exc)

            if image_backend in {"nilearn", "both"}:

                def _save_nilearn_image() -> str:
                    return _write_image(generate_mni152_image(coord))

                try:
                    path = await asyncio.to_thread(_save_nilearn_image)
                    if res.image is None:
                        res.image = path
                    res.images["nilearn"] = path
                except Exception as exc:
                    _note_failure(res, "nilearn image generation", exc)

        return idx, res

    tasks = [asyncio.create_task(_process(i, item)) for i, item in enumerate(inputs)]
    completed = 0
    for fut in asyncio.as_completed(tasks):
        idx, res = await fut
        # Store by original position so the output order matches ``inputs``.
        results[idx] = res
        completed += 1
        if progress_callback:
            progress_callback(completed, total, res)
        else:
            logging.info("Processed %d/%d inputs", completed, total)

    final_results = [r for r in results if r is not None]

    if export_format is not None and export_path is not None:
        await asyncio.to_thread(
            _export_results,
            final_results,
            export_format,
            str(export_path),
        )

    return final_results