Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/pylattica/core/runner/asynchronous_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def _run( # pylint: disable=too-many-positional-arguments
self,
_: SimulationState,
result: SimulationResult,
live_state: SimulationState,
controller: BasicController,
num_steps: int,
verbose: bool = False,
Expand All @@ -58,6 +57,7 @@ def _run( # pylint: disable=too-many-positional-arguments
"""

site_queue = deque()
live_state = result.live_state

def _add_sites_to_queue():
next_site = controller.get_random_site(live_state)
Expand All @@ -84,7 +84,6 @@ def _add_sites_to_queue():
state_updates = controller_response

state_updates = merge_updates(state_updates, site_id=site_id)
live_state.batch_update(state_updates)
site_queue.extend(next_sites)

result.add_step(state_updates)
Expand All @@ -95,5 +94,4 @@ def _add_sites_to_queue():
if len(site_queue) == 0:
break

result.set_output(live_state)
return result
4 changes: 1 addition & 3 deletions src/pylattica/core/runner/base_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ def run(

result = controller.instantiate_result(initial_state.copy())
controller.pre_run(initial_state)
live_state = initial_state.copy()

self._run(initial_state, result, live_state, controller, num_steps, verbose)
self._run(initial_state, result, controller, num_steps, verbose)

result.set_output(live_state)
return result
5 changes: 1 addition & 4 deletions src/pylattica/core/runner/synchronous_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def _run( # pylint: disable=too-many-positional-arguments
self,
initial_state: SimulationState,
result: SimulationResult,
live_state: SimulationState,
controller: BasicController,
num_steps: int,
verbose: bool = False,
Expand Down Expand Up @@ -74,11 +73,9 @@ def _run( # pylint: disable=too-many-positional-arguments
else:
printif(verbose, "Running in series.")
for _ in tqdm(range(num_steps)):
updates = self._take_step(live_state, controller)
live_state.batch_update(updates)
updates = self._take_step(result.live_state, controller)
result.add_step(updates)

result.set_output(live_state)
return result

def _take_step_parallel(self, updates: dict, pool, chunk_size) -> SimulationState:
Expand Down
131 changes: 123 additions & 8 deletions src/pylattica/core/simulation_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ class SimulationResult:
Maximum number of diffs to keep in memory. When exceeded, older diffs
are dropped and a checkpoint is created. Set to None for unlimited
history (default, but may cause memory issues for long simulations).
live_compress : bool, optional
If True, store full state snapshots at compress_freq intervals during
simulation instead of diffs. This avoids the expensive O(n) reconstruction
in load_steps() at the cost of more memory per frame. When enabled,
load_steps() becomes a no-op since frames are already stored.
"""

@classmethod
Expand All @@ -29,13 +34,15 @@ def from_file(cls, fpath):

@classmethod
def from_dict(cls, res_dict):
diffs = res_dict["diffs"]
diffs = res_dict.get("diffs", [])
compress_freq = res_dict.get("compress_freq", 1)
max_history = res_dict.get("max_history", None)
live_compress = res_dict.get("live_compress", False)
res = cls(
SimulationState.from_dict(res_dict["initial_state"]),
compress_freq=compress_freq,
max_history=max_history,
live_compress=live_compress,
)
# Restore checkpoint if present
if "checkpoint_state" in res_dict and res_dict["checkpoint_state"] is not None:
Expand All @@ -44,6 +51,11 @@ def from_dict(cls, res_dict):
)
res._checkpoint_step = res_dict.get("checkpoint_step", 0)

# Restore frames if present (for live_compress mode)
if "frames" in res_dict and res_dict["frames"]:
for step_str, state_dict in res_dict["frames"].items():
res._frames[int(step_str)] = SimulationState.from_dict(state_dict)

for diff in diffs:
if SITES in diff:
diff[SITES] = {int(k): v for k, v in diff[SITES].items()}
Expand All @@ -56,13 +68,26 @@ def from_dict(cls, res_dict):
"total_steps", res._checkpoint_step + len(diffs)
)

