From 0f123cf486f9013961fec0850a58b432d5c46cb3 Mon Sep 17 00:00:00 2001 From: Max Gallant Date: Tue, 24 Feb 2026 17:24:38 -0800 Subject: [PATCH 1/3] Remove the need for posthoc compression by accumulating compressed frames live --- .../core/runner/asynchronous_runner.py | 4 +- src/pylattica/core/runner/base_runner.py | 4 +- .../core/runner/synchronous_runner.py | 5 +- src/pylattica/core/simulation_result.py | 131 ++++++++++++++++-- tests/core/test_simulation_result.py | 121 ++++++++++++++++ 5 files changed, 247 insertions(+), 18 deletions(-) diff --git a/src/pylattica/core/runner/asynchronous_runner.py b/src/pylattica/core/runner/asynchronous_runner.py index 1a07207..73e4519 100644 --- a/src/pylattica/core/runner/asynchronous_runner.py +++ b/src/pylattica/core/runner/asynchronous_runner.py @@ -31,7 +31,6 @@ def _run( # pylint: disable=too-many-positional-arguments self, _: SimulationState, result: SimulationResult, - live_state: SimulationState, controller: BasicController, num_steps: int, verbose: bool = False, @@ -58,6 +57,7 @@ def _run( # pylint: disable=too-many-positional-arguments """ site_queue = deque() + live_state = result.live_state def _add_sites_to_queue(): next_site = controller.get_random_site(live_state) @@ -84,7 +84,6 @@ def _add_sites_to_queue(): state_updates = controller_response state_updates = merge_updates(state_updates, site_id=site_id) - live_state.batch_update(state_updates) site_queue.extend(next_sites) result.add_step(state_updates) @@ -95,5 +94,4 @@ def _add_sites_to_queue(): if len(site_queue) == 0: break - result.set_output(live_state) return result diff --git a/src/pylattica/core/runner/base_runner.py b/src/pylattica/core/runner/base_runner.py index 268db57..5727a0d 100644 --- a/src/pylattica/core/runner/base_runner.py +++ b/src/pylattica/core/runner/base_runner.py @@ -62,9 +62,7 @@ def run( result = controller.instantiate_result(initial_state.copy()) controller.pre_run(initial_state) - live_state = initial_state.copy() - self._run(initial_state, result, live_state, controller, num_steps, verbose) + self._run(initial_state, result, controller, num_steps, verbose) - result.set_output(live_state) return result diff --git a/src/pylattica/core/runner/synchronous_runner.py b/src/pylattica/core/runner/synchronous_runner.py index 8f28056..b755f27 100644 --- a/src/pylattica/core/runner/synchronous_runner.py +++ b/src/pylattica/core/runner/synchronous_runner.py @@ -42,7 +42,6 @@ def _run( # pylint: disable=too-many-positional-arguments self, initial_state: SimulationState, result: SimulationResult, - live_state: SimulationState, controller: BasicController, num_steps: int, verbose: bool = False, @@ -74,11 +73,9 @@ def _run( # pylint: disable=too-many-positional-arguments else: printif(verbose, "Running in series.") for _ in tqdm(range(num_steps)): - updates = self._take_step(live_state, controller) - live_state.batch_update(updates) + updates = self._take_step(result.live_state, controller) result.add_step(updates) - result.set_output(live_state) return result def _take_step_parallel(self, updates: dict, pool, chunk_size) -> SimulationState: diff --git a/src/pylattica/core/simulation_result.py b/src/pylattica/core/simulation_result.py index 42c6359..a9986ce 100644 --- a/src/pylattica/core/simulation_result.py +++ b/src/pylattica/core/simulation_result.py @@ -21,6 +21,11 @@ class SimulationResult: Maximum number of diffs to keep in memory. When exceeded, older diffs are dropped and a checkpoint is created. Set to None for unlimited history (default, but may cause memory issues for long simulations). + live_compress : bool, optional + If True, store full state snapshots at compress_freq intervals during + simulation instead of diffs. This avoids the expensive O(n) reconstruction + in load_steps() at the cost of more memory per frame. When enabled, + load_steps() becomes a no-op since frames are already stored. """ @classmethod @@ -29,13 +34,15 @@ def from_file(cls, fpath): @classmethod def from_dict(cls, res_dict): - diffs = res_dict["diffs"] + diffs = res_dict.get("diffs", []) compress_freq = res_dict.get("compress_freq", 1) max_history = res_dict.get("max_history", None) + live_compress = res_dict.get("live_compress", False) res = cls( SimulationState.from_dict(res_dict["initial_state"]), compress_freq=compress_freq, max_history=max_history, + live_compress=live_compress, ) # Restore checkpoint if present if "checkpoint_state" in res_dict and res_dict["checkpoint_state"] is not None: @@ -44,6 +51,11 @@ def from_dict(cls, res_dict): ) res._checkpoint_step = res_dict.get("checkpoint_step", 0) + # Restore frames if present (for live_compress mode) + if "frames" in res_dict and res_dict["frames"]: + for step_str, state_dict in res_dict["frames"].items(): + res._frames[int(step_str)] = SimulationState.from_dict(state_dict) + for diff in diffs: if SITES in diff: diff[SITES] = {int(k): v for k, v in diff[SITES].items()} @@ -56,6 +68,18 @@ def from_dict(cls, res_dict): "total_steps", res._checkpoint_step + len(diffs) ) + # Reconstruct live_state to reflect the final state + if res._frames: + # In live_compress mode, use the last frame + last_step = max(res._frames.keys()) + res._live_state = res._frames[last_step].copy() + elif res._diffs: + # Replay all diffs to get final state + if res._checkpoint_state is not None: + res._live_state = res._checkpoint_state.copy() + for diff in res._diffs: + res._live_state.batch_update(diff) + return res def __init__( @@ -63,6 +87,7 @@ def __init__( starting_state: SimulationState, compress_freq: int = 1, max_history: int = None, + live_compress: bool = False, ): """Initializes a SimulationResult with the specified starting_state. @@ -71,24 +96,39 @@ def __init__( starting_state : SimulationState The state with which the simulation started. compress_freq : int, optional - Compression frequency for sampling, by default 1. + Compression frequency for sampling, by default 1. When live_compress + is True, this controls how often full state snapshots are stored. max_history : int, optional Maximum number of diffs to keep in memory. When exceeded, a checkpoint is created and old diffs are dropped. This prevents unbounded memory growth during long simulations. Set to None (default) for unlimited history. Recommended: 1000-10000 for - long simulations. + long simulations. Ignored when live_compress is True. + live_compress : bool, optional + If True, store full state snapshots at compress_freq intervals + during simulation instead of storing diffs. This avoids the O(n) + reconstruction cost of load_steps() but uses more memory per stored + frame. Default is False (store diffs, reconstruct post-hoc). """ self.initial_state = starting_state self.compress_freq = compress_freq self.max_history = max_history + self.live_compress = live_compress self._diffs: list[dict] = [] self._stored_states = {} + self._frames: Dict[int, SimulationState] = {} # For live_compress mode # Checkpoint support for bounded history self._checkpoint_state: SimulationState = None self._checkpoint_step: int = 0 self._total_steps: int = 0 + # Live state that gets updated with each step + self._live_state: SimulationState = starting_state.copy() + + # Store initial state as frame 0 if live_compress is enabled + if self.live_compress: + self._frames[0] = starting_state.copy() + def get_diffs(self) -> list[dict]: """Returns the list of diffs. @@ -99,6 +139,15 @@ def get_diffs(self) -> list[dict]: """ return self._diffs + @property + def live_state(self) -> SimulationState: + """The current live state of the simulation. + + This state is updated with each call to add_step(). Use this to access + the current simulation state during a run. + """ + return self._live_state + def add_step(self, updates: Dict[int, Dict]) -> None: """Takes a set of updates as a dictionary mapping site IDs to the new values for various state parameters. For instance, if at the @@ -111,14 +160,30 @@ def add_step(self, updates: Dict[int, Dict]) -> None: } } + This method: + 1. Applies the updates to the internal live_state + 2. Increments the step counter + 3. In live_compress mode: stores frames at compress_freq intervals + 4. In normal mode: stores the diff for later reconstruction + Parameters ---------- updates : dict The changes associated with a new simulation step. """ - self._diffs.append(updates) + # Update the live state + self._live_state.batch_update(updates) self._total_steps += 1 + # In live_compress mode, store frames at intervals instead of diffs + if self.live_compress: + if self._total_steps % self.compress_freq == 0: + self._frames[self._total_steps] = self._live_state.copy() + return + + # Normal mode: store diffs + self._diffs.append(updates) + # Check if we need to create a checkpoint and drop old diffs if self.max_history is not None and len(self._diffs) > self.max_history: self._create_checkpoint() @@ -173,13 +238,20 @@ def steps(self) -> List[SimulationState]: """Yields all available steps from this simulation. Note: When max_history is set, only steps from the checkpoint onward - are available. Use earliest_available_step to check. + are available. When live_compress is set, only frames at compress_freq + intervals are available. Use earliest_available_step to check. Yields ------ SimulationState Each step's state (as a copy to avoid mutation issues). """ + # If frames exist (live_compress mode), yield them in order + if self._frames: + for step_no in sorted(self._frames.keys()): + yield self._frames[step_no].copy() + return + # Start from checkpoint or initial state if self._checkpoint_state is not None: live_state = self._checkpoint_state.copy() @@ -206,17 +278,38 @@ def last_step(self) -> SimulationState: def first_step(self): return self.get_step(0) - def set_output(self, step: SimulationState): - self.output = step + @property + def output(self) -> SimulationState: + """The final output state of the simulation (alias for live_state).""" + return self._live_state def load_steps(self, interval=1): """Pre-loads steps into memory at the specified interval for faster access. + When live_compress is enabled, this is a no-op since frames are already + stored during simulation. If a different interval is requested than what + was used during simulation (compress_freq), an error is raised. + Parameters ---------- interval : int, optional Store every Nth step in memory, by default 1. + + Raises + ------ + ValueError + If live_compress was used and requested interval doesn't match compress_freq. """ + # If frames already exist (live_compress mode), no reconstruction needed + if self._frames: + if interval != self.compress_freq: + raise ValueError( + f"Cannot load steps with interval={interval}. This result was " + f"created with live_compress=True and compress_freq={self.compress_freq}. " + f"Only interval={self.compress_freq} is available." + ) + return + # Clear old cache first self._stored_states.clear() @@ -255,8 +348,20 @@ def get_step(self, step_no) -> SimulationState: Raises ------ ValueError - If step_no is before the earliest available step (when using max_history). + If step_no is before the earliest available step (when using max_history), + or if step_no is not available in live_compress mode. """ + # Check frames first (live_compress mode) + if self._frames: + if step_no in self._frames: + return self._frames[step_no] + # In live_compress mode, only frames at compress_freq intervals exist + raise ValueError( + f"Cannot retrieve step {step_no}. This result was created with " + f"live_compress=True and compress_freq={self.compress_freq}. " + f"Available steps: {sorted(self._frames.keys())}" + ) + if step_no < self._checkpoint_step: raise ValueError( f"Cannot retrieve step {step_no}. Earliest available step is " @@ -288,6 +393,7 @@ def as_dict(self): "diffs": self._diffs, "compress_freq": self.compress_freq, "max_history": self.max_history, + "live_compress": self.live_compress, "total_steps": self._total_steps, "@module": self.__class__.__module__, "@class": self.__class__.__name__, @@ -299,6 +405,15 @@ def as_dict(self): else: result["checkpoint_state"] = None result["checkpoint_step"] = 0 + + # Include frames if present (live_compress mode) + if self._frames: + result["frames"] = { + str(step): state.as_dict() for step, state in self._frames.items() + } + else: + result["frames"] = {} + return result def to_file(self, fpath: str = None) -> None: diff --git a/tests/core/test_simulation_result.py b/tests/core/test_simulation_result.py index b928d48..5fb2108 100644 --- a/tests/core/test_simulation_result.py +++ b/tests/core/test_simulation_result.py @@ -297,3 +297,124 @@ def test_compress_result_invalid_size(initial_state): # Can't compress to more steps than we have with pytest.raises(ValueError, match="Cannot compress"): compress_result(result, 100) + + +def test_live_compress_stores_frames(initial_state): + """Test that live_compress stores frames at compress_freq intervals.""" + result = SimulationResult(initial_state, compress_freq=10, live_compress=True) + + # Add 25 steps + for step in range(25): + updates = {0: {"value": step}} + result.add_step(updates) + + # Should have frames at 0, 10, 20 (initial + steps 10 and 20) + assert 0 in result._frames + assert 10 in result._frames + assert 20 in result._frames + assert 25 not in result._frames # Not a multiple of 10 + + # Diffs should be empty in live_compress mode + assert len(result._diffs) == 0 + + +def test_live_compress_get_step(initial_state): + """Test that get_step works with live_compress mode.""" + result = SimulationResult(initial_state, compress_freq=5, live_compress=True) + + for step in range(10): + updates = {0: {"value": step}} + result.add_step(updates) + + # Can get steps at frame intervals + state_5 = result.get_step(5) + assert state_5.get_site_state(0)["value"] == 4 # 0-indexed, step 5 has value 4 + + state_10 = result.get_step(10) + assert state_10.get_site_state(0)["value"] == 9 + + # Cannot get steps that aren't at frame intervals + with pytest.raises(ValueError, match="live_compress"): + result.get_step(3) + + +def test_live_compress_load_steps_noop(initial_state): + """Test that load_steps is a no-op in live_compress mode.""" + result = SimulationResult(initial_state, compress_freq=5, live_compress=True) + + for step in range(10): + updates = {0: {"value": step}} + result.add_step(updates) + + # load_steps with matching interval should be a no-op + result.load_steps(interval=5) # Should not raise + + # load_steps with non-matching interval should raise + with pytest.raises(ValueError, match="interval"): + result.load_steps(interval=1) + + +def test_live_compress_steps_generator(initial_state): + """Test that steps() yields frames in live_compress mode.""" + result = SimulationResult(initial_state, compress_freq=5, live_compress=True) + + for step in range(10): + updates = {0: {"value": step}} + result.add_step(updates) + + # Should yield frames in order: 0, 5, 10 + steps = list(result.steps()) + assert len(steps) == 3 # Frames at 0, 5, 10 + + # First frame is initial (no value set yet) + # Frame at step 5 has value 4 (last update before step 5 frame is taken) + assert steps[1].get_site_state(0)["value"] == 4 + # Frame at step 10 has value 9 + assert steps[2].get_site_state(0)["value"] == 9 + + +def test_live_compress_serialization(initial_state): + """Test that live_compress results serialize and deserialize correctly.""" + result = SimulationResult(initial_state, compress_freq=5, live_compress=True) + + for step in range(10): + updates = {0: {"value": step}} + result.add_step(updates) + + # Serialize and deserialize + result_dict = result.as_dict() + restored = SimulationResult.from_dict(result_dict) + + # Check properties are preserved + assert restored.live_compress is True + assert restored.compress_freq == 5 + assert len(restored._frames) == 3 # 0, 5, 10 + assert len(restored._diffs) == 0 + + # Check live_state is correctly restored + assert restored.live_state.get_site_state(0)["value"] == 9 + + +def test_live_state_property(initial_state): + """Test that live_state property reflects current state.""" + result = SimulationResult(initial_state) + + # Initial live_state - site 0 doesn't exist yet + assert result.live_state.get_site_state(0) is None + + # After adding steps, live_state is updated + result.add_step({0: {"value": 42}}) + assert result.live_state.get_site_state(0)["value"] == 42 + + result.add_step({0: {"value": 100}}) + assert result.live_state.get_site_state(0)["value"] == 100 + + +def test_output_property(initial_state): + """Test that output property is an alias for live_state.""" + result = SimulationResult(initial_state) + + result.add_step({0: {"value": 42}}) + + # output should be same as live_state + assert result.output is result.live_state From 2b4a425c75645acb1af28c35157f8e790ce3c7d8 Mon Sep 17 00:00:00 2001 From: Max Gallant Date: Tue, 24 Feb 2026 17:28:00 -0800 Subject: [PATCH 2/3] lint --- src/pylattica/visualization/result_artist.py | 4 ++-- src/pylattica/visualization/structure_artist.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pylattica/visualization/result_artist.py b/src/pylattica/visualization/result_artist.py index f8c5582..cb5e6dc 100644 --- a/src/pylattica/visualization/result_artist.py +++ b/src/pylattica/visualization/result_artist.py @@ -97,10 +97,10 @@ def jupyter_play(self, cell_size: int = 20, wait: int = 1, **kwargs): wait : int, optional The time duration between frames in the animation. Defaults to 1., by default 1 """ - from IPython.display import ( + from IPython.display import ( # pylint: disable=import-error clear_output, display, - ) # pylint: disable=import-error # pragma: no cover + ) # pragma: no cover imgs = self._get_images(cell_size=cell_size, **kwargs) # pragma: no cover for img in imgs: # pragma: no cover diff --git a/src/pylattica/visualization/structure_artist.py b/src/pylattica/visualization/structure_artist.py index 216527e..2e63bf4 100644 --- a/src/pylattica/visualization/structure_artist.py +++ b/src/pylattica/visualization/structure_artist.py @@ -29,9 +29,9 @@ def jupyter_show(self, state: SimulationState, **kwargs): state : SimulationState The simulation state to display. """ - from IPython.display import ( + from IPython.display import ( # pylint: disable=import-error display, - ) # pylint: disable=import-error # pragma: no cover + ) # pragma: no cover img = self.get_img(state, **kwargs) # pragma: no cover display(img) # pragma: no cover From 57a404eb523a12cb063c68de3f5df55c14b8a17f Mon Sep 17 00:00:00 2001 From: Max Gallant Date: Tue, 24 Feb 2026 17:30:40 -0800 Subject: [PATCH 3/3] add some tests --- tests/core/test_simulation_result.py | 37 ++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/core/test_simulation_result.py b/tests/core/test_simulation_result.py index 5fb2108..9b3ad5b 100644 --- a/tests/core/test_simulation_result.py +++ b/tests/core/test_simulation_result.py @@ -410,6 +410,43 @@ def test_live_state_property(initial_state): assert result.live_state.get_site_state(0)["value"] == 100 +def test_from_dict_restores_live_state_from_diffs(initial_state): + """Test that from_dict replays diffs to restore live_state.""" + result = SimulationResult(initial_state) + + for step in range(5): + result.add_step({0: {"value": step}}) + + # Serialize and deserialize + result_dict = result.as_dict() + restored = SimulationResult.from_dict(result_dict) + + # live_state should be restored by replaying diffs + assert restored.live_state.get_site_state(0)["value"] == 4 + + +def test_from_dict_restores_live_state_from_checkpoint(initial_state): + """Test that from_dict uses checkpoint when restoring live_state.""" + # Use max_history to trigger checkpoint creation + result = SimulationResult(initial_state, max_history=5) + + # Add enough steps to trigger checkpoint + for step in range(10): + result.add_step({0: {"value": step}}) + + # Should have a checkpoint now + assert result._checkpoint_state is not None + + # Serialize and deserialize + result_dict = result.as_dict() + restored = SimulationResult.from_dict(result_dict) + + # live_state should be restored correctly (final value is 9) + assert restored.live_state.get_site_state(0)["value"] == 9 + # Checkpoint should be restored + assert restored._checkpoint_state is not None + + def test_output_property(initial_state): """Test that output property is an alias for live_state.""" result = SimulationResult(initial_state)