From 1b6b33222ebe74b5fb514ade523b294b966b4e09 Mon Sep 17 00:00:00 2001 From: Max Gallant Date: Tue, 24 Feb 2026 10:57:50 -0800 Subject: [PATCH 1/6] Add a history memory cap to pylattica stored diffs and also resolve O(n) scaling in get random site --- src/pylattica/core/basic_controller.py | 15 +- src/pylattica/core/simulation_result.py | 196 ++++++++++++++++++++---- src/pylattica/core/simulation_state.py | 4 +- tests/core/test_simulation_result.py | 104 +++++++++++++ 4 files changed, 285 insertions(+), 34 deletions(-) diff --git a/src/pylattica/core/basic_controller.py b/src/pylattica/core/basic_controller.py index 33d6e04..ffdd075 100644 --- a/src/pylattica/core/basic_controller.py +++ b/src/pylattica/core/basic_controller.py @@ -15,8 +15,18 @@ class has a single responsibility, which is to implement the update SimulationState will be passed to this method, along with the ID of the site at which the update rule should be applied. It is up to the user to decide what updates should be produced using this information. + + Attributes + ---------- + max_history : int, optional + Maximum number of step diffs to keep in memory during simulation. + Set this to limit memory usage for long simulations. When exceeded, + older steps are checkpointed and dropped. Default is None (unlimited). """ + # Override this in subclasses to limit memory usage + max_history: int = None + @abstractmethod def get_state_update(self, site_id: int, prev_state: SimulationState): pass # pragma: no cover @@ -25,7 +35,8 @@ def pre_run(self, initial_state: SimulationState) -> None: pass def get_random_site(self, state: SimulationState): - return random.randint(0, len(state.site_ids()) - 1) + # Use state.size (O(1)) instead of len(state.site_ids()) which is O(n) + return random.randint(0, state.size - 1) def instantiate_result(self, starting_state: SimulationState): - return SimulationResult(starting_state=starting_state) + return SimulationResult(starting_state=starting_state, max_history=self.max_history) diff --git a/src/pylattica/core/simulation_result.py b/src/pylattica/core/simulation_result.py index ed28d53..20e3b35 100644 --- a/src/pylattica/core/simulation_result.py +++ b/src/pylattica/core/simulation_result.py @@ -17,6 +17,10 @@ class SimulationResult: ---------- initial_state : SimulationState The state with which the simulation started. + max_history : int, optional + Maximum number of diffs to keep in memory. When exceeded, older diffs + are dropped and a checkpoint is created. Set to None for unlimited + history (default, but may cause memory issues for long simulations). """ @classmethod @@ -27,31 +31,54 @@ def from_file(cls, fpath): def from_dict(cls, res_dict): diffs = res_dict["diffs"] compress_freq = res_dict.get("compress_freq", 1) + max_history = res_dict.get("max_history", None) res = cls( SimulationState.from_dict(res_dict["initial_state"]), compress_freq=compress_freq, + max_history=max_history, ) + # Restore checkpoint if present + if "checkpoint_state" in res_dict and res_dict["checkpoint_state"] is not None: + res._checkpoint_state = SimulationState.from_dict(res_dict["checkpoint_state"]) + res._checkpoint_step = res_dict.get("checkpoint_step", 0) + for diff in diffs: if SITES in diff: diff[SITES] = {int(k): v for k, v in diff[SITES].items()} if GENERAL not in diff and SITES not in diff: diff = {int(k): v for k, v in diff.items()} - res.add_step(diff) + res._diffs.append(diff) # Bypass add_step to avoid re-checkpointing + + # Restore total_steps from serialized data, or compute from diffs + checkpoint + res._total_steps = res_dict.get("total_steps", res._checkpoint_step + len(diffs)) return res - def __init__(self, starting_state: SimulationState, compress_freq: int = 1): + def __init__(self, starting_state: SimulationState, compress_freq: int = 1, max_history: int = None): """Initializes a SimulationResult with the specified starting_state. Parameters ---------- starting_state : SimulationState The state with which the simulation started. + compress_freq : int, optional + Compression frequency for sampling, by default 1. + max_history : int, optional + Maximum number of diffs to keep in memory. When exceeded, a + checkpoint is created and old diffs are dropped. This prevents + unbounded memory growth during long simulations. Set to None + (default) for unlimited history. Recommended: 1000-10000 for + long simulations. """ self.initial_state = starting_state self.compress_freq = compress_freq + self.max_history = max_history self._diffs: list[dict] = [] self._stored_states = {} + # Checkpoint support for bounded history + self._checkpoint_state: SimulationState = None + self._checkpoint_step: int = 0 + self._total_steps: int = 0 def add_step(self, updates: Dict[int, Dict]) -> None: """Takes a set of updates as a dictionary mapping site IDs @@ -71,26 +98,79 @@ def add_step(self, updates: Dict[int, Dict]) -> None: The changes associated with a new simulation step. """ self._diffs.append(updates) + self._total_steps += 1 + + # Check if we need to create a checkpoint and drop old diffs + if self.max_history is not None and len(self._diffs) > self.max_history: + self._create_checkpoint() + + def _create_checkpoint(self) -> None: + """Creates a checkpoint by computing the current state and dropping old diffs. + + This is called automatically when max_history is exceeded. It computes + the state at the midpoint of the current diffs, stores it as a checkpoint, + and drops all diffs before that point. + """ + # Compute checkpoint at half the current diffs (keeps half the history) + checkpoint_offset = len(self._diffs) // 2 + + # Compute the state at the checkpoint + if self._checkpoint_state is not None: + state = self._checkpoint_state.copy() + else: + state = self.initial_state.copy() + + for i in range(checkpoint_offset): + state.batch_update(self._diffs[i]) + + # Update checkpoint + self._checkpoint_state = state + self._checkpoint_step += checkpoint_offset + + # Drop old diffs + self._diffs = self._diffs[checkpoint_offset:] + + # Clear stored states cache (indices are now invalid) + self._stored_states.clear() + + @property + def earliest_available_step(self) -> int: + """Returns the earliest step number that can be reconstructed. + + When max_history is set and checkpoints have been created, early + steps are no longer available. + """ + return self._checkpoint_step @property def original_length(self) -> int: return int(len(self) * self.compress_freq) def __len__(self) -> int: - return len(self._diffs) + 1 + # Total steps = checkpoint step + remaining diffs + 1 (for initial state) + return self._total_steps + 1 def steps(self) -> List[SimulationState]: - """Returns a list of all the steps from this simulation. + """Yields all available steps from this simulation. - Returns - ------- - List[SimulationState] - The list of steps + Note: When max_history is set, only steps from the checkpoint onward + are available. Use earliest_available_step to check. + + Yields + ------ + SimulationState + Each step's state (as a copy to avoid mutation issues). """ - live_state = self.initial_state.copy() + # Start from checkpoint or initial state + if self._checkpoint_state is not None: + live_state = self._checkpoint_state.copy() + else: + live_state = self.initial_state.copy() + + yield live_state.copy() # Yield a copy to avoid mutation issues for diff in self._diffs: - yield live_state live_state.batch_update(diff) + yield live_state.copy() @property def last_step(self) -> SimulationState: @@ -111,12 +191,30 @@ def set_output(self, step: SimulationState): self.output = step def load_steps(self, interval=1): - live_state = self.initial_state.copy() - self._stored_states[0] = self.initial_state.copy() + """Pre-loads steps into memory at the specified interval for faster access. + + Parameters + ---------- + interval : int, optional + Store every Nth step in memory, by default 1. + """ + # Clear old cache first + self._stored_states.clear() + + # Start from checkpoint or initial state + if self._checkpoint_state is not None: + live_state = self._checkpoint_state.copy() + start_step = self._checkpoint_step + else: + live_state = self.initial_state.copy() + start_step = 0 + + self._stored_states[start_step] = live_state.copy() + for ud_idx in tqdm.tqdm( range(0, len(self._diffs)), desc="Constructing result from diffs" ): - step_no = ud_idx + 1 + step_no = start_step + ud_idx + 1 live_state.batch_update(self._diffs[ud_idx]) if step_no % interval == 0 and self._stored_states.get(step_no) is None: stored_state = live_state.copy() @@ -134,28 +232,55 @@ def get_step(self, step_no) -> SimulationState: ------- SimulationState The simulation state at the requested step. - """ - # if step_no % self.compress_freq != 0: - # raise ValueError(f"Cannot retrieve step no {step_no} because this result has been compressed with sampling frequency {self.compress_freq}") + Raises + ------ + ValueError + If step_no is before the earliest available step (when using max_history). + """ + if step_no < self._checkpoint_step: + raise ValueError( + f"Cannot retrieve step {step_no}. Earliest available step is " + f"{self._checkpoint_step} (earlier steps were dropped due to max_history={self.max_history})." + ) stored = self._stored_states.get(step_no) if stored is not None: return stored + + # Start from checkpoint (or initial state if no checkpoint) + if self._checkpoint_state is not None: + state = self._checkpoint_state.copy() + start_idx = 0 else: state = self.initial_state.copy() - for ud_idx in range(0, step_no): - state.batch_update(self._diffs[ud_idx]) - return state + start_idx = 0 + + # Apply diffs from checkpoint to requested step + diffs_to_apply = step_no - self._checkpoint_step + for ud_idx in range(start_idx, diffs_to_apply): + state.batch_update(self._diffs[ud_idx]) + + return state def as_dict(self): - return { + result = { "initial_state": self.initial_state.as_dict(), "diffs": self._diffs, "compress_freq": self.compress_freq, + "max_history": self.max_history, + "total_steps": self._total_steps, "@module": self.__class__.__module__, "@class": self.__class__.__name__, } + # Include checkpoint if present + if self._checkpoint_state is not None: + result["checkpoint_state"] = self._checkpoint_state.as_dict() + result["checkpoint_step"] = self._checkpoint_step + else: + result["checkpoint_state"] = None + result["checkpoint_step"] = 0 + return result def to_file(self, fpath: str = None) -> None: """Serializes this result to the specified filepath. @@ -175,29 +300,38 @@ def to_file(self, fpath: str = None) -> None: def compress_result(result: SimulationResult, num_steps: int): - i_state = result.first_step - # total steps is the actual number of diffs stored, not the number of original simulation steps taken - total_steps = len(result) - if num_steps >= total_steps: + """Compress a simulation result by sampling fewer steps. + + Parameters + ---------- + result : SimulationResult + The result to compress. + num_steps : int + Target number of steps in the compressed result. + + Returns + ------- + SimulationResult + A new result with fewer steps. + """ + # Use earliest available step as the starting point + i_state = result.get_step(result.earliest_available_step) + available_steps = len(result) - result.earliest_available_step + if num_steps >= available_steps: raise ValueError( - f"Cannot upsample SimulationResult of length {total_steps} to size {num_steps}." + f"Cannot compress SimulationResult with {available_steps} available steps to {num_steps} steps." ) - exact_sample_freq = total_steps / (num_steps) - # print(total_steps, current_sample_freq) + exact_sample_freq = available_steps / num_steps total_compress_freq = exact_sample_freq * result.compress_freq compressed_result = SimulationResult(i_state, compress_freq=total_compress_freq) live_state = SimulationState(copy.deepcopy(i_state._state)) - added = 0 next_sample_step = exact_sample_freq for i, diff in enumerate(result._diffs): curr_step = i + 1 live_state.batch_update(diff) - # if curr_step % current_sample_freq == 0: if curr_step > next_sample_step: - # print(curr_step) - added += 1 compressed_result.add_step(live_state.as_state_update()) next_sample_step += exact_sample_freq return compressed_result diff --git a/src/pylattica/core/simulation_state.py b/src/pylattica/core/simulation_state.py index d94626b..945a35e 100644 --- a/src/pylattica/core/simulation_state.py +++ b/src/pylattica/core/simulation_state.py @@ -58,12 +58,14 @@ def __init__(self, state: Dict = None): def size(self) -> int: """Gives the number of sites for which state information is stored. + This is O(1) - it does not create a list of site IDs. + Returns ------- int The number of sites for which state information is stored. """ - return len(self.site_ids()) + return len(self._state[SITES]) def site_ids(self) -> List[int]: """A list of site IDs for which some state is stored. diff --git a/tests/core/test_simulation_result.py b/tests/core/test_simulation_result.py index 1b51149..09e47ef 100644 --- a/tests/core/test_simulation_result.py +++ b/tests/core/test_simulation_result.py @@ -102,3 +102,107 @@ def test_write_file_autoname(random_result_small: SimulationResult): def test_diff_storage(random_result_small_ordered: SimulationResult): diff_one = random_result_small_ordered._diffs[0] assert len(diff_one.keys()) == 1 + + +def test_max_history_limits_memory(initial_state): + """Test that max_history limits the number of diffs kept in memory.""" + result = SimulationResult(initial_state, max_history=50) + + # Add 200 steps + for step in range(200): + updates = {0: {"value": step}} + result.add_step(updates) + + # Should have at most max_history diffs in memory + assert len(result._diffs) <= 50 + + # But total steps should still be correct + assert len(result) == 201 # 200 steps + initial state + assert result._total_steps == 200 + + +def test_max_history_creates_checkpoint(initial_state): + """Test that exceeding max_history creates a checkpoint.""" + result = SimulationResult(initial_state, max_history=50) + + # Add 100 steps to trigger checkpointing + for step in range(100): + updates = {0: {"value": step}} + result.add_step(updates) + + # Checkpoint should have been created + assert result._checkpoint_state is not None + assert result._checkpoint_step > 0 + assert result.earliest_available_step == result._checkpoint_step + + +def test_max_history_get_step_recent(initial_state): + """Test that recent steps are still accessible with max_history.""" + result = SimulationResult(initial_state, max_history=50) + + for step in range(100): + updates = {0: {"value": step}} + result.add_step(updates) + + # Should be able to get the last step + last_step = result.get_step(100) + assert last_step.get_site_state(0)["value"] == 99 + + # Should be able to get steps after checkpoint + earliest = result.earliest_available_step + step = result.get_step(earliest + 1) + assert step is not None + + +def test_max_history_get_step_early_raises(initial_state): + """Test that requesting steps before checkpoint raises ValueError.""" + result = SimulationResult(initial_state, max_history=50) + + for step in range(100): + updates = {0: {"value": step}} + result.add_step(updates) + + earliest = result.earliest_available_step + assert earliest > 0 # Checkpoint should exist + + with pytest.raises(ValueError, match="Cannot retrieve step"): + result.get_step(0) + + +def test_max_history_serialization(initial_state): + """Test that results with max_history serialize and deserialize correctly.""" + result = SimulationResult(initial_state, max_history=50) + + for step in range(100): + updates = {0: {"value": step}} + result.add_step(updates) + + # Serialize and deserialize + d = result.as_dict() + rehydrated = SimulationResult.from_dict(d) + + # Check state is preserved + assert rehydrated.max_history == result.max_history + assert rehydrated._checkpoint_step == result._checkpoint_step + assert rehydrated._total_steps == result._total_steps + assert len(rehydrated._diffs) == len(result._diffs) + + # Check we can get the same steps + for step_no in range(result.earliest_available_step, len(result)): + orig = result.get_step(step_no) + rehyd = rehydrated.get_step(step_no) + assert orig.as_dict() == rehyd.as_dict() + + +def test_max_history_none_unlimited(initial_state): + """Test that max_history=None allows unlimited growth (default behavior).""" + result = SimulationResult(initial_state, max_history=None) + + for step in range(500): + updates = {0: {"value": step}} + result.add_step(updates) + + # All diffs should be in memory + assert len(result._diffs) == 500 + assert result._checkpoint_state is None + assert result.earliest_available_step == 0 From 0c682121d553f58a1d51944f60a4a0e5985107ba Mon Sep 17 00:00:00 2001 From: Max Gallant Date: Tue, 24 Feb 2026 14:11:01 -0800 Subject: [PATCH 2/6] vastly faster VN neighborhood building --- .../structures/square_grid/neighborhoods.py | 127 ++++++++++++++++-- 1 file changed, 118 insertions(+), 9 deletions(-) diff --git a/src/pylattica/structures/square_grid/neighborhoods.py b/src/pylattica/structures/square_grid/neighborhoods.py index e4126ce..0911131 100644 --- a/src/pylattica/structures/square_grid/neighborhoods.py +++ b/src/pylattica/structures/square_grid/neighborhoods.py @@ -1,10 +1,14 @@ import numpy as np +import rustworkx as rx from ...core.coordinate_utils import get_points_in_cube from ...core.neighborhood_builders import ( StochasticNeighborhoodBuilder, MotifNeighborhoodBuilder, DistanceNeighborhoodBuilder, + NeighborhoodBuilder, ) +from ...core.neighborhoods import Neighborhood +from ...core.periodic_structure import PeriodicStructure class VonNeumannNbHood2DBuilder(MotifNeighborhoodBuilder): @@ -28,25 +32,130 @@ def __init__(self, size=1): super().__init__(filtered_points) -class VonNeumannNbHood3DBuilder(MotifNeighborhoodBuilder): - """A helper class for generating von Neumann type neighborhoods in square 3D structures.""" +class VonNeumannNbHood3DBuilder(NeighborhoodBuilder): + """Optimized Von Neumann neighborhood builder for simple cubic 3D grids. + + Uses direct index math instead of coordinate lookups, providing + much faster performance for large grids. + + For a cubic grid of size n³: + - Site i is at position (x, y, z) where x = i % n, y = (i // n) % n, z = i // n² + - Neighbor at offset (dx, dy, dz) has ID: ((x+dx) % n) + ((y+dy) % n) * n + ((z+dz) % n) * n² + """ def __init__(self, size: int): - """Constructs the VonNeumannNbHood3D Builder + """Constructs the VonNeumannNbHood3DBuilder. Parameters ---------- size : int - The size of the neighborhood. + The size of the neighborhood (Manhattan distance). """ + # Generate Von Neumann neighborhood offsets (excluding origin) points = get_points_in_cube(-size, size + 1, 3) + self._offsets = [ + tuple(point) for point in points + if sum(np.abs(p) for p in point) <= size and any(p != 0 for p in point) + ] + # Precompute distances for edge weights + self._distances = { + offset: np.sqrt(sum(p**2 for p in offset)) + for offset in self._offsets + } + # Cache for grid size (computed once per structure) + self._cached_n = None + self._cached_n_sites = None - filtered_points = [] - for point in points: - if sum(np.abs(p) for p in point) <= size: - filtered_points.append(point) + def _get_grid_size(self, struct: PeriodicStructure) -> int: + """Infer grid size n from structure (cached).""" + n_sites = len(struct.site_ids) + if n_sites != self._cached_n_sites: + n = int(round(n_sites ** (1/3))) + if n ** 3 != n_sites: + raise ValueError(f"Structure has {n_sites} sites, not a perfect cube.") + self._cached_n = n + self._cached_n_sites = n_sites + return self._cached_n - super().__init__(filtered_points) + def get_neighbors(self, curr_site: dict, struct: PeriodicStructure) -> list: + """Get neighbors of a site using fast index math. + + Parameters + ---------- + curr_site : dict + Site dictionary with 'id' key + struct : PeriodicStructure + The structure (used to infer grid size) + + Returns + ------- + list + List of (neighbor_id, distance) tuples + """ + from ...core.constants import SITE_ID + + n = self._get_grid_size(struct) + site_id = curr_site[SITE_ID] + + # Convert site ID to (x, y, z) coordinates + x = site_id % n + y = (site_id // n) % n + z = site_id // (n * n) + + neighbors = [] + for dx, dy, dz in self._offsets: + # Compute neighbor coordinates with periodic boundary conditions + nx = (x + dx) % n + ny = (y + dy) % n + nz = (z + dz) % n + + # Convert back to site ID + neighbor_id = nx + ny * n + nz * (n * n) + neighbors.append((neighbor_id, self._distances[(dx, dy, dz)])) + + return neighbors + + def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood: + """Build neighborhood graph using vectorized index math. + + This override provides much faster performance than the base class + by computing all edges in bulk using numpy operations. + """ + n_sites = len(struct.site_ids) + n = self._get_grid_size(struct) + + graph = rx.PyDiGraph() + + # Add all nodes at once + graph.add_nodes_from(range(n_sites)) + + # Vectorized computation: create coordinate arrays for all sites + site_ids = np.arange(n_sites, dtype=np.int64) + x = site_ids % n + y = (site_ids // n) % n + z = site_ids // (n * n) + + # Collect all edges across all offsets, then add in one batch + all_edges = [] + for dx, dy, dz in self._offsets: + # Compute neighbor coordinates with periodic boundary conditions + nx = (x + dx) % n + ny = (y + dy) % n + nz = (z + dz) % n + + # Convert to neighbor site IDs + neighbor_ids = nx + ny * n + nz * (n * n) + weight = self._distances[(dx, dy, dz)] + + # Stack source, dest, weight as columns and extend + # Use numpy operations to avoid Python loop overhead + edge_data = np.column_stack([site_ids, neighbor_ids]) + all_edges.extend((int(s), int(d), weight) for s, d in edge_data) + + # Add all edges in one batch + graph.extend_from_weighted_edge_list(all_edges) + + return Neighborhood(graph) class MooreNbHoodBuilder(MotifNeighborhoodBuilder): From 0b9c10991609fb9ffd3b1556feab85ff8bd02856 Mon Sep 17 00:00:00 2001 From: Max Gallant Date: Tue, 24 Feb 2026 14:11:09 -0800 Subject: [PATCH 3/6] tests for compress_result --- tests/core/test_simulation_result.py | 91 ++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/tests/core/test_simulation_result.py b/tests/core/test_simulation_result.py index 09e47ef..b928d48 100644 --- a/tests/core/test_simulation_result.py +++ b/tests/core/test_simulation_result.py @@ -3,6 +3,7 @@ import random import os from pylattica.core import SimulationResult, SimulationState +from pylattica.core.simulation_result import compress_result @pytest.fixture @@ -206,3 +207,93 @@ def test_max_history_none_unlimited(initial_state): assert len(result._diffs) == 500 assert result._checkpoint_state is None assert result.earliest_available_step == 0 + + +def test_max_history_steps_generator(initial_state): + """Test that steps() works correctly with checkpoints.""" + result = SimulationResult(initial_state, max_history=50) + + for step in range(100): + updates = {0: {"value": step}} + result.add_step(updates) + + # Iterate through available steps + steps_list = list(result.steps()) + + # Should have steps from checkpoint onward + expected_count = len(result._diffs) + 1 # diffs + checkpoint state + assert len(steps_list) == expected_count + + # Each step should be a separate object (copies) + assert steps_list[0] is not steps_list[1] + + +def test_max_history_load_steps(initial_state): + """Test that load_steps() works correctly with checkpoints.""" + result = SimulationResult(initial_state, max_history=50) + + for step in range(100): + updates = {0: {"value": step}} + result.add_step(updates) + + # Load steps at interval + result.load_steps(interval=10) + + # Should have cached states + assert len(result._stored_states) > 0 + + # Cached states should be after checkpoint + for step_no in result._stored_states: + assert step_no >= result.earliest_available_step + + +def test_original_length(initial_state): + """Test the original_length property.""" + result = SimulationResult(initial_state, compress_freq=1) + + for step in range(10): + updates = {0: {"value": step}} + result.add_step(updates) + + # With compress_freq=1, original_length should equal len + assert result.original_length == len(result) + + # With compress_freq=2, original_length should be doubled + result_compressed = SimulationResult(initial_state, compress_freq=2) + for step in range(10): + updates = {0: {"value": step}} + result_compressed.add_step(updates) + + assert result_compressed.original_length == len(result_compressed) * 2 + + +def test_compress_result(initial_state): + """Test the compress_result function.""" + result = SimulationResult(initial_state) + + # Add 100 steps with deterministic values + for step in range(100): + updates = {0: {"value": step}} + result.add_step(updates) + + # Compress to 20 steps + compressed = compress_result(result, 20) + + # Should have fewer steps + assert len(compressed) <= 25 # Some margin for sampling + + # compress_freq should be updated + assert compressed.compress_freq > 1 + + +def test_compress_result_invalid_size(initial_state): + """Test that compress_result raises error for invalid target size.""" + result = SimulationResult(initial_state) + + for step in range(10): + updates = {0: {"value": step}} + result.add_step(updates) + + # Can't compress to more steps than we have + with pytest.raises(ValueError, match="Cannot compress"): + compress_result(result, 100) From 114f464fbdc4e08a7de1a3e192fdc909a9c53fc3 Mon Sep 17 00:00:00 2001 From: Max Gallant Date: Tue, 24 Feb 2026 14:36:13 -0800 Subject: [PATCH 4/6] Black linting --- pyproject.toml | 3 ++- src/pylattica/core/basic_controller.py | 4 +++- src/pylattica/core/neighborhood_builders.py | 20 ++++++++----------- src/pylattica/core/simulation_result.py | 15 +++++++++++--- src/pylattica/models/game_of_life/__init__.py | 2 +- .../models/game_of_life/life_phase_set.py | 2 +- .../structures/square_grid/grid_setup.py | 4 +++- .../structures/square_grid/neighborhoods.py | 10 +++++----- .../visualization/square_grid_artist_2D.py | 4 +++- 9 files changed, 38 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 14bd30e..7b2405a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ enabled = true [dependency-groups] dev = [ + "black>=24.8.0", "pytest>=7.1.3", "pytest-cov>=4.0.0", -] \ No newline at end of file +] diff --git a/src/pylattica/core/basic_controller.py b/src/pylattica/core/basic_controller.py index ffdd075..50e0aa2 100644 --- a/src/pylattica/core/basic_controller.py +++ b/src/pylattica/core/basic_controller.py @@ -39,4 +39,6 @@ def get_random_site(self, state: SimulationState): return random.randint(0, state.size - 1) def instantiate_result(self, starting_state: SimulationState): - return SimulationResult(starting_state=starting_state, max_history=self.max_history) + return SimulationResult( + starting_state=starting_state, max_history=self.max_history + ) diff --git a/src/pylattica/core/neighborhood_builders.py b/src/pylattica/core/neighborhood_builders.py index 9b58c6f..0e1ded9 100644 --- a/src/pylattica/core/neighborhood_builders.py +++ b/src/pylattica/core/neighborhood_builders.py @@ -146,9 +146,9 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood site_ids = np.array([s[SITE_ID] for s in all_sites]) # Convert to fractional coordinates for periodic KD-tree - frac_coords = np.array([ - struct.lattice.get_fractional_coords(loc) for loc in locations - ]) + frac_coords = np.array( + [struct.lattice.get_fractional_coords(loc) for loc in locations] + ) # Compute the maximum fractional radius that could correspond to # the Cartesian cutoff. For non-orthogonal lattices, we need to use @@ -163,9 +163,7 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood dim = struct.lattice.dim # Build boxsize array: 1.0 for periodic dimensions, large value for non-periodic - boxsize = np.array([ - 1.0 if periodic[i] else 1e10 for i in range(dim) - ]) + boxsize = np.array([1.0 if periodic[i] else 1e10 for i in range(dim)]) # Wrap fractional coordinates to [0, 1) for periodic dimensions frac_coords_wrapped = frac_coords.copy() @@ -296,9 +294,9 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood site_ids = np.array([s[SITE_ID] for s in all_sites]) # Convert to fractional coordinates for periodic KD-tree - frac_coords = np.array([ - struct.lattice.get_fractional_coords(loc) for loc in locations - ]) + frac_coords = np.array( + [struct.lattice.get_fractional_coords(loc) for loc in locations] + ) # Compute the maximum fractional radius for the outer cutoff. # Use the maximum stretch factor of the inverse matrix for non-orthogonal lattices. @@ -311,9 +309,7 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood dim = struct.lattice.dim # Build boxsize array - boxsize = np.array([ - 1.0 if periodic[i] else 1e10 for i in range(dim) - ]) + boxsize = np.array([1.0 if periodic[i] else 1e10 for i in range(dim)]) # Wrap fractional coordinates to [0, 1) for periodic dimensions frac_coords_wrapped = frac_coords.copy() diff --git a/src/pylattica/core/simulation_result.py b/src/pylattica/core/simulation_result.py index 20e3b35..69b3b37 100644 --- a/src/pylattica/core/simulation_result.py +++ b/src/pylattica/core/simulation_result.py @@ -39,7 +39,9 @@ def from_dict(cls, res_dict): ) # Restore checkpoint if present if "checkpoint_state" in res_dict and res_dict["checkpoint_state"] is not None: - res._checkpoint_state = SimulationState.from_dict(res_dict["checkpoint_state"]) + res._checkpoint_state = SimulationState.from_dict( + res_dict["checkpoint_state"] + ) res._checkpoint_step = res_dict.get("checkpoint_step", 0) for diff in diffs: @@ -50,11 +52,18 @@ def from_dict(cls, res_dict): res._diffs.append(diff) # Bypass add_step to avoid re-checkpointing # Restore total_steps from serialized data, or compute from diffs + checkpoint - res._total_steps = res_dict.get("total_steps", res._checkpoint_step + len(diffs)) + res._total_steps = res_dict.get( + "total_steps", res._checkpoint_step + len(diffs) + ) return res - def __init__(self, starting_state: SimulationState, compress_freq: int = 1, max_history: int = None): + def __init__( + self, + starting_state: SimulationState, + compress_freq: int = 1, + max_history: int = None, + ): """Initializes a SimulationResult with the specified starting_state. Parameters diff --git a/src/pylattica/models/game_of_life/__init__.py b/src/pylattica/models/game_of_life/__init__.py index a21bc66..82cbbdc 100644 --- a/src/pylattica/models/game_of_life/__init__.py +++ b/src/pylattica/models/game_of_life/__init__.py @@ -1,2 +1,2 @@ from .controller import GameOfLifeController, Life, Seeds, Anneal, Diamoeba, Maze -from .life_phase_set import LIFE_PHASE_SET \ No newline at end of file +from .life_phase_set import LIFE_PHASE_SET diff --git a/src/pylattica/models/game_of_life/life_phase_set.py b/src/pylattica/models/game_of_life/life_phase_set.py index 0c9e9ee..c807409 100644 --- a/src/pylattica/models/game_of_life/life_phase_set.py +++ b/src/pylattica/models/game_of_life/life_phase_set.py @@ -1,3 +1,3 @@ from ...discrete.phase_set import PhaseSet -LIFE_PHASE_SET = PhaseSet(["alive", "dead"]) \ No newline at end of file +LIFE_PHASE_SET = PhaseSet(["alive", "dead"]) diff --git a/src/pylattica/structures/square_grid/grid_setup.py b/src/pylattica/structures/square_grid/grid_setup.py index 9404306..0a86247 100644 --- a/src/pylattica/structures/square_grid/grid_setup.py +++ b/src/pylattica/structures/square_grid/grid_setup.py @@ -348,7 +348,9 @@ def setup_random_sites( while num_sites_planted < num_sites_desired: if total_attempts > 1000 * num_sites_desired: - print(f"Only able to place {num_sites_planted} in {total_attempts} attempts") + print( + f"Only able to place {num_sites_planted} in {total_attempts} attempts" + ) break rand_site = random.choice(all_sites) diff --git a/src/pylattica/structures/square_grid/neighborhoods.py b/src/pylattica/structures/square_grid/neighborhoods.py index 0911131..a3fefe8 100644 --- a/src/pylattica/structures/square_grid/neighborhoods.py +++ b/src/pylattica/structures/square_grid/neighborhoods.py @@ -54,13 +54,13 @@ def __init__(self, size: int): # Generate Von Neumann neighborhood offsets (excluding origin) points = get_points_in_cube(-size, size + 1, 3) self._offsets = [ - tuple(point) for point in points + tuple(point) + for point in points if sum(np.abs(p) for p in point) <= size and any(p != 0 for p in point) ] # Precompute distances for edge weights self._distances = { - offset: np.sqrt(sum(p**2 for p in offset)) - for offset in self._offsets + offset: np.sqrt(sum(p**2 for p in offset)) for offset in self._offsets } # Cache for grid size (computed once per structure) self._cached_n = None @@ -70,8 +70,8 @@ def _get_grid_size(self, struct: PeriodicStructure) -> int: """Infer grid size n from structure (cached).""" n_sites = len(struct.site_ids) if n_sites != self._cached_n_sites: - n = int(round(n_sites ** (1/3))) - if n ** 3 != n_sites: + n = int(round(n_sites ** (1 / 3))) + if n**3 != n_sites: raise ValueError(f"Structure has {n_sites} sites, not a perfect cube.") self._cached_n = n self._cached_n_sites = n_sites diff --git a/src/pylattica/visualization/square_grid_artist_2D.py b/src/pylattica/visualization/square_grid_artist_2D.py index e47843c..66862f9 100644 --- a/src/pylattica/visualization/square_grid_artist_2D.py +++ b/src/pylattica/visualization/square_grid_artist_2D.py @@ -62,7 +62,9 @@ def _draw_image(self, state: SimulationState, **kwargs): for phase in legend_order: color = legend.get(phase) - p_col_start = state_size * cell_size + legend_border_width + legend_hoffset + p_col_start = ( + state_size * cell_size + legend_border_width + legend_hoffset + ) p_row_start = count * cell_size + legend_voffset for p_x in range(p_col_start, p_col_start + cell_size): for p_y in range(p_row_start, p_row_start + cell_size): From 4034a3ddcbd50e4dffe296a7d116eb482fcc1147 Mon Sep 17 00:00:00 2001 From: Max Gallant Date: Tue, 24 Feb 2026 14:45:35 -0800 Subject: [PATCH 5/6] lint --- .prospector.yaml | 3 +++ pyproject.toml | 1 + src/pylattica/core/neighborhood_builders.py | 5 ----- src/pylattica/core/runner/asynchronous_runner.py | 2 +- src/pylattica/core/runner/synchronous_runner.py | 2 +- src/pylattica/core/simulation_result.py | 14 ++++++++++++-- src/pylattica/core/simulation_state.py | 10 ++++++++++ src/pylattica/structures/square_grid/grid_setup.py | 6 +++--- .../structures/square_grid/growth_setup.py | 2 +- src/pylattica/visualization/result_artist.py | 4 ++-- .../visualization/square_grid_artist_3D.py | 2 +- src/pylattica/visualization/structure_artist.py | 2 +- 12 files changed, 36 insertions(+), 17 deletions(-) diff --git a/.prospector.yaml b/.prospector.yaml index 4876c10..21da665 100644 --- a/.prospector.yaml +++ b/.prospector.yaml @@ -2,6 +2,8 @@ max-line-length: 120 test-warnings: false doc-warnings: false strictness: medium +with: [] +uses: [] ignore-paths: - docs - tests @@ -21,6 +23,7 @@ pycodestyle: pylint: disable: + - django-not-available - unsubscriptable-object - invalid-name - arguments-differ # to account for jobflow diff --git a/pyproject.toml b/pyproject.toml index 7b2405a..393505c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,6 +114,7 @@ enabled = true [dependency-groups] dev = [ "black>=24.8.0", + "prospector>=1.10.3", "pytest>=7.1.3", "pytest-cov>=4.0.0", ] diff --git a/src/pylattica/core/neighborhood_builders.py b/src/pylattica/core/neighborhood_builders.py index 0e1ded9..9e57bed 100644 --- a/src/pylattica/core/neighborhood_builders.py +++ b/src/pylattica/core/neighborhood_builders.py @@ -139,8 +139,6 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood else: sites_to_process = struct.sites(site_class=site_class) - n_sites = len(all_sites) - # Extract locations and IDs as arrays for vectorized operations locations = np.array([s[LOCATION] for s in all_sites]) site_ids = np.array([s[SITE_ID] for s in all_sites]) @@ -174,9 +172,6 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood # Build KD-tree with periodic boundary conditions tree = cKDTree(frac_coords_wrapped, boxsize=boxsize) - # Create index mapping from site_id to array index - id_to_idx = {sid: idx for idx, sid in enumerate(site_ids)} - # Process each site sites_to_process_ids = set(s[SITE_ID] for s in sites_to_process) diff --git a/src/pylattica/core/runner/asynchronous_runner.py b/src/pylattica/core/runner/asynchronous_runner.py index e7f985f..1a07207 100644 --- a/src/pylattica/core/runner/asynchronous_runner.py +++ b/src/pylattica/core/runner/asynchronous_runner.py @@ -27,7 +27,7 @@ class AsynchronousRunner(Runner): that this mode should be used with the is_async initialization parameter. """ - def _run( + def _run( # pylint: disable=too-many-positional-arguments self, _: SimulationState, result: SimulationResult, diff --git a/src/pylattica/core/runner/synchronous_runner.py b/src/pylattica/core/runner/synchronous_runner.py index ab49ce7..8f28056 100644 --- a/src/pylattica/core/runner/synchronous_runner.py +++ b/src/pylattica/core/runner/synchronous_runner.py @@ -38,7 +38,7 @@ def __init__(self, parallel: bool = False, workers: int = None) -> None: self.parallel = parallel self.workers = workers - def _run( + def _run( # pylint: disable=too-many-positional-arguments self, initial_state: SimulationState, result: SimulationResult, diff --git a/src/pylattica/core/simulation_result.py b/src/pylattica/core/simulation_result.py index 69b3b37..42c6359 100644 --- a/src/pylattica/core/simulation_result.py +++ b/src/pylattica/core/simulation_result.py @@ -89,6 +89,16 @@ def __init__( self._checkpoint_step: int = 0 self._total_steps: int = 0 + def get_diffs(self) -> list[dict]: + """Returns the list of diffs. + + Returns + ------- + list[dict] + The list of state diffs. + """ + return self._diffs + def add_step(self, updates: Dict[int, Dict]) -> None: """Takes a set of updates as a dictionary mapping site IDs to the new values for various state parameters. For instance, if at the @@ -335,9 +345,9 @@ def compress_result(result: SimulationResult, num_steps: int): total_compress_freq = exact_sample_freq * result.compress_freq compressed_result = SimulationResult(i_state, compress_freq=total_compress_freq) - live_state = SimulationState(copy.deepcopy(i_state._state)) + live_state = SimulationState(copy.deepcopy(i_state.get_state())) next_sample_step = exact_sample_freq - for i, diff in enumerate(result._diffs): + for i, diff in enumerate(result.get_diffs()): curr_step = i + 1 live_state.batch_update(diff) if curr_step > next_sample_step: diff --git a/src/pylattica/core/simulation_state.py b/src/pylattica/core/simulation_state.py index 945a35e..fa36693 100644 --- a/src/pylattica/core/simulation_state.py +++ b/src/pylattica/core/simulation_state.py @@ -183,5 +183,15 @@ def copy(self) -> SimulationState: def as_state_update(self) -> Dict: return copy.deepcopy(self._state) + def get_state(self) -> Dict: + """Returns the internal state dictionary. + + Returns + ------- + Dict + The internal state dictionary. + """ + return self._state + def __eq__(self, other: SimulationState) -> bool: return self._state == other._state diff --git a/src/pylattica/structures/square_grid/grid_setup.py b/src/pylattica/structures/square_grid/grid_setup.py index 0a86247..3c44be8 100644 --- a/src/pylattica/structures/square_grid/grid_setup.py +++ b/src/pylattica/structures/square_grid/grid_setup.py @@ -145,7 +145,7 @@ def setup_particle( ) return Simulation(state, structure) - def setup_random_particles( + def setup_random_particles( # pylint: disable=too-many-positional-arguments self, size: int, radius: int, @@ -188,7 +188,7 @@ def setup_random_particles( return Simulation(state, structure) - def add_particle_to_state( + def add_particle_to_state( # pylint: disable=too-many-positional-arguments self, structure: PeriodicStructure, state: SimulationState, @@ -276,7 +276,7 @@ def setup_noise(self, size: int, phases: typing.List[str]) -> Simulation: ) return Simulation(state, structure) - def setup_random_sites( + def setup_random_sites( # pylint: disable=too-many-positional-arguments self, size: int, num_sites_desired: int, diff --git a/src/pylattica/structures/square_grid/growth_setup.py b/src/pylattica/structures/square_grid/growth_setup.py index 12ff005..9990f3a 100644 --- a/src/pylattica/structures/square_grid/growth_setup.py +++ b/src/pylattica/structures/square_grid/growth_setup.py @@ -31,7 +31,7 @@ def __init__(self, phase_set: PhaseSet, dim=2): self._phases = phase_set self.dim = dim - def grow( + def grow( # pylint: disable=too-many-positional-arguments self, size: int, num_sites_desired: int, diff --git a/src/pylattica/visualization/result_artist.py b/src/pylattica/visualization/result_artist.py index 0ed2451..c42e218 100644 --- a/src/pylattica/visualization/result_artist.py +++ b/src/pylattica/visualization/result_artist.py @@ -13,7 +13,7 @@ _dsr_globals = {} -def default_annotation_builder(step, step_no): +def default_annotation_builder(_step, step_no): return f"Step {step_no}" @@ -97,7 +97,7 @@ def jupyter_play(self, cell_size: int = 20, wait: int = 1, **kwargs): wait : int, optional The time duration between frames in the animation. Defaults to 1., by default 1 """ - from IPython.display import clear_output, display # pragma: no cover + from IPython.display import clear_output, display # pylint: disable=import-error # pragma: no cover imgs = self._get_images(cell_size=cell_size, **kwargs) # pragma: no cover for img in imgs: # pragma: no cover diff --git a/src/pylattica/visualization/square_grid_artist_3D.py b/src/pylattica/visualization/square_grid_artist_3D.py index fcd8981..6eac96c 100644 --- a/src/pylattica/visualization/square_grid_artist_3D.py +++ b/src/pylattica/visualization/square_grid_artist_3D.py @@ -50,7 +50,7 @@ def _draw_image(self, state: SimulationState, **kwargs): colors = list(np.array(color_cache[color]) / 255) ax.voxels(data, facecolors=colors, edgecolor="k", linewidth=0.25) - if kwargs.get("show_legend") == True: + if kwargs.get("show_legend"): legend = self.cell_artist.get_legend(state) legend_handles = [] for phase, color in legend.items(): diff --git a/src/pylattica/visualization/structure_artist.py b/src/pylattica/visualization/structure_artist.py index 6f31a62..2346c0d 100644 --- a/src/pylattica/visualization/structure_artist.py +++ b/src/pylattica/visualization/structure_artist.py @@ -29,7 +29,7 @@ def jupyter_show(self, state: SimulationState, **kwargs): state : SimulationState The simulation state to display. """ - from IPython.display import display # pragma: no cover + from IPython.display import display # pylint: disable=import-error # pragma: no cover img = self.get_img(state, **kwargs) # pragma: no cover display(img) # pragma: no cover From e5b01fc45b7946f8cdaba7ac3c23591c70677faa Mon Sep 17 00:00:00 2001 From: Max Gallant Date: Tue, 24 Feb 2026 14:48:09 -0800 Subject: [PATCH 6/6] more linting --- src/pylattica/visualization/result_artist.py | 5 ++++- src/pylattica/visualization/structure_artist.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/pylattica/visualization/result_artist.py b/src/pylattica/visualization/result_artist.py index c42e218..f8c5582 100644 --- a/src/pylattica/visualization/result_artist.py +++ b/src/pylattica/visualization/result_artist.py @@ -97,7 +97,10 @@ def jupyter_play(self, cell_size: int = 20, wait: int = 1, **kwargs): wait : int, optional The time duration between frames in the animation. Defaults to 1., by default 1 """ - from IPython.display import clear_output, display # pylint: disable=import-error # pragma: no cover + from IPython.display import ( + clear_output, + display, + ) # pylint: disable=import-error # pragma: no cover imgs = self._get_images(cell_size=cell_size, **kwargs) # pragma: no cover for img in imgs: # pragma: no cover diff --git a/src/pylattica/visualization/structure_artist.py b/src/pylattica/visualization/structure_artist.py index 2346c0d..216527e 100644 --- a/src/pylattica/visualization/structure_artist.py +++ b/src/pylattica/visualization/structure_artist.py @@ -29,7 +29,9 @@ def jupyter_show(self, state: SimulationState, **kwargs): state : SimulationState The simulation state to display. """ - from IPython.display import display # pylint: disable=import-error # pragma: no cover + from IPython.display import ( + display, + ) # pylint: disable=import-error # pragma: no cover img = self.get_img(state, **kwargs) # pragma: no cover display(img) # pragma: no cover