Source code for coord2region.coord2region

"""Coordinate-to-region mapping utilities.

This module provides classes and helper functions for converting between
MNI coordinates, voxel indices, and anatomical region labels. It enables
lookups and transformations across multiple brain atlases.
"""

import logging
import pickle
import warnings
from typing import Any

import mne
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree

from .fetching import AtlasFetcher

# NOTE(review): ``basicConfig`` configures the *root* logger as an import-time
# side effect of this module; kept as-is because callers may rely on it.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _mni_to_tal(coords: list[float] | np.ndarray) -> np.ndarray: """Convert MNI coordinates to Talairach space. Parameters ---------- coords : array-like, shape (..., 3) Coordinates defined in MNI space. The input is reshaped to ``(-1, 3)`` for transformation and the original shape is restored in the output. Returns ------- np.ndarray Coordinates in Talairach space with the same shape as ``coords``. Examples -------- >>> _mni_to_tal([0, 0, 0]) array([0., 0., 0.]) """ coords = np.asarray(coords, dtype=float) orig_shape = coords.shape coords = coords.reshape(-1, 3) A = np.array( [[0.99, 0, 0, 0], [0, 0.9688, 0.0460, 0], [0, -0.0485, 0.9189, 0], [0, 0, 0, 1]] ) B = np.array( [[0.99, 0, 0, 0], [0, 0.9688, 0.0420, 0], [0, -0.0485, 0.8390, 0], [0, 0, 0, 1]] ) out = np.empty_like(coords) for i, c in enumerate(coords): vec = np.append(c, 1) out[i] = (A @ vec)[:3] if c[2] >= 0 else (B @ vec)[:3] return out.reshape(orig_shape) def _tal_to_mni(coords: list[float] | np.ndarray) -> np.ndarray: """Convert Talairach coordinates to MNI space. Parameters ---------- coords : array-like, shape (..., 3) Coordinates defined in Talairach space. The input is reshaped to ``(-1, 3)`` for transformation and the original shape is restored in the output. Returns ------- np.ndarray Coordinates in MNI space with the same shape as ``coords``. 
Examples -------- >>> _tal_to_mni([0, 0, 0]) array([0., 0., 0.]) """ coords = np.asarray(coords, dtype=float) orig_shape = coords.shape coords = coords.reshape(-1, 3) A = np.array( [[0.99, 0, 0, 0], [0, 0.9688, 0.0460, 0], [0, -0.0485, 0.9189, 0], [0, 0, 0, 1]] ) B = np.array( [[0.99, 0, 0, 0], [0, 0.9688, 0.0420, 0], [0, -0.0485, 0.8390, 0], [0, 0, 0, 1]] ) A_inv = np.linalg.inv(A) B_inv = np.linalg.inv(B) out = np.empty_like(coords) for i, c in enumerate(coords): vec = np.append(c, 1) out[i] = (A_inv @ vec)[:3] if c[2] >= 0 else (B_inv @ vec)[:3] return out.reshape(orig_shape) _TRANSFORMS = { ("mni", "tal"): _mni_to_tal, ("tal", "mni"): _tal_to_mni, } def _get_numeric_hemi(hemi: str | int) -> int: """Convert hemisphere string to numeric code (0 or 1).""" if isinstance(hemi, int): return hemi if hemi is None: return None if isinstance(hemi, str): if hemi.lower() in ("l", "lh", "left"): return 0 if hemi.lower() in ("r", "rh", "right"): return 1 raise ValueError("Invalid hemisphere value. Use 'L', 'R', 'LH', 'RH', 0, or 1.")
[docs] class AtlasMapper: """Stores a single atlas and provides coordinate conversions. The atlas may be volumetric (a 3D numpy array with an associated 4x4 affine) or surface-based (a vertices array). In either case, the mapper supports conversions between coordinates, voxel indices, and region labels. Parameters ---------- name : str Identifier for the atlas (e.g., "aal" or "brodmann"). vol : np.ndarray A 3D numpy array representing the volumetric atlas. hdr : np.ndarray A 4x4 affine transform mapping voxel indices to MNI/world coordinates. labels : dict or list or None, optional Region labels. If a dict, keys should be strings for numeric indices and values are region names. If a list/array, it should match ``indexes``. indexes : list or np.ndarray or None, optional Region indices corresponding to ``labels``. Not needed if ``labels`` is a dict. regions : dict or None, optional For surface atlases, mapping of region names to vertex indices. system : str, optional The anatomical coordinate space (e.g., "mni" or "tal"). Attributes ---------- name : str Atlas identifier. vol : np.ndarray Volumetric atlas array. hdr : np.ndarray Affine transform mapping voxel indices to MNI/world coordinates. labels : dict or list or None Region labels. indexes : list or np.ndarray or None Region indices corresponding to labels. regions : dict or None Mapping of region names to vertex indices for surface atlases. system : str Anatomical coordinate space. shape : tuple Shape of the volumetric atlas. """ def __init__( self, name: str, vol: np.ndarray, hdr: np.ndarray, labels: dict[str, str] | list[str] | np.ndarray | None = None, indexes: list[int] | np.ndarray | None = None, subject: str | None = "fsaverage", regions: dict[str, np.ndarray] | None = None, subjects_dir: str | None = None, system: str = "mni", ) -> None:
[docs] self.name = name
[docs] self.labels = labels
[docs] self.indexes = indexes
# Ensure region->vertex mapping uses integer vertex indices if regions is not None: self.regions = { key: np.asarray(vals, dtype=int).ravel() for key, vals in regions.items() } else: self.regions = None
[docs] self.vertex_to_region = None
[docs] self.system = system
# Basic shape checks if isinstance(vol, np.ndarray): self.vol = np.asarray(vol) # volumetric atlas if hdr is not None and self.vol.ndim == 3: self.hdr = np.asarray(hdr) if self.hdr.shape != (4, 4): raise ValueError("`hdr` must be a 4x4 transform matrix.") self.shape = self.vol.shape self.atlas_type = "volume" # coordinate atlas (list of region centroids) elif self.vol.ndim == 2 and self.vol.shape[1] == 3: self.hdr = None self.atlas_type = "coords" if self.indexes is None: self.indexes = np.arange(self.vol.shape[0]) else: raise ValueError("Unsupported array format for `vol`.") elif isinstance(vol, list): arr = np.asarray(vol) if arr.ndim == 2 and arr.shape[1] == 3: # coordinate atlas provided as list self.vol = arr.astype(float) self.hdr = None self.atlas_type = "coords" if self.indexes is None: self.indexes = np.arange(self.vol.shape[0]) else: # For surface atlases, `vol` is a list of vertex arrays per hemisphere self.vol = [np.asarray(v, dtype=int) for v in vol] self.hdr = None self.atlas_type = "surface" self.subject = subject self.subjects_dir = subjects_dir self.vertex_to_region = { int(v): k for k, verts in (regions or {}).items() for v in np.asarray(verts).ravel() } # If labels is a dict, prepare an inverse mapping: # region_name -> region_index if isinstance(self.labels, dict): self._label2index = {v: k for k, v in self.labels.items()} else: self._label2index = None # Cache for region centroids (used by nearest-region queries) self._centroids_cache: dict[int, np.ndarray] | None = None # Cached KD-tree for voxel center lookup (volume atlases) self._voxel_kdtree: cKDTree | None = None self._voxel_indices: np.ndarray | None = None # ------------------------------------------------------------------------- # Internal lookups (private) # ------------------------------------------------------------------------- def _lookup_region_name(self, value: int | str) -> str: """ Return the region name corresponding to the given region index (int/str). 
Returns "Unknown" if not found. """ if not isinstance(value, int | str): raise ValueError("value must be int or str") if self.atlas_type == "surface" and self.vertex_to_region is not None: try: return self.vertex_to_region.get(int(value), "Unknown") except ValueError: return "Unknown" value_str = str(value) if isinstance(self.labels, dict): return self.labels.get(value_str, "Unknown") if self.indexes is not None and self.labels is not None: try: if isinstance(self.indexes, list): pos = self.indexes.index(int(value)) else: pos = int(np.where(self.indexes == int(value))[0][0]) return self.labels[pos] except (ValueError, IndexError): return "Unknown" elif self.labels is not None: try: return self.labels[int(value)] except (ValueError, IndexError): return "Unknown" return "Unknown" def _lookup_region_index(self, label: str) -> int | str: """ Return the numeric region index corresponding to the given region name. Returns "Unknown" if not found. """ if not isinstance(label, str): raise ValueError("label must be a string") if self.atlas_type == "surface" and self.regions is not None: return np.asarray(self.regions.get(label, [])) if self._label2index is not None: return self._label2index.get(label, "Unknown") if self.indexes is not None and self.labels is not None: try: if isinstance(self.labels, list): pos = self.labels.index(label) else: pos = int(np.where(np.array(self.labels) == label)[0][0]) # Return the corresponding numeric index from self.indexes if isinstance(self.indexes, list): return self.indexes[pos] else: return int(self.indexes[pos]) except (ValueError, IndexError): return "Unknown" elif self.labels is not None: # If self.labels is just a list of strings try: return int(np.where(np.array(self.labels) == label)[0][0]) except (ValueError, IndexError): return "Unknown" return "Unknown" # ------------------------------------------------------------------------- # Region name / index # -------------------------------------------------------------------------
[docs] def region_name_from_index(self, region_idx: int | str) -> str: """Return region name from numeric region index.""" return self._lookup_region_name(region_idx)
[docs] def region_index_from_name(self, region_name: str) -> int | str | np.ndarray: """Return region index from region name.""" return self._lookup_region_index(region_name)
[docs] def list_all_regions(self) -> list[str]: """Return a list of all unique region names in this atlas.""" if self.regions is not None: return list(self.regions.keys()) if self.labels is None: return [] regions = self.labels.values() if isinstance(self.labels, dict) else self.labels return list(dict.fromkeys(regions))
[docs] def infer_hemisphere(self, region: int | str) -> str | None: """ Return the hemisphere ('L' or 'R') inferred from ``region``. Returns None if not found or not applicable. """ # Convert numeric region to string name, if needed: region_name = ( region if isinstance(region, str) else self._lookup_region_name(region) ) if isinstance(region, str): # If a string is actually an index, resolve it to the label first. resolved_name = self._lookup_region_name(region) if resolved_name != "Unknown": region_name = resolved_name if region_name in (None, "Unknown"): return None # Ensure the region actually belongs to the current atlas. if isinstance(region_name, str): idx = self._lookup_region_index(region_name) missing = isinstance(idx, str) and idx == "Unknown" if isinstance(idx, np.ndarray) and idx.size == 0: missing = True if missing: warnings.warn( f"Region '{region_name}' is not part of the '{self.name}' atlas.", UserWarning, stacklevel=2, ) return None if self.name.lower() == "schaefer": parts = region_name.split("_", 1) lower = parts[-1].lower() return ( "L" if lower.startswith("lh") else "R" if lower.startswith("rh") else None ) lower = region_name.lower() return ( "L" if lower.endswith(("_lh", "-lh")) else "R" if lower.endswith(("_rh", "-rh")) else None )
# ------------------------------------------------------------------------- # Coordinate system conversions # -------------------------------------------------------------------------
[docs] def convert_system( self, coord: list[float] | np.ndarray, source_system: str, target_system: str, ) -> np.ndarray: """Convert coordinates between anatomical systems.""" source = source_system.lower() target = target_system.lower() if source == target: return np.asarray(coord, dtype=float) try: func = _TRANSFORMS[(source, target)] except KeyError: raise ValueError( f"Unsupported system conversion: {source_system} -> {target_system}" ) from None return func(coord)
# ------------------------------------------------------------------------- # MNI <--> voxel conversions # ------------------------------------------------------------------------- def _build_voxel_kdtree(self) -> None: """Build a KD-tree of voxel centers for nearest-neighbor queries. Lazy initialization of the KD-tree for efficient nearest voxel lookups. The tree is built once on first use and cached for subsequent queries. """ if self.atlas_type != "volume" or self._voxel_kdtree is not None: return grid = np.indices(self.vol.shape).reshape(3, -1).T mni_coords = grid @ self.hdr[:3, :3].T + self.hdr[:3, 3] self._voxel_indices = grid.astype(int) self._voxel_kdtree = cKDTree(mni_coords)
[docs] def mni_to_voxel(self, mni_coord: list[float] | np.ndarray) -> tuple[int, int, int]: """Convert an MNI coordinate to the nearest voxel indices. The coordinate is transformed using the atlas affine. If it does not exactly match a voxel center, the voxel whose MNI coordinates are closest in Euclidean distance is returned. """ if not isinstance(mni_coord, list | np.ndarray): raise ValueError("`mni_coord` must be a list or numpy array.") pos_arr = np.asarray(mni_coord) if pos_arr.shape != (3,): raise ValueError("`mni_coord` must be a 3-element (x,y,z).") # MNI coordinates are 3D (x, y, z). For affine transforms we use # homogeneous coordinates (x, y, z, 1) homogeneous = np.append(pos_arr, 1) voxel = np.linalg.inv(self.hdr) @ homogeneous # self.hdr is a 4×4 affine matrix mapping voxel indices to MNI # coordinates. Its inverse maps MNI back to voxel space. The @ # applies the matrix multiplication. rounded = np.round(voxel[:3]).astype(int) # Check if this voxel maps back exactly to the MNI coordinate back = (self.hdr @ np.append(rounded, 1))[:3] if np.allclose(back, pos_arr, atol=1e-6): return tuple(rounded) # Otherwise search for the voxel with minimal distance in MNI space self._build_voxel_kdtree() if self._voxel_kdtree is None or self._voxel_indices is None: raise RuntimeError( f"Failed to construct voxel KD-tree for atlas '{self.name}'. " "This may indicate memory issues or invalid volume data." ) _, idx = self._voxel_kdtree.query(pos_arr) nearest = self._voxel_indices[idx] return tuple(int(v) for v in nearest)
[docs] def mni_to_vertex( self, mni_coord: list[float] | np.ndarray, hemi: list[int] | int | None = None, ) -> np.ndarray | int: """Convert an MNI coordinate to the nearest vertex index. Parameters ---------- mni_coord : list | ndarray The target MNI coordinate ``[x, y, z]``. hemi : int | list[int] | None Hemisphere(s) to restrict the search to. ``0`` for left, ``1`` for right. If ``None`` (default) both hemispheres are searched. Returns ------- int | ndarray Index/indices of the matching vertex. If no vertex matches exactly, the closest vertex is returned. """ mni_coord = np.asarray(mni_coord) # Determine which hemispheres to search if hemi is None: hemis = [0, 1] elif isinstance(hemi, list | tuple | np.ndarray): hemis = [_get_numeric_hemi(h) for h in hemi] else: hemis = [_get_numeric_hemi(hemi)] all_vertices: list[np.ndarray] = [] all_coords: list[np.ndarray] = [] for h in hemis: verts = np.asarray(self.vol[h]) if verts.size == 0: continue coords = mne.vertex_to_mni(verts, h, self.subject, self.subjects_dir) all_vertices.append(verts) all_coords.append(coords) if not all_vertices: return np.array([]) vertices = np.concatenate(all_vertices) coords = np.vstack(all_coords) dists = np.linalg.norm(coords - mni_coord, axis=1) exact = np.where(dists == 0)[0] if exact.size: matches = vertices[exact] return matches if matches.size > 1 else int(matches[0]) closest_vertex = vertices[int(np.argmin(dists))] return int(closest_vertex)
[docs] def convert_to_source( self, target: list[float] | np.ndarray, hemi: list[int] | int | None = None, source_system: str = "mni", ) -> np.ndarray: """Convert a coordinate to the atlas source space. Parameters ---------- target : list | ndarray The coordinate to convert. hemi : int | list[int] | None Hemisphere(s) to search when using surface atlases. ``0`` for left and ``1`` for right. If ``None`` (default) both hemispheres are searched. source_system : str, optional Coordinate system of ``target``. Defaults to ``"mni"``. """ if source_system.lower() != self.system.lower(): target = self.convert_system(target, source_system, self.system) if self.atlas_type == "volume": return self.mni_to_voxel(target) if self.atlas_type == "surface": return self.mni_to_vertex(target, hemi) if self.atlas_type == "coords": arr = np.asarray(self.vol, dtype=float) tgt = np.asarray(target, dtype=float).reshape(1, 3) mask = np.all(np.isclose(arr, tgt), axis=1) if not mask.any(): return np.array([], dtype=int) inds = np.where(mask)[0] if self.indexes is not None: return np.array([self.indexes[i] for i in inds]) return inds
[docs] def voxel_to_mni(self, voxel_ijk: list[int] | np.ndarray) -> np.ndarray: """ Convert voxel indices (i, j, k) to MNI/world coordinates. Returns an array of shape (3,). """ if not isinstance(voxel_ijk, list | np.ndarray): raise ValueError("`voxel_ijk` must be list or numpy array.") src_arr = np.atleast_2d(voxel_ijk) ones = np.ones((src_arr.shape[0], 1)) homogeneous = np.hstack([src_arr, ones]) transformed = homogeneous @ self.hdr.T coords = transformed[:, :3] / transformed[:, 3, np.newaxis] if src_arr.shape[0] == 1: return coords[0] return coords
[docs] def vertex_to_mni( self, vertices: list[int] | np.ndarray, hemi: list[int] | int ) -> np.ndarray: """ Convert vertices to MNI coordinates. Returns an array of shape (3,). """ # use mne.vertex_to_mni coords = mne.vertex_to_mni(vertices, hemi, self.subject, self.subjects_dir) return coords
def _vertices_to_mni(self, vertices: np.ndarray) -> np.ndarray: """Convert vertices from both hemispheres to MNI coordinates.""" vertices = np.atleast_1d(vertices).astype(int) if vertices.size == 0: return np.empty((0, 3)) lh_vertices, rh_vertices = self.vol lh_mask = np.isin(vertices, lh_vertices) coords = [] if lh_mask.any(): coords.append( mne.vertex_to_mni(vertices[lh_mask], 0, self.subject, self.subjects_dir) ) if (~lh_mask).any(): coords.append( mne.vertex_to_mni( vertices[~lh_mask], 1, self.subject, self.subjects_dir ) ) return np.vstack(coords) if coords else np.empty((0, 3))
[docs] def convert_to_mni( self, source: list[int] | np.ndarray, hemi: list[int] | int | None = None, ) -> np.ndarray: """Convert source space coordinates to MNI.""" if self.atlas_type == "volume": return self.voxel_to_mni(source) if self.atlas_type == "surface": if hemi is None: raise ValueError("hemi must be provided for surface atlases") return self.vertex_to_mni(source, hemi) if self.atlas_type == "coords": return np.asarray(source, dtype=float)
# ------------------------------------------------------------------------- # MNI <--> region # -------------------------------------------------------------------------
[docs] def mni_to_region_index( self, mni_coord: list[float] | np.ndarray, max_distance: float | None = None, hemi: list[int] | int | None = None, return_distance: bool = False, ) -> int | str | tuple[int | str, float]: """Return the region index for a given MNI coordinate. Parameters ---------- mni_coord : list | ndarray Target MNI coordinate. max_distance : float | None If provided, fall back to the nearest region and apply this distance threshold. Distances greater than ``max_distance`` return ``"Unknown"``. hemi : int | list[int] | None Hemisphere restriction for surface atlases. return_distance : bool Whether to also return the distance to the reported region. """ coord = np.asarray(mni_coord, dtype=float) result: int | str | np.ndarray dist = 0.0 if self.atlas_type == "volume": ind = np.asarray(self.convert_to_source(coord)) if ind.size == 0 or np.any((ind < 0) | (ind >= np.array(self.shape))): result, dist = self._nearest_region_index(coord, hemi) else: result = int(self.vol[tuple(ind)]) if result == 0: result, dist = self._nearest_region_index(coord, hemi) elif self.atlas_type == "surface": if hemi is not None: verts = np.atleast_1d(self.convert_to_source(coord, hemi)) hemis = ( [_get_numeric_hemi(h) for h in hemi] if isinstance(hemi, list | tuple | np.ndarray) else [_get_numeric_hemi(hemi)] ) else: verts = np.atleast_1d(self.convert_to_source(coord)) hemis = [0, 1] exact_matches: list[int] = [] for v in verts: v_int = int(v) hemi_v = next( (h for h in hemis if v_int in np.asarray(self.vol[h])), None ) if hemi_v is not None: v_mni = mne.vertex_to_mni( [v_int], hemi_v, self.subject, self.subjects_dir )[0] if np.allclose(v_mni, coord): exact_matches.append(v_int) elif self.vertex_to_region and v_int in self.vertex_to_region: exact_matches.append(v_int) if exact_matches: result = ( np.array(exact_matches) if len(exact_matches) > 1 else int(exact_matches[0]) ) else: result, dist = self._nearest_region_index(coord, hemi) elif self.atlas_type == "coords": exact = 
np.atleast_1d(self.convert_to_source(coord)) if exact.size > 0: result = exact if exact.size > 1 else int(exact[0]) else: result, dist = self._nearest_region_index(coord, hemi) else: result, dist = self._nearest_region_index(coord, hemi) if max_distance is not None and dist > max_distance: result = "Unknown" return (result, dist) if return_distance else result
[docs] def mni_to_region_name( self, mni_coord: list[float] | np.ndarray, max_distance: float | None = None, hemi: list[int] | int | None = None, return_distance: bool = False, ) -> str | tuple[str, float]: """Return the region name for a given MNI coordinate.""" idx, dist = self.mni_to_region_index( mni_coord, max_distance=max_distance, hemi=hemi, return_distance=True, ) if isinstance(idx, np.ndarray): names = {self._lookup_region_name(int(i)) for i in idx} name = names.pop() if len(names) == 1 else "Unknown" else: name = "Unknown" if idx == "Unknown" else self._lookup_region_name(idx) return (name, dist) if return_distance else name
# ------------------------------------------------------------------ # Nearest region helpers # ------------------------------------------------------------------ def _compute_centroids(self) -> None: """Compute and cache centroids for all regions (volume atlases).""" if self.atlas_type != "volume" or self._centroids_cache is not None: return centroids = {} for idx in np.unique(self.vol): if idx == 0: continue coords = self.region_index_to_mni(int(idx)) # Ensure 2D shape even for singleton regions (1x3). Without this, # mean(axis=0) on a 1D array can yield a scalar and later stacking # of centroids would fail (shape mismatch), as seen in iEEG case. coords = np.atleast_2d(coords) if coords.size == 0: continue centroids[int(idx)] = coords.mean(axis=0) self._centroids_cache = centroids def _nearest_region_index( self, mni_coord: list[float] | np.ndarray, hemi: list[int] | int | None = None, ) -> tuple[int | str, float]: """Return (nearest region index, distance) to ``mni_coord``.""" coord = np.asarray(mni_coord, dtype=float) if self.atlas_type == "volume": self._compute_centroids() if not self._centroids_cache: return "Unknown", float("inf") ids = np.array(list(self._centroids_cache.keys())) cents = np.vstack(list(self._centroids_cache.values())) dists = np.linalg.norm(cents - coord, axis=1) min_idx = np.argmin(dists) return int(ids[min_idx]), float(dists[min_idx]) if self.atlas_type == "surface": if hemi is None: hemis = [0, 1] elif isinstance(hemi, list | tuple | np.ndarray): hemis = [_get_numeric_hemi(h) for h in hemi] else: hemis = [_get_numeric_hemi(hemi)] all_vertices: list[np.ndarray] = [] all_coords: list[np.ndarray] = [] for h in hemis: verts = np.asarray(self.vol[h]) if verts.size == 0: continue coords = mne.vertex_to_mni(verts, h, self.subject, self.subjects_dir) all_vertices.append(verts) all_coords.append(coords) if not all_vertices: return "Unknown", float("inf") vertices = np.concatenate(all_vertices) coords = np.vstack(all_coords) dists = 
np.linalg.norm(coords - coord, axis=1) min_idx = int(np.argmin(dists)) return int(vertices[min_idx]), float(dists[min_idx]) if self.atlas_type == "coords": coords = np.asarray(self.vol, dtype=float) dists = np.linalg.norm(coords - coord, axis=1) min_idx = int(np.argmin(dists)) idx = self.indexes[min_idx] if self.indexes is not None else min_idx return int(idx), float(dists[min_idx]) return "Unknown", float("inf") # ------------------------------------------------------------------------- # region index/name <--> all voxel coords # -------------------------------------------------------------------------
[docs] def region_index_to_mni( self, region_idx: int | str | list[int] | np.ndarray, hemi: int | None = None, ) -> np.ndarray: """ Return MNI coordinates for voxels or vertices in ``region_idx``. Returns an Nx3 array or an empty array if none found. """ # Make sure region_idx is an integer: if self.atlas_type == "volume": try: idx_val = int(region_idx) except (ValueError, TypeError): return np.empty((0, 3)) coords = np.argwhere(self.vol == idx_val) if coords.size == 0: return np.empty((0, 3)) return self.convert_to_mni(coords, hemi) elif self.atlas_type == "surface": try: verts = np.atleast_1d(region_idx).astype(int) except (ValueError, TypeError): return np.empty((0, 3)) return self._vertices_to_mni(verts) elif self.atlas_type == "coords": try: idx_val = int(region_idx) except (ValueError, TypeError): return np.empty((0, 3)) if self.indexes is not None: try: pos = list(self.indexes).index(idx_val) except ValueError: return np.empty((0, 3)) else: pos = idx_val if pos < 0 or pos >= len(self.vol): return np.empty((0, 3)) return np.atleast_2d(self.vol[pos])
[docs] def region_name_to_mni(self, region_name: str) -> np.ndarray: """Return MNI coordinates for voxels matching ``region_name``. Returns an Nx3 array or an empty array if no matches are found. """ region_idx = self.region_index_from_name(region_name) if isinstance(region_idx, str) and region_idx == "Unknown": return np.empty((0, 3)) if isinstance(region_idx, np.ndarray) and region_idx.size == 0: return np.empty((0, 3)) return self.region_index_to_mni( region_idx, _get_numeric_hemi(self.infer_hemisphere(region_name)) )
[docs] def region_centroid(self, region: int | str) -> np.ndarray: """Return the centroid MNI coordinate for a region or vertex index.""" if isinstance(region, str): coords = self.region_name_to_mni(region) else: coords = self.region_index_to_mni(region) # Some regions can contain exactly one voxel/vertex; keep (1, 3) # to make mean/distances robust and consistent with batch cases. coords = np.atleast_2d(coords) if coords.size == 0: return np.empty((0,)) return coords.mean(axis=0)
[docs] def distance_to_region_centroid( self, mni_coord: list[float] | np.ndarray, region: int | str ) -> float: """Return Euclidean distance from ``mni_coord`` to a region centroid.""" centroid = self.region_centroid(region) if centroid.size == 0: return float("inf") coord = np.asarray(mni_coord, dtype=float) return float(np.linalg.norm(coord - centroid))
[docs] def distance_to_region_boundary( self, mni_coord: list[float] | np.ndarray, region: int | str ) -> float: """Return distance from ``mni_coord`` to the nearest point in ``region``.""" if isinstance(region, str): coords = self.region_name_to_mni(region) else: coords = self.region_index_to_mni(region) # Guard against 1D arrays so pairwise distances work reliably. coords = np.atleast_2d(coords) if coords.size == 0: return float("inf") coord = np.asarray(mni_coord, dtype=float) dists = np.linalg.norm(coords - coord, axis=1) return float(dists.min())
[docs] def membership_scores( self, mni_coord: list[float] | np.ndarray, method: str = "centroid", ) -> dict[int | str, float]: """Return normalized membership probabilities for all regions.""" coord = np.asarray(mni_coord, dtype=float) # Determine region identifiers if self.atlas_type == "volume": region_ids = ( [int(i) for i in (self.indexes or [])] if self.indexes is not None else [int(i) for i in np.unique(self.vol) if int(i) != 0] ) elif self.atlas_type == "coords": if self.indexes is not None: region_ids = [int(i) for i in self.indexes] else: region_ids = list(range(len(self.vol))) elif self.atlas_type == "surface" and self.regions is not None: region_ids = list(self.regions.keys()) else: return {} dists = [] for rid in region_ids: if method == "boundary": d = self.distance_to_region_boundary(coord, rid) else: d = self.distance_to_region_centroid(coord, rid) dists.append(d) dists_arr = np.array(dists, dtype=float) scores = np.exp(-dists_arr) total = float(scores.sum()) if total > 0: scores /= total # Map region identifiers to names if possible if isinstance(self.labels, dict): names = [self.labels.get(str(r), str(r)) for r in region_ids] elif isinstance(self.labels, list | np.ndarray): names = list(self.labels) else: names = region_ids return dict(zip(names, scores, strict=False))
# ------------------------------------------------------------------ # Persistence # ------------------------------------------------------------------ _SERIAL_VERSION = 1 def _get_state(self) -> dict[str, Any]: """Return minimal state necessary to recreate this mapper.""" state: dict[str, Any] = { "name": self.name, "vol": self.vol, "hdr": self.hdr, "labels": self.labels, "indexes": self.indexes, "regions": self.regions, "system": self.system, } if hasattr(self, "subject"): state["subject"] = self.subject if hasattr(self, "subjects_dir"): state["subjects_dir"] = self.subjects_dir return state
[docs] def save(self, filename: str) -> None: """Serialize this ``AtlasMapper`` to ``filename`` using pickle.""" data = { "metadata": { "class": self.__class__.__name__, "version": self._SERIAL_VERSION, }, "state": self._get_state(), } with open(filename, "wb") as f: pickle.dump(data, f)
@classmethod
[docs] def load(cls, filename: str) -> "AtlasMapper": """Load an ``AtlasMapper`` from ``filename``. Parameters ---------- filename : str Path to the serialized mapper. Returns ------- AtlasMapper A reconstructed mapper instance. Raises ------ ValueError If the file metadata is incompatible. """ with open(filename, "rb") as f: data = pickle.load(f) meta = data.get("metadata", {}) if meta.get("class") != cls.__name__: raise ValueError("File does not contain AtlasMapper data") if meta.get("version") != cls._SERIAL_VERSION: raise ValueError("Incompatible AtlasMapper version") state = data.get("state", {}) return cls(**state)
[docs] class BatchAtlasMapper: """Provide batch (vectorized) conversions for a single atlas mapper. Parameters ---------- mapper : AtlasMapper The atlas mapper to wrap for vectorized operations. Attributes ---------- mapper : AtlasMapper Wrapped atlas mapper used for transformations. Examples -------- >>> mapper = AtlasMapper(...) >>> batch = BatchAtlasMapper(mapper) >>> regions = batch.batch_mni_to_region_name([[0, 0, 0], [10, -20, 30]]) """ def __init__(self, mapper: AtlasMapper) -> None: if not isinstance(mapper, AtlasMapper): raise ValueError("mapper must be an instance of AtlasMapper")
[docs] self.mapper = mapper
# ---- region name <-> index (batch) ---------------------------------------
[docs] def batch_region_name_from_index(self, values: list[int | str]) -> list[str]: """Return the region name for each index in ``values``.""" return [self.mapper.region_name_from_index(val) for val in values]
[docs] def batch_region_index_from_name(self, labels: list[str]) -> list[int | str]: """Return the region index for each name in ``labels``.""" return [self.mapper.region_index_from_name(label) for label in labels]
# ---- MNI <-> voxel (batch) -----------------------------------------------
[docs] def batch_mni_to_voxel( self, positions: list[list[float]] | np.ndarray ) -> list[tuple]: """Convert MNI coordinates to voxel indices (i, j, k).""" positions_arr = np.atleast_2d(positions) return [self.mapper.mni_to_voxel(pos) for pos in positions_arr]
[docs] def batch_voxel_to_mni(self, sources: list[list[int]] | np.ndarray) -> np.ndarray: """ Convert a batch of voxel indices (i, j, k) to MNI coordinates. Returns an Nx3 array. """ sources_arr = np.atleast_2d(sources) return np.array([self.mapper.voxel_to_mni(s) for s in sources_arr])
# ---- MNI -> region (batch) -----------------------------------------------
[docs] def batch_mni_to_region_index( self, positions: list[list[float]] | np.ndarray, max_distance: float | None = None, hemi: list[int] | int | None = None, ) -> list[int | str]: """Return region index for each coordinate, using nearest lookup if needed.""" positions_arr = np.atleast_2d(positions) return [ self.mapper.mni_to_region_index(pos, max_distance=max_distance, hemi=hemi) for pos in positions_arr ]
[docs] def batch_mni_to_region_name( self, positions: list[list[float]] | np.ndarray, max_distance: float | None = None, hemi: list[int] | int | None = None, ) -> list[str]: """Return region name for each coordinate, using nearest lookup if needed.""" positions_arr = np.atleast_2d(positions) return [ self.mapper.mni_to_region_name(pos, max_distance=max_distance, hemi=hemi) for pos in positions_arr ]
# ---- region index/name -> MNI coords (batch) -----------------------------
[docs] def batch_region_index_to_mni(self, indices: list[int | str]) -> list[np.ndarray]: """Return MNI coordinates (Nx3) for each region index.""" return [self.mapper.region_index_to_mni(idx) for idx in indices]
[docs] def batch_region_name_to_mni(self, regions: list[str]) -> list[np.ndarray]: """Return MNI coordinates (Nx3) for each region name.""" return [self.mapper.region_name_to_mni(r) for r in regions]
[docs] class MultiAtlasMapper: """Manage multiple atlases and provide batch queries across them. Parameters ---------- data_dir : str Directory for atlas data. atlases : dict Dictionary mapping atlas names to keyword arguments passed to :class:`AtlasFetcher` in order to retrieve each atlas. Attributes ---------- mappers : dict Mapping of atlas names to :class:`BatchAtlasMapper` instances. """ def __init__(self, data_dir: str, atlases: dict[str, dict[str, Any]]) -> None:
[docs] self.mappers = {}
atlas_fetcher = AtlasFetcher(data_dir=data_dir) for name, kwargs in atlases.items(): atlas_data = atlas_fetcher.fetch_atlas(name, **kwargs) vol = atlas_data["vol"] hdr = atlas_data["hdr"] labels = atlas_data.get("labels") indexes = atlas_data.get("indexes") subject = kwargs.get("subject", "fsaverage") subjects_dir = kwargs.get("subjects_dir") # Handle coordinate atlases represented as DataFrames or lists if isinstance(vol, pd.DataFrame): df = vol if {"x", "y", "z"}.issubset(df.columns): vol = df[["x", "y", "z"]].to_numpy() else: vol = df.iloc[:, :3].to_numpy() if labels is None: for col in ["label", "labels", "name", "region", "roi"]: if col in df.columns: labels = df[col].astype(str).tolist() break if indexes is None: indexes = df.index.to_list() else: arr = np.asarray(vol) if hdr is None and arr.ndim == 2 and arr.shape[1] == 3: vol = arr if indexes is None: indexes = np.arange(vol.shape[0]) single_mapper = AtlasMapper( name=name, vol=vol, hdr=hdr, labels=labels, indexes=indexes, regions=atlas_data.get("regions"), subject=subject, subjects_dir=subjects_dir, system="mni", # or read from atlas_data if you store that ) batch_mapper = BatchAtlasMapper(single_mapper) self.mappers[name] = batch_mapper
[docs] def batch_mni_to_region_names( self, coords: list[list[float]] | np.ndarray, max_distance: float | None = None, hemi: list[int] | int | None = None, ) -> dict[str, list[str]]: """ Convert a batch of MNI coordinates to region names for all atlases. Returns a dict {atlas_name: [region_name, region_name, ...], ...}. """ results = {} for atlas_name, mapper in self.mappers.items(): results[atlas_name] = mapper.batch_mni_to_region_name( coords, max_distance=max_distance, hemi=hemi ) return results
[docs] def batch_region_name_to_mni( self, region_names: list[str] ) -> dict[str, list[np.ndarray]]: """ Convert a list of region names to MNI coordinates for all atlases. Returns a dict {atlas_name: [np.array_of_coords_per_region, ...], ...}. """ results = {} for atlas_name, mapper in self.mappers.items(): results[atlas_name] = mapper.batch_region_name_to_mni(region_names) return results
# ------------------------------------------------------------------ # Persistence # ------------------------------------------------------------------ _SERIAL_VERSION = 1
[docs] def save(self, filename: str) -> None: """Serialize all contained mappers to ``filename`` using pickle.""" mapper_states = { name: mapper.mapper._get_state() for name, mapper in self.mappers.items() } data = { "metadata": { "class": self.__class__.__name__, "version": self._SERIAL_VERSION, }, "state": {"mappers": mapper_states}, } with open(filename, "wb") as f: pickle.dump(data, f)
@classmethod
[docs] def load(cls, filename: str) -> "MultiAtlasMapper": """Load a ``MultiAtlasMapper`` instance from ``filename``.""" with open(filename, "rb") as f: data = pickle.load(f) meta = data.get("metadata", {}) if meta.get("class") != cls.__name__: raise ValueError("File does not contain MultiAtlasMapper data") if meta.get("version") != cls._SERIAL_VERSION: raise ValueError("Incompatible MultiAtlasMapper version") mapper_states = data.get("state", {}).get("mappers", {}) obj = cls.__new__(cls) obj.mappers = {} for name, mstate in mapper_states.items(): atlas = AtlasMapper(**mstate) obj.mappers[name] = BatchAtlasMapper(atlas) return obj