Source code for coord2region.utils.utils

"""Utility functions for handling atlas files and labels.

This module provides functions to parse label files, load atlas volumes,
and manage surface-based atlases using FreeSurfer annotations.
"""

import os
from pathlib import Path

import numpy as np


[docs] def fetch_labels(labels): """Parse a labels input. A list is returned as-is. If ``labels`` is a string, it is treated as the path to an XML file and parsed for entries within a ``<data><label><name>...</name></label></data>`` structure. Parameters ---------- labels : list of str or str A list of label names or a path to an XML file containing labels. Returns ------- list of str Parsed label names. Raises ------ ValueError If the XML file is invalid, cannot be parsed, contains no labels, or if ``labels`` is neither a list nor a string. Examples -------- >>> fetch_labels(['A', 'B']) ['A', 'B'] >>> fetch_labels('atlas.xml') # doctest: +SKIP ['Region1', 'Region2'] """ if isinstance(labels, str): import xml.etree.ElementTree as ET try: tree = ET.parse(labels) root = tree.getroot() data = root.find("data") if data is None: raise ValueError("Invalid XML file: missing 'data' element.") label_list = [] for label in data.findall("label"): name_elem = label.find("name") if name_elem is not None: label_list.append(name_elem.text) if not label_list: raise ValueError("No labels found in the XML file.") return label_list except Exception as e: raise ValueError(f"Error processing XML file {labels}: {e}") from e elif isinstance(labels, list): return labels else: raise ValueError(f"Invalid labels type: {type(labels)}")
[docs] def pack_vol_output(file): """Load an atlas file and return volume data and header. Parameters ---------- file : str or Nifti1Image Path to a NIfTI/NPZ file or a loaded :class:`~nibabel.nifti1.Nifti1Image`. Returns ------- dict Dictionary with ``'vol'`` and ``'hdr'`` entries. Raises ------ ValueError If the file format or object type is not supported. Examples -------- >>> pack_vol_output('atlas.nii.gz') # doctest: +SKIP {'vol': array(...), 'hdr': array(...)} """ if isinstance(file, str): path = os.path.abspath(file) _, ext = os.path.splitext(file) ext = ext.lower() if ext in [".nii", ".gz", ".nii.gz"]: import nibabel as nib img = nib.load(file) vol_data = img.get_fdata(dtype=np.float32) hdr_matrix = img.affine return { "vol": vol_data, "hdr": hdr_matrix, } elif ext == ".npz": arch = np.load(path, allow_pickle=True) vol_data = arch["vol"] hdr_matrix = arch["hdr"] return { "vol": vol_data, "hdr": hdr_matrix, } else: raise ValueError(f"Unrecognized file format '{ext}' for path: {path}") else: from nibabel.nifti1 import Nifti1Image if isinstance(file, Nifti1Image): vol_data = file.get_fdata(dtype=np.float32) hdr_matrix = file.affine return { "vol": vol_data, "hdr": hdr_matrix, } else: raise ValueError("Unsupported type for pack_vol_output")
[docs] def pack_surf_output( atlas_name, fetcher, subject: str = "fsaverage", subjects_dir: str | os.PathLike | None = None, **kwargs, ): """Load a surface-based atlas using FreeSurfer annotations. Parameters ---------- atlas_name : str Name of the atlas (e.g., ``'aparc'``). fetcher : callable or None Function used to download the atlas if necessary. subject : str, optional Subject identifier, by default ``'fsaverage'``. subjects_dir : path-like or None, optional FreeSurfer subjects directory, by default ``None``. **kwargs Additional keyword arguments passed to the fetcher. Returns ------- dict Dictionary with ``'vol'``, ``'hdr'``, ``'labels'``, ``'indexes'`` and ``'regions'`` keys where ``'regions'`` maps region names to vertex indices. Raises ------ ValueError If the atlas cannot be located or fetched. Examples -------- >>> pack_surf_output('aparc', None) # doctest: +SKIP {'vol': [...], 'hdr': None, 'labels': array([...]), 'indexes': array([...]), 'regions': {'Region': array([...])}} """ # Determine subjects_dir: use provided or from MNE config import mne subjects_dir_path = Path(subjects_dir) if subjects_dir is not None else None subjects_dir_arg = str(subjects_dir_path) if subjects_dir_path is not None else None def _download_fsaverage_annotations() -> None: """Ensure fsaverage annotations are available locally.""" if str(subject).lower() != "fsaverage": return resolved_subjects_dir = mne.utils.get_subjects_dir( subjects_dir_arg, raise_error=False ) resolved_subjects_path = ( Path(resolved_subjects_dir) if resolved_subjects_dir is not None else None ) if resolved_subjects_path is not None: annot_dir = resolved_subjects_path / subject / "label" if annot_dir.exists(): return fetch_kwargs = {} if subjects_dir_arg is not None: fetch_kwargs["subjects_dir"] = subjects_dir_arg try: mne.datasets.fetch_fsaverage(verbose=False, **fetch_kwargs) except TypeError: # pragma: no cover - older MNE versions if subjects_dir_arg is not None: mne.datasets.fetch_fsaverage(subjects_dir_arg) else: mne.datasets.fetch_fsaverage() def _read_labels_from_annot(): return mne.read_labels_from_annot( subject, atlas_name, subjects_dir=subjects_dir_arg, **kwargs, ) if fetcher is None: _download_fsaverage_annotations() try: labels = _read_labels_from_annot() except (OSError, FileNotFoundError): _download_fsaverage_annotations() labels = _read_labels_from_annot() else: try: labels = fetcher(subject=subject, subjects_dir=subjects_dir_arg, **kwargs) except Exception: _download_fsaverage_annotations() try: fetcher(subject=None, subjects_dir=subjects_dir_arg, **kwargs) except Exception: pass labels = _read_labels_from_annot() src = mne.setup_source_space( subject, spacing="oct6", subjects_dir=subjects_dir_arg, add_dist=False, ) lh_vert = src[0]["vertno"] # Left hemisphere vertices rh_vert = src[1]["vertno"] # Right hemisphere vertices # Map label names to indices in the vertex arrays. from collections import defaultdict region_vertices_lh = defaultdict(list) region_vertices_rh = defaultdict(list) labmap_lh = {} labmap_rh = {} for label in labels: if label.hemi == "lh": match = np.nonzero(np.isin(lh_vert, label.vertices))[0] verts = lh_vert[match] region_vertices_lh[label.name].extend(verts.tolist()) for idx in match: labmap_lh[idx] = label.name elif label.hemi == "rh": match = np.nonzero(np.isin(rh_vert, label.vertices))[0] verts = rh_vert[match] region_vertices_rh[label.name].extend(verts.tolist()) for idx in match: labmap_rh[idx] = label.name region_vertices_lh = { k: np.array(v, dtype=int) for k, v in region_vertices_lh.items() } region_vertices_rh = { k: np.array(v, dtype=int) for k, v in region_vertices_rh.items() } region_vertices = {**region_vertices_lh, **region_vertices_rh} indexes_lh = np.sort(np.array(list(labmap_lh.keys()))) labels_lh = np.array([labmap_lh[i] for i in indexes_lh]) vmap_lh = lh_vert[indexes_lh] indexes_rh = np.sort(np.array(list(labmap_rh.keys()))) labels_rh = np.array([labmap_rh[i] for i in indexes_rh]) vmap_rh = rh_vert[indexes_rh] labels_combined = np.concatenate([labels_lh, labels_rh]) indexes_combined = np.concatenate([vmap_lh, vmap_rh]) return { "vol": [lh_vert, rh_vert], "hdr": None, "labels": labels_combined, "indexes": indexes_combined, "regions": region_vertices, }