Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions bluecellulab/cell/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,8 @@ def add_replay_synapse(self,
connection_modifiers: dict,
condition_parameters: Conditions,
popids: tuple[int, int],
extracellular_calcium: float | None) -> None:
extracellular_calcium: float | None,
post_gid: int) -> None:
"""Add synapse based on the syn_description to the cell."""
synapse = SynapseFactory.create_synapse(
cell=self,
Expand All @@ -403,7 +404,8 @@ def add_replay_synapse(self,
condition_parameters=condition_parameters,
popids=popids,
extracellular_calcium=extracellular_calcium,
connection_modifiers=connection_modifiers)
connection_modifiers=connection_modifiers,
post_gid=post_gid)

self.synapses[synapse_id] = synapse

Expand Down
3 changes: 3 additions & 0 deletions bluecellulab/circuit/circuit_access/bluepy_circuit_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ def morph_filepath(self, cell_id: CellId) -> str:
def emodel_path(self, cell_id: CellId) -> str:
return os.path.join(self._emodels_dir, f"{self._fetch_emodel_name(cell_id)}.hoc")

def node_population_sizes(self):
raise NotImplementedError

@property
def _emodels_dir(self) -> str:
return self.config.impl.Run['METypePath']
3 changes: 3 additions & 0 deletions bluecellulab/circuit/circuit_access/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,6 @@ def morph_filepath(self, cell_id: CellId) -> str:

def emodel_path(self, cell_id: CellId) -> str:
raise NotImplementedError

def node_population_sizes(self) -> dict[str, int]:
raise NotImplementedError
12 changes: 8 additions & 4 deletions bluecellulab/circuit/circuit_access/sonata_circuit_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,7 @@ def _select_edge_pop_names(self, projections) -> list[str]:
def extract_synapses(
self, cell_id: CellId, projections: Optional[list[str] | str]
) -> pd.DataFrame:
"""Extract the synapses.

If projections is None, all the synapses are extracted.
"""
"""Extract the synapses."""
snap_node_id = CircuitNodeId(cell_id.population_name, cell_id.id)
edges = self._circuit.edges

Expand Down Expand Up @@ -308,3 +305,10 @@ def morph_filepath(self, cell_id: CellId) -> str:
def emodel_path(self, cell_id: CellId) -> str:
node_population = self._circuit.nodes[cell_id.population_name]
return str(node_population.models.get_filepath(cell_id.id))

