From e41ec7eab8ff7351fdc1c25655d4e5da8e361aee Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Fri, 27 Feb 2026 11:42:35 +0100 Subject: [PATCH 01/16] refactor reporting --- bluecellulab/cell/core.py | 59 ++-- bluecellulab/circuit_simulation.py | 4 +- bluecellulab/reports/manager.py | 46 ++- bluecellulab/reports/typing.py | 62 ++++ bluecellulab/reports/utils.py | 302 ++++++++++++----- bluecellulab/reports/writers/compartment.py | 76 ++--- bluecellulab/type_aliases.py | 7 +- .../simulation_config_compartment_set.json | 9 + tests/test_reports/test_compartment_writer.py | 305 ++++++++++-------- 9 files changed, 564 insertions(+), 306 deletions(-) create mode 100644 bluecellulab/reports/typing.py diff --git a/bluecellulab/cell/core.py b/bluecellulab/cell/core.py index 40596285..ef0d13d6 100644 --- a/bluecellulab/cell/core.py +++ b/bluecellulab/cell/core.py @@ -19,9 +19,10 @@ from pathlib import Path import queue -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple from typing_extensions import deprecated +from bluecellulab.reports.typing import ReportSite import neuron import numpy as np import pandas as pd @@ -1012,46 +1013,62 @@ def resolve_segments_from_config(self, report_cfg) -> List[Tuple[NeuronSection, targets.append((sec, sec_name, seg.x)) return targets - def configure_recording(self, recording_sites, variable_name, report_name): - """Configure recording of a variable on a single cell. - - This function sets up the recording of the specified variable (e.g., membrane voltage) - in the target cell, for each resolved segment. + def configure_recording(self, + recording_sites: Iterable[tuple[NeuronSection | None, str, float]], + variable_name: str, + report_name: str + ) -> list[tuple[ReportSite, str]]: + """ + Attach NEURON recordings for a variable at the given sites and return the + recording names created. Parameters ---------- - cell : Any - The cell object on which to configure recordings. - - recording_sites : list of tuples - List of tuples (section, section_name, segment) where: - - section is the section object in the cell. - - section_name is the name of the section. - - segment is the Neuron segment index (0-1). - + recording_sites : iterable + (section, section_name, segx) tuples describing recording locations. variable_name : str - The name of the variable to record (e.g., "v" for membrane voltage). - + Variable to record (e.g. "v", "ina", "kca.gkca"). report_name : str - The name of the report (used in logging). + Report identifier (for logging). + + Returns + ------- + list[tuple[ReportSite, str]] + Pairs of (site, rec_name) for sites that were successfully configured. """ node_id = self.cell_id.id + added: list[str] = [] for sec, sec_name, seg in recording_sites: try: - self.add_variable_recording(variable=variable_name, section=sec, segx=seg) + if sec is None: + self.add_variable_recording(variable=variable_name, section=None, segx=float(seg)) + sec_obj = self.soma + rec_name = section_to_variable_recording_str(sec_obj, float(seg), variable_name) + else: + rec_name = section_to_variable_recording_str(sec, float(seg), variable_name) + if rec_name not in self.recordings: + self.add_variable_recording(variable=variable_name, section=sec, segx=float(seg)) + added.append(rec_name) + logger.info( f"Recording '{variable_name}' at {sec_name}({seg}) on GID {node_id} for report '{report_name}'" ) + except AttributeError: logger.warning( f"Recording for variable '{variable_name}' is not implemented in Cell." ) - return + continue + except Exception as e: logger.warning( - f"Failed to record '{variable_name}' at {sec_name}({seg}) on GID {node_id} for report '{report_name}': {e}" + f"Failed to record '{variable_name}' at {sec_name}({seg}) on GID {node_id} " + f"for report '{report_name}': {e}" ) + continue + + return added def add_currents_recordings( self, diff --git a/bluecellulab/circuit_simulation.py b/bluecellulab/circuit_simulation.py index 15bc4924..3d8b040d 100644 --- a/bluecellulab/circuit_simulation.py +++ b/bluecellulab/circuit_simulation.py @@ -22,7 +22,7 @@ import logging import warnings -from bluecellulab.reports.utils import configure_all_reports +from bluecellulab.reports.utils import prepare_recordings_for_reports import neuron import numpy as np import pandas as pd @@ -334,7 +334,7 @@ def instantiate_gids( add_linear_stimuli=add_linear_stimuli ) - configure_all_reports( + self.recording_index, self.sites_index = prepare_recordings_for_reports( cells=self.cells, simulation_config=self.circuit_access.config ) diff --git a/bluecellulab/reports/manager.py b/bluecellulab/reports/manager.py index cab76d03..4e2e4535 100644 --- a/bluecellulab/reports/manager.py +++ b/bluecellulab/reports/manager.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Dict +from typing import Any, Optional, Dict + +from bluecellulab.circuit.node_id import CellId from bluecellulab.reports.writers import get_writer from bluecellulab.reports.utils import SUPPORTED_REPORT_TYPES, extract_spikes_from_cells # helper you already have / write @@ -30,31 +32,23 @@ def __init__(self, config, sim_dt: float): def write_all( self, - cells_or_traces: Dict, - spikes_by_pop: Optional[Dict[str, Dict[int, list]]] = None, + cells: Dict[CellId, Any], + spikes_by_pop: Optional[Dict[str, Dict[int, list[float]]]] = None, ): - """Write all configured reports (compartment and spike) in SONATA - format. - - Parameters - ---------- - cells_or_traces : dict - A dictionary mapping (population, gid) to either: - - Cell objects with recorded data (used in single-process simulations), or - - Precomputed trace dictionaries, e.g., {"voltage": ndarray}, typically gathered across ranks in parallel runs. - - spikes_by_pop : dict, optional - A precomputed dictionary of spike times by population. - If not provided, spike times are extracted from `cells_or_traces`. - - Notes - ----- - In parallel simulations, you must gather all traces and spikes to rank 0 and pass them here. """ - self._write_voltage_reports(cells_or_traces) - self._write_spike_report(spikes_by_pop or extract_spikes_from_cells(cells_or_traces, location=self.cfg.spike_location, threshold=self.cfg.spike_threshold)) - - def _write_voltage_reports(self, cells_or_traces): + Write all configured reports (compartment and spike) in SONATA format. + `cells` entries must expose `report_sites` and `get_recording(rec_name)` + for compartment reports. If `spikes_by_pop` is None, entries must also + provide `get_recorded_spikes(location=..., threshold=...)`. + """ + self._write_compartment_reports(cells) + self._write_spike_report( + spikes_by_pop or extract_spikes_from_cells( + cells, location=self.cfg.spike_location, threshold=self.cfg.spike_threshold + ) + ) + + def _write_compartment_reports(self, cells): for name, rcfg in self.cfg.get_report_entries().items(): if rcfg.get("type") not in SUPPORTED_REPORT_TYPES: continue @@ -83,9 +77,9 @@ def _write_voltage_reports(self, cells_or_traces): out_path = self.cfg.report_file_path(rcfg, name) writer = get_writer("compartment")(rcfg, out_path, self.dt) - writer.write(cells_or_traces, self.cfg.tstart, self.cfg.tstop) + writer.write(cells, self.cfg.tstart, self.cfg.tstop) - def _write_spike_report(self, spikes_by_pop): + def _write_spike_report(self, spikes_by_pop: Dict[str, Dict[int, list[float]]]): out_path = self.cfg.spikes_file_path writer = get_writer("spikes")({}, out_path, self.dt) writer.write(spikes_by_pop) diff --git a/bluecellulab/reports/typing.py b/bluecellulab/reports/typing.py new file mode 100644 index 00000000..51f1bfb2 --- /dev/null +++ b/bluecellulab/reports/typing.py @@ -0,0 +1,62 @@ +# Copyright 2026 Open Brain Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any, Protocol, TypeAlias + +# Keep section as Any here: NeuronSection is currently a runtime alias to +# NEURON's hoc object, and using `NeuronSection | None` in a TypeAlias is +# evaluated at import time. +ReportSite: TypeAlias = tuple[Any, str, float] + + +class ReportSiteResolvable(Protocol): + """Object able to resolve recording locations for a SONATA report. + + Implemented by instantiated Cell objects during simulation setup. + """ + + def resolve_segments_from_config( + self, + report_cfg: dict + ) -> list[ReportSite]: ... + + def resolve_segments_from_compartment_set( + self, + node_id: int, + compartment_nodes: list + ) -> list[ReportSite]: ... + + +class ReportConfigurableCell(ReportSiteResolvable, Protocol): + """Cell-like object that can configure recordings from resolved sites.""" + + def configure_recording( + self, + recording_sites: Iterable[ReportSite], + variable_name: str, + report_name: str, + ) -> list[str]: ... + + +class SpikeExtractableCell(Protocol): + """Cell-like object that can return recorded spike times.""" + + def get_recorded_spikes( + self, + location: str = "soma", + threshold: float = -20.0, + ) -> Any: ... diff --git a/bluecellulab/reports/utils.py b/bluecellulab/reports/utils.py index a3b7e1d4..197ba345 100644 --- a/bluecellulab/reports/utils.py +++ b/bluecellulab/reports/utils.py @@ -12,11 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. """Report class of bluecellulab.""" +from __future__ import annotations from collections import defaultdict +from dataclasses import dataclass import logging -from typing import Dict, Any, List +from typing import Dict, Any, List, Mapping, Optional, Tuple +from bluecellulab.circuit.node_id import CellId +import numpy as np + +from bluecellulab.cell.section_tools import section_to_variable_recording_str +from bluecellulab.type_aliases import NeuronSection, SiteEntry from bluecellulab.tools import ( resolve_source_nodes, ) @@ -25,27 +32,15 @@ SUPPORTED_REPORT_TYPES = {"compartment", "compartment_set"} +def prepare_recordings_for_reports(cells, simulation_config): + recording_index: dict[CellId, list[str]] = defaultdict(list) # (pop,gid) -> [rec_name,...] ordered + sites_index: dict[CellId, list[SiteEntry]] = defaultdict(list) -def configure_all_reports(cells, simulation_config): - """Configure recordings for all reports defined in the simulation - configuration. - - This iterates through all report entries, resolves source nodes or compartments, - and configures the corresponding recordings on each cell. - - Parameters - ---------- - cells : dict - Mapping from (population, gid) → Cell object. - - simulation_config : Any - Simulation configuration object providing report entries, - node sets, and compartment sets. - """ report_entries = simulation_config.get_report_entries() for report_name, report_cfg in report_entries.items(): report_type = report_cfg.get("type", "compartment") + if report_type == "compartment_set": source_sets = simulation_config.get_compartment_sets() source_name = report_cfg.get("compartment_set") @@ -76,104 +71,88 @@ def configure_all_reports(cells, simulation_config): continue population = source["population"] - node_ids, compartment_nodes = resolve_source_nodes( - source, report_type, cells, population - ) + node_ids, compartment_nodes = resolve_source_nodes(source, report_type, cells, population) + recording_sites_per_cell = build_recording_sites( cells, node_ids, population, report_type, report_cfg, compartment_nodes ) - variable_name = report_cfg.get("variable_name", "v") - for node_id, recording_sites in recording_sites_per_cell.items(): - cell = cells.get((population, node_id)) - if not cell or recording_sites is None: + variable = report_cfg.get("variable_name", "v") + + for node_id, sites in recording_sites_per_cell.items(): + cell_id = CellId(population, node_id) + cell = cells.get(cell_id) + if cell is None or not sites: continue - cell.configure_recording(recording_sites, variable_name, report_name) + rec_names = cell.configure_recording(sites, variable, report_name) + for (sec, sec_name, segx), rec_name in zip(sites, rec_names, strict=True): + recording_index[cell_id].append(rec_name) -def build_recording_sites( - cells_or_traces, node_ids, population, report_type, report_cfg, compartment_nodes -): - """Build per-cell recording sites based on source type and report - configuration. + sites_index[cell_id].append({ + "report": report_name, + "rec_name": rec_name, + "section": sec_name, + "segx": float(segx), + }) - This function resolves the segments (section, name, seg.x) where variables - should be recorded for each cell, based on either a node set (standard - compartment reports) or a compartment set (predefined segment list). + return dict(recording_index), dict(sites_index) + +def build_recording_sites( + cells: Dict[CellId, Any], + node_ids: list[int], + population: str, + report_type: str, + report_cfg: dict, + compartment_nodes: list | None, +) -> Dict[int, List[Tuple[Any, str, float]]]: + """ + Resolve recording sites for instantiated cells in one population. Parameters ---------- - cells_or_traces : dict - Either a mapping from (population, node_id) to Cell objects (live sim), - or from gid_key strings to trace dicts (gathered traces on rank 0). - - node_ids : list of int - List of node IDs for which recordings should be configured. - + cells : dict[CellId, Any] + Mapping from CellId to cell-like objects. + node_ids : list[int] + Node IDs to resolve within `population`. population : str - Name of the population to which the cells belong. - + Population name used to build CellId(population, node_id). report_type : str - The report type, either 'compartment_set' or 'compartment'. - + "compartment" or "compartment_set". report_cfg : dict - Configuration dictionary specifying report parameters - - compartment_nodes : list or None - Optional list of [node_id, section_name, seg_x] defining segment locations - for each cell (used if report_type == 'compartment_set'). + Report configuration. + compartment_nodes : list | None + Compartment-set entries used when `report_type == "compartment_set"`. Returns ------- - dict - Mapping from node ID to list of recording site tuples: - (section_object, section_name, seg_x). + dict[int, list[tuple[Any, str, float]]] + Mapping `{node_id: [(section_obj, section_name, segx), ...]}`. """ - targets_per_cell = {} + targets_per_cell: Dict[int, List[Tuple[Any, str, float]]] = {} for node_id in node_ids: - # Handle both (pop, id) and "pop_id" keys - key = (population, node_id) - cell_or_trace = cells_or_traces.get(key) or cells_or_traces.get(f"{population}_{node_id}") - if not cell_or_trace: + cell = cells.get(CellId(population, node_id)) + if cell is None: continue - if isinstance(cell_or_trace, dict): # Trace dict, not Cell - if report_type == "compartment_set": - # Find all entries matching node_id - targets = [ - (None, section_name, segx) - for nid, section_name, segx in compartment_nodes - if nid == node_id - ] - elif report_type == "compartment": - section_name = report_cfg.get("sections", "soma") - segx = 0.5 if report_cfg.get("compartments", "center") == "center" else 0.0 - targets = [(None, f"{section_name}[0]", segx)] - else: - raise NotImplementedError( - f"Unsupported report type '{report_type}' in trace-based output." - ) + if report_type == "compartment_set": + if compartment_nodes is None: + continue + targets = cell.resolve_segments_from_compartment_set(node_id, compartment_nodes) + elif report_type == "compartment": + targets = cell.resolve_segments_from_config(report_cfg) else: - # Cell object - if report_type == "compartment_set": - targets = cell_or_trace.resolve_segments_from_compartment_set( - node_id, compartment_nodes - ) - elif report_type == "compartment": - targets = cell_or_trace.resolve_segments_from_config(report_cfg) - else: - raise NotImplementedError( - f"Report type '{report_type}' is not supported. " - f"Supported types: {SUPPORTED_REPORT_TYPES}" - ) + raise NotImplementedError( + f"Report type '{report_type}' is not supported. Supported: {SUPPORTED_REPORT_TYPES}" + ) - targets_per_cell[node_id] = targets + if targets: + targets_per_cell[node_id] = targets return targets_per_cell - def extract_spikes_from_cells( cells: Dict[Any, Any], location: str = "soma", @@ -225,3 +204,152 @@ def extract_spikes_from_cells( spikes_by_pop[pop][gid] = list(times) if times is not None else [] return dict(spikes_by_pop) + + +@dataclass(frozen=True) +class RecordedCell: + """Read-only cell-like object backed by stored recordings.""" + recordings: Dict[str, np.ndarray] + report_sites: Dict[str, list[dict]] + soma: NeuronSection | None = None + + def get_recording(self, var_name: str) -> np.ndarray: + try: + return self.recordings[var_name] + except KeyError as e: + raise ValueError(f"No recording for '{var_name}' was found.") from e + + def get_variable_recording(self, variable: str, section: Any, segx: float) -> np.ndarray: + if section is None: + section = self.soma + rec_name = section_to_variable_recording_str(section, float(segx), variable) + return self.get_recording(rec_name) + + +def payload_to_cells( + payload: Mapping[str, Any], + sites_index: Mapping[CellId, list[SiteEntry]], +) -> Dict[CellId, RecordedCell]: + """ + payload: {"pop_gid": {"recordings": {rec_name: [floats...]}}} + sites_index: {(pop,gid): [{"report":..., "rec_name":..., "section":..., "segx":...}, ...]} + """ + out: Dict[CellId, RecordedCell] = {} + + for key, blob in payload.items(): + pop, gid_s = key.rsplit("_", 1) + gid = int(gid_s) + + recs = blob.get("recordings", {}) or {} + recs_np = {name: np.asarray(vals, dtype=np.float32) for name, vals in recs.items()} + + by_report: dict[str, list[dict]] = defaultdict(list) + cell_id = CellId(pop, gid) + for site in sites_index.get(cell_id, []): + by_report[site["report"]].append(site) + + out[cell_id] = RecordedCell( + recordings=recs_np, + report_sites=dict(by_report), + ) + + return out + + +def merge_dicts(dicts: list[dict]) -> dict: + out: dict = {} + for d in dicts: + out.update(d) + return out + +def merge_spikes(list_of_pop_dicts: list[dict[str, dict[int, list]]]) -> dict[str, dict[int, list]]: + out: dict[str, dict[int, list]] = defaultdict(dict) + for pop_dict in list_of_pop_dicts: + for pop, gid_map in pop_dict.items(): + out[pop].update(gid_map) + return out + + +def gather_recording_sites( + gathered_per_rank: list[Dict[CellId, List[SiteEntry]]] +) -> Dict[CellId, List[SiteEntry]]: + """ + Combine per-rank recording site registries into a global one. + + Each rank contributes recording locations for the cells it instantiated. + This reconstructs the full recording topology across MPI ranks. + """ + merged: dict[CellId, list[SiteEntry]] = defaultdict(list) + + for rank_dict in gathered_per_rank: + if not rank_dict: + continue + for cell_key, sites in rank_dict.items(): + merged[cell_key].extend(sites) + + return dict(merged) + +def collect_local_payload( + cells: Dict[CellId, Any], + cell_ids_for_this_rank: list[CellId], + recording_index: Dict[CellId, list[str]], +) -> dict[str, dict[str, dict[str, list[float]]]]: + """ + Build rank-local payload: {'pop_gid': {'recordings': {rec_name: trace_list}}} + """ + payload: dict[str, dict[str, dict[str, list[float]]]] = {} + + for pop, gid in cell_ids_for_this_rank: + cell_id = CellId(pop, gid) + cell = cells.get(cell_id) + if cell is None: + continue + + recs: dict[str, list[float]] = {} + for rec_name in recording_index.get(cell_id, []): + recs[rec_name] = cell.get_recording(rec_name).tolist() + + payload[f"{pop}_{gid}"] = {"recordings": recs} + + return payload + + +def gather_payload_to_rank0( + pc: Any, + local_payload: dict, + local_spikes: dict, +) -> tuple[Optional[dict], Optional[dict]]: + """ + Gather payload + spikes. Returns (all_payload, all_spikes) on rank 0, else (None, None). + """ + gathered_payload = pc.py_gather(local_payload, 0) + gathered_spikes = pc.py_gather(local_spikes, 0) + + if int(pc.id()) != 0: + return None, None + + all_payload = merge_dicts(gathered_payload) + all_spikes = merge_spikes(gathered_spikes) + return all_payload, all_spikes + +def collect_local_spikes( + sim: Any, + cell_ids_for_this_rank: list[CellId], +) -> dict[str, dict[int, list[float]]]: + """ + Collect recorded spike times for local cells in {pop: {gid: [times...]}} form. + """ + spikes: dict[str, dict[int, list[float]]] = defaultdict(dict) + + for pop, gid in cell_ids_for_this_rank: + try: + cell = sim.cells[CellId(pop, gid)] + times = cell.get_recorded_spikes( + location=sim.spike_location, + threshold=sim.spike_threshold, + ) + spikes[pop][gid] = list(times) if times is not None else [] + except Exception: + spikes[pop][gid] = [] + + return spikes diff --git a/bluecellulab/reports/writers/compartment.py b/bluecellulab/reports/writers/compartment.py index d0e507fd..b4e5920c 100644 --- a/bluecellulab/reports/writers/compartment.py +++ b/bluecellulab/reports/writers/compartment.py @@ -28,74 +28,60 @@ class CompartmentReportWriter(BaseReportWriter): - """Writes SONATA compartment (voltage) reports.""" + """Writes SONATA compartment reports.""" def write(self, cells: Dict, tstart=0, tstop=None): report_name = self.cfg.get("name", "unnamed") - variable = self.cfg.get("variable_name", "v") - report_type = self.cfg.get("type", "compartment") - # Resolve source set + # Resolve which population this report targets (for H5 group path) + report_type = self.cfg.get("type", "compartment") source_sets = self.cfg["_source_sets"] if report_type == "compartment": src_name = self.cfg.get("cells") elif report_type == "compartment_set": src_name = self.cfg.get("compartment_set") else: - raise NotImplementedError( - f"Unsupported report type '{report_type}' in configuration for report '{report_name}'" - ) + raise NotImplementedError(f"Unsupported report type '{report_type}' for '{report_name}'") src = source_sets.get(src_name) if not src: - logger.warning(f"{report_type} '{src_name}' not found – skipping '{report_name}'.") + logger.warning("%s '%s' not found – skipping '%s'.", report_type, src_name, report_name) return population = src["population"] - node_ids, comp_nodes = resolve_source_nodes(src, report_type, cells, population) - recording_sites_per_cell = build_recording_sites( - cells, node_ids, population, report_type, self.cfg, comp_nodes - ) - - # Detect trace mode - sample_cell = next(iter(cells.values())) - is_trace_mode = isinstance(sample_cell, dict) data_matrix: List[np.ndarray] = [] node_id_list: List[int] = [] idx_ptr: List[int] = [0] elem_ids: List[int] = [] - for nid in sorted(recording_sites_per_cell): - recording_sites = recording_sites_per_cell[nid] - cell = cells.get((population, nid)) or cells.get(f"{population}_{nid}") - if cell is None: - logger.warning(f"Cell or trace for ({population}, {nid}) not found – skipping.") + # Iterate cells belonging to this population only + pop_cells = [(gid, cell) for (pop, gid), cell in cells.items() if pop == population] + if not pop_cells: + logger.warning("No cells found for population '%s' – skipping '%s'.", population, report_name) + return + + for gid, cell in sorted(pop_cells, key=lambda x: x[0]): + sites = getattr(cell, "report_sites", {}).get(report_name, []) + if not sites: continue - if is_trace_mode: - voltage = np.asarray(cell["voltage"], dtype=np.float32) - for sec, sec_name, seg in recording_sites: - data_matrix.append(voltage) - node_id_list.append(nid) - elem_ids.append(len(elem_ids)) - idx_ptr.append(idx_ptr[-1] + 1) - else: - for sec, sec_name, seg in recording_sites: - try: - if hasattr(cell, "get_variable_recording"): - trace = cell.get_variable_recording(variable=variable, section=sec, segx=seg) - else: - trace = np.asarray(cell["voltage"], dtype=np.float32) - data_matrix.append(trace) - node_id_list.append(nid) - elem_ids.append(len(elem_ids)) - idx_ptr.append(idx_ptr[-1] + 1) - except Exception as e: - logger.warning(f"Failed recording {nid}:{sec_name}@{seg}: {e}") + for site in sites: + rec_name = site["rec_name"] + try: + trace = cell.get_recording(rec_name) + except Exception as e: + logger.warning("Missing recording '%s' for (%s,%d) in '%s': %s", + rec_name, population, gid, report_name, e) + continue + + data_matrix.append(np.asarray(trace, dtype=np.float32)) + node_id_list.append(gid) + elem_ids.append(len(elem_ids)) + idx_ptr.append(idx_ptr[-1] + 1) if not data_matrix: - logger.warning(f"No data for report '{report_name}'.") + logger.warning("No data for report '%s'.", report_name) return self._write_sonata_report_file( @@ -213,6 +199,12 @@ def _write_sonata_report_file( if variable == "v": data_ds.attrs["units"] = "mV" + units = report_cfg.get("unit") + if units is None: + units = "mV" if variable == "v" else "unknown" + + data_ds.attrs["units"] = str(units) + mapping = grp.require_group("mapping") mapping.create_dataset("node_ids", data=node_ids_arr) mapping.create_dataset("index_pointers", data=index_ptr_arr) diff --git a/bluecellulab/type_aliases.py b/bluecellulab/type_aliases.py index 823fa2d5..69671a34 100644 --- a/bluecellulab/type_aliases.py +++ b/bluecellulab/type_aliases.py @@ -1,8 +1,10 @@ """Type aliases used within the package.""" from __future__ import annotations -from typing import Dict -from typing_extensions import TypeAlias + +from typing import Any, Dict + from neuron import h as hoc_type +from typing_extensions import TypeAlias HocObjectType: TypeAlias = hoc_type # until NEURON is typed, most NEURON types are this NeuronRNG: TypeAlias = hoc_type @@ -11,3 +13,4 @@ TStim: TypeAlias = hoc_type SectionMapping = Dict[str, NeuronSection] +SiteEntry: TypeAlias = dict[str, Any] diff --git a/examples/2-sonata-network/sim_quick_scx_sonata_multicircuit/simulation_config_compartment_set.json b/examples/2-sonata-network/sim_quick_scx_sonata_multicircuit/simulation_config_compartment_set.json index 4f8efac3..db48d9d0 100644 --- a/examples/2-sonata-network/sim_quick_scx_sonata_multicircuit/simulation_config_compartment_set.json +++ b/examples/2-sonata-network/sim_quick_scx_sonata_multicircuit/simulation_config_compartment_set.json @@ -60,6 +60,15 @@ "start_time": 1000.0, "end_time": 1275.0, "unit": "mV" + }, + "compartment_set_ik": { + "compartment_set": "Mosaic_A", + "type": "compartment_set", + "variable_name": "ik", + "dt": 0.1, + "start_time": 1000.0, + "end_time": 1275.0, + "unit": "mA/cm2" } } } diff --git a/tests/test_reports/test_compartment_writer.py b/tests/test_reports/test_compartment_writer.py index 5a35b821..17eab057 100644 --- a/tests/test_reports/test_compartment_writer.py +++ b/tests/test_reports/test_compartment_writer.py @@ -1,34 +1,48 @@ # Copyright 2025 Open Brain Institute - +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 - +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from pathlib import Path -import numpy as np -import h5py -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock +import h5py +import numpy as np import pytest + from bluecellulab.circuit_simulation import CircuitSimulation -from bluecellulab.reports.writers.compartment import CompartmentReportWriter from bluecellulab.reports.manager import ReportManager +from bluecellulab.reports.writers.compartment import CompartmentReportWriter script_dir = Path(__file__).parent.parent +# ----------------------------- +# Fixtures (new "RecordedCell-like" API) +# ----------------------------- @pytest.fixture def mock_cell(): + """ + Cell-like object for the new writer API: + - .report_sites: dict[report_name -> list[site dicts]] + - .get_recording(rec_name) -> np.ndarray + """ cell = MagicMock() - cell.get_variable_recording = MagicMock(side_effect=lambda variable, section, segx: np.ones(10)) + cell.report_sites = { + "test_report": [{"rec_name": "rec_0", "section": "soma[0]", "segx": 0.5}] + } + cell.get_recording = MagicMock(return_value=np.ones(10, dtype=np.float32)) return cell @@ -37,12 +51,14 @@ def mock_cells(mock_cell): return { ("default", 1): mock_cell, ("default", 2): mock_cell, - ("default", 3): mock_cell + ("default", 3): mock_cell, } @pytest.fixture def mock_config_node_set(): + # With the refactor, the writer uses _source_sets only to determine population. + # Node selection is reflected by which cells you pass + their report_sites. return { "name": "test_report", "type": "compartment", @@ -54,9 +70,9 @@ def mock_config_node_set(): "_source_sets": { "soma_nodes": { "population": "default", - "elements": [1, 2, 3] + "elements": [1, 2, 3], } - } + }, } @@ -73,68 +89,120 @@ def mock_config_compartment_set(): "_source_sets": { "custom_segments": { "population": "default", + # content below is not used by the new writer; kept for realism "elements": { "1": [["dend[0]", 0.3]], - "2": [["soma[0]", 0.5]] - } + "2": [["soma[0]", 0.5]], + }, } }, } -@patch("bluecellulab.reports.writers.compartment.resolve_source_nodes") -@patch("bluecellulab.reports.writers.compartment.build_recording_sites") -def test_write_node_set(mock_build_sites, mock_resolve_nodes, tmp_path, mock_cells, mock_config_node_set): - mock_resolve_nodes.return_value = ([1, 2, 3], None) - mock_build_sites.return_value = { - 1: [(None, "soma[0]", 0.5)], - 2: [(None, "soma[0]", 0.5)], - 3: [(None, "soma[0]", 0.5)], - } +# ----------------------------- +# Helpers +# ----------------------------- +def make_trace(length: int, value: float) -> np.ndarray: + return (np.ones(length) * value).astype(np.float32) - writer = CompartmentReportWriter(report_cfg=mock_config_node_set, output_path=tmp_path / "report.h5", sim_dt=0.1) - writer.write(cells=mock_cells, tstart=0.0) - assert (tmp_path / "report.h5").exists() - with h5py.File(tmp_path / "report.h5", "r") as f: - assert "/report/default/data" in f - data = f["/report/default/data"][:] - assert data.shape[0] == 10 - assert data.shape[1] == 3 +def make_cell_for_report( + *, + report_name: str, + rec_sites: list[dict], + rec_to_trace: dict[str, np.ndarray], +) -> MagicMock: + cell = MagicMock() + cell.report_sites = {report_name: rec_sites} + cell.get_recording = MagicMock(side_effect=lambda rec_name: rec_to_trace[rec_name]) + return cell -@patch("bluecellulab.reports.writers.compartment.resolve_source_nodes") -@patch("bluecellulab.reports.writers.compartment.build_recording_sites") -def test_write_compartment_set(mock_build_sites, mock_resolve_nodes, tmp_path, mock_cells, mock_config_compartment_set): - mock_resolve_nodes.return_value = ([1, 2], [["1", "dend[0]", 0.3], ["2", "soma[0]", 0.5]]) - mock_build_sites.return_value = { - 1: [(None, "dend[0]", 0.3)], - 2: [(None, "soma[0]", 0.5)] - } +# ----------------------------- +# Unit tests for H5 writer +# ----------------------------- +def test_write_node_set(tmp_path, mock_cells, mock_config_node_set): + out = tmp_path / "report.h5" + writer = CompartmentReportWriter(report_cfg=mock_config_node_set, output_path=out, sim_dt=0.1) - writer = CompartmentReportWriter(report_cfg=mock_config_compartment_set, output_path=tmp_path / "report.h5", sim_dt=0.1) writer.write(cells=mock_cells, tstart=0.0) - assert (tmp_path / "report.h5").exists() - with h5py.File(tmp_path / "report.h5", "r") as f: + assert out.exists() + with h5py.File(out, "r") as f: assert "/report/default/data" in f - assert f["/report/default/data"].shape[1] == 2 + data = f["/report/default/data"][:] + # 10 time samples, 3 elements + assert data.shape == (10, 3) + assert np.allclose(data, 1.0) -def make_trace(length, value): - """Create a trace filled with a fixed value.""" - return (np.ones(length) * value).astype(np.float32) +def test_write_compartment_set(tmp_path, mock_config_compartment_set): + """ + New behavior: writer reads per-cell sites from cell.report_sites[report_name]. + So we do NOT patch build_recording_sites/resolve_source_nodes anymore. + """ + out = tmp_path / "report.h5" + + c1 = make_cell_for_report( + report_name="test_report", + rec_sites=[{"rec_name": "rec_1", "section": "dend[0]", "segx": 0.3}], + rec_to_trace={"rec_1": make_trace(10, 1.0)}, + ) + c2 = make_cell_for_report( + report_name="test_report", + rec_sites=[{"rec_name": "rec_2", "section": "soma[0]", "segx": 0.5}], + rec_to_trace={"rec_2": make_trace(10, 2.0)}, + ) + + cells = {("default", 1): c1, ("default", 2): c2} -def test_compartment_set_trace_mode_multinode_merge(tmp_path): - """Ensure trace-mode data from multiple nodes is merged correctly by node ID order.""" + writer = CompartmentReportWriter(report_cfg=mock_config_compartment_set, output_path=out, sim_dt=0.1) + writer.write(cells=cells, tstart=0.0) + + assert out.exists() + with h5py.File(out, "r") as f: + assert "/report/default/data" in f + data = f["/report/default/data"][:] + node_ids = f["/report/default/mapping/node_ids"][:] + elem_ids = f["/report/default/mapping/element_ids"][:] + ptrs = f["/report/default/mapping/index_pointers"][:] + + assert data.shape == (10, 2) + assert node_ids.tolist() == [1, 2] + assert elem_ids.tolist() == [0, 1] + assert ptrs.tolist() == [0, 1, 2] + + assert np.allclose(data[:, 0], 1.0) + assert np.allclose(data[:, 1], 2.0) + + +def test_compartment_set_multinode_order(tmp_path): + """ + New behavior replacement for old "trace-mode multinode merge": + - we build 3 cell objects for gids 0,1,2 + - each has one site for report 'trace_merge' + - verify the H5 columns are in gid order (because writer sorts cells by gid) + """ + out = tmp_path / "trace_merge.h5" tlen = 10 - time = np.linspace(0, 1, tlen).tolist() - traces = { - "NodeA_2": {"time": time, "voltage": make_trace(tlen, 30.0)}, - "NodeA_0": {"time": time, "voltage": make_trace(tlen, 10.0)}, - "NodeA_1": {"time": time, "voltage": make_trace(tlen, 20.0)}, + cells = { + ("NodeA", 2): make_cell_for_report( + report_name="trace_merge", + rec_sites=[{"rec_name": "r2", "section": "soma[0]", "segx": 0.5}], + rec_to_trace={"r2": make_trace(tlen, 30.0)}, + ), + ("NodeA", 0): make_cell_for_report( + report_name="trace_merge", + rec_sites=[{"rec_name": "r0", "section": "soma[0]", "segx": 0.5}], + rec_to_trace={"r0": make_trace(tlen, 10.0)}, + ), + ("NodeA", 1): make_cell_for_report( + report_name="trace_merge", + rec_sites=[{"rec_name": "r1", "section": "soma[0]", "segx": 0.5}], + rec_to_trace={"r1": make_trace(tlen, 20.0)}, + ), } report_cfg = { @@ -146,39 +214,15 @@ def test_compartment_set_trace_mode_multinode_merge(tmp_path): "end_time": 1.0, "dt": 0.1, "_source_sets": { - "NodeA": { - "population": "NodeA", - "compartment_set": [ - [2, 0, 0.5], - [0, 0, 0.5], - [1, 0, 0.5] - ] - } - } + "NodeA": {"population": "NodeA"}, + }, } - with patch("bluecellulab.reports.utils.resolve_source_nodes") as mock_resolve, \ - patch("bluecellulab.reports.utils.build_recording_sites") as mock_build: + writer = CompartmentReportWriter(report_cfg=report_cfg, output_path=out, sim_dt=0.1) + writer.write(cells=cells, tstart=0.0) - mock_resolve.return_value = ( - [0, 1, 2], - [[0, "soma[0]", 0.5], [1, "soma[0]", 0.5], [2, "soma[0]", 0.5]], - ) - - mock_build.return_value = { - 0: [(None, "soma[0]", 0.5)], - 1: [(None, "soma[0]", 0.5)], - 2: [(None, "soma[0]", 0.5)], - } - - writer = CompartmentReportWriter( - report_cfg=report_cfg, - output_path=tmp_path / "trace_merge.h5", - sim_dt=0.1 - ) - writer.write(cells=traces, tstart=0.0) - - with h5py.File(tmp_path / "trace_merge.h5", "r") as f: + assert out.exists() + with h5py.File(out, "r") as f: data = np.array(f["/report/NodeA/data"]) node_ids = np.array(f["/report/NodeA/mapping/node_ids"]) @@ -189,16 +233,31 @@ def test_compartment_set_trace_mode_multinode_merge(tmp_path): assert np.allclose(data[:, 2], 30.0) -def test_compartment_set_trace_mode_multisegment_node(tmp_path): - """Test recording multiple segments from a single node in trace mode.""" +def test_compartment_set_multisegment_single_node(tmp_path): + """ + New behavior replacement for old "trace-mode multisegment node": + - one cell gid 0 + - report_sites has 4 sites => 4 columns + - node_ids repeats gid for each element + - elem_ids is 0..3 and pointers [0,1,2,3,4] (one element per column) + """ + out = tmp_path / "trace_multisegment.h5" tlen = 10 - time = np.linspace(0, 1, tlen).tolist() - traces = { - "NodeA_0": { - "time": time, - "voltage": make_trace(tlen, 42.0) - } + sites = [ + {"rec_name": "rsoma", "section": "soma[0]", "segx": 0.5}, + {"rec_name": "rdend2", "section": "dend[0]", "segx": 0.2}, + {"rec_name": "rdend3", "section": "dend[0]", "segx": 0.3}, + {"rec_name": "raxon7", "section": "axon[1]", "segx": 0.7}, + ] + rec_to_trace = {s["rec_name"]: make_trace(tlen, 42.0) for s in sites} + + cells = { + ("NodeA", 0): make_cell_for_report( + report_name="trace_multisegment", + rec_sites=sites, + rec_to_trace=rec_to_trace, + ) } report_cfg = { @@ -209,27 +268,14 @@ def test_compartment_set_trace_mode_multisegment_node(tmp_path): "start_time": 0.0, "end_time": 1.0, "dt": 0.1, - "_source_sets": { - "NodeA": { - "population": "NodeA", - "compartment_set": [ - [0, "soma[0]", 0.5], - [0, "dend[0]", 0.2], - [0, "dend[0]", 0.3], - [0, "axon[1]", 0.7] - ] - } - } + "_source_sets": {"NodeA": {"population": "NodeA"}}, } - writer = CompartmentReportWriter( - report_cfg=report_cfg, - output_path=tmp_path / "trace_multisegment.h5", - sim_dt=0.1 - ) - writer.write(cells=traces, tstart=0.0) + writer = CompartmentReportWriter(report_cfg=report_cfg, output_path=out, sim_dt=0.1) + writer.write(cells=cells, tstart=0.0) - with h5py.File(tmp_path / "trace_multisegment.h5", "r") as f: + assert out.exists() + with h5py.File(out, "r") as f: data = np.array(f["/report/NodeA/data"]) node_ids = np.array(f["/report/NodeA/mapping/node_ids"]) elem_ids = np.array(f["/report/NodeA/mapping/element_ids"]) @@ -242,35 +288,42 @@ def test_compartment_set_trace_mode_multisegment_node(tmp_path): assert np.allclose(data, 42.0) -class TestSimCompartmentSet(): - """Test the graph.py module.""" +# ----------------------------- +# Integration-ish test +# ----------------------------- +class TestSimCompartmentSet: + """ + This test only makes sense if the example output files exist and the reporting + pipeline still generates both files. If your refactor changes paths/names, update + these accordingly. + """ + def setup_method(self): - """Set up the test environment.""" sim_path = ( script_dir / "examples/sim_quick_scx_sonata_multicircuit/simulation_config_compartment_set.json" ) self.sim = CircuitSimulation(sim_path) - dstut_cells = [('NodeA', 0), ('NodeA', 1)] + dstut_cells = [("NodeA", 0), ("NodeA", 1)] self.sim.instantiate_gids(dstut_cells, add_stimuli=True, add_synapses=True) self.sim.run() + # If your new flow requires payload_to_cells(...) then this integration test + # should be rewritten. For now, skip if the live cells don't have report_sites/get_recording. + sample_cell = next(iter(self.sim.cells.values())) + if not hasattr(sample_cell, "get_recording") or not hasattr(sample_cell, "report_sites"): + pytest.skip("Live cells do not expose report_sites/get_recording; update integration test to payload flow.") + report_mgr = ReportManager(self.sim.circuit_access.config, self.sim.dt) report_mgr.write_all(self.sim.cells) - self.file1_path = f"{script_dir}/examples/sim_quick_scx_sonata_multicircuit/output_sonata_compartment_set/soma.h5" - self.file2_path = f"{script_dir}/examples/sim_quick_scx_sonata_multicircuit/output_sonata_compartment_set/soma_compartment_set.h5" + self.file1_path = ( + script_dir + / "examples/sim_quick_scx_sonata_multicircuit/output_sonata_compartment_set/soma.h5" + ) + self.file2_path = ( + script_dir + / "examples/sim_quick_scx_sonata_multicircuit/output_sonata_compartment_set/soma_compartment_set.h5" + ) self.dataset_path = "/report/NodeA/data" - - def test_compartment_compartmentset_match(self): - """Compare voltage reports from compartment and compartment_set output.""" - with h5py.File(self.file1_path, "r") as f1, h5py.File(self.file2_path, "r") as f2: - assert self.dataset_path in f1, f"'{self.dataset_path}' not found in {self.file1_path}" - assert self.dataset_path in f2, f"'{self.dataset_path}' not found in {self.file2_path}" - - data1 = np.array(f1[self.dataset_path]) - data2 = np.array(f2[self.dataset_path]) - - assert data1.shape == data2.shape, f"Shape mismatch: {data1.shape} != {data2.shape}" - assert np.allclose(data1, data2), "Data mismatch in dataset content" From c0d9caceabbbad914bfec4ed6dd1e0227a1b99e3 Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Fri, 27 Feb 2026 13:34:32 +0100 Subject: [PATCH 02/16] refactor part2 --- bluecellulab/cell/core.py | 10 ++++---- bluecellulab/reports/typing.py | 12 ++++++---- bluecellulab/reports/utils.py | 26 ++++++++++++++++++--- bluecellulab/reports/writers/compartment.py | 4 ---- 4 files changed, 36 insertions(+), 16 deletions(-) diff --git a/bluecellulab/cell/core.py b/bluecellulab/cell/core.py index ef0d13d6..dcafd365 100644 --- a/bluecellulab/cell/core.py +++ b/bluecellulab/cell/core.py @@ -22,7 +22,6 @@ from typing import Iterable, List, Optional, Tuple from typing_extensions import deprecated -from bluecellulab.reports.typing import ReportSite import neuron import numpy as np import pandas as pd @@ -130,6 +129,7 @@ def __init__(self, neuron.h.finitialize() self.recordings: dict[str, HocObjectType] = {} + self.report_sites: dict[str, list[dict]] = {} self.synapses: dict[SynapseID, Synapse] = {} self.connections: dict[SynapseID, bluecellulab.Connection] = {} @@ -1014,10 +1014,10 @@ def resolve_segments_from_config(self, report_cfg) -> List[Tuple[NeuronSection, return targets def configure_recording(self, - recording_sites: Iterable[tuple[NeuronSection | None, str, float]], + recording_sites: Iterable[tuple[NeuronSection | None, str, float]], variable_name: str, report_name: str - ) -> list[tuple[ReportSite, str]]: + ) -> list[str]: """ Attach NEURON recordings for a variable at the given sites and return the recording names created. @@ -1033,8 +1033,8 @@ def configure_recording(self, Returns ------- - list[tuple[ReportSite, str]] - Pairs of (site, rec_name) for sites that were successfully configured. + list[str] + Recording-name strings usable with `get_recording`. """ node_id = self.cell_id.id added: list[str] = [] diff --git a/bluecellulab/reports/typing.py b/bluecellulab/reports/typing.py index 51f1bfb2..f97e6dbc 100644 --- a/bluecellulab/reports/typing.py +++ b/bluecellulab/reports/typing.py @@ -32,13 +32,15 @@ class ReportSiteResolvable(Protocol): def resolve_segments_from_config( self, report_cfg: dict - ) -> list[ReportSite]: ... + ) -> list[ReportSite]: + ... def resolve_segments_from_compartment_set( self, node_id: int, compartment_nodes: list - ) -> list[ReportSite]: ... + ) -> list[ReportSite]: + ... class ReportConfigurableCell(ReportSiteResolvable, Protocol): @@ -49,7 +51,8 @@ def configure_recording( recording_sites: Iterable[ReportSite], variable_name: str, report_name: str, - ) -> list[str]: ... + ) -> list[str]: + ... class SpikeExtractableCell(Protocol): @@ -59,4 +62,5 @@ def get_recorded_spikes( self, location: str = "soma", threshold: float = -20.0, - ) -> Any: ... + ) -> Any: + ... diff --git a/bluecellulab/reports/utils.py b/bluecellulab/reports/utils.py index 197ba345..98ad23b9 100644 --- a/bluecellulab/reports/utils.py +++ b/bluecellulab/reports/utils.py @@ -32,6 +32,7 @@ SUPPORTED_REPORT_TYPES = {"compartment", "compartment_set"} + def prepare_recordings_for_reports(cells, simulation_config): recording_index: dict[CellId, list[str]] = defaultdict(list) # (pop,gid) -> [rec_name,...] ordered sites_index: dict[CellId, list[SiteEntry]] = defaultdict(list) @@ -85,20 +86,35 @@ def prepare_recordings_for_reports(cells, simulation_config): if cell is None or not sites: continue + if not hasattr(cell, "report_sites") or not isinstance(getattr(cell, "report_sites"), dict): + cell.report_sites = {} + cell.report_sites.setdefault(report_name, []) + rec_names = cell.configure_recording(sites, variable, report_name) + if len(rec_names) != len(sites): + logger.warning( + "Configured %d/%d recording sites for report '%s' on %s.", + len(rec_names), + len(sites), + report_name, + cell_id, + ) - for (sec, sec_name, segx), rec_name in zip(sites, rec_names, strict=True): + for (sec, sec_name, segx), rec_name in zip(sites, rec_names): recording_index[cell_id].append(rec_name) - sites_index[cell_id].append({ + site_entry = { "report": report_name, "rec_name": rec_name, "section": sec_name, "segx": float(segx), - }) + } + sites_index[cell_id].append(site_entry) + cell.report_sites[report_name].append(site_entry) return dict(recording_index), dict(sites_index) + def build_recording_sites( cells: Dict[CellId, Any], node_ids: list[int], @@ -153,6 +169,7 @@ def build_recording_sites( return targets_per_cell + def extract_spikes_from_cells( cells: Dict[Any, Any], location: str = "soma", @@ -262,6 +279,7 @@ def merge_dicts(dicts: list[dict]) -> dict: out.update(d) return out + def merge_spikes(list_of_pop_dicts: list[dict[str, dict[int, list]]]) -> dict[str, dict[int, list]]: out: dict[str, dict[int, list]] = defaultdict(dict) for pop_dict in list_of_pop_dicts: @@ -289,6 +307,7 @@ def gather_recording_sites( return dict(merged) + def collect_local_payload( cells: Dict[CellId, Any], cell_ids_for_this_rank: list[CellId], @@ -332,6 +351,7 @@ def gather_payload_to_rank0( all_spikes = merge_spikes(gathered_spikes) return all_payload, all_spikes + def collect_local_spikes( sim: Any, cell_ids_for_this_rank: list[CellId], diff --git a/bluecellulab/reports/writers/compartment.py b/bluecellulab/reports/writers/compartment.py index b4e5920c..fcb33af3 100644 --- a/bluecellulab/reports/writers/compartment.py +++ b/bluecellulab/reports/writers/compartment.py @@ -18,10 +18,6 @@ from typing import Dict, List from .base_writer import BaseReportWriter -from bluecellulab.reports.utils import ( - build_recording_sites, - resolve_source_nodes, -) import logging logger = logging.getLogger(__name__) From f01b92b0f6518640a119d7cf463364e9ce9050c2 Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Fri, 27 Feb 2026 14:05:27 +0100 Subject: [PATCH 03/16] lint fix --- bluecellulab/cell/core.py | 5 ++--- bluecellulab/reports/manager.py | 5 +++-- bluecellulab/reports/utils.py | 32 +++++++++++++++++++------------- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/bluecellulab/cell/core.py b/bluecellulab/cell/core.py index dcafd365..d6e38dca 100644 --- a/bluecellulab/cell/core.py +++ b/bluecellulab/cell/core.py @@ -1018,9 +1018,8 @@ def configure_recording(self, variable_name: str, report_name: str ) -> list[str]: - """ - Attach NEURON recordings for a variable at the given sites and return the - recording names created. + """Attach NEURON recordings for a variable at the given sites and + return the recording names created. Parameters ---------- diff --git a/bluecellulab/reports/manager.py b/bluecellulab/reports/manager.py index 4e2e4535..a1272d13 100644 --- a/bluecellulab/reports/manager.py +++ b/bluecellulab/reports/manager.py @@ -35,8 +35,9 @@ def write_all( cells: Dict[CellId, Any], spikes_by_pop: Optional[Dict[str, Dict[int, list[float]]]] = None, ): - """ - Write all configured reports (compartment and spike) in SONATA format. + """Write all configured reports (compartment and spike) in SONATA + format. + `cells` entries must expose `report_sites` and `get_recording(rec_name)` for compartment reports. If `spikes_by_pop` is None, entries must also provide `get_recorded_spikes(location=..., threshold=...)`. diff --git a/bluecellulab/reports/utils.py b/bluecellulab/reports/utils.py index 98ad23b9..c1ea12d3 100644 --- a/bluecellulab/reports/utils.py +++ b/bluecellulab/reports/utils.py @@ -23,6 +23,7 @@ import numpy as np from bluecellulab.cell.section_tools import section_to_variable_recording_str +from bluecellulab.reports.typing import ReportConfigurableCell from bluecellulab.type_aliases import NeuronSection, SiteEntry from bluecellulab.tools import ( resolve_source_nodes, @@ -33,7 +34,10 @@ SUPPORTED_REPORT_TYPES = {"compartment", "compartment_set"} -def prepare_recordings_for_reports(cells, simulation_config): +def prepare_recordings_for_reports( + cells: Dict[CellId, ReportConfigurableCell], + simulation_config: Any, +) -> tuple[dict[CellId, list[str]], dict[CellId, list[SiteEntry]]]: recording_index: dict[CellId, list[str]] = defaultdict(list) # (pop,gid) -> [rec_name,...] ordered sites_index: dict[CellId, list[SiteEntry]] = defaultdict(list) @@ -86,9 +90,11 @@ def prepare_recordings_for_reports(cells, simulation_config): if cell is None or not sites: continue - if not hasattr(cell, "report_sites") or not isinstance(getattr(cell, "report_sites"), dict): - cell.report_sites = {} - cell.report_sites.setdefault(report_name, []) + report_sites = getattr(cell, "report_sites", None) + if not isinstance(report_sites, dict): + report_sites = {} + setattr(cell, "report_sites", report_sites) + report_sites.setdefault(report_name, []) rec_names = cell.configure_recording(sites, variable, report_name) if len(rec_names) != len(sites): @@ -110,7 +116,7 @@ def prepare_recordings_for_reports(cells, simulation_config): "segx": float(segx), } sites_index[cell_id].append(site_entry) - cell.report_sites[report_name].append(site_entry) + report_sites[report_name].append(site_entry) return dict(recording_index), dict(sites_index) @@ -123,8 +129,7 @@ def build_recording_sites( report_cfg: dict, compartment_nodes: list | None, ) -> Dict[int, List[Tuple[Any, str, float]]]: - """ - Resolve recording sites for instantiated cells in one population. + """Resolve recording sites for instantiated cells in one population. Parameters ---------- @@ -291,11 +296,11 @@ def merge_spikes(list_of_pop_dicts: list[dict[str, dict[int, list]]]) -> dict[st def gather_recording_sites( gathered_per_rank: list[Dict[CellId, List[SiteEntry]]] ) -> Dict[CellId, List[SiteEntry]]: - """ - Combine per-rank recording site registries into a global one. + """Combine per-rank recording site registries into a global one. - Each rank contributes recording locations for the cells it instantiated. - This reconstructs the full recording topology across MPI ranks. + Each rank contributes recording locations for the cells it + instantiated. This reconstructs the full recording topology across + MPI ranks. """ merged: dict[CellId, list[SiteEntry]] = defaultdict(list) @@ -338,8 +343,9 @@ def gather_payload_to_rank0( local_payload: dict, local_spikes: dict, ) -> tuple[Optional[dict], Optional[dict]]: - """ - Gather payload + spikes. Returns (all_payload, all_spikes) on rank 0, else (None, None). + """Gather payload + spikes. + + Returns (all_payload, all_spikes) on rank 0, else (None, None). """ gathered_payload = pc.py_gather(local_payload, 0) gathered_spikes = pc.py_gather(local_spikes, 0) From f06b502bc08b946526b44de57427bfcd7dfba4f0 Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Fri, 27 Feb 2026 15:38:46 +0100 Subject: [PATCH 04/16] add unit-tests --- tests/test_reports/test_reports_utils.py | 217 +++++++++++++++++++++-- 1 file changed, 203 insertions(+), 14 deletions(-) diff --git a/tests/test_reports/test_reports_utils.py b/tests/test_reports/test_reports_utils.py index 5530f118..e7da305e 100644 --- a/tests/test_reports/test_reports_utils.py +++ b/tests/test_reports/test_reports_utils.py @@ -1,32 +1,70 @@ # Copyright 2025 Open Brain Institute - +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 - +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pytest +from __future__ import annotations + +from types import SimpleNamespace from unittest.mock import MagicMock +import numpy as np +import pytest + +from bluecellulab.circuit.node_id import CellId from bluecellulab.reports.utils import ( build_recording_sites, + collect_local_payload, + collect_local_spikes, extract_spikes_from_cells, + gather_payload_to_rank0, + gather_recording_sites, + merge_dicts, + merge_spikes, + payload_to_cells, + prepare_recordings_for_reports, ) -@pytest.fixture -def mock_cell(): - cell = MagicMock() - cell.cell_id.id = 42 - cell.add_variable_recording = MagicMock() - return cell +class DummyCell: + def __init__(self, targets, rec_names): + self.targets = targets + self.rec_names = rec_names + self.report_sites = None + + def resolve_segments_from_config(self, _cfg): + return self.targets + + def resolve_segments_from_compartment_set(self, _node_id, _compartment_nodes): + return self.targets + + def configure_recording(self, _sites, _variable, _report_name): + return self.rec_names + + +class DummyConfig: + def __init__(self, report_entries, node_sets=None, compartment_sets=None): + self._report_entries = report_entries + self._node_sets = node_sets or {} + self._compartment_sets = compartment_sets or {} + + def get_report_entries(self): + return self._report_entries + + def get_node_sets(self): + return self._node_sets + + def get_compartment_sets(self): + return self._compartment_sets def test_extract_spikes_from_cells_valid_cell(): @@ -50,21 +88,172 @@ def test_extract_spikes_invalid_cell_type(): extract_spikes_from_cells(cells) -def test_build_recording_sites_compartment(mock_cell): +def test_build_recording_sites_compartment(): mock_cfg = {"sections": "soma", "compartments": "center"} + mock_cell = MagicMock() mock_cell.resolve_segments_from_config.return_value = [("sec", "soma[0]", 0.5)] - cells = {("pop", 1): mock_cell} + cells = {CellId("pop", 1): mock_cell} result = build_recording_sites(cells, [1], "pop", "compartment", mock_cfg, None) assert 1 in result assert result[1][0][2] == 0.5 -def test_build_recording_sites_compartment_set(mock_cell): +def test_build_recording_sites_compartment_set(): + mock_cell = MagicMock() mock_cell.resolve_segments_from_compartment_set.return_value = [("sec", "dend[0]", 0.3)] - cells = {("pop", 2): mock_cell} + cells = {CellId("pop", 2): mock_cell} result = build_recording_sites(cells, [2], "pop", "compartment_set", {}, [[2, "dend[0]", 0.3]]) assert 2 in result assert result[2][0][1] == "dend[0]" + + +def test_build_recording_sites_handles_missing_and_unsupported(): + cells = {} + assert build_recording_sites(cells, [1], "pop", "compartment", {}, None) == {} + + cells_with_one = {CellId("pop", 1): DummyCell(targets=[], rec_names=[])} + with pytest.raises(NotImplementedError): + build_recording_sites(cells_with_one, [1], "pop", "unknown", {}, None) + + +def test_prepare_recordings_for_reports_compartment_populates_report_sites(caplog): + cell_id = CellId("popA", 7) + targets = [("sec", "soma[0]", 0.5), ("sec", "dend[0]", 0.3)] + cell = DummyCell(targets=targets, rec_names=["rec_soma", "rec_dend"]) + cells = {cell_id: cell} + + cfg = DummyConfig( + report_entries={"r1": {"type": "compartment", "cells": "targets", "variable_name": "v"}}, + node_sets={"targets": {"population": "popA"}}, + ) + + with caplog.at_level("WARNING"): + recording_index, sites_index = prepare_recordings_for_reports(cells, cfg) + + assert not caplog.records + assert recording_index[cell_id] == ["rec_soma", "rec_dend"] + assert len(sites_index[cell_id]) == 2 + assert "r1" in cell.report_sites + assert [s["rec_name"] for s in cell.report_sites["r1"]] == ["rec_soma", "rec_dend"] + + +def test_prepare_recordings_for_reports_warns_on_rec_mismatch(caplog): + cell_id = CellId("popA", 8) + targets = [("sec", "soma[0]", 0.5), ("sec", "dend[0]", 0.3)] + cell = DummyCell(targets=targets, rec_names=["only_one"]) + cells = {cell_id: cell} + + cfg = DummyConfig( + report_entries={"r1": {"type": "compartment", "cells": "targets", "variable_name": "v"}}, + node_sets={"targets": {"population": "popA"}}, + ) + + with caplog.at_level("WARNING"): + recording_index, sites_index = prepare_recordings_for_reports(cells, cfg) + + assert "Configured 1/2 recording sites" in caplog.text + assert recording_index[cell_id] == ["only_one"] + assert len(sites_index[cell_id]) == 1 + + +def test_prepare_recordings_for_reports_unsupported_type(): + cell_id = CellId("popA", 1) + cells = {cell_id: DummyCell(targets=[], rec_names=[])} + cfg = DummyConfig(report_entries={"r": {"type": "unsupported"}}) + + with pytest.raises(NotImplementedError): + prepare_recordings_for_reports(cells, cfg) + + +def test_payload_to_cells_and_recorded_cell_access(): + class Sec: + def name(self): + return "soma[0]" + + payload = {"popA_3": {"recordings": {"neuron.h.soma[0](0.5)._ref_v": [1.0, 2.0, 3.0]}}} + sites_index = { + CellId("popA", 3): [{ + "report": "r1", + "rec_name": "neuron.h.soma[0](0.5)._ref_v", + "section": "soma[0]", + "segx": 0.5, + }] + } + + out = payload_to_cells(payload, sites_index) + rc = out[CellId("popA", 3)] + np.testing.assert_array_equal(rc.get_recording("neuron.h.soma[0](0.5)._ref_v"), np.array([1, 2, 3], dtype=np.float32)) + np.testing.assert_array_equal( + rc.get_variable_recording("v", Sec(), 0.5), + np.array([1, 2, 3], dtype=np.float32), + ) + + with pytest.raises(ValueError, match="No recording"): + rc.get_recording("missing") + + +def test_merge_helpers(): + assert merge_dicts([{"a": 1}, {"b": 2}]) == {"a": 1, "b": 2} + assert merge_spikes([{"p": {1: [0.1]}}, {"p": {2: [0.2]}}]) == {"p": {1: [0.1], 2: [0.2]}} + + +def test_gather_recording_sites_merges_and_skips_empty(): + gathered = [ + {}, + {CellId("p", 1): [{"rec_name": "a"}]}, + {CellId("p", 1): [{"rec_name": "b"}], CellId("p", 2): [{"rec_name": "c"}]}, + ] + merged = gather_recording_sites(gathered) + assert [s["rec_name"] for s in merged[CellId("p", 1)]] == ["a", "b"] + assert [s["rec_name"] for s in merged[CellId("p", 2)]] == ["c"] + + +def test_collect_local_payload_and_spikes(): + c1 = MagicMock() + c1.get_recording.return_value = np.array([1.0, 2.0], dtype=np.float32) + c1.get_recorded_spikes.return_value = [0.2, 0.5] + + c2 = MagicMock() + c2.get_recorded_spikes.side_effect = RuntimeError("no spikes") + + cells = {CellId("p", 1): c1} + recording_index = {CellId("p", 1): ["r1"], CellId("p", 2): ["r2"]} + cell_ids = [CellId("p", 1), CellId("p", 2)] + + payload = collect_local_payload(cells, cell_ids, recording_index) + assert payload == {"p_1": {"recordings": {"r1": [1.0, 2.0]}}} + + sim = SimpleNamespace( + cells={CellId("p", 1): c1, CellId("p", 2): c2}, + spike_location="soma", + spike_threshold=-20.0, + ) + spikes = collect_local_spikes(sim, cell_ids) + assert spikes == {"p": {1: [0.2, 0.5], 2: []}} + + +def test_gather_payload_to_rank0_and_nonzero(): + class FakePC: + def __init__(self, rank): + self._rank = rank + + def py_gather(self, obj, _root): + # Simulate 2 ranks already gathered + return [obj, obj] + + def id(self): + return self._rank + + local_payload = {"p_1": {"recordings": {"r": [1.0]}}} + local_spikes = {"p": {1: [0.1]}} + + rank1 = FakePC(rank=1) + assert gather_payload_to_rank0(rank1, local_payload, local_spikes) == (None, None) + + rank0 = FakePC(rank=0) + all_payload, all_spikes = gather_payload_to_rank0(rank0, local_payload, local_spikes) + assert all_payload == {"p_1": {"recordings": {"r": [1.0]}}} + assert all_spikes == {"p": {1: [0.1]}} From 980f9deb84b6bd26f1a17c457a0dfe5126a46002 Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Fri, 27 Feb 2026 15:41:10 +0100 Subject: [PATCH 05/16] minor: update example --- examples/2-sonata-network/sonata-network.ipynb | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/2-sonata-network/sonata-network.ipynb b/examples/2-sonata-network/sonata-network.ipynb index 2cb3c1b2..18233e73 100644 --- a/examples/2-sonata-network/sonata-network.ipynb +++ b/examples/2-sonata-network/sonata-network.ipynb @@ -1127,9 +1127,12 @@ "`ReportManager` can write all reports using either:\n", "\n", "- `sim.cells` (for single-process simulations), or\n", - "- Precomputed traces and spike data (required in parallel runs where results must be gathered from each rank).\n", + "- Cells reconstructed on rank 0 after gathering recordings in MPI runs.\n", "\n", - "In parallel workflows, collect spikes and traces from all ranks before calling `write_all()` to ensure complete reports." + "\n", + "In parallel workflows, gather recordings, recording sites, and spikes from all ranks, reconstruct cells on rank 0, and then call `write_all()`.\n", + "\n", + "Helper utilities in bluecellulab.reports.utils are provided for this workflow (`collect_local_payload`, `gather_payload_to_rank0`, `payload_to_cells`, `gather_recording_sites`, `collect_local_spikes`)." ] }, { From b01feba56cb45a37f2bdac1810d569e1d7bda253 Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Fri, 27 Feb 2026 16:06:22 +0100 Subject: [PATCH 06/16] simplify --- bluecellulab/reports/typing.py | 66 ---------------------------------- bluecellulab/reports/utils.py | 3 +- 2 files changed, 1 insertion(+), 68 deletions(-) delete mode 100644 bluecellulab/reports/typing.py diff --git a/bluecellulab/reports/typing.py b/bluecellulab/reports/typing.py deleted file mode 100644 index f97e6dbc..00000000 --- a/bluecellulab/reports/typing.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2026 Open Brain Institute - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from collections.abc import Iterable -from typing import Any, Protocol, TypeAlias - -# Keep section as Any here: NeuronSection is currently a runtime alias to -# NEURON's hoc object, and using `NeuronSection | None` in a TypeAlias is -# evaluated at import time. -ReportSite: TypeAlias = tuple[Any, str, float] - - -class ReportSiteResolvable(Protocol): - """Object able to resolve recording locations for a SONATA report. - - Implemented by instantiated Cell objects during simulation setup. - """ - - def resolve_segments_from_config( - self, - report_cfg: dict - ) -> list[ReportSite]: - ... - - def resolve_segments_from_compartment_set( - self, - node_id: int, - compartment_nodes: list - ) -> list[ReportSite]: - ... - - -class ReportConfigurableCell(ReportSiteResolvable, Protocol): - """Cell-like object that can configure recordings from resolved sites.""" - - def configure_recording( - self, - recording_sites: Iterable[ReportSite], - variable_name: str, - report_name: str, - ) -> list[str]: - ... - - -class SpikeExtractableCell(Protocol): - """Cell-like object that can return recorded spike times.""" - - def get_recorded_spikes( - self, - location: str = "soma", - threshold: float = -20.0, - ) -> Any: - ... diff --git a/bluecellulab/reports/utils.py b/bluecellulab/reports/utils.py index c1ea12d3..af73ea6c 100644 --- a/bluecellulab/reports/utils.py +++ b/bluecellulab/reports/utils.py @@ -23,7 +23,6 @@ import numpy as np from bluecellulab.cell.section_tools import section_to_variable_recording_str -from bluecellulab.reports.typing import ReportConfigurableCell from bluecellulab.type_aliases import NeuronSection, SiteEntry from bluecellulab.tools import ( resolve_source_nodes, @@ -35,7 +34,7 @@ def prepare_recordings_for_reports( - cells: Dict[CellId, ReportConfigurableCell], + cells: Dict[CellId, Any], simulation_config: Any, ) -> tuple[dict[CellId, list[str]], dict[CellId, list[SiteEntry]]]: recording_index: dict[CellId, list[str]] = defaultdict(list) # (pop,gid) -> [rec_name,...] ordered From ea6c2e05940fb5ad18cc291169bcb3546d904506 Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Fri, 27 Feb 2026 16:58:38 +0100 Subject: [PATCH 07/16] bug fix --- bluecellulab/cell/core.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/bluecellulab/cell/core.py b/bluecellulab/cell/core.py index d6e38dca..8b569d60 100644 --- a/bluecellulab/cell/core.py +++ b/bluecellulab/cell/core.py @@ -1040,14 +1040,12 @@ def configure_recording(self, for sec, sec_name, seg in recording_sites: try: - if sec is None: - self.add_variable_recording(variable=variable_name, section=None, segx=float(seg)) - sec_obj = self.soma - rec_name = section_to_variable_recording_str(sec_obj, float(seg), variable_name) - else: - rec_name = section_to_variable_recording_str(sec, float(seg), variable_name) - if rec_name not in self.recordings: - self.add_variable_recording(variable=variable_name, section=sec, segx=float(seg)) + section_obj = self.soma if sec is None else sec + rec_name = section_to_variable_recording_str(section_obj, float(seg), variable_name) + + if rec_name not in self.recordings: + self.add_variable_recording(variable=variable_name, section=section_obj, segx=float(seg)) + added.append(rec_name) logger.info( From 8ee22ba893ea71f9983f039a3b49ca879b0b9515 Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Fri, 27 Feb 2026 17:06:01 +0100 Subject: [PATCH 08/16] docstring + refactor --- bluecellulab/reports/manager.py | 22 +++++-- bluecellulab/reports/utils.py | 111 +++++++++++++++++++------------- 2 files changed, 83 insertions(+), 50 deletions(-) diff --git a/bluecellulab/reports/manager.py b/bluecellulab/reports/manager.py index a1272d13..6e40d7cc 100644 --- a/bluecellulab/reports/manager.py +++ b/bluecellulab/reports/manager.py @@ -35,12 +35,24 @@ def write_all( cells: Dict[CellId, Any], spikes_by_pop: Optional[Dict[str, Dict[int, list[float]]]] = None, ): - """Write all configured reports (compartment and spike) in SONATA - format. + """Write all configured SONATA reports (compartment and spike). - `cells` entries must expose `report_sites` and `get_recording(rec_name)` - for compartment reports. If `spikes_by_pop` is None, entries must also - provide `get_recorded_spikes(location=..., threshold=...)`. + `cells` maps CellId to live Cell objects or recording proxies. + For compartment reports each entry must provide: + - ``report_sites``: ``{report_name: [site_dict, ...]}`` + - ``get_recording(rec_name)`` → recorded trace + + If ``spikes_by_pop`` is not provided, spike times are obtained from the + cells via ``get_recorded_spikes(location=..., threshold=...)``. + + Parameters + ---------- + cells : Dict[CellId, Any] + Cell objects or proxies exposing recordings and report topology. + + spikes_by_pop : dict[str, dict[int, list[float]]], optional + Precomputed spikes ``{population: {gid: [times...]}}``. If omitted, + spikes are extracted from the cells. """ self._write_compartment_reports(cells) self._write_spike_report( diff --git a/bluecellulab/reports/utils.py b/bluecellulab/reports/utils.py index af73ea6c..cfccffd8 100644 --- a/bluecellulab/reports/utils.py +++ b/bluecellulab/reports/utils.py @@ -33,89 +33,110 @@ SUPPORTED_REPORT_TYPES = {"compartment", "compartment_set"} +def _get_source_for_report(simulation_config: Any, report_name: str, report_cfg: dict) -> tuple[str, dict]: + report_type = report_cfg.get("type", "compartment") + + if report_type == "compartment_set": + source_sets = simulation_config.get_compartment_sets() + source_name = report_cfg.get("compartment_set") + key = "compartment_set" + elif report_type == "compartment": + source_sets = simulation_config.get_node_sets() + source_name = report_cfg.get("cells") + key = "cells" + else: + raise NotImplementedError( + f"Report type '{report_type}' is not supported. Supported types: {SUPPORTED_REPORT_TYPES}" + ) + + if not source_name: + logger.warning("Report '%s' missing '%s' for type '%s'.", report_name, key, report_type) + raise KeyError("missing_source_name") + + source = source_sets.get(source_name) + if not source: + logger.warning("%s '%s' not found for report '%s', skipping.", report_type, source_name, report_name) + raise KeyError("missing_source") + + return report_type, source + + +def _ensure_report_sites(cell: Any) -> dict[str, list[SiteEntry]]: + report_sites = getattr(cell, "report_sites", None) + if not isinstance(report_sites, dict): + report_sites = {} + setattr(cell, "report_sites", report_sites) + return report_sites + + def prepare_recordings_for_reports( cells: Dict[CellId, Any], simulation_config: Any, ) -> tuple[dict[CellId, list[str]], dict[CellId, list[SiteEntry]]]: - recording_index: dict[CellId, list[str]] = defaultdict(list) # (pop,gid) -> [rec_name,...] ordered - sites_index: dict[CellId, list[SiteEntry]] = defaultdict(list) + """Configure report recordings on instantiated cells and build recording + indices. - report_entries = simulation_config.get_report_entries() + Parameters + ---------- + cells + Mapping of CellId -> live Cell objects. + simulation_config + Simulation config providing report entries and node/compartment sets. - for report_name, report_cfg in report_entries.items(): - report_type = report_cfg.get("type", "compartment") + Returns + ------- + (recording_index, sites_index) + recording_index maps CellId -> ordered list of recording names (rec_name). + sites_index maps CellId -> list of site entries (report, rec_name, section, segx). - if report_type == "compartment_set": - source_sets = simulation_config.get_compartment_sets() - source_name = report_cfg.get("compartment_set") - if not source_name: - logger.warning( - f"Report '{report_name}' does not specify a node set in 'compartment_set' for {report_type}." - ) - continue - elif report_type == "compartment": - source_sets = simulation_config.get_node_sets() - source_name = report_cfg.get("cells") - if not source_name: - logger.warning( - f"Report '{report_name}' does not specify a node set in 'cells' for {report_type}." - ) - continue - else: - raise NotImplementedError( - f"Report type '{report_type}' is not supported. " - f"Supported types: {SUPPORTED_REPORT_TYPES}" - ) + Notes + ----- + Populates `cell.report_sites[report_name]` with the configured site entries. + """ + recording_index: dict[CellId, list[str]] = defaultdict(list) + sites_index: dict[CellId, list[SiteEntry]] = defaultdict(list) - source = source_sets.get(source_name) - if not source: - logger.warning( - f"{report_type} '{source_name}' not found for report '{report_name}', skipping recording." - ) + for report_name, report_cfg in simulation_config.get_report_entries().items(): + try: + report_type, source = _get_source_for_report(simulation_config, report_name, report_cfg) + except KeyError: continue population = source["population"] node_ids, compartment_nodes = resolve_source_nodes(source, report_type, cells, population) - recording_sites_per_cell = build_recording_sites( + sites_per_cell = build_recording_sites( cells, node_ids, population, report_type, report_cfg, compartment_nodes ) - variable = report_cfg.get("variable_name", "v") - for node_id, sites in recording_sites_per_cell.items(): + for node_id, sites in sites_per_cell.items(): cell_id = CellId(population, node_id) cell = cells.get(cell_id) if cell is None or not sites: continue - report_sites = getattr(cell, "report_sites", None) - if not isinstance(report_sites, dict): - report_sites = {} - setattr(cell, "report_sites", report_sites) + report_sites = _ensure_report_sites(cell) report_sites.setdefault(report_name, []) rec_names = cell.configure_recording(sites, variable, report_name) if len(rec_names) != len(sites): logger.warning( "Configured %d/%d recording sites for report '%s' on %s.", - len(rec_names), - len(sites), - report_name, - cell_id, + len(rec_names), len(sites), report_name, cell_id, ) for (sec, sec_name, segx), rec_name in zip(sites, rec_names): recording_index[cell_id].append(rec_name) - site_entry = { + entry: SiteEntry = { "report": report_name, "rec_name": rec_name, "section": sec_name, "segx": float(segx), } - sites_index[cell_id].append(site_entry) - report_sites[report_name].append(site_entry) + sites_index[cell_id].append(entry) + report_sites[report_name].append(entry) return dict(recording_index), dict(sites_index) From 4f8bf199efcd9760fbc929c024e5adc1b342ab1f Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Mon, 2 Mar 2026 20:04:37 +0100 Subject: [PATCH 09/16] fix --- bluecellulab/cell/core.py | 29 ++++++++------ .../circuit_access/sonata_circuit_access.py | 4 +- bluecellulab/circuit_simulation.py | 1 - bluecellulab/reports/utils.py | 38 ++++++++++--------- bluecellulab/type_aliases.py | 13 ++++++- tests/test_cell/test_core.py | 5 ++- 6 files changed, 54 insertions(+), 36 deletions(-) diff --git a/bluecellulab/cell/core.py b/bluecellulab/cell/core.py index 8b569d60..c177547f 100644 --- a/bluecellulab/cell/core.py +++ b/bluecellulab/cell/core.py @@ -46,7 +46,7 @@ from bluecellulab.stimulus.circuit_stimulus_definitions import SynapseReplay from bluecellulab.synapse import SynapseFactory, Synapse from bluecellulab.synapse.synapse_types import SynapseID -from bluecellulab.type_aliases import HocObjectType, NeuronSection, SectionMapping +from bluecellulab.type_aliases import HocObjectType, NeuronSection, ReportSite, SectionMapping from bluecellulab.cell.section_tools import currents_vars, section_to_variable_recording_str logger = logging.getLogger(__name__) @@ -1017,7 +1017,7 @@ def configure_recording(self, recording_sites: Iterable[tuple[NeuronSection | None, str, float]], variable_name: str, report_name: str - ) -> list[str]: + ) -> list[tuple[ReportSite, str]]: """Attach NEURON recordings for a variable at the given sites and return the recording names created. @@ -1032,21 +1032,27 @@ def configure_recording(self, Returns ------- - list[str] - Recording-name strings usable with `get_recording`. + list[tuple[ReportSite, str]] + (site, rec_name) pairs for successfully configured recordings. """ node_id = self.cell_id.id - added: list[str] = [] + configured: list[tuple[ReportSite, str]] = [] + + for site in recording_sites: + sec, sec_name, seg = site - for sec, sec_name, seg in recording_sites: try: section_obj = self.soma if sec is None else sec rec_name = section_to_variable_recording_str(section_obj, float(seg), variable_name) if rec_name not in self.recordings: - self.add_variable_recording(variable=variable_name, section=section_obj, segx=float(seg)) + self.add_variable_recording( + variable=variable_name, + section=None if sec is None else sec, + segx=float(seg), + ) - added.append(rec_name) + configured.append((site, rec_name)) logger.info( f"Recording '{variable_name}' at {sec_name}({seg}) on GID {node_id} for report '{report_name}'" @@ -1054,18 +1060,17 @@ def configure_recording(self, except AttributeError: logger.warning( - f"Recording for variable '{variable_name}' is not implemented in Cell." + "Recording '%s' not available at %s(%s) on GID %s for report '%s'", + variable_name, sec_name, seg, node_id, report_name, ) - continue except Exception as e: logger.warning( f"Failed to record '{variable_name}' at {sec_name}({seg}) on GID {node_id} " f"for report '{report_name}': {e}" ) - continue - return added + return configured def add_currents_recordings( self, diff --git a/bluecellulab/circuit/circuit_access/sonata_circuit_access.py b/bluecellulab/circuit/circuit_access/sonata_circuit_access.py index a76f4a04..d7abbd5f 100644 --- a/bluecellulab/circuit/circuit_access/sonata_circuit_access.py +++ b/bluecellulab/circuit/circuit_access/sonata_circuit_access.py @@ -29,7 +29,7 @@ from bluecellulab.circuit import CellId, SynapseProperty from bluecellulab.circuit.config import SimulationConfig from bluecellulab.circuit.synapse_properties import SynapseProperties -from bluecellulab.circuit.config import SimulationConfig, SonataSimulationConfig +from bluecellulab.circuit.config import SonataSimulationConfig from bluecellulab.circuit.synapse_properties import ( properties_from_snap, properties_to_snap, @@ -301,7 +301,7 @@ def morph_filepath(self, cell_id: CellId) -> str: node_population = self._circuit.nodes[cell_id.population_name] try: # if asc defined in alternate morphology return str(node_population.morph.get_filepath(cell_id.id, extension="asc")) - except BluepySnapError as e: + except BluepySnapError: logger.debug(f"No asc morphology found for {cell_id}, trying swc.") return str(node_population.morph.get_filepath(cell_id.id)) diff --git a/bluecellulab/circuit_simulation.py b/bluecellulab/circuit_simulation.py index 3d8b040d..0ddcf82c 100644 --- a/bluecellulab/circuit_simulation.py +++ b/bluecellulab/circuit_simulation.py @@ -28,7 +28,6 @@ import pandas as pd from pydantic.types import NonNegativeInt from typing_extensions import deprecated -from typing import Optional import bluecellulab from bluecellulab.cell import CellDict diff --git a/bluecellulab/reports/utils.py b/bluecellulab/reports/utils.py index cfccffd8..4da5a944 100644 --- a/bluecellulab/reports/utils.py +++ b/bluecellulab/reports/utils.py @@ -18,6 +18,7 @@ from dataclasses import dataclass import logging from typing import Dict, Any, List, Mapping, Optional, Tuple +from neuron import h from bluecellulab.circuit.node_id import CellId import numpy as np @@ -32,6 +33,8 @@ SUPPORTED_REPORT_TYPES = {"compartment", "compartment_set"} +NeuronSection = type(h.Section()) # or your existing alias + def _get_source_for_report(simulation_config: Any, report_name: str, report_cfg: dict) -> tuple[str, dict]: report_type = report_cfg.get("type", "compartment") @@ -61,14 +64,6 @@ def _get_source_for_report(simulation_config: Any, report_name: str, report_cfg: return report_type, source -def _ensure_report_sites(cell: Any) -> dict[str, list[SiteEntry]]: - report_sites = getattr(cell, "report_sites", None) - if not isinstance(report_sites, dict): - report_sites = {} - setattr(cell, "report_sites", report_sites) - return report_sites - - def prepare_recordings_for_reports( cells: Dict[CellId, Any], simulation_config: Any, @@ -116,17 +111,20 @@ def prepare_recordings_for_reports( if cell is None or not sites: continue - report_sites = _ensure_report_sites(cell) - report_sites.setdefault(report_name, []) + cell.report_sites.setdefault(report_name, []) - rec_names = cell.configure_recording(sites, variable, report_name) - if len(rec_names) != len(sites): + configured = cell.configure_recording(sites, variable, report_name) + + if len(configured) != len(sites): logger.warning( "Configured %d/%d recording sites for report '%s' on %s.", - len(rec_names), len(sites), report_name, cell_id, + len(configured), + len(sites), + report_name, + cell_id, ) - for (sec, sec_name, segx), rec_name in zip(sites, rec_names): + for (sec, sec_name, segx), rec_name in configured: recording_index[cell_id].append(rec_name) entry: SiteEntry = { @@ -136,7 +134,7 @@ def prepare_recordings_for_reports( "segx": float(segx), } sites_index[cell_id].append(entry) - report_sites[report_name].append(entry) + cell.report_sites[report_name].append(entry) return dict(recording_index), dict(sites_index) @@ -353,7 +351,8 @@ def collect_local_payload( for rec_name in recording_index.get(cell_id, []): recs[rec_name] = cell.get_recording(rec_name).tolist() - payload[f"{pop}_{gid}"] = {"recordings": recs} + key = f"{pop}_{gid}" + payload[key] = {"recordings": recs} return payload @@ -395,7 +394,12 @@ def collect_local_spikes( threshold=sim.spike_threshold, ) spikes[pop][gid] = list(times) if times is not None else [] - except Exception: + except Exception as e: + logger.debug( + "Failed to collect spikes for (%s, %d): %s", + pop, gid, e, + exc_info=True, + ) spikes[pop][gid] = [] return spikes diff --git a/bluecellulab/type_aliases.py b/bluecellulab/type_aliases.py index 69671a34..4879712b 100644 --- a/bluecellulab/type_aliases.py +++ b/bluecellulab/type_aliases.py @@ -1,7 +1,7 @@ """Type aliases used within the package.""" from __future__ import annotations -from typing import Any, Dict +from typing import Dict, NamedTuple, Optional, TypedDict from neuron import h as hoc_type from typing_extensions import TypeAlias @@ -13,4 +13,13 @@ TStim: TypeAlias = hoc_type SectionMapping = Dict[str, NeuronSection] -SiteEntry: TypeAlias = dict[str, Any] +class SiteEntry(TypedDict): + report: str + rec_name: str + section: str + segx: float + +class ReportSite(NamedTuple): + section: Optional[NeuronSection] + section_name: str + segx: float \ No newline at end of file diff --git a/tests/test_cell/test_core.py b/tests/test_cell/test_core.py index 2cd1c5ef..13f913e9 100644 --- a/tests/test_cell/test_core.py +++ b/tests/test_cell/test_core.py @@ -639,13 +639,14 @@ def test_resolve_segments_compartment_set_by_id(self): assert seg_1 == 0.25 def test_configure_recording_success(self): - sites = [(None, "soma[0]", 0.5), (None, "dend[0]", 0.3)] + dend = MagicMock(name="dend_section") + sites = [(None, "soma[0]", 0.5), (dend, "dend[0]", 0.3)] self.cell.add_variable_recording = MagicMock() self.cell.configure_recording(sites, "v", "test_report") self.cell.add_variable_recording.assert_any_call(variable="v", section=None, segx=0.5) - self.cell.add_variable_recording.assert_any_call(variable="v", section=None, segx=0.3) + self.cell.add_variable_recording.assert_any_call(variable="v", section=dend, segx=0.3) # Optional: check number of total calls assert self.cell.add_variable_recording.call_count == 2 From d5f37c92504a4a8f85c8076af7bc05aeba55586c Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Mon, 2 Mar 2026 20:24:17 +0100 Subject: [PATCH 10/16] fix --- bluecellulab/cell/core.py | 5 +++-- bluecellulab/reports/utils.py | 7 ++----- bluecellulab/type_aliases.py | 5 ++++- tests/test_reports/test_reports_utils.py | 6 +++--- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/bluecellulab/cell/core.py b/bluecellulab/cell/core.py index c177547f..b11389aa 100644 --- a/bluecellulab/cell/core.py +++ b/bluecellulab/cell/core.py @@ -1040,6 +1040,7 @@ def configure_recording(self, for site in recording_sites: sec, sec_name, seg = site + report_site = ReportSite(sec, sec_name, float(seg)) try: section_obj = self.soma if sec is None else sec @@ -1052,7 +1053,7 @@ def configure_recording(self, segx=float(seg), ) - configured.append((site, rec_name)) + configured.append((report_site, rec_name)) logger.info( f"Recording '{variable_name}' at {sec_name}({seg}) on GID {node_id} for report '{report_name}'" @@ -1060,7 +1061,7 @@ def configure_recording(self, except AttributeError: logger.warning( - "Recording '%s' not available at %s(%s) on GID %s for report '%s'", + "Recording for variable '%s' is not implemented at %s(%s) on GID %s for report '%s'", variable_name, sec_name, seg, node_id, report_name, ) diff --git a/bluecellulab/reports/utils.py b/bluecellulab/reports/utils.py index 4da5a944..15af8a9c 100644 --- a/bluecellulab/reports/utils.py +++ b/bluecellulab/reports/utils.py @@ -18,7 +18,6 @@ from dataclasses import dataclass import logging from typing import Dict, Any, List, Mapping, Optional, Tuple -from neuron import h from bluecellulab.circuit.node_id import CellId import numpy as np @@ -33,8 +32,6 @@ SUPPORTED_REPORT_TYPES = {"compartment", "compartment_set"} -NeuronSection = type(h.Section()) # or your existing alias - def _get_source_for_report(simulation_config: Any, report_name: str, report_cfg: dict) -> tuple[str, dict]: report_type = report_cfg.get("type", "compartment") @@ -250,7 +247,7 @@ def extract_spikes_from_cells( class RecordedCell: """Read-only cell-like object backed by stored recordings.""" recordings: Dict[str, np.ndarray] - report_sites: Dict[str, list[dict]] + report_sites: Dict[str, list[SiteEntry]] soma: NeuronSection | None = None def get_recording(self, var_name: str) -> np.ndarray: @@ -283,7 +280,7 @@ def payload_to_cells( recs = blob.get("recordings", {}) or {} recs_np = {name: np.asarray(vals, dtype=np.float32) for name, vals in recs.items()} - by_report: dict[str, list[dict]] = defaultdict(list) + by_report: dict[str, list[SiteEntry]] = defaultdict(list) cell_id = CellId(pop, gid) for site in sites_index.get(cell_id, []): by_report[site["report"]].append(site) diff --git a/bluecellulab/type_aliases.py b/bluecellulab/type_aliases.py index 4879712b..6434c0b0 100644 --- a/bluecellulab/type_aliases.py +++ b/bluecellulab/type_aliases.py @@ -13,13 +13,16 @@ TStim: TypeAlias = hoc_type SectionMapping = Dict[str, NeuronSection] + + class SiteEntry(TypedDict): report: str rec_name: str section: str segx: float + class ReportSite(NamedTuple): section: Optional[NeuronSection] section_name: str - segx: float \ No newline at end of file + segx: float diff --git a/tests/test_reports/test_reports_utils.py b/tests/test_reports/test_reports_utils.py index e7da305e..ccef350e 100644 --- a/tests/test_reports/test_reports_utils.py +++ b/tests/test_reports/test_reports_utils.py @@ -39,7 +39,7 @@ class DummyCell: def __init__(self, targets, rec_names): self.targets = targets self.rec_names = rec_names - self.report_sites = None + self.report_sites: dict[str, list[dict]] = {} def resolve_segments_from_config(self, _cfg): return self.targets @@ -47,8 +47,8 @@ def resolve_segments_from_config(self, _cfg): def resolve_segments_from_compartment_set(self, _node_id, _compartment_nodes): return self.targets - def configure_recording(self, _sites, _variable, _report_name): - return self.rec_names + def configure_recording(self, sites, _variable, _report_name): + return list(zip(sites, self.rec_names)) class DummyConfig: From ec4bf0f4bf431db23033e7697ffcf06d1eb14c74 Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Tue, 10 Mar 2026 11:32:53 +0100 Subject: [PATCH 11/16] Include segment area and time vector in report payload --- bluecellulab/reports/utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bluecellulab/reports/utils.py b/bluecellulab/reports/utils.py index 15af8a9c..29f8c6c9 100644 --- a/bluecellulab/reports/utils.py +++ b/bluecellulab/reports/utils.py @@ -21,6 +21,7 @@ from bluecellulab.circuit.node_id import CellId import numpy as np +import neuron from bluecellulab.cell.section_tools import section_to_variable_recording_str from bluecellulab.type_aliases import NeuronSection, SiteEntry @@ -129,6 +130,7 @@ def prepare_recordings_for_reports( "rec_name": rec_name, "section": sec_name, "segx": float(segx), + "area_um2": float(neuron.h.area(segx, sec=sec)), } sites_index[cell_id].append(entry) cell.report_sites[report_name].append(entry) @@ -348,6 +350,11 @@ def collect_local_payload( for rec_name in recording_index.get(cell_id, []): recs[rec_name] = cell.get_recording(rec_name).tolist() + try: + recs["neuron.h._ref_t"] = cell.get_recording("neuron.h._ref_t").tolist() + except Exception: + pass + key = f"{pop}_{gid}" payload[key] = {"recordings": recs} From ef74168cec65f2ff1f2b3fe06c363105ecca3dbd Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Tue, 10 Mar 2026 11:36:33 +0100 Subject: [PATCH 12/16] lint fix --- bluecellulab/type_aliases.py | 1 + examples/7-Extra-Simulation/Multiple-protocols.ipynb | 1 - .../copy-hoc-morphs.ipynb | 3 ++- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/bluecellulab/type_aliases.py b/bluecellulab/type_aliases.py index 6434c0b0..29197616 100644 --- a/bluecellulab/type_aliases.py +++ b/bluecellulab/type_aliases.py @@ -20,6 +20,7 @@ class SiteEntry(TypedDict): rec_name: str section: str segx: float + area_um2: float class ReportSite(NamedTuple): diff --git a/examples/7-Extra-Simulation/Multiple-protocols.ipynb b/examples/7-Extra-Simulation/Multiple-protocols.ipynb index 4b41c231..d8196f0f 100644 --- a/examples/7-Extra-Simulation/Multiple-protocols.ipynb +++ b/examples/7-Extra-Simulation/Multiple-protocols.ipynb @@ -55,7 +55,6 @@ "\n", "from pathlib import Path\n", "\n", - "from matplotlib import pyplot as plt\n", "\n", "import neuron\n", "from bluecellulab import Cell\n", diff --git a/tests/examples/circuit_hipp_mooc_most_central_10_SP_PC/copy-hoc-morphs.ipynb b/tests/examples/circuit_hipp_mooc_most_central_10_SP_PC/copy-hoc-morphs.ipynb index cac8a3ed..2807d072 100644 --- a/tests/examples/circuit_hipp_mooc_most_central_10_SP_PC/copy-hoc-morphs.ipynb +++ b/tests/examples/circuit_hipp_mooc_most_central_10_SP_PC/copy-hoc-morphs.ipynb @@ -6,7 +6,8 @@ "metadata": {}, "outputs": [], "source": [ - "import os, shutil\n", + "import os\n", + "import shutil\n", "import pandas as pd" ] }, From 8b0f1f5fc37c1879f0c96fd792d2117d6c4f8eee Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Tue, 10 Mar 2026 12:07:10 +0100 Subject: [PATCH 13/16] tests fix --- bluecellulab/reports/utils.py | 8 +++++++- bluecellulab/type_aliases.py | 2 +- .../output_sonata_hypamp/soma.h5 | Bin 13528 -> 13528 bytes .../output_sonata_noinput/soma.h5 | Bin 13528 -> 13528 bytes .../output_sonata_hypamp/soma.h5 | Bin 18328 -> 18328 bytes .../output_sonata_noinput/soma.h5 | Bin 18328 -> 18328 bytes .../output_sonata_ornstein/soma.h5 | Bin 18328 -> 18328 bytes .../output_sonata_shotnoise/soma.h5 | Bin 18328 -> 18328 bytes tests/test_reports/test_reports_utils.py | 9 ++++++++- 9 files changed, 16 insertions(+), 3 deletions(-) diff --git a/bluecellulab/reports/utils.py b/bluecellulab/reports/utils.py index 29f8c6c9..adddc93f 100644 --- a/bluecellulab/reports/utils.py +++ b/bluecellulab/reports/utils.py @@ -125,12 +125,18 @@ def prepare_recordings_for_reports( for (sec, sec_name, segx), rec_name in configured: recording_index[cell_id].append(rec_name) + area_um2 = None + try: + area_um2 = float(neuron.h.area(segx, sec=sec)) + except Exception: + pass + entry: SiteEntry = { "report": report_name, "rec_name": rec_name, "section": sec_name, "segx": float(segx), - "area_um2": float(neuron.h.area(segx, sec=sec)), + "area_um2": area_um2, } sites_index[cell_id].append(entry) cell.report_sites[report_name].append(entry) diff --git a/bluecellulab/type_aliases.py b/bluecellulab/type_aliases.py index 29197616..f719b78f 100644 --- a/bluecellulab/type_aliases.py +++ b/bluecellulab/type_aliases.py @@ -20,7 +20,7 @@ class SiteEntry(TypedDict): rec_name: str section: str segx: float - area_um2: float + area_um2: float | None class ReportSite(NamedTuple): diff --git a/tests/examples/sim_quick_scx_sonata/output_sonata_hypamp/soma.h5 b/tests/examples/sim_quick_scx_sonata/output_sonata_hypamp/soma.h5 index 8f37a650e497d9d3ce742c3c2a021d0f81ce4b8a..320ed29c85e09adc1a4f5dfe480b69f7776f8882 100644 GIT binary patch delta 58 zcmcbSc_VWJ3lBRJ0|*#Q+-SPlf#)DQM;MsTz&!b*usq|6&4MEH`6nwdOKg6iyhebL Jd2+9bHvr7l590s; delta 53 zcmcbSc_VWJ3lBRZ0|*#Q+-SPlf#)FmWJwWzrUU$&4MnE&Pdp$$S%Fz%^8@8I0*p+P HdriCnu?i4L diff --git a/tests/examples/sim_quick_scx_sonata/output_sonata_noinput/soma.h5 b/tests/examples/sim_quick_scx_sonata/output_sonata_noinput/soma.h5 index 3946c44903ea1e56ca7aef81210e02b90e599b84..bc6dd61af719a15f618f68b4569d2ef0bc486558 100644 GIT binary patch delta 58 zcmcbSc_VWJ3lBRJ0|*#Q+-SPlf#)DQM;MsTz&!b*usq|6&4MEH`6nwdOKg6iyhebL Jd2+9bHvr7l590s; delta 53 zcmcbSc_VWJ3lBRZ0|*#Q+-SPlf#)FmWJwWzrUU$&4MnE&Pdp$$S%Fz%^8@8I0*p+P HdriCnu?i4L diff --git a/tests/examples/sim_quick_scx_sonata_multicircuit/output_sonata_hypamp/soma.h5 b/tests/examples/sim_quick_scx_sonata_multicircuit/output_sonata_hypamp/soma.h5 index faf57c79b1186cc214f4b7b1a8e231a6e54bad73..9b680adfd6631a8ee8470501b26beca8502f0d57 100644 GIT binary patch delta 60 zcmbQy&p4x>aRUnvI}-y47);z~y4iuJQHmoB%x7SpyirY_amD6~YV-LgD=aRUnvJ0k-K7);z~y4iuJQEKu{HGZZ8{F^_jP3NC@Kz_0Uv&7~FW@`i( KnI`+XdjkN=oDnnt diff --git a/tests/examples/sim_quick_scx_sonata_multicircuit/output_sonata_noinput/soma.h5 b/tests/examples/sim_quick_scx_sonata_multicircuit/output_sonata_noinput/soma.h5 index a6dc7faed126f6ca044e9cc22ac50941bcd69482..618ba083df4d74748d8fe08c929a0b6535821c94 100644 GIT binary patch delta 60 zcmbQy&p4x>aRUnvI}-y47);z~y4iuJQHmoB%x7SpyirY_amD6~YV-LgD=aRUnvJ0k-K7);z~y4iuJQEKu{HGZZ8{F^_jP3NC@Kz_0Uv&7~FW@`i( KnI`+XdjkN=oDnnt diff --git a/tests/examples/sim_quick_scx_sonata_multicircuit/output_sonata_ornstein/soma.h5 b/tests/examples/sim_quick_scx_sonata_multicircuit/output_sonata_ornstein/soma.h5 index 1a9d53ae648ac71e920a8abf1e37139861f23cdc..58160e47f3df5cb0b3752a31a647b4e15c80bee8 100644 GIT binary patch delta 60 zcmbQy&p4x>aRUnvI}-y47);z~y4iuJQHmoB%x7SpyirY_amD6~YV-LgD=aRUnvJ0k-K7);z~y4iuJQEKu{HGZZ8{F^_jP3NC@Kz_0Uv&7~FW@`i( KnI`+XdjkN=oDnnt diff --git a/tests/examples/sim_quick_scx_sonata_multicircuit/output_sonata_shotnoise/soma.h5 b/tests/examples/sim_quick_scx_sonata_multicircuit/output_sonata_shotnoise/soma.h5 index 7f008df690af6a163a71ceaa8b502d1939865f2b..2d3f9d476e4e26d93096456ec3218654af60f4c8 100644 GIT binary patch delta 60 zcmbQy&p4x>aRUnvI}-y47);z~y4iuJQHmoB%x7SpyirY_amD6~YV-LgD=aRUnvJ0k-K7);z~y4iuJQEKu{HGZZ8{F^_jP3NC@Kz_0Uv&7~FW@`i( KnI`+XdjkN=oDnnt diff --git a/tests/test_reports/test_reports_utils.py b/tests/test_reports/test_reports_utils.py index ccef350e..2cdb1b4f 100644 --- a/tests/test_reports/test_reports_utils.py +++ b/tests/test_reports/test_reports_utils.py @@ -224,7 +224,14 @@ def test_collect_local_payload_and_spikes(): cell_ids = [CellId("p", 1), CellId("p", 2)] payload = collect_local_payload(cells, cell_ids, recording_index) - assert payload == {"p_1": {"recordings": {"r1": [1.0, 2.0]}}} + assert payload == { + "p_1": { + "recordings": { + "r1": [1.0, 2.0], + "neuron.h._ref_t": [1.0, 2.0], + } + } + } sim = SimpleNamespace( cells={CellId("p", 1): c1, CellId("p", 2): c2}, From faa31d14f4b75ccbe81c6fbb3ab2357d821f5f2e Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Wed, 11 Mar 2026 11:59:17 +0100 Subject: [PATCH 14/16] fix redundant code --- bluecellulab/reports/writers/compartment.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bluecellulab/reports/writers/compartment.py b/bluecellulab/reports/writers/compartment.py index fcb33af3..16bb9ad0 100644 --- a/bluecellulab/reports/writers/compartment.py +++ b/bluecellulab/reports/writers/compartment.py @@ -192,8 +192,6 @@ def _write_sonata_report_file( data_ds = grp.create_dataset("data", data=data_array.astype(np.float32)) variable = report_cfg.get("variable_name", "v") - if variable == "v": - data_ds.attrs["units"] = "mV" units = report_cfg.get("unit") if units is None: From 9dba71364d611005454603e68b750a588b660422 Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Wed, 11 Mar 2026 14:13:37 +0100 Subject: [PATCH 15/16] cleanup + add more unit-tests --- tests/test_reports/test_compartment_writer.py | 57 ++++--------------- tests/test_reports/test_reports_utils.py | 41 +++++++++++++ 2 files changed, 53 insertions(+), 45 deletions(-) diff --git a/tests/test_reports/test_compartment_writer.py b/tests/test_reports/test_compartment_writer.py index 17eab057..509f4388 100644 --- a/tests/test_reports/test_compartment_writer.py +++ b/tests/test_reports/test_compartment_writer.py @@ -28,16 +28,8 @@ script_dir = Path(__file__).parent.parent -# ----------------------------- -# Fixtures (new "RecordedCell-like" API) -# ----------------------------- @pytest.fixture def mock_cell(): - """ - Cell-like object for the new writer API: - - .report_sites: dict[report_name -> list[site dicts]] - - .get_recording(rec_name) -> np.ndarray - """ cell = MagicMock() cell.report_sites = { "test_report": [{"rec_name": "rec_0", "section": "soma[0]", "segx": 0.5}] @@ -57,8 +49,6 @@ def mock_cells(mock_cell): @pytest.fixture def mock_config_node_set(): - # With the refactor, the writer uses _source_sets only to determine population. - # Node selection is reflected by which cells you pass + their report_sites. return { "name": "test_report", "type": "compartment", @@ -89,7 +79,6 @@ def mock_config_compartment_set(): "_source_sets": { "custom_segments": { "population": "default", - # content below is not used by the new writer; kept for realism "elements": { "1": [["dend[0]", 0.3]], "2": [["soma[0]", 0.5]], @@ -99,9 +88,6 @@ def mock_config_compartment_set(): } -# ----------------------------- -# Helpers -# ----------------------------- def make_trace(length: int, value: float) -> np.ndarray: return (np.ones(length) * value).astype(np.float32) @@ -118,9 +104,6 @@ def make_cell_for_report( return cell -# ----------------------------- -# Unit tests for H5 writer -# ----------------------------- def test_write_node_set(tmp_path, mock_cells, mock_config_node_set): out = tmp_path / "report.h5" writer = CompartmentReportWriter(report_cfg=mock_config_node_set, output_path=out, sim_dt=0.1) @@ -138,10 +121,6 @@ def test_write_node_set(tmp_path, mock_cells, mock_config_node_set): def test_write_compartment_set(tmp_path, mock_config_compartment_set): - """ - New behavior: writer reads per-cell sites from cell.report_sites[report_name]. - So we do NOT patch build_recording_sites/resolve_source_nodes anymore. - """ out = tmp_path / "report.h5" c1 = make_cell_for_report( @@ -178,12 +157,6 @@ def test_write_compartment_set(tmp_path, mock_config_compartment_set): def test_compartment_set_multinode_order(tmp_path): - """ - New behavior replacement for old "trace-mode multinode merge": - - we build 3 cell objects for gids 0,1,2 - - each has one site for report 'trace_merge' - - verify the H5 columns are in gid order (because writer sorts cells by gid) - """ out = tmp_path / "trace_merge.h5" tlen = 10 @@ -234,13 +207,6 @@ def test_compartment_set_multinode_order(tmp_path): def test_compartment_set_multisegment_single_node(tmp_path): - """ - New behavior replacement for old "trace-mode multisegment node": - - one cell gid 0 - - report_sites has 4 sites => 4 columns - - node_ids repeats gid for each element - - elem_ids is 0..3 and pointers [0,1,2,3,4] (one element per column) - """ out = tmp_path / "trace_multisegment.h5" tlen = 10 @@ -288,16 +254,7 @@ def test_compartment_set_multisegment_single_node(tmp_path): assert np.allclose(data, 42.0) -# ----------------------------- -# Integration-ish test -# ----------------------------- class TestSimCompartmentSet: - """ - This test only makes sense if the example output files exist and the reporting - pipeline still generates both files. If your refactor changes paths/names, update - these accordingly. - """ - def setup_method(self): sim_path = ( script_dir @@ -309,8 +266,6 @@ def setup_method(self): self.sim.instantiate_gids(dstut_cells, add_stimuli=True, add_synapses=True) self.sim.run() - # If your new flow requires payload_to_cells(...) then this integration test - # should be rewritten. For now, skip if the live cells don't have report_sites/get_recording. sample_cell = next(iter(self.sim.cells.values())) if not hasattr(sample_cell, "get_recording") or not hasattr(sample_cell, "report_sites"): pytest.skip("Live cells do not expose report_sites/get_recording; update integration test to payload flow.") @@ -327,3 +282,15 @@ def setup_method(self): / "examples/sim_quick_scx_sonata_multicircuit/output_sonata_compartment_set/soma_compartment_set.h5" ) self.dataset_path = "/report/NodeA/data" + + def test_compartment_compartmentset_match(self): + """Compare voltage reports from compartment and compartment_set output.""" + with h5py.File(self.file1_path, "r") as f1, h5py.File(self.file2_path, "r") as f2: + assert self.dataset_path in f1, f"'{self.dataset_path}' not found in {self.file1_path}" + assert self.dataset_path in f2, f"'{self.dataset_path}' not found in {self.file2_path}" + + data1 = np.array(f1[self.dataset_path]) + data2 = np.array(f2[self.dataset_path]) + + assert data1.shape == data2.shape, f"Shape mismatch: {data1.shape} != {data2.shape}" + assert np.allclose(data1, data2), "Data mismatch in dataset content" diff --git a/tests/test_reports/test_reports_utils.py b/tests/test_reports/test_reports_utils.py index 2cdb1b4f..c49fdd2c 100644 --- a/tests/test_reports/test_reports_utils.py +++ b/tests/test_reports/test_reports_utils.py @@ -159,6 +159,47 @@ def test_prepare_recordings_for_reports_warns_on_rec_mismatch(caplog): assert len(sites_index[cell_id]) == 1 +def test_prepare_recordings_for_reports_populates_area_um2(monkeypatch): + cell_id = CellId("popA", 9) + targets = [("sec", "soma[0]", 0.5)] + cell = DummyCell(targets=targets, rec_names=["rec_soma"]) + cells = {cell_id: cell} + + cfg = DummyConfig( + report_entries={"r1": {"type": "compartment", "cells": "targets", "variable_name": "v"}}, + node_sets={"targets": {"population": "popA"}}, + ) + + mock_neuron = SimpleNamespace(h=SimpleNamespace(area=lambda _segx, sec: 12.34)) + monkeypatch.setattr("bluecellulab.reports.utils.neuron", mock_neuron) + + _, sites_index = prepare_recordings_for_reports(cells, cfg) + + assert sites_index[cell_id][0]["area_um2"] == 12.34 + + +def test_prepare_recordings_for_reports_area_failure_sets_none(monkeypatch): + cell_id = CellId("popA", 10) + targets = [("sec", "soma[0]", 0.5)] + cell = DummyCell(targets=targets, rec_names=["rec_soma"]) + cells = {cell_id: cell} + + cfg = DummyConfig( + report_entries={"r1": {"type": "compartment", "cells": "targets", "variable_name": "v"}}, + node_sets={"targets": {"population": "popA"}}, + ) + + def _raise_area(_segx, sec): + raise RuntimeError("area unavailable") + + mock_neuron = SimpleNamespace(h=SimpleNamespace(area=_raise_area)) + monkeypatch.setattr("bluecellulab.reports.utils.neuron", mock_neuron) + + _, sites_index = prepare_recordings_for_reports(cells, cfg) + + assert sites_index[cell_id][0]["area_um2"] is None + + def test_prepare_recordings_for_reports_unsupported_type(): cell_id = CellId("popA", 1) cells = {cell_id: DummyCell(targets=[], rec_names=[])} From 37482b78db1eaecc20b96ffbc86be0c37277f058 Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Wed, 11 Mar 2026 15:31:47 +0100 Subject: [PATCH 16/16] address reviews --- bluecellulab/reports/utils.py | 52 ++++++++++++------------ tests/test_reports/test_reports_utils.py | 11 ++++- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/bluecellulab/reports/utils.py b/bluecellulab/reports/utils.py index adddc93f..25df7641 100644 --- a/bluecellulab/reports/utils.py +++ b/bluecellulab/reports/utils.py @@ -159,31 +159,39 @@ def build_recording_sites( cells : dict[CellId, Any] Mapping from CellId to cell-like objects. node_ids : list[int] - Node IDs to resolve within `population`. + List of node IDs for which recordings should be configured. population : str - Population name used to build CellId(population, node_id). + Name of the population to which the cells belong. report_type : str - "compartment" or "compartment_set". + The report type, either 'compartment_set' or 'compartment'. report_cfg : dict Report configuration. - compartment_nodes : list | None - Compartment-set entries used when `report_type == "compartment_set"`. + compartment_nodes : list or None + Optional list of [node_id, section_name, seg_x] defining segment locations + for each cell (used if report_type == 'compartment_set'). Returns ------- - dict[int, list[tuple[Any, str, float]]] - Mapping `{node_id: [(section_obj, section_name, segx), ...]}`. + dict + Mapping from node ID to list of recording site tuples: + (section_object, section_name, seg_x). """ targets_per_cell: Dict[int, List[Tuple[Any, str, float]]] = {} + if report_type == "compartment_set" and compartment_nodes is None: + logger.warning( + "Report type 'compartment_set' requires compartment nodes, but none were found " + "for population '%s'. No recording sites will be resolved.", + population, + ) + return {} + for node_id in node_ids: cell = cells.get(CellId(population, node_id)) if cell is None: continue if report_type == "compartment_set": - if compartment_nodes is None: - continue targets = cell.resolve_segments_from_compartment_set(node_id, compartment_nodes) elif report_type == "compartment": targets = cell.resolve_segments_from_config(report_cfg) @@ -272,24 +280,22 @@ def get_variable_recording(self, variable: str, section: Any, segx: float) -> np def payload_to_cells( - payload: Mapping[str, Any], + payload: Mapping[CellId, Any], sites_index: Mapping[CellId, list[SiteEntry]], ) -> Dict[CellId, RecordedCell]: """ - payload: {"pop_gid": {"recordings": {rec_name: [floats...]}}} - sites_index: {(pop,gid): [{"report":..., "rec_name":..., "section":..., "segx":...}, ...]} + payload: {CellId(...): {"recordings": {rec_name: [floats...]}}} + sites_index: {CellId(...): [{"report":..., "rec_name":..., "section":..., "segx":...}, ...]} """ out: Dict[CellId, RecordedCell] = {} - for key, blob in payload.items(): - pop, gid_s = key.rsplit("_", 1) - gid = int(gid_s) + for cell_id, blob in payload.items(): recs = blob.get("recordings", {}) or {} recs_np = {name: np.asarray(vals, dtype=np.float32) for name, vals in recs.items()} by_report: dict[str, list[SiteEntry]] = defaultdict(list) - cell_id = CellId(pop, gid) + for site in sites_index.get(cell_id, []): by_report[site["report"]].append(site) @@ -333,18 +339,18 @@ def gather_recording_sites( for cell_key, sites in rank_dict.items(): merged[cell_key].extend(sites) - return dict(merged) + return merged def collect_local_payload( cells: Dict[CellId, Any], cell_ids_for_this_rank: list[CellId], recording_index: Dict[CellId, list[str]], -) -> dict[str, dict[str, dict[str, list[float]]]]: +) -> dict[CellId, dict[str, Any]]: """ Build rank-local payload: {'pop_gid': {'recordings': {rec_name: trace_list}}} """ - payload: dict[str, dict[str, dict[str, list[float]]]] = {} + payload: dict[CellId, dict[str, dict[str, list[float]]]] = {} for pop, gid in cell_ids_for_this_rank: cell_id = CellId(pop, gid) @@ -356,13 +362,9 @@ def collect_local_payload( for rec_name in recording_index.get(cell_id, []): recs[rec_name] = cell.get_recording(rec_name).tolist() - try: - recs["neuron.h._ref_t"] = cell.get_recording("neuron.h._ref_t").tolist() - except Exception: - pass + recs["neuron.h._ref_t"] = cell.get_time().tolist() - key = f"{pop}_{gid}" - payload[key] = {"recordings": recs} + payload[cell_id] = {"recordings": recs} return payload diff --git a/tests/test_reports/test_reports_utils.py b/tests/test_reports/test_reports_utils.py index c49fdd2c..88ed17c6 100644 --- a/tests/test_reports/test_reports_utils.py +++ b/tests/test_reports/test_reports_utils.py @@ -214,7 +214,13 @@ class Sec: def name(self): return "soma[0]" - payload = {"popA_3": {"recordings": {"neuron.h.soma[0](0.5)._ref_v": [1.0, 2.0, 3.0]}}} + payload = { + CellId("popA", 3): { + "recordings": { + "neuron.h.soma[0](0.5)._ref_v": [1.0, 2.0, 3.0] + } + } + } sites_index = { CellId("popA", 3): [{ "report": "r1", @@ -255,6 +261,7 @@ def test_gather_recording_sites_merges_and_skips_empty(): def test_collect_local_payload_and_spikes(): c1 = MagicMock() c1.get_recording.return_value = np.array([1.0, 2.0], dtype=np.float32) + c1.get_time.return_value = np.array([1.0, 2.0], dtype=np.float32) c1.get_recorded_spikes.return_value = [0.2, 0.5] c2 = MagicMock() @@ -266,7 +273,7 @@ def test_collect_local_payload_and_spikes(): payload = collect_local_payload(cells, cell_ids, recording_index) assert payload == { - "p_1": { + CellId("p", 1): { "recordings": { "r1": [1.0, 2.0], "neuron.h._ref_t": [1.0, 2.0],