From bfd45bebaa4d6bdfce492a53262d21d28388b5f8 Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Tue, 24 Feb 2026 16:29:32 +0100 Subject: [PATCH 1/3] neurodamus synaptic seeding alignment --- bluecellulab/cell/core.py | 6 +- .../circuit/circuit_access/definition.py | 3 +- .../circuit_access/sonata_circuit_access.py | 14 +- bluecellulab/circuit/gid_resolver.py | 23 ++ bluecellulab/circuit_simulation.py | 165 +++++++------- bluecellulab/connection.py | 2 + bluecellulab/synapse/synapse_factory.py | 7 +- bluecellulab/synapse/synapse_types.py | 32 +-- tests/test_circuit_simulation_mpi.py | 212 ++++++++++-------- 9 files changed, 249 insertions(+), 215 deletions(-) create mode 100644 bluecellulab/circuit/gid_resolver.py diff --git a/bluecellulab/cell/core.py b/bluecellulab/cell/core.py index 40596285..83a8ad7d 100644 --- a/bluecellulab/cell/core.py +++ b/bluecellulab/cell/core.py @@ -394,7 +394,8 @@ def add_replay_synapse(self, connection_modifiers: dict, condition_parameters: Conditions, popids: tuple[int, int], - extracellular_calcium: float | None) -> None: + extracellular_calcium: float | None, + post_gid: int) -> None: """Add synapse based on the syn_description to the cell.""" synapse = SynapseFactory.create_synapse( cell=self, @@ -403,7 +404,8 @@ def add_replay_synapse(self, condition_parameters=condition_parameters, popids=popids, extracellular_calcium=extracellular_calcium, - connection_modifiers=connection_modifiers) + connection_modifiers=connection_modifiers, + post_gid=post_gid) self.synapses[synapse_id] = synapse diff --git a/bluecellulab/circuit/circuit_access/definition.py b/bluecellulab/circuit/circuit_access/definition.py index b37b2640..2fcbb5f1 100644 --- a/bluecellulab/circuit/circuit_access/definition.py +++ b/bluecellulab/circuit/circuit_access/definition.py @@ -88,7 +88,7 @@ def get_cell_properties( raise NotImplementedError def extract_synapses( - self, cell_id: CellId, projections: Optional[list[str] | str] + self, cell_id: CellId, projections: Optional[list[str] | str | bool] ) -> pd.DataFrame: raise NotImplementedError @@ -122,3 +122,4 @@ def morph_filepath(self, cell_id: CellId) -> str: def emodel_path(self, cell_id: CellId) -> str: raise NotImplementedError + diff --git a/bluecellulab/circuit/circuit_access/sonata_circuit_access.py b/bluecellulab/circuit/circuit_access/sonata_circuit_access.py index a76f4a04..a0ffe400 100644 --- a/bluecellulab/circuit/circuit_access/sonata_circuit_access.py +++ b/bluecellulab/circuit/circuit_access/sonata_circuit_access.py @@ -178,12 +178,9 @@ def _select_edge_pop_names(self, projections) -> list[str]: return out def extract_synapses( - self, cell_id: CellId, projections: Optional[list[str] | str] + self, cell_id: CellId, projections: Optional[list[str] | str | bool] ) -> pd.DataFrame: - """Extract the synapses. - - If projections is None, all the synapses are extracted. - """ + """Extract the synapses.""" snap_node_id = CircuitNodeId(cell_id.population_name, cell_id.id) edges = self._circuit.edges @@ -308,3 +305,10 @@ def morph_filepath(self, cell_id: CellId) -> str: def emodel_path(self, cell_id: CellId) -> str: node_population = self._circuit.nodes[cell_id.population_name] return str(node_population.models.get_filepath(cell_id.id)) + + def node_population_sizes(self) -> dict[str, int]: + out: dict[str, int] = {} + for pop_name, node_pop in self._circuit.nodes.items(): + s = getattr(node_pop, "size", None) + out[str(pop_name)] = int(s() if callable(s) else s) + return out diff --git a/bluecellulab/circuit/gid_resolver.py b/bluecellulab/circuit/gid_resolver.py new file mode 100644 index 00000000..1c8a8d41 --- /dev/null +++ b/bluecellulab/circuit/gid_resolver.py @@ -0,0 +1,23 @@ +# Copyright 2025 Open Brain Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""GID resolver for NEURON simulations in BlueCelluLab.""" +from dataclasses import dataclass + + +@dataclass(frozen=True) +class GidNamespace: + pop_offset: dict[str, int] + + def global_gid(self, pop: str, local_id: int) -> int: + return int(self.pop_offset[pop]) + int(local_id) + 1 # 1-based indexing to mirror Neurodamus synapse seeding implementation diff --git a/bluecellulab/circuit_simulation.py b/bluecellulab/circuit_simulation.py index 15bc4924..1c6b664a 100644 --- a/bluecellulab/circuit_simulation.py +++ b/bluecellulab/circuit_simulation.py @@ -22,6 +22,7 @@ import logging import warnings +from bluecellulab.circuit.gid_resolver import GidNamespace from bluecellulab.reports.utils import configure_all_reports import neuron import numpy as np @@ -134,13 +135,12 @@ def __init__( self.spike_threshold = self.circuit_access.config.spike_threshold self.spike_location = self.circuit_access.config.spike_location - self.projections: list[str] | str | None = None + self.projections: list[str] | str | bool = False condition_parameters = self.circuit_access.config.condition_parameters() set_global_condition_parameters(condition_parameters) - self._gid_stride = 1_000_000 # must exceed the max node_id in any population - self._pop_index: dict[str, int] = {"": 0} + self.gid_resolver: Optional[GidNamespace] = None def instantiate_gids( self, @@ -276,18 +276,25 @@ def instantiate_gids( "if you want to specify use add_replay or " "pre_spike_trains") - # legacy for backward compatibility - if add_projections is None: - add_projections = False - if add_projections is True: self.projections = self.circuit_access.config.get_all_projection_names() - elif add_projections is False: - self.projections = None else: self.projections = add_projections self._add_cells(cell_ids) + + need_gids = ( + add_synapses + or (self.pc is not None) + or add_replay + or interconnect_cells + or pre_spike_trains is not None + or self.print_cellstate + ) + + if need_gids: + self.gids = self._build_gid_namespace() + if add_synapses: self._add_synapses( pre_gids=pre_gids, @@ -297,7 +304,6 @@ def instantiate_gids( raise BluecellulabError("add_replay option can not be used if " "add_synapses is False") if self.pc is not None: - self._init_pop_index_mpi() self._register_gids_for_mpi() self.pc.barrier() self.pc.setup_transfer() @@ -473,12 +479,17 @@ def _add_cell_synapses( syn_description["source_popid"], syn_description["target_popid"], ) + + post_gid = self.global_gid(cell_id.population_name, cell_id.id) + self._instantiate_synapse( cell_id=cell_id, syn_id=idx, # type: ignore syn_description=syn_description, add_minis=add_minis, - popids=popids + popids=popids, + post_gid=post_gid + ) logger.info(f"Added {syn_descriptions} synapses for gid {cell_id}") if add_minis: @@ -607,26 +618,26 @@ def _add_connections( syn_description: pd.Series = synapse.syn_description delay_weights = synapse.delay_weights source_population = syn_description["source_population_name"] - pre_gid = CellId(source_population, int(syn_description[SynapseProperty.PRE_GID])) + pre_local_id = CellId(source_population, int(syn_description[SynapseProperty.PRE_GID])) - ov = self._find_matching_override(connections_overrides, pre_gid, post_gid) + ov = self._find_matching_override(connections_overrides, pre_local_id, post_gid) if ov is not None and ov.weight == 0.0: logger.debug( "Skipping connection due to zero weight override: %s -> %s | syn_id=%s", - pre_gid, post_gid, syn_id + pre_local_id, post_gid, syn_id ) continue if self.pc is None: - real_synapse_connection = bool(interconnect_cells) and (pre_gid in self.cells) + real_synapse_connection = bool(interconnect_cells) and (pre_local_id in self.cells) else: real_synapse_connection = bool(interconnect_cells) if real_synapse_connection: if ( user_pre_spike_trains is not None - and pre_gid in user_pre_spike_trains + and pre_local_id in user_pre_spike_trains ): raise BluecellulabError( """Specifying prespike trains of real connections""" @@ -636,18 +647,18 @@ def _add_connections( connection = bluecellulab.Connection( self.cells[post_gid].synapses[syn_id], pre_spiketrain=None, - pre_cell=self.cells[pre_gid], + pre_cell=self.cells[pre_local_id], stim_dt=self.dt, parallel_context=None, spike_threshold=self.spike_threshold, spike_location=self.spike_location, ) else: # MPI cross-rank - pre_g = self.global_gid(pre_gid.population_name, pre_gid.id) + pre_gid = self.global_gid(pre_local_id.population_name, pre_local_id.id) connection = bluecellulab.Connection( self.cells[post_gid].synapses[syn_id], pre_spiketrain=None, - pre_gid=pre_g, + pre_gid=pre_gid, pre_cell=None, stim_dt=self.dt, parallel_context=self.pc, @@ -655,9 +666,9 @@ def _add_connections( spike_location=self.spike_location, ) - logger.debug(f"Added real connection between {pre_gid} and {post_gid}, {syn_id}") + logger.debug(f"Added real connection between {pre_local_id} and {post_gid}, {syn_id}") else: # replay connection - pre_spiketrain = pre_spike_trains.get(pre_gid, None) + pre_spiketrain = pre_spike_trains.get(pre_local_id, None) connection = bluecellulab.Connection( self.cells[post_gid].synapses[syn_id], pre_spiketrain=pre_spiketrain, @@ -668,12 +679,12 @@ def _add_connections( spike_location=self.spike_location ) - logger.debug(f"Added replay connection from {pre_gid} to {post_gid}, {syn_id}") + logger.debug(f"Added replay connection from {pre_local_id} to {post_gid}, {syn_id}") if ov is not None: logger.debug( "Override matched: %s -> %s | syn_id=%s | weight=%s delay=%s", - pre_gid, post_gid, syn_id, ov.weight, ov.delay + pre_local_id, post_gid, syn_id, ov.weight, ov.delay ) syn_delay = getattr(ov, "synapse_delay_override", None) @@ -681,21 +692,21 @@ def _add_connections( connection.set_netcon_delay(float(syn_delay)) logger.debug( "Applied synapse_delay_override %.4g ms to %s -> %s | syn_id=%s", - syn_delay, pre_gid, post_gid, syn_id + syn_delay, pre_local_id, post_gid, syn_id ) if ov.delay is not None: logger.warning( "SONATA override 'delay' (delayed weight activation) is not supported yet; " "applying weight immediately. %s -> %s | syn_id=%s | delay=%s", - pre_gid, post_gid, syn_id, ov.delay + pre_local_id, post_gid, syn_id, ov.delay ) if ov.weight is not None: connection.set_weight_scalar(float(ov.weight)) logger.debug( "Applied weight override factor %.4g to %s -> %s | syn_id=%s | final_weight=%.4g", - ov.weight, pre_gid, post_gid, syn_id, connection.post_netcon_weight + ov.weight, pre_local_id, post_gid, syn_id, connection.post_netcon_weight ) self.cells[post_gid].connections[syn_id] = connection @@ -716,7 +727,7 @@ def _add_cells(self, cell_ids: list[CellId]) -> None: if self.circuit_access.node_properties_available: cell.connect_to_circuit(SonataProxy(cell_id, self.circuit_access)) - def _instantiate_synapse(self, cell_id: CellId, syn_id: SynapseID, syn_description, + def _instantiate_synapse(self, cell_id: CellId, syn_id: SynapseID, syn_description, post_gid: int, add_minis=False, popids=(0, 0)) -> None: """Instantiate one synapse for a given gid, syn_id and syn_description.""" @@ -730,7 +741,8 @@ def _instantiate_synapse(self, cell_id: CellId, syn_id: SynapseID, syn_descripti self.cells[cell_id].add_replay_synapse( syn_id, syn_description, syn_connection_parameters, condition_parameters, - popids=popids, extracellular_calcium=self.circuit_access.config.extracellular_calcium) + popids=popids, extracellular_calcium=self.circuit_access.config.extracellular_calcium, + post_gid=post_gid) if add_minis: mini_frequencies = self.circuit_access.fetch_mini_frequencies(cell_id) logger.debug(f"Adding minis for synapse {syn_id}: syn_description={syn_description}, connection={syn_connection_parameters}, frequency={mini_frequencies}") @@ -1007,77 +1019,50 @@ def create_cell_from_circuit(self, cell_id: CellId) -> bluecellulab.Cell: emodel_properties=cell_kwargs['emodel_properties']) def global_gid(self, pop: str, gid: int) -> int: - """Return a globally unique NEURON GID for a (population, node_id) - pair. - - NEURON's ParallelContext requires presynaptic sources to be identified by a - single integer GID across all ranks. In SONATA circuits, node ids are only - unique *within* a population, so we combine (population_name, node_id) into a - single integer: + if self.gids is None: + raise RuntimeError("GID namespace not initialized yet.") + return self.gids.global_gid(pop, gid) - global_gid = pop_index[population] * STRIDE + node_id - - Notes - ----- - STRIDE must be larger than the maximum node_id in any population to avoid GID - collisions. - - Parameters - ---------- - pop : str - SONATA population name (e.g. "S1nonbarrel_neurons", "POm", ...). - gid : int - Node id within that population. - - Returns - ------- - int - Globally unique integer GID used with ParallelContext. - """ - if pop not in self._pop_index: - raise KeyError( - f"Population '{pop}' missing from pop index. " - f"Known pops: {sorted(self._pop_index.keys())}" - ) - return self._pop_index[pop] * self._gid_stride + int(gid) + def _build_gid_namespace(self) -> GidNamespace: + ca = self.circuit_access # SonataCircuitAccess + sizes = ca.node_population_sizes() # pop -> N nodes - def _init_pop_index_mpi(self) -> None: - """Build a consistent pop->index mapping across ranks.""" - assert self.pc is not None + pops_sorted = sorted(sizes.keys()) + max_raw = {p: max(0, int(n) - 1) for p, n in sizes.items()} # ids 0..N-1 - local_pops = set() - for post_gid in self.cells: - local_pops.add(post_gid.population_name) - synapses = getattr(self.cells[post_gid], "synapses", None) - if synapses: - for syn in synapses.values(): - sd = syn.syn_description - local_pops.add(sd["source_population_name"]) - - gathered = self.pc.py_gather(sorted(local_pops), 0) - - if int(self.pc.id()) == 0: - all_pops = set([""]) - for lst in gathered: - all_pops.update(lst) - pops_sorted = sorted(all_pops) - pop_index = {p: i for i, p in enumerate(pops_sorted)} - else: - pop_index = None + # compute only on rank 0 and broadcast + if self.pc is not None: + if int(self.pc.id()) == 0: + pop_offset = self._compute_offsets_from_max(pops_sorted, max_raw) + else: + pop_offset = None + pop_offset = self.pc.py_broadcast(pop_offset, 0) + return GidNamespace(pop_offset) - pop_index = self.pc.py_broadcast(pop_index, 0) - self._pop_index = pop_index + pop_offset = self._compute_offsets_from_max(pops_sorted, max_raw) + return GidNamespace(pop_offset) def _register_gids_for_mpi(self) -> None: assert self.pc is not None + assert self.gids is not None for cell_id, cell in self.cells.items(): g = self.global_gid(cell_id.population_name, cell_id.id) - self.pc.set_gid2node(g, int(self.pc.id())) - nc = cell.create_netcon_spikedetector( - None, - location=self.spike_location, - threshold=self.spike_threshold - ) + nc = cell.create_netcon_spikedetector(None, location=self.spike_location, threshold=self.spike_threshold) self.pc.cell(g, nc) + + def _compute_offsets_from_max(self, pops_sorted: list[str], max_raw: dict[str, int]) -> dict[str, int]: + pop_offset: dict[str, int] = {} + prev: str | None = None + + for p in pops_sorted: + if prev is None: + pop_offset[p] = 0 + else: + prev_count = int(max_raw[prev]) + 1 + end_prev = pop_offset[prev] + prev_count + pop_offset[p] = ((end_prev + 999) // 1000) * 1000 + prev = p + + return pop_offset diff --git a/bluecellulab/connection.py b/bluecellulab/connection.py index 31f2162a..97d1f36d 100644 --- a/bluecellulab/connection.py +++ b/bluecellulab/connection.py @@ -39,6 +39,7 @@ def __init__( self.delay = post_synapse.syn_description[SynapseProperty.AXONAL_DELAY] self.weight = post_synapse.syn_description[SynapseProperty.G_SYNX] self.pre_cell = pre_cell + self.pre_gid = pre_gid self.pre_spiketrain = pre_spiketrain self.post_synapse = post_synapse self.pc = parallel_context @@ -86,6 +87,7 @@ def __init__( raise ValueError("pre_gid must be provided when using ParallelContext") self.post_netcon = self.pc.gid_connect(int(pre_gid), self.post_synapse.hsynapse) + # NetCon setup self.set_netcon_weight(self.post_netcon_weight) self.set_netcon_delay(self.post_netcon_delay) self.post_netcon.threshold = spike_threshold diff --git a/bluecellulab/synapse/synapse_factory.py b/bluecellulab/synapse/synapse_factory.py index a074ad15..1981878a 100644 --- a/bluecellulab/synapse/synapse_factory.py +++ b/bluecellulab/synapse/synapse_factory.py @@ -49,6 +49,7 @@ def create_synapse( popids: tuple[int, int], extracellular_calcium: float | None, connection_modifiers: dict, + post_gid: int, ) -> Synapse: """Returns a Synapse object.""" syn_type = cls.determine_synapse_type(syn_description) @@ -61,13 +62,13 @@ def create_synapse( else: randomize_gaba_risetime = True synapse = GabaabSynapse(cell.cell_id, syn_hoc_args, syn_id, syn_description, - popids, extracellular_calcium, randomize_gaba_risetime) + popids, post_gid, extracellular_calcium, randomize_gaba_risetime) elif syn_type == SynapseType.AMPANMDA: synapse = AmpanmdaSynapse(cell.cell_id, syn_hoc_args, syn_id, syn_description, - popids, extracellular_calcium) + popids, post_gid, extracellular_calcium) else: synapse = GluSynapse(cell.cell_id, syn_hoc_args, syn_id, syn_description, - popids, extracellular_calcium) + popids, post_gid, extracellular_calcium) synapse = cls.apply_connection_modifiers(connection_modifiers, synapse) diff --git a/bluecellulab/synapse/synapse_types.py b/bluecellulab/synapse/synapse_types.py index a69fc47a..7969162b 100644 --- a/bluecellulab/synapse/synapse_types.py +++ b/bluecellulab/synapse/synapse_types.py @@ -52,6 +52,7 @@ def __init__( syn_id: tuple[str, int], syn_description: pd.Series, popids: tuple[int, int], + post_gid: int, extracellular_calcium: float | None = None): """Constructor. @@ -80,7 +81,8 @@ def __init__( self.source_popid, self.target_popid = popids - self.pre_gid = int(self.syn_description[SynapseProperty.PRE_GID]) + self.pre_local_id = int(self.syn_description[SynapseProperty.PRE_GID]) + self.post_gid = int(post_gid) self.hoc_args = hoc_args self.mech_name: str = "not-yet-defined" @@ -145,7 +147,7 @@ def _set_gabaab_ampanmda_rng(self) -> None: """ rng_settings = RNGSettings.get_instance() if rng_settings.mode == "Random123": - self.randseed1 = self.post_cell_id.id + 250 + self.randseed1 = self.post_gid + 250 self.randseed2 = self.syn_id.sid + 100 self.randseed3 = self.source_popid * 65536 + self.target_popid + \ rng_settings.synapse_seed + 300 @@ -157,12 +159,12 @@ def _set_gabaab_ampanmda_rng(self) -> None: rndd = neuron.h.Random() if rng_settings.mode == "Compatibility": self.randseed1 = self.syn_id.sid * 100000 + 100 - self.randseed2 = self.post_cell_id.id + \ + self.randseed2 = self.post_gid + \ 250 + rng_settings.base_seed elif rng_settings.mode == "UpdatedMCell": self.randseed1 = self.syn_id.sid * 1000 + 100 self.randseed2 = self.source_popid * 16777216 + \ - self.post_cell_id.id + \ + self.post_gid + \ 250 + rng_settings.base_seed + \ rng_settings.synapse_seed else: @@ -207,7 +209,7 @@ def info_dict(self) -> dict[str, Any]: synapse_dict: dict[str, Any] = {} synapse_dict['synapse_id'] = self.syn_id - synapse_dict['pre_cell_id'] = self.pre_gid + synapse_dict['pre_cell_id'] = self.pre_local_id synapse_dict['post_cell_id'] = self.post_cell_id.id synapse_dict['syn_description'] = self.syn_description.to_dict() # if keys are enum make them str @@ -238,8 +240,8 @@ def __del__(self) -> None: class GluSynapse(Synapse): - def __init__(self, gid, hoc_args, syn_id, syn_description, popids, extracellular_calcium): - super().__init__(gid, hoc_args, syn_id, syn_description, popids, extracellular_calcium) + def __init__(self, gid, hoc_args, syn_id, syn_description, popids, post_gid, extracellular_calcium): + super().__init__(gid, hoc_args, syn_id, syn_description, popids, post_gid, extracellular_calcium) self.use_glusynapse_helper() def use_glusynapse_helper(self) -> None: @@ -285,7 +287,7 @@ def use_glusynapse_helper(self) -> None: if self.syn_description[SynapseProperty.NRRP] >= 0: self.hsynapse.Nrrp = self.syn_description[SynapseProperty.NRRP] - self.randseed1 = self.post_cell_id.id + self.randseed1 = self.post_gid self.randseed2 = 100000 + self.syn_id.sid rng_settings = RNGSettings.get_instance() self.randseed3 = rng_settings.synapse_seed + 200 @@ -301,8 +303,8 @@ def info_dict(self): class GabaabSynapse(Synapse): - def __init__(self, gid, hoc_args, syn_id, syn_description, popids, extracellular_calcium, randomize_risetime=True): - super().__init__(gid, hoc_args, syn_id, syn_description, popids, extracellular_calcium) + def __init__(self, gid, hoc_args, syn_id, syn_description, popids, post_gid, extracellular_calcium, randomize_risetime=True): + super().__init__(gid, hoc_args, syn_id, syn_description, popids, post_gid, extracellular_calcium) self.use_gabaab_helper(randomize_risetime) def use_gabaab_helper(self, randomize_gaba_risetime: bool) -> None: @@ -331,19 +333,19 @@ def use_gabaab_helper(self, randomize_gaba_risetime: bool) -> None: if rng_settings.mode == "Compatibility": rng.MCellRan4( self.syn_id.sid * 100000 + 100, - self.post_cell_id.id + 250 + rng_settings.base_seed) + self.post_gid + 250 + rng_settings.base_seed) elif rng_settings.mode == "UpdatedMCell": rng.MCellRan4( self.syn_id.sid * 1000 + 100, self.source_popid * 16777216 + - self.post_cell_id.id + + self.post_gid + 250 + rng_settings.base_seed + rng_settings.synapse_seed) elif rng_settings.mode == "Random123": rng.Random123( - self.post_cell_id.id + + self.post_gid + 250, self.syn_id.sid + 100, @@ -385,8 +387,8 @@ def info_dict(self): class AmpanmdaSynapse(Synapse): - def __init__(self, gid, hoc_args, syn_id, syn_description, popids, extracellular_calcium): - super().__init__(gid, hoc_args, syn_id, syn_description, popids, extracellular_calcium) + def __init__(self, gid, hoc_args, syn_id, syn_description, popids, post_gid, extracellular_calcium): + super().__init__(gid, hoc_args, syn_id, syn_description, popids, post_gid, extracellular_calcium) self.use_ampanmda_helper() def use_ampanmda_helper(self) -> None: diff --git a/tests/test_circuit_simulation_mpi.py b/tests/test_circuit_simulation_mpi.py index acbd64a7..243e67c1 100644 --- a/tests/test_circuit_simulation_mpi.py +++ b/tests/test_circuit_simulation_mpi.py @@ -1,11 +1,11 @@ # Copyright 2025 Open Brain Institute - +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 - +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,6 +19,20 @@ from bluecellulab.circuit.config.sections import ConnectionOverrides +# ----------------------------- +# Fakes / helpers +# ----------------------------- + +class FakeSonataCircuitAccess: + """Minimal SONATA CircuitAccess stub for GID namespace tests.""" + def __init__(self, sizes: dict[str, int]): + self._sizes = dict(sizes) + + def node_population_sizes(self) -> dict[str, int]: + # pop -> N (count) + return dict(self._sizes) + + class FakePC: def __init__(self, rank=0, gather_result=None): self._rank = rank @@ -26,13 +40,16 @@ def __init__(self, rank=0, gather_result=None): self.set_gid2node_calls = [] self.cell_calls = [] self.broadcasted = None + self.gathered = None def id(self): return self._rank def py_gather(self, data, root): self.gathered = (data, root) - return self.gather_result or [data] + # In MPI, rank0 sees list of all ranks' contributions. + # Our FakePC can be configured to return a provided gather_result. + return self.gather_result if self.gather_result is not None else [data] def py_broadcast(self, value, root): self.broadcasted = (value, root) @@ -66,103 +83,35 @@ def create_netcon_spikedetector(self, *_, **__): return "netcon" -def make_sim(pc=None): +def make_sim(*, pc=None, circuit_access=None): + """ + Create a CircuitSimulation instance without running __init__. + We only populate the fields used by the tested methods. + """ sim = CircuitSimulation.__new__(CircuitSimulation) + sim.pc = pc - sim._gid_stride = 1_000 - sim._pop_index = {"": 0} sim.dt = 0.1 sim.spike_threshold = -20.0 sim.spike_location = "soma" - return sim - - -def test_init_pop_index_mpi_collects_all_populations(): - pc = FakePC(rank=0, gather_result=[["PostA", "SourceA"], ["PostB"]]) - sim = make_sim(pc=pc) - - syn_a = DummySynapse(source_pop="SourceA", pre_gid=1) - cell_a = DummyCell(synapses={"s0": syn_a}) - cell_b = DummyCell(synapses={}) - - sim.cells = { - CellId("PostA", 10): cell_a, - CellId("PostB", 11): cell_b, - } - - sim._init_pop_index_mpi() - - assert sim._pop_index[""] == 0 - assert sim._pop_index["PostA"] == 1 - assert sim._pop_index["PostB"] == 2 - assert sim._pop_index["SourceA"] == 3 - - -def test_register_gids_for_mpi_uses_global_mapping(): - pc = FakePC(rank=1) - sim = make_sim(pc=pc) - sim._pop_index = {"": 0, "PopX": 2} - cell_id = CellId("PopX", 7) - sim.cells = {cell_id: DummyCell()} - - sim._register_gids_for_mpi() - - expected_gid = sim.global_gid("PopX", 7) - assert pc.set_gid2node_calls == [(expected_gid, 1)] - assert pc.cell_calls == [(expected_gid, "netcon")] + # New fields introduced by refactor + sim.gids = None + sim.projections = False + # SONATA-only + sim.circuit_format = circuit_simulation.CircuitFormat.SONATA + sim.circuit_access = circuit_access if circuit_access is not None else FakeSonataCircuitAccess({}) -def test_add_connections_mpi_uses_global_pre_gid(monkeypatch): - pc = FakePC(rank=0) - sim = make_sim(pc=pc) - sim._pop_index = {"": 0, "PrePop": 1, "PostPop": 2} - - post_id = CellId("PostPop", 5) - syn = DummySynapse(source_pop="PrePop", pre_gid=3) - post_cell = DummyCell(synapses={"syn1": syn}) - sim.cells = {post_id: post_cell} - - created = [] - - class FakeConnection: - def __init__(self, synapse, pre_spiketrain, pre_gid=None, pre_cell=None, - stim_dt=None, parallel_context=None, spike_threshold=None, - spike_location=None): - self.synapse = synapse - self.pre_spiketrain = pre_spiketrain - self.pre_gid = pre_gid - self.pre_cell = pre_cell - self.parallel_context = parallel_context - self.weight = 1.0 - created.append(self) - - monkeypatch.setattr(circuit_simulation.bluecellulab, "Connection", FakeConnection) - - sim._add_connections(interconnect_cells=True) - - assert "syn1" in post_cell.connections - conn = post_cell.connections["syn1"] - assert conn is created[0] - assert conn.pre_cell is None - assert conn.parallel_context is pc - assert conn.pre_gid == sim.global_gid("PrePop", 3) - - -def test_global_gid_uses_stride_and_pop_index(): - sim = make_sim(pc=None) - sim._gid_stride = 1000 - sim._pop_index = {"": 0, "PopA": 2} - - assert sim.global_gid("PopA", 7) == 2 * 1000 + 7 + # Default empty cells + sim.cells = {} + return sim -def test_global_gid_raises_for_unknown_population(): - sim = make_sim(pc=None) - sim._pop_index = {"": 0} - with pytest.raises(KeyError): - sim.global_gid("UnknownPop", 1) +# ----------------------------- +# Tests: connection overrides +# ----------------------------- def test_add_connections_skips_zero_weight_override(monkeypatch): sim = make_sim(pc=None) @@ -174,8 +123,13 @@ def test_add_connections_skips_zero_weight_override(monkeypatch): overrides = [ ConnectionOverrides( - source="src", target="dst", delay=None, weight=0.0, - spont_minis=None, synapse_configure=None, mod_override=None, + source="src", + target="dst", + delay=None, + weight=0.0, + spont_minis=None, + synapse_configure=None, + mod_override=None, ) ] @@ -189,8 +143,12 @@ def target_contains_cell(self, *_): sim.circuit_access = FakeCircuitAccess(overrides) - # Monkeypatch Connection so it would raise if invoked - monkeypatch.setattr(circuit_simulation.bluecellulab, "Connection", lambda *a, **k: (_ for _ in ()).throw(AssertionError("Connection should not be created"))) + # If Connection is constructed, the test should fail + monkeypatch.setattr( + circuit_simulation.bluecellulab, + "Connection", + lambda *a, **k: (_ for _ in ()).throw(AssertionError("Connection should not be created")), + ) sim._add_connections(interconnect_cells=False) @@ -207,12 +165,24 @@ def test_add_connections_applies_last_matching_override(monkeypatch): overrides = [ ConnectionOverrides( - source="any", target="any", synapse_delay_override=1.5, delay=None, weight=2.0, - spont_minis=None, synapse_configure=None, mod_override=None, + source="any", + target="any", + synapse_delay_override=1.5, + delay=None, + weight=2.0, + spont_minis=None, + synapse_configure=None, + mod_override=None, ), ConnectionOverrides( - source="any", target="any", synapse_delay_override=4.0, delay=None, weight=3.0, - spont_minis=None, synapse_configure=None, mod_override=None, + source="any", + target="any", + synapse_delay_override=4.0, + delay=None, + weight=3.0, + spont_minis=None, + synapse_configure=None, + mod_override=None, ), ] @@ -222,7 +192,7 @@ def __init__(self, ov): self.config = type("cfg", (), {"connection_entries": lambda _self: self._ov})() def target_contains_cell(self, *_): - return True # everything matches for simplicity + return True # Everything matches for simplicity sim.circuit_access = FakeCircuitAccess(overrides) @@ -247,3 +217,47 @@ def set_netcon_delay(self, delay: float): conn = post_cell.connections["synX"] assert conn.post_netcon_delay == pytest.approx(4.0) assert conn.post_netcon_weight == pytest.approx(3.0) + + +# ----------------------------- +# Tests: GID namespace behavior +# ----------------------------- + +def test_gid_namespace_offsets_are_1000_blocked_and_1_based(): + sim = make_sim(pc=None, circuit_access=FakeSonataCircuitAccess({"PopA": 3, "PopB": 2})) + sim.gids = sim._build_gid_namespace() + + # PopA: offset 0, so gid = local_id + 1 + assert sim.global_gid("PopA", 0) == 1 + assert sim.global_gid("PopA", 2) == 3 + + # PopB: should begin at the next 1000-block after PopA is filled + assert sim.global_gid("PopB", 0) == 1001 + assert sim.global_gid("PopB", 1) == 1002 + + +def test_gid_namespace_does_not_depend_on_projections(): + sim = make_sim(pc=None, circuit_access=FakeSonataCircuitAccess({"PopA": 3, "PopB": 2})) + + sim.projections = False + gids1 = sim._build_gid_namespace() + + sim.projections = True + gids2 = sim._build_gid_namespace() + + assert gids1.pop_offset == gids2.pop_offset + + +def test_register_gids_for_mpi_uses_gid_namespace(): + pc = FakePC(rank=1) + sim = make_sim(pc=pc, circuit_access=FakeSonataCircuitAccess({"PopX": 10})) + sim.gids = sim._build_gid_namespace() + + cell_id = CellId("PopX", 7) + sim.cells = {cell_id: DummyCell()} + + sim._register_gids_for_mpi() + + expected_gid = sim.global_gid("PopX", 7) + assert pc.set_gid2node_calls == [(expected_gid, 1)] + assert pc.cell_calls == [(expected_gid, "netcon")] From 164c267219a64a9c65b33287a9e367dda7925d4e Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Fri, 20 Mar 2026 18:10:55 +0100 Subject: [PATCH 2/3] update unit-test --- .../circuit_access/bluepy_circuit_access.py | 3 +++ bluecellulab/circuit_simulation.py | 16 ++-------------- bluecellulab/connection.py | 2 +- bluecellulab/synapse/synapse_types.py | 1 + tests/test_synapse/test_synapse_factory.py | 3 ++- 5 files changed, 9 insertions(+), 16 deletions(-) diff --git a/bluecellulab/circuit/circuit_access/bluepy_circuit_access.py b/bluecellulab/circuit/circuit_access/bluepy_circuit_access.py index 51294846..11d7724f 100644 --- a/bluecellulab/circuit/circuit_access/bluepy_circuit_access.py +++ b/bluecellulab/circuit/circuit_access/bluepy_circuit_access.py @@ -360,6 +360,9 @@ def morph_filepath(self, cell_id: CellId) -> str: def emodel_path(self, cell_id: CellId) -> str: return os.path.join(self._emodels_dir, f"{self._fetch_emodel_name(cell_id)}.hoc") + def node_population_sizes(self): + raise NotImplementedError + @property def _emodels_dir(self) -> str: return self.config.impl.Run['METypePath'] diff --git a/bluecellulab/circuit_simulation.py b/bluecellulab/circuit_simulation.py index 1c6b664a..c61f3116 100644 --- a/bluecellulab/circuit_simulation.py +++ b/bluecellulab/circuit_simulation.py @@ -1024,21 +1024,9 @@ def global_gid(self, pop: str, gid: int) -> int: return self.gids.global_gid(pop, gid) def _build_gid_namespace(self) -> GidNamespace: - ca = self.circuit_access # SonataCircuitAccess - sizes = ca.node_population_sizes() # pop -> N nodes - + sizes = self.circuit_access.node_population_sizes() pops_sorted = sorted(sizes.keys()) - max_raw = {p: max(0, int(n) - 1) for p, n in sizes.items()} # ids 0..N-1 - - # compute only on rank 0 and broadcast - if self.pc is not None: - if int(self.pc.id()) == 0: - pop_offset = self._compute_offsets_from_max(pops_sorted, max_raw) - else: - pop_offset = None - pop_offset = self.pc.py_broadcast(pop_offset, 0) - return GidNamespace(pop_offset) - + max_raw = {p: max(0, int(n) - 1) for p, n in sizes.items()} pop_offset = self._compute_offsets_from_max(pops_sorted, max_raw) return GidNamespace(pop_offset) diff --git a/bluecellulab/connection.py b/bluecellulab/connection.py index 97d1f36d..b960c709 100644 --- a/bluecellulab/connection.py +++ b/bluecellulab/connection.py @@ -118,7 +118,7 @@ def info_dict(self): connection_dict = {} - connection_dict['pre_cell_id'] = self.post_synapse.pre_gid + connection_dict['pre_cell_id'] = self.post_synapse.pre_local_id connection_dict['post_cell_id'] = self.post_synapse.post_cell_id.id connection_dict['post_synapse_id'] = self.post_synapse.syn_id.sid diff --git a/bluecellulab/synapse/synapse_types.py b/bluecellulab/synapse/synapse_types.py index 7969162b..7b4a04d3 100644 --- a/bluecellulab/synapse/synapse_types.py +++ b/bluecellulab/synapse/synapse_types.py @@ -82,6 +82,7 @@ def __init__( self.source_popid, self.target_popid = popids self.pre_local_id = int(self.syn_description[SynapseProperty.PRE_GID]) + self.pre_gid = int(self.syn_description[SynapseProperty.PRE_GID]) self.post_gid = int(post_gid) self.hoc_args = hoc_args diff --git a/tests/test_synapse/test_synapse_factory.py b/tests/test_synapse/test_synapse_factory.py index db43db10..1e8bc775 100644 --- a/tests/test_synapse/test_synapse_factory.py +++ b/tests/test_synapse/test_synapse_factory.py @@ -60,9 +60,10 @@ def test_create_synapse(self): 'SpontMinis': 0.0, 'SynapseConfigure': ['%s.Use = 1 %s.Use_GB = 1 %s.Use_p = 1 %s.gmax0_AMPA = gmax_p_AMPA %s.rho_GB = 1 %s.rho0_GB = 1 %s.gmax_AMPA = %s.gmax_p_AMPA'] } + post_gid = 1 synapse = SynapseFactory.create_synapse( - self.cell, syn_id, self.syn_description, condition_parameters, popids, extracellular_calcium, connection_modifiers + self.cell, syn_id, self.syn_description, condition_parameters, popids, extracellular_calcium, connection_modifiers, post_gid ) assert isinstance(synapse, GluSynapse) assert synapse.weight == connection_modifiers["Weight"] From 436d13177b264ab77444800d2e3b27a4d088aef2 Mon Sep 17 00:00:00 2001 From: ilkankilic Date: Fri, 20 Mar 2026 18:33:24 +0100 Subject: [PATCH 3/3] clean + lint fixes --- bluecellulab/circuit/circuit_access/definition.py | 4 +++- bluecellulab/circuit/circuit_access/sonata_circuit_access.py | 4 ++-- bluecellulab/circuit_simulation.py | 4 +++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/bluecellulab/circuit/circuit_access/definition.py b/bluecellulab/circuit/circuit_access/definition.py index 2fcbb5f1..cf4646be 100644 --- a/bluecellulab/circuit/circuit_access/definition.py +++ b/bluecellulab/circuit/circuit_access/definition.py @@ -88,7 +88,7 @@ def get_cell_properties( raise NotImplementedError def extract_synapses( - self, cell_id: CellId, projections: Optional[list[str] | str | bool] + self, cell_id: CellId, projections: Optional[list[str] | str] ) -> pd.DataFrame: raise NotImplementedError @@ -123,3 +123,5 @@ def morph_filepath(self, cell_id: CellId) -> str: def emodel_path(self, cell_id: CellId) -> str: raise NotImplementedError + def node_population_sizes(self) -> dict[str, int]: + raise NotImplementedError diff --git a/bluecellulab/circuit/circuit_access/sonata_circuit_access.py b/bluecellulab/circuit/circuit_access/sonata_circuit_access.py index a0ffe400..3ca25fc1 100644 --- a/bluecellulab/circuit/circuit_access/sonata_circuit_access.py +++ b/bluecellulab/circuit/circuit_access/sonata_circuit_access.py @@ -178,7 +178,7 @@ def _select_edge_pop_names(self, projections) -> list[str]: return out def extract_synapses( - self, cell_id: CellId, projections: Optional[list[str] | str | bool] + self, cell_id: CellId, projections: Optional[list[str] | str] ) -> pd.DataFrame: """Extract the synapses.""" snap_node_id = CircuitNodeId(cell_id.population_name, cell_id.id) @@ -309,6 +309,6 @@ def emodel_path(self, cell_id: CellId) -> str: def node_population_sizes(self) -> dict[str, int]: out: dict[str, int] = {} for pop_name, node_pop in self._circuit.nodes.items(): - s = getattr(node_pop, "size", None) + s = node_pop.size out[str(pop_name)] = int(s() if callable(s) else s) return out diff --git a/bluecellulab/circuit_simulation.py b/bluecellulab/circuit_simulation.py index c61f3116..d98d8f3b 100644 --- a/bluecellulab/circuit_simulation.py +++ b/bluecellulab/circuit_simulation.py @@ -135,7 +135,7 @@ def __init__( self.spike_threshold = self.circuit_access.config.spike_threshold self.spike_location = self.circuit_access.config.spike_location - self.projections: list[str] | str | bool = False + self.projections: list[str] | str | None = None condition_parameters = self.circuit_access.config.condition_parameters() set_global_condition_parameters(condition_parameters) @@ -278,6 +278,8 @@ def instantiate_gids( if add_projections is True: self.projections = self.circuit_access.config.get_all_projection_names() + elif add_projections is None or add_projections is False: + self.projections = None else: self.projections = add_projections