diff --git a/bluecellulab/cell/point_process.py b/bluecellulab/cell/point_process.py new file mode 100644 index 00000000..aa4c7ed8 --- /dev/null +++ b/bluecellulab/cell/point_process.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +from dataclasses import dataclass +import logging +from pathlib import Path +from typing import Any, Mapping, Optional + +from bluecellulab.circuit.simulation_access import get_synapse_replay_spikes +from bluecellulab.exceptions import BluecellulabError +from neuron import h +import numpy as np + +from bluecellulab.circuit.node_id import CellId + +logger = logging.getLogger(__name__) + +class BasePointProcessCell: + """Base class for NEURON artificial point processes (IntFire1/2/...).""" + + def __init__(self, cell_id: Optional[CellId]) -> None: + self.cell_id = cell_id + + self._spike_times = h.Vector() + self._spike_detector: Optional[h.NetCon] = None + self.pointcell = None # type: ignore[assignment] + self.synapses: dict = {} + self.connections: dict = {} + + @property + def hoc_cell(self): + return self.pointcell + + + def init_callbacks(self): + pass + + def connect_to_circuit(self, proxy) -> None: + self._circuit_proxy = proxy + + def delete(self) -> None: + # Stop recording + if self._spike_detector is not None: + # NetCon will be GC'd when no Python refs remain + self._spike_detector = None + if self._spike_times is not None: + self._spike_times = None + + # Drop pointer to underlying NEURON object + self.pointcell = None + + + def get_spike_times(self) -> list[float]: + return list(self._spike_times) + + def create_netcon_spikedetector( + self, + sec, # ignored for artificial cells + location=None, # ignored for artificial cells + threshold: float = 0.0, + ) -> h.NetCon: + nc = h.NetCon(self.pointcell, None) + nc.threshold = threshold # harmless for artificial cells + return nc + + def is_recording_spikes(self, location=None, threshold: float | None = None) -> bool: + return self._spike_detector is not None + + def start_recording_spikes(self, sec, location=None, threshold: float = 0.0) -> None: + if self._spike_detector is not None: + return + self._spike_times = h.Vector() + self._spike_detector = h.NetCon(self.pointcell, None) + self._spike_detector.threshold = threshold + self._spike_detector.record(self._spike_times) + + + def connect2target(self, target_pp=None) -> h.NetCon: + """Neurodamus-like helper: NetCon from this cell to a target point process.""" + return h.NetCon(self.pointcell, target_pp) + + +class HocPointProcessCell(BasePointProcessCell): + """Point process that wraps an arbitrary HOC/mod artificial mechanism. + """ + + def __init__( + self, + cell_id: Optional[CellId], + mechanism_name: str, + param_overrides: Optional[Mapping[str, Any]] = None, + spike_threshold: float = 1.0, + ) -> None: + super().__init__(cell_id) + + try: + mech_cls = getattr(h, mechanism_name) + except AttributeError as exc: + raise BluecellulabError( + f"Point mechanism '{mechanism_name}' not found in NEURON. " + "Make sure the mod/hoc files are compiled and loaded." + ) from exc + + point = mech_cls() + if param_overrides: + for name, value in param_overrides.items(): + if hasattr(point, name): + setattr(point, name, value) + + self.pointcell = point + self.start_recording_spikes(None, None, threshold=spike_threshold) + + def add_synapse_replay(self, stimulus, spike_threshold: float, spike_location: str) -> None: + """SONATA-style spike replay for point processes. + + This is a simplified analogue of Cell.add_synapse_replay, but instead of + mapping spikes to individual synapses, we directly connect each presynaptic + node_id's spike train to this artificial cell via VecStim → NetCon. + """ + file_path = Path(stimulus.spike_file).expanduser() + + if not file_path.is_absolute(): + config_dir = stimulus.config_dir + if config_dir is not None: + file_path = Path(config_dir) / file_path + + file_path = file_path.resolve() + + if not file_path.exists(): + raise FileNotFoundError(f"Spike file not found: {str(file_path)}") + + synapse_spikes = get_synapse_replay_spikes(str(file_path)) + + if not hasattr(self, "_replay_vecs"): + self._replay_vecs: list[h.Vector] = [] + if not hasattr(self, "_replay_vecstims"): + self._replay_vecstims: list[h.VecStim] = [] + if not hasattr(self, "_replay_netcons"): + self._replay_netcons: list[h.NetCon] = [] + + for pre_node_id, spikes in synapse_spikes.items(): + delay = getattr(stimulus, "delay", 0.0) or 0.0 + duration = getattr(stimulus, "duration", np.inf) + + spikes_of_interest = spikes[ + (spikes >= delay) & (spikes <= duration) + ] + if spikes_of_interest.size == 0: + continue + + vec = h.Vector(spikes_of_interest) + vs = h.VecStim() + vs.play(vec) + + nc = h.NetCon(vs, self.pointcell) + # Use stimulus weight if available, otherwise default to 1.0 + weight = getattr(stimulus, "weight", 1.0) + nc.weight[0] = weight + nc.delay = 0.0 # delay already baked into spike times + + self._replay_vecs.append(vec) + self._replay_vecstims.append(vs) + self._replay_netcons.append(nc) + + logger.debug( + f"Added replay connection from pre_node_id={pre_node_id} " + f"to point neuron {self.cell_id}" + ) + +def mechanism_name_from_model_template(model_template: str) -> str: + """Translate SONATA model_template into a NEURON mechanism name. + + Examples: + 'hoc:AllenPointCell' -> 'AllenPointCell' + 'nrn:IntFire1' -> 'IntFire1' + 'AllenPointCell' -> 'AllenPointCell' + """ + mt = str(model_template).strip() + if ":" in mt: + prefix, name = mt.split(":", 1) + prefix = prefix.lower() + if prefix in ("hoc", "nrn"): + return name + return mt + +@dataclass +class IntFire1Params: + tau: float = 10.0 + refrac: float = 2.0 + + +class IntFire1Cell(BasePointProcessCell): + def __init__( + self, + cell_id: Optional[CellId] = None, + tau: float = 10.0, + refrac: float = 2.0, + ) -> None: + super().__init__(cell_id) + point = h.IntFire1() + point.tau = tau + point.refrac = refrac + self.pointcell = point + + self.start_recording_spikes(None, None, threshold=1.0) + + +@dataclass +class IntFire2Params: + taum: float = 10.0 + taus: float = 20.0 + ib: float = 0.0 + + +class IntFire2Cell(BasePointProcessCell): + def __init__( + self, + cell_id: Optional[CellId] = None, + taum: float = 10.0, + taus: float = 20.0, + ib: float = 0.0, + ) -> None: + super().__init__(cell_id) + point = h.IntFire2() + point.taum = taum + point.taus = taus + point.ib = ib + self.pointcell = point + + self.start_recording_spikes(None, None, threshold=1.0) + + +def create_intfire1_cell( + tau: float = 10.0, + refrac: float = 2.0, + cell_id: Optional[CellId] = None, +) -> IntFire1Cell: + return IntFire1Cell(cell_id=cell_id, tau=tau, refrac=refrac) + + +def create_intfire2_cell( + taum: float = 10.0, + taus: float = 20.0, + ib: float = 0.0, + cell_id: Optional[CellId] = None, +) -> IntFire2Cell: + return IntFire2Cell(cell_id=cell_id, taum=taum, taus=taus, ib=ib) diff --git a/bluecellulab/circuit_simulation.py b/bluecellulab/circuit_simulation.py index 0d202dd8..b80b1405 100644 --- a/bluecellulab/circuit_simulation.py +++ b/bluecellulab/circuit_simulation.py @@ -65,6 +65,8 @@ from bluecellulab.simulation.modifications import apply_modifications from bluecellulab.synapse.synapse_types import SynapseID +from bluecellulab.cell.point_process import BasePointProcessCell, HocPointProcessCell, mechanism_name_from_model_template + logger = logging.getLogger(__name__) @@ -413,7 +415,32 @@ def _add_stimuli( except ValueError: pass + all_point_processes = all( + isinstance(cell, BasePointProcessCell) for cell in self.cells.values() + ) + for stimulus in stimuli_entries: + + # 1) SynapseReplay: works for both morpho cells and point processes + if isinstance(stimulus, circuit_stimulus_definitions.SynapseReplay): + for cell_id, cell in self.cells.items(): + if self.circuit_access.target_contains_cell(stimulus.target, cell_id): + if hasattr(cell, "add_synapse_replay"): + print("Adding SynapseReplay to cell", cell_id) + cell.add_synapse_replay( + stimulus, self.spike_threshold, self.spike_location + ) + logger.debug( + f"Added SynapseReplay {stimulus} to point/morpho cell {cell_id}" + ) + # No section/compartment logic needed for SynapseReplay + continue + + # 2) Other stimuli: require morphology + # If all cells are point processes, skip these stimuli entirely. + if all_point_processes: + continue + # Build a unified list of (cell_id, section, segx, section_name) targets targets: list[tuple] = [] @@ -424,6 +451,9 @@ def _add_stimuli( stimulus.node_set ) for cell_id in self.cells: + # Skip point processes: they have no soma + if isinstance(self.cells[cell_id], BasePointProcessCell): + continue if cell_id not in gids_of_target: continue sec = self.cells[cell_id].soma @@ -1167,6 +1197,26 @@ def fetch_cell_kwargs(self, cell_id: CellId) -> dict: def create_cell_from_circuit(self, cell_id: CellId) -> bluecellulab.Cell: """Create a Cell object from the circuit.""" + if self.circuit_format == CircuitFormat.SONATA: + try: + info = self.circuit_access.fetch_cell_info(cell_id) # type: ignore[attr-defined] + except AttributeError: + info = pd.Series() + + model_type = str(info.get("model_type", "")).lower() + model_template = str(info.get("model_template", "")) + + if model_type == "point_process": + mech_name = mechanism_name_from_model_template(model_template) + + # TODO (later): parse dynamics_params and feed param_overrides + return HocPointProcessCell( + cell_id=cell_id, + mechanism_name=mech_name, + param_overrides=None, + spike_threshold=self.spike_threshold, + ) + cell_kwargs = self.fetch_cell_kwargs(cell_id) return bluecellulab.Cell( template_path=cell_kwargs["template_path"], diff --git a/bluecellulab/point/__init__.py b/bluecellulab/point/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bluecellulab/point/connection_params.py b/bluecellulab/point/connection_params.py new file mode 100644 index 00000000..3cf62d22 --- /dev/null +++ b/bluecellulab/point/connection_params.py @@ -0,0 +1,15 @@ +from __future__ import annotations +from dataclasses import dataclass + + +@dataclass +class PointProcessConnParameters: + """Point-neuron connection parameters (Allen-style / Neurodamus mirror).""" + + sgid: int # source gid + delay: float # ms + weight: float # NetCon weight + + # isec: int = -1 + # ipt: int = -1 + # offset: float = 0.5 \ No newline at end of file diff --git a/bluecellulab/point/point_connection.py b/bluecellulab/point/point_connection.py new file mode 100644 index 00000000..6501107a --- /dev/null +++ b/bluecellulab/point/point_connection.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import Iterable, List, Optional + +from neuron import h + +from bluecellulab.point.connection_params import PointProcessConnParameters +from bluecellulab.cell.point_process import BasePointProcessCell + +pc = h.ParallelContext() + + +DEFAULT_SPIKE_THRESHOLD = 0.0 + + +class PointProcessConnection: + """Allen-style point connection: sgid -> PointNeuronCell.pointcell. + + Mirrors Neurodamus PointConnection: + - at most one synapse per connection + - uses pc.gid_connect(sgid, cell.pointcell) + - can later be extended with replay (VecStim) if needed. + """ + + def __init__( + self, + synapse_params: Iterable[PointProcessConnParameters], + weight_factor: float = 1.0, + syndelay_override: Optional[float] = None, + attach_src_cell: bool = True, + replay=None, # placeholder for future replay object + ) -> None: + self.synapse_params = list(synapse_params) + assert len(self.synapse_params) <= 1, ( + "PointProcessConnection supports max. one synapse per connection" + ) + + self.weight_factor = weight_factor + self.syndelay_override = syndelay_override + self.attach_src_cell = attach_src_cell + self._replay = replay + + self._netcons: List[h.NetCon] = [] + + @property + def netcons(self) -> list[h.NetCon]: + return self._netcons + + def finalize(self, cell: BasePointProcessCell) -> int: + """Create NetCon(s) onto the given point neuron cell. + + Returns + ------- + int + Number of synapses (0 or 1). + """ + n_syns = 0 + + for params in self.synapse_params: + n_syns += 1 + + if self.attach_src_cell: + # --- main path: presyn cell with sgid --- + nc = pc.gid_connect(params.sgid, cell.pointcell) + nc.delay = self.syndelay_override or float(params.delay) + nc.weight[0] = float(params.weight) * self.weight_factor + nc.threshold = DEFAULT_SPIKE_THRESHOLD + self._netcons.append(nc) + + # --- replay path (optional, stubbed) --- + if self._replay is not None and getattr(self._replay, "has_data", lambda: False)(): + vecstim = h.VecStim() + vecstim.play(self._replay.time_vec) + nc = h.NetCon( + vecstim, + cell.pointcell, + 10.0, + self.syndelay_override or float(params.delay), + float(params.weight), + ) + nc.weight[0] = float(params.weight) * self.weight_factor + self._replay._store(vecstim, nc) + + return n_syns \ No newline at end of file