# Reconstruct live_state to reflect the final state
if res._frames:
# In live_compress mode, use the last frame
last_step = max(res._frames.keys())
res._live_state = res._frames[last_step].copy()
elif res._diffs:
# Replay all diffs to get final state
if res._checkpoint_state is not None:
res._live_state = res._checkpoint_state.copy()
for diff in res._diffs:
res._live_state.batch_update(diff)

return res

def __init__(
self,
starting_state: SimulationState,
compress_freq: int = 1,
max_history: int = None,
live_compress: bool = False,
):
"""Initializes a SimulationResult with the specified starting_state.

Expand All @@ -71,24 +96,39 @@ def __init__(
starting_state : SimulationState
The state with which the simulation started.
compress_freq : int, optional
Compression frequency for sampling, by default 1.
Compression frequency for sampling, by default 1. When live_compress
is True, this controls how often full state snapshots are stored.
max_history : int, optional
Maximum number of diffs to keep in memory. When exceeded, a
checkpoint is created and old diffs are dropped. This prevents
unbounded memory growth during long simulations. Set to None
(default) for unlimited history. Recommended: 1000-10000 for
long simulations.
long simulations. Ignored when live_compress is True.
live_compress : bool, optional
If True, store full state snapshots at compress_freq intervals
during simulation instead of storing diffs. This avoids the O(n)
reconstruction cost of load_steps() but uses more memory per stored
frame. Default is False (store diffs, reconstruct post-hoc).
"""
self.initial_state = starting_state
self.compress_freq = compress_freq
self.max_history = max_history
self.live_compress = live_compress
self._diffs: list[dict] = []
self._stored_states = {}
self._frames: Dict[int, SimulationState] = {} # For live_compress mode
# Checkpoint support for bounded history
self._checkpoint_state: SimulationState = None
self._checkpoint_step: int = 0
self._total_steps: int = 0

# Live state that gets updated with each step
self._live_state: SimulationState = starting_state.copy()

# Store initial state as frame 0 if live_compress is enabled
if self.live_compress:
self._frames[0] = starting_state.copy()

def get_diffs(self) -> list[dict]:
"""Returns the list of diffs.

Expand All @@ -99,6 +139,15 @@ def get_diffs(self) -> list[dict]:
"""
return self._diffs

@property
def live_state(self) -> SimulationState:
"""The current live state of the simulation.

This state is updated with each call to add_step(). Use this to access
the current simulation state during a run.
"""
return self._live_state

def add_step(self, updates: Dict[int, Dict]) -> None:
"""Takes a set of updates as a dictionary mapping site IDs
to the new values for various state parameters. For instance, if at the
Expand All @@ -111,14 +160,30 @@ def add_step(self, updates: Dict[int, Dict]) -> None:
}
}

This method:
1. Applies the updates to the internal live_state
2. Increments the step counter
3. In live_compress mode: stores frames at compress_freq intervals
4. In normal mode: stores the diff for later reconstruction

Parameters
----------
updates : dict
The changes associated with a new simulation step.
"""
self._diffs.append(updates)
# Update the live state
self._live_state.batch_update(updates)
self._total_steps += 1

# In live_compress mode, store frames at intervals instead of diffs
if self.live_compress:
if self._total_steps % self.compress_freq == 0:
self._frames[self._total_steps] = self._live_state.copy()
return

# Normal mode: store diffs
self._diffs.append(updates)

# Check if we need to create a checkpoint and drop old diffs
if self.max_history is not None and len(self._diffs) > self.max_history:
self._create_checkpoint()
Expand Down Expand Up @@ -173,13 +238,20 @@ def steps(self) -> List[SimulationState]:
"""Yields all available steps from this simulation.

Note: When max_history is set, only steps from the checkpoint onward
are available. Use earliest_available_step to check.
are available. When live_compress is set, only frames at compress_freq
intervals are available. Use earliest_available_step to check.

Yields
------
SimulationState
Each step's state (as a copy to avoid mutation issues).
"""
# If frames exist (live_compress mode), yield them in order
if self._frames:
for step_no in sorted(self._frames.keys()):
yield self._frames[step_no].copy()
return