def node_population_sizes(self) -> dict[str, int]:
out: dict[str, int] = {}
for pop_name, node_pop in self._circuit.nodes.items():
s = node_pop.size
out[str(pop_name)] = int(s() if callable(s) else s)
return out
23 changes: 23 additions & 0 deletions bluecellulab/circuit/gid_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2025 Open Brain Institute

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GID resolver for NEURON simulations in BlueCelluLab."""
from dataclasses import dataclass


@dataclass(frozen=True)
class GidNamespace:
pop_offset: dict[str, int]

def global_gid(self, pop: str, local_id: int) -> int:
return int(self.pop_offset[pop]) + int(local_id) + 1 # 1-based indexing to mirror Neurodamus synapse seeding implementation
157 changes: 66 additions & 91 deletions bluecellulab/circuit_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import logging
import warnings

from bluecellulab.circuit.gid_resolver import GidNamespace
from bluecellulab.reports.utils import configure_all_reports
import neuron
import numpy as np
Expand Down Expand Up @@ -139,8 +140,7 @@ def __init__(
condition_parameters = self.circuit_access.config.condition_parameters()
set_global_condition_parameters(condition_parameters)

self._gid_stride = 1_000_000 # must exceed the max node_id in any population
self._pop_index: dict[str, int] = {"": 0}
self.gid_resolver: Optional[GidNamespace] = None

def instantiate_gids(
self,
Expand Down Expand Up @@ -276,18 +276,27 @@ def instantiate_gids(
"if you want to specify use add_replay or "
"pre_spike_trains")

# legacy for backward compatibility
if add_projections is None:
add_projections = False

if add_projections is True:
self.projections = self.circuit_access.config.get_all_projection_names()
elif add_projections is False:
elif add_projections is None or add_projections is False:
self.projections = None
else:
self.projections = add_projections

self._add_cells(cell_ids)

need_gids = (
add_synapses
or (self.pc is not None)
or add_replay
or interconnect_cells
or pre_spike_trains is not None
or self.print_cellstate
)

if need_gids:
self.gids = self._build_gid_namespace()

if add_synapses:
self._add_synapses(
pre_gids=pre_gids,
Expand All @@ -297,7 +306,6 @@ def instantiate_gids(
raise BluecellulabError("add_replay option can not be used if "
"add_synapses is False")
if self.pc is not None:
self._init_pop_index_mpi()
self._register_gids_for_mpi()
self.pc.barrier()
self.pc.setup_transfer()
Expand Down Expand Up @@ -473,12 +481,17 @@ def _add_cell_synapses(
syn_description["source_popid"],
syn_description["target_popid"],
)

post_gid = self.global_gid(cell_id.population_name, cell_id.id)

self._instantiate_synapse(
cell_id=cell_id,
syn_id=idx, # type: ignore
syn_description=syn_description,
add_minis=add_minis,
popids=popids
popids=popids,
post_gid=post_gid

)
logger.info(f"Added {syn_descriptions} synapses for gid {cell_id}")
if add_minis:
Expand Down Expand Up @@ -607,26 +620,26 @@ def _add_connections(
syn_description: pd.Series = synapse.syn_description
delay_weights = synapse.delay_weights
source_population = syn_description["source_population_name"]
pre_gid = CellId(source_population, int(syn_description[SynapseProperty.PRE_GID]))
pre_local_id = CellId(source_population, int(syn_description[SynapseProperty.PRE_GID]))

ov = self._find_matching_override(connections_overrides, pre_gid, post_gid)
ov = self._find_matching_override(connections_overrides, pre_local_id, post_gid)

if ov is not None and ov.weight == 0.0:
logger.debug(
"Skipping connection due to zero weight override: %s -> %s | syn_id=%s",
pre_gid, post_gid, syn_id
pre_local_id, post_gid, syn_id
)
continue

if self.pc is None:
real_synapse_connection = bool(interconnect_cells) and (pre_gid in self.cells)
real_synapse_connection = bool(interconnect_cells) and (pre_local_id in self.cells)
else:
real_synapse_connection = bool(interconnect_cells)

if real_synapse_connection:
if (
user_pre_spike_trains is not None
and pre_gid in user_pre_spike_trains
and pre_local_id in user_pre_spike_trains
):
raise BluecellulabError(
"""Specifying prespike trains of real connections"""
Expand All @@ -636,28 +649,28 @@ def _add_connections(
connection = bluecellulab.Connection(
self.cells[post_gid].synapses[syn_id],
pre_spiketrain=None,
pre_cell=self.cells[pre_gid],
pre_cell=self.cells[pre_local_id],
stim_dt=self.dt,
parallel_context=None,
spike_threshold=self.spike_threshold,
spike_location=self.spike_location,
)
else: # MPI cross-rank
pre_g = self.global_gid(pre_gid.population_name, pre_gid.id)
pre_gid = self.global_gid(pre_local_id.population_name, pre_local_id.id)
connection = bluecellulab.Connection(
self.cells[post_gid].synapses[syn_id],
pre_spiketrain=None,
pre_gid=pre_g,
pre_gid=pre_gid,
pre_cell=None,
stim_dt=self.dt,
parallel_context=self.pc,
spike_threshold=self.spike_threshold,
spike_location=self.spike_location,
)

logger.debug(f"Added real connection between {pre_gid} and {post_gid}, {syn_id}")
logger.debug(f"Added real connection between {pre_local_id} and {post_gid}, {syn_id}")
else: # replay connection
pre_spiketrain = pre_spike_trains.get(pre_gid, None)
pre_spiketrain = pre_spike_trains.get(pre_local_id, None)
connection = bluecellulab.Connection(
self.cells[post_gid].synapses[syn_id],
pre_spiketrain=pre_spiketrain,
Expand All @@ -668,34 +681,34 @@ def _add_connections(
spike_location=self.spike_location
)

logger.debug(f"Added replay connection from {pre_gid} to {post_gid}, {syn_id}")
logger.debug(f"Added replay connection from {pre_local_id} to {post_gid}, {syn_id}")

if ov is not None:
logger.debug(
"Override matched: %s -> %s | syn_id=%s | weight=%s delay=%s",
pre_gid, post_gid, syn_id, ov.weight, ov.delay
pre_local_id, post_gid, syn_id, ov.weight, ov.delay
)

syn_delay = getattr(ov, "synapse_delay_override", None)
if syn_delay is not None:
connection.set_netcon_delay(float(syn_delay))
logger.debug(
"Applied synapse_delay_override %.4g ms to %s -> %s | syn_id=%s",
syn_delay, pre_gid, post_gid, syn_id
syn_delay, pre_local_id, post_gid, syn_id
)

if ov.delay is not None:
logger.warning(
"SONATA override 'delay' (delayed weight activation) is not supported yet; "
"applying weight immediately. %s -> %s | syn_id=%s | delay=%s",
pre_gid, post_gid, syn_id, ov.delay
pre_local_id, post_gid, syn_id, ov.delay
)

if ov.weight is not None:
connection.set_weight_scalar(float(ov.weight))
logger.debug(
"Applied weight override factor %.4g to %s -> %s | syn_id=%s | final_weight=%.4g",
ov.weight, pre_gid, post_gid, syn_id, connection.post_netcon_weight
ov.weight, pre_local_id, post_gid, syn_id, connection.post_netcon_weight
)

self.cells[post_gid].connections[syn_id] = connection
Expand All @@ -716,7 +729,7 @@ def _add_cells(self, cell_ids: list[CellId]) -> None:
if self.circuit_access.node_properties_available:
cell.connect_to_circuit(SonataProxy(cell_id, self.circuit_access))

def _instantiate_synapse(self, cell_id: CellId, syn_id: SynapseID, syn_description,
def _instantiate_synapse(self, cell_id: CellId, syn_id: SynapseID, syn_description, post_gid: int,
add_minis=False, popids=(0, 0)) -> None:
"""Instantiate one synapse for a given gid, syn_id and
syn_description."""
Expand All @@ -730,7 +743,8 @@ def _instantiate_synapse(self, cell_id: CellId, syn_id: SynapseID, syn_descripti

self.cells[cell_id].add_replay_synapse(
syn_id, syn_description, syn_connection_parameters, condition_parameters,
popids=popids, extracellular_calcium=self.circuit_access.config.extracellular_calcium)
popids=popids, extracellular_calcium=self.circuit_access.config.extracellular_calcium,
post_gid=post_gid)
if add_minis:
mini_frequencies = self.circuit_access.fetch_mini_frequencies(cell_id)
logger.debug(f"Adding minis for synapse {syn_id}: syn_description={syn_description}, connection={syn_connection_parameters}, frequency={mini_frequencies}")
Expand Down Expand Up @@ -1007,77 +1021,38 @@ def create_cell_from_circuit(self, cell_id: CellId) -> bluecellulab.Cell:
emodel_properties=cell_kwargs['emodel_properties'])

def global_gid(self, pop: str, gid: int) -> int:
"""Return a globally unique NEURON GID for a (population, node_id)
pair.

NEURON's ParallelContext requires presynaptic sources to be identified by a
single integer GID across all ranks. In SONATA circuits, node ids are only
unique *within* a population, so we combine (population_name, node_id) into a
single integer:

global_gid = pop_index[population] * STRIDE + node_id

Notes
-----
STRIDE must be larger than the maximum node_id in any population to avoid GID
collisions.

Parameters
----------
pop : str
SONATA population name (e.g. "S1nonbarrel_neurons", "POm", ...).
gid : int
Node id within that population.

Returns
-------
int
Globally unique integer GID used with ParallelContext.
"""
if pop not in self._pop_index:
raise KeyError(
f"Population '{pop}' missing from pop index. "
f"Known pops: {sorted(self._pop_index.keys())}"
)
return self._pop_index[pop] * self._gid_stride + int(gid)

def _init_pop_index_mpi(self) -> None:
"""Build a consistent pop->index mapping across ranks."""
assert self.pc is not None

local_pops = set()
for post_gid in self.cells:
local_pops.add(post_gid.population_name)
synapses = getattr(self.cells[post_gid], "synapses", None)
if synapses:
for syn in synapses.values():
sd = syn.syn_description
local_pops.add(sd["source_population_name"])

gathered = self.pc.py_gather(sorted(local_pops), 0)

if int(self.pc.id()) == 0:
all_pops = set([""])
for lst in gathered:
all_pops.update(lst)
pops_sorted = sorted(all_pops)
pop_index = {p: i for i, p in enumerate(pops_sorted)}
else:
pop_index = None
if self.gids is None:
raise RuntimeError("GID namespace not initialized yet.")
return self.gids.global_gid(pop, gid)

pop_index = self.pc.py_broadcast(pop_index, 0)
self._pop_index = pop_index
def _build_gid_namespace(self) -> GidNamespace:
sizes = self.circuit_access.node_population_sizes()
pops_sorted = sorted(sizes.keys())
max_raw = {p: max(0, int(n) - 1) for p, n in sizes.items()}
pop_offset = self._compute_offsets_from_max(pops_sorted, max_raw)
return GidNamespace(pop_offset)

def _register_gids_for_mpi(self) -> None:
assert self.pc is not None
assert self.gids is not None

for cell_id, cell in self.cells.items():
g = self.global_gid(cell_id.population_name, cell_id.id)

self.pc.set_gid2node(g, int(self.pc.id()))
nc = cell.create_netcon_spikedetector(
None,
location=self.spike_location,
threshold=self.spike_threshold
)
nc = cell.create_netcon_spikedetector(None, location=self.spike_location, threshold=self.spike_threshold)
self.pc.cell(g, nc)

def _compute_offsets_from_max(self, pops_sorted: list[str], max_raw: dict[str, int]) -> dict[str, int]:
pop_offset: dict[str, int] = {}
prev: str | None = None

for p in pops_sorted:
if prev is None:
pop_offset[p] = 0
else:
prev_count = int(max_raw[prev]) + 1
end_prev = pop_offset[prev] + prev_count
pop_offset[p] = ((end_prev + 999) // 1000) * 1000
prev = p

return pop_offset
Loading