"""Utilities for managing atlas files.
This module handles downloading, caching, and loading atlas files used by the
mapping utilities. It provides helpers for retrieving label information and
packing volumetric atlas outputs.
"""
import csv
import json
import logging
import os
import pickle
import shutil
from collections.abc import Sequence
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import Any
import mne
from fpdf import FPDF
from .paths import ensure_mne_data_directory, resolve_working_directory
from .utils import fetch_labels, pack_vol_output
# Module-level logger shared by every helper in this file.
logger = logging.getLogger(__name__)
[docs]
class AtlasFileHandler:
"""Handle file operations for atlas fetching.
Parameters
----------
data_dir : str or None, optional
Base directory for downloaded atlas files. Defaults to
``~/coord2region``. Relative paths are interpreted relative to the
user's home directory.
subjects_dir : str or None, optional
FreeSurfer ``SUBJECTS_DIR``. If ``None``, the value is inferred from
:func:`mne.get_config`.
Attributes
----------
data_dir : str
Base directory where atlas files and other outputs are stored.
cached_data_dir : str
Directory for cached datasets.
generated_images_dir : str
Directory for generated images.
results_dir : str
Directory for exported results.
subjects_dir : str or None
Path to the FreeSurfer subjects directory.
nilearn_data : str
Directory for caching Nilearn datasets.
mne_data_dir : str
Directory registered with MNE for dataset downloads.
Examples
--------
>>> handler = AtlasFileHandler() # doctest: +SKIP
>>> handler.data_dir # doctest: +SKIP
'/home/user/coord2region'
"""
def __init__(self, data_dir: str | None = None, subjects_dir: str | None = None):
"""Initialize the file handler.
data_dir : str or None, optional
Base directory for storing downloaded atlas files. Defaults to
``~/coord2region``. Relative paths are interpreted relative to the
user's home directory.
subjects_dir : str or None, optional
Path to the FreeSurfer ``SUBJECTS_DIR``. If ``None``, the value is
looked up via :func:`mne.get_config`.
Raises
------
ValueError
If the data directory cannot be created or is not writable.
Examples
--------
>>> AtlasFileHandler() # doctest: +SKIP
"""
base_dir = resolve_working_directory(data_dir)
[docs]
self.data_dir = str(base_dir)
try:
os.makedirs(self.data_dir, exist_ok=True)
except Exception as e:
raise ValueError(
f"Could not create data directory {self.data_dir}: {e}"
) from e
if not os.access(self.data_dir, os.W_OK):
raise ValueError(f"Data directory {self.data_dir} is not writable")
[docs]
self.cached_data_dir = os.path.join(self.data_dir, "cached_data")
[docs]
self.generated_images_dir = os.path.join(self.data_dir, "generated_images")
[docs]
self.results_dir = os.path.join(self.data_dir, "results")
[docs]
self.nilearn_data = os.path.join(self.data_dir, "nilearn_data")
[docs]
self.mne_data_dir = str(ensure_mne_data_directory(base_dir))
subject_path: Path | None
if subjects_dir is not None:
subject_path = Path(subjects_dir).expanduser()
if not subject_path.is_absolute():
subject_path = (base_dir / subject_path).resolve()
else:
subject_path = None
try:
config_subjects_dir = mne.get_config("SUBJECTS_DIR", None)
except Exception: # pragma: no cover - defensive
config_subjects_dir = None
if config_subjects_dir:
candidate = Path(config_subjects_dir).expanduser()
if not candidate.is_absolute():
candidate = (base_dir / candidate).resolve()
subject_path = candidate
if subject_path is None:
env_subjects_dir = os.environ.get("SUBJECTS_DIR")
if env_subjects_dir:
candidate = Path(env_subjects_dir).expanduser()
if not candidate.is_absolute():
candidate = (base_dir / candidate).resolve()
subject_path = candidate
if subject_path is None:
try:
sample_root = Path(mne.datasets.sample.data_path(download=False))
except Exception: # pragma: no cover - depends on mne internals
logger.debug(
"Unable to locate MNE sample dataset for default subjects_dir:",
exc_info=True,
)
else:
default_root = sample_root.expanduser()
default_path = default_root / "subjects"
subject_path = default_path.resolve()
try:
mne.utils.set_config(
"SUBJECTS_DIR", str(subject_path), set_env=True
)
except Exception: # pragma: no cover - defensive
logger.debug(
"Failed to set MNE SUBJECTS_DIR configuration",
exc_info=True,
)
[docs]
self.subjects_dir = str(subject_path) if subject_path is not None else None
for path in (
self.cached_data_dir,
self.generated_images_dir,
self.results_dir,
self.nilearn_data,
self.mne_data_dir,
self.subjects_dir,
):
if path is not None:
os.makedirs(path, exist_ok=True)
[docs]
def save(self, obj, filename: str):
"""Save an object to the data directory using pickle.
Parameters
----------
obj : Any
The object to serialize.
filename : str
Name of the file to save the object to.
Raises
------
ValueError
If the data directory is not writable.
Exception
If there is an error during saving.
Examples
--------
>>> handler = AtlasFileHandler()
>>> handler.save({'a': 1}, 'example.pkl') # doctest: +SKIP
"""
filepath = os.path.join(self.data_dir, filename)
try:
with open(filepath, "wb") as f:
pickle.dump(obj, f)
logger.info(f"Object saved to {filepath}")
except Exception as e:
logger.exception(f"Error saving object to {filepath}: {e}")
raise
[docs]
def load(self, filename: str):
"""Load an object from the data directory.
Parameters
----------
filename : str
Name of the file to load the object from.
Returns
-------
object or None
The loaded object, or ``None`` if the file does not exist.
Raises
------
Exception
If there is an error during loading.
Examples
--------
>>> handler = AtlasFileHandler()
>>> handler.load('missing.pkl') # doctest: +SKIP
None
"""
filepath = os.path.join(self.data_dir, filename)
if os.path.exists(filepath):
try:
with open(filepath, "rb") as f:
obj = pickle.load(f)
logger.info(f"Object loaded from {filepath}")
return obj
except Exception as e:
logger.exception(f"Error loading object from {filepath}: {e}")
raise
else:
return None
[docs]
def fetch_from_local(self, atlas_file: str, atlas_dir: str, labels: str | list):
"""Load an atlas from a local file.
Parameters
----------
atlas_file : str
The name of the atlas file.
atlas_dir : str
Directory where the atlas file is located.
labels : str or list
Labels file or a list of label names.
Returns
-------
dict
Dictionary containing the atlas data.
Raises
------
FileNotFoundError
If the atlas or labels file is not found.
Exception
If there is an error during loading.
Examples
--------
>>> handler = AtlasFileHandler()
>>> handler.fetch_from_local('atlas.nii.gz', '.', ['A', 'B']) # doctest: +SKIP
{'vol': array(...), 'hdr': array(...), 'labels': ['A', 'B']}
"""
logger.info(f"Loading local atlas file: {atlas_file}")
found_path = next(
(
os.path.join(root, atlas_file)
for root, _, files in os.walk(atlas_dir)
if atlas_file in files
),
None,
)
if found_path is None:
raise FileNotFoundError(
f"Atlas file {atlas_file} not found in {atlas_dir} or its "
"subdirectories"
)
logger.info(f"Atlas file found at {found_path}")
output = pack_vol_output(found_path)
if isinstance(labels, str):
found_path = next(
(
os.path.join(root, labels)
for root, _, files in os.walk(atlas_dir)
if labels in files
),
None,
)
if found_path is None:
raise FileNotFoundError(
f"Labels file {labels} not found in {atlas_dir} or its "
"subdirectories"
)
logger.info(f"Labels file found at {found_path}")
output["labels"] = fetch_labels(found_path)
elif isinstance(labels, list):
output["labels"] = fetch_labels(labels)
return output
[docs]
def fetch_from_url(self, atlas_url: str, **kwargs):
"""Download an atlas from a URL.
Parameters
----------
atlas_url : str
The URL of the atlas file.
**kwargs
Additional arguments for the download.
Returns
-------
str
Local path to the downloaded (and possibly decompressed) file.
Raises
------
RuntimeError
If the download fails.
ValueError
If the data directory is not writable.
Exception
If there is an error during downloading.
Examples
--------
>>> handler = AtlasFileHandler()
>>> handler.fetch_from_url('http://example.com/atlas.nii.gz') # doctest: +SKIP
'/path/to/atlas.nii.gz'
"""
import warnings
warnings.warn(
"The file name is expected to be in the URL", UserWarning, stacklevel=2
)
import gzip
import shutil
import tarfile
import urllib.parse
import zipfile
import requests
parsed = urllib.parse.urlparse(atlas_url)
file_name = os.path.basename(parsed.path)
local_path = os.path.join(self.data_dir, file_name)
if not os.path.exists(local_path):
logger.info(f"Downloading atlas from {atlas_url}...")
try:
with requests.get(atlas_url, stream=True, timeout=30, verify=True) as r:
r.raise_for_status()
with open(local_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
logger.info(f"Atlas downloaded to {local_path}")
except Exception as e:
if os.path.exists(local_path):
os.remove(local_path)
logger.exception(f"Failed to download from {atlas_url}")
raise RuntimeError(f"Failed to download from {atlas_url}") from e
else:
logger.info(f"Atlas already exists: {local_path}. Skipping download.")
# Check if the downloaded file is compressed and decompress if necessary.
decompressed_path = local_path
if zipfile.is_zipfile(local_path):
logger.info(f"Extracting zip file {local_path}")
extract_dir = os.path.join(self.data_dir, file_name.rstrip(".zip"))
with zipfile.ZipFile(local_path, "r") as zip_ref:
zip_ref.extractall(extract_dir)
decompressed_path = extract_dir
elif tarfile.is_tarfile(local_path):
logger.info(f"Extracting tar archive {local_path}")
# Remove possible extensions to form the extract directory name
base_name = file_name
for ext in [".tar.gz", ".tgz", ".tar"]:
if base_name.endswith(ext):
base_name = base_name[: -len(ext)]
break
extract_dir = os.path.join(self.data_dir, base_name)
with tarfile.open(local_path, "r:*") as tar_ref:
try:
tar_ref.extractall(extract_dir, filter="data")
except TypeError: # Python < 3.12 or old 3.11/3.10
tar_ref.extractall(extract_dir)
decompressed_path = extract_dir
elif local_path.endswith(".gz") and not local_path.endswith(".tar.gz"):
logger.info(f"Decompressing gzip file {local_path}")
decompressed_file = local_path[:-3]
with gzip.open(local_path, "rb") as f_in:
with open(decompressed_file, "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
decompressed_path = decompressed_file
return decompressed_path
def _results_to_dicts(results: Sequence[Any]) -> list[dict]:
"""Convert dataclass or mapping results to plain dictionaries."""
dicts: list[dict] = []
for res in results:
if is_dataclass(res):
dicts.append(asdict(res))
elif isinstance(res, dict):
dicts.append(res)
else:
dicts.append(dict(res)) # type: ignore[arg-type]
return dicts
def save_as_pdf(results: Sequence[Any], path: str) -> None:
    """Save pipeline results to a PDF file or directory.

    Each result becomes a one-page PDF holding its ``"coordinate"``,
    ``"summary"`` and ``"image"`` entries (whichever are present).

    Parameters
    ----------
    results : Sequence[Any]
        Results accepted by ``_results_to_dicts``.
    path : str
        When there is more than one result, or *path* is an existing
        directory, *path* is used as a directory and the files are named
        ``result_<n>.pdf``. Otherwise the single result is written to *path*
        itself. Nothing is written when *results* is empty.
    """
    dict_results = _results_to_dicts(results)
    # Decided once: it cannot change while the loop below is running.
    as_directory = len(dict_results) > 1 or os.path.isdir(path)
    if as_directory:
        os.makedirs(path, exist_ok=True)
    elif dict_results:
        # Same courtesy as save_as_csv: create the parent of a single file.
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    for idx, res in enumerate(dict_results, start=1):
        pdf = FPDF()
        pdf.add_page()
        pdf.set_font("Arial", size=12)
        coord = res.get("coordinate")
        if coord is not None:
            pdf.multi_cell(0, 10, f"Coordinate: {coord}")
        summary = res.get("summary")
        if summary:
            pdf.multi_cell(0, 10, summary)
        img = res.get("image")
        if img:
            try:  # pragma: no cover - depends on PIL
                pdf.image(img, w=100)
            except Exception:
                # Best effort: a broken image must not prevent the export.
                logger.debug("Could not embed image %s in PDF", img, exc_info=True)
        fname = os.path.join(path, f"result_{idx}.pdf") if as_directory else path
        pdf.output(fname)
def save_as_csv(results: Sequence[Any], path: str) -> None:
    """Save pipeline results to a CSV file.

    The file has one row per result and the fixed columns ``coordinate``,
    ``region_labels``, ``summary``, ``studies``, ``image`` and ``images``.
    Missing keys produce empty cells; ``list`` and ``dict`` values are stored
    as JSON text. Keys outside the fixed columns are ignored.

    Parameters
    ----------
    results : Sequence[Any]
        Results accepted by ``_results_to_dicts``.
    path : str
        Destination file. Its parent directory is created if needed.
    """
    dict_results = _results_to_dicts(results)
    fieldnames = [
        "coordinate",
        "region_labels",
        "summary",
        "studies",
        "image",
        "images",
    ]
    os.makedirs(os.path.dirname(os.path.abspath(path)) or ".", exist_ok=True)
    with open(path, "w", newline="", encoding="utf8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for row in dict_results:
            record = {}
            # Only the exported columns are encoded, so an unserialisable
            # value under some other key cannot abort the export.
            for key in fieldnames:
                value = row.get(key)
                if isinstance(value, (list, dict)):
                    value = json.dumps(value)
                record[key] = value
            writer.writerow(record)
def _copy_into(src: Any, dest_dir: str) -> None:
    """Copy *src* into *dest_dir* if it names an existing path (best effort)."""
    if not src or not os.path.exists(src):
        return
    try:
        shutil.copy(src, os.path.join(dest_dir, os.path.basename(src)))
    except Exception:
        # Best effort: a failed copy must not abort the rest of the export.
        logger.debug("Could not copy %s to %s", src, dest_dir, exc_info=True)


def save_batch_folder(results: Sequence[Any], path: str) -> None:
    """Save results as a directory with individual JSON files and images.

    For the n-th result a sub-directory ``result_<n>`` is created under
    *path*. It receives ``result.json`` (the result dictionary) plus copies of
    the file named by ``"image"`` and of every file listed in the values of
    ``"images"``, as far as those files exist.

    Parameters
    ----------
    results : Sequence[Any]
        Results accepted by ``_results_to_dicts``.
    path : str
        Output directory; created if needed.
    """
    dict_results = _results_to_dicts(results)
    os.makedirs(path, exist_ok=True)
    for idx, res in enumerate(dict_results, start=1):
        out_dir = os.path.join(path, f"result_{idx}")
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, "result.json"), "w", encoding="utf8") as f:
            json.dump(res, f, indent=2)
        _copy_into(res.get("image"), out_dir)
        # "images" may be present but None, so ``.get("images", {})`` is not
        # enough to guarantee a mapping.
        for extra in (res.get("images") or {}).values():
            _copy_into(extra, out_dir)