# Start from checkpoint or initial state
if self._checkpoint_state is not None:
live_state = self._checkpoint_state.copy()
Expand All @@ -206,17 +278,38 @@ def last_step(self) -> SimulationState:
def first_step(self):
return self.get_step(0)

def set_output(self, step: SimulationState):
self.output = step
@property
def output(self) -> SimulationState:
"""The final output state of the simulation (alias for live_state)."""
return self._live_state

def load_steps(self, interval=1):
"""Pre-loads steps into memory at the specified interval for faster access.

When live_compress is enabled, this is a no-op since frames are already
stored during simulation. If a different interval is requested than what
was used during simulation (compress_freq), an error is raised.

Parameters
----------
interval : int, optional
Store every Nth step in memory, by default 1.

Raises
------
ValueError
If live_compress was used and requested interval doesn't match compress_freq.
"""
# If frames already exist (live_compress mode), no reconstruction needed
if self._frames:
if interval != self.compress_freq:
raise ValueError(
f"Cannot load steps with interval={interval}. This result was "
f"created with live_compress=True and compress_freq={self.compress_freq}. "
f"Only interval={self.compress_freq} is available."
)
return

# Clear old cache first
self._stored_states.clear()

Expand Down Expand Up @@ -255,8 +348,20 @@ def get_step(self, step_no) -> SimulationState:
Raises
------
ValueError
If step_no is before the earliest available step (when using max_history).
If step_no is before the earliest available step (when using max_history),
or if step_no is not available in live_compress mode.
"""
# Check frames first (live_compress mode)
if self._frames:
if step_no in self._frames:
return self._frames[step_no]
# In live_compress mode, only frames at compress_freq intervals exist
raise ValueError(
f"Cannot retrieve step {step_no}. This result was created with "
f"live_compress=True and compress_freq={self.compress_freq}. "
f"Available steps: {sorted(self._frames.keys())}"
)

if step_no < self._checkpoint_step:
raise ValueError(
f"Cannot retrieve step {step_no}. Earliest available step is "
Expand Down Expand Up @@ -288,6 +393,7 @@ def as_dict(self):
"diffs": self._diffs,
"compress_freq": self.compress_freq,
"max_history": self.max_history,
"live_compress": self.live_compress,
"total_steps": self._total_steps,
"@module": self.__class__.__module__,
"@class": self.__class__.__name__,
Expand All @@ -299,6 +405,15 @@ def as_dict(self):
else:
result["checkpoint_state"] = None
result["checkpoint_step"] = 0

# Include frames if present (live_compress mode)
if self._frames:
result["frames"] = {
str(step): state.as_dict() for step, state in self._frames.items()
}
else:
result["frames"] = {}

return result

def to_file(self, fpath: str = None) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/pylattica/visualization/result_artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ def jupyter_play(self, cell_size: int = 20, wait: int = 1, **kwargs):
wait : int, optional
The time duration between frames in the animation. Defaults to 1., by default 1
"""
from IPython.display import (
from IPython.display import ( # pylint: disable=import-error
clear_output,
display,
) # pylint: disable=import-error # pragma: no cover
) # pragma: no cover

imgs = self._get_images(cell_size=cell_size, **kwargs) # pragma: no cover
for img in imgs: # pragma: no cover
Expand Down
4 changes: 2 additions & 2 deletions src/pylattica/visualization/structure_artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def jupyter_show(self, state: SimulationState, **kwargs):
state : SimulationState
The simulation state to display.
"""
from IPython.display import (
from IPython.display import ( # pylint: disable=import-error
display,
) # pylint: disable=import-error # pragma: no cover
) # pragma: no cover

img = self.get_img(state, **kwargs) # pragma: no cover
display(img) # pragma: no cover
Expand Down
Loading
Loading