Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .prospector.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ max-line-length: 120
test-warnings: false
doc-warnings: false
strictness: medium
with: []
uses: []
ignore-paths:
- docs
- tests
Expand All @@ -21,6 +23,7 @@ pycodestyle:

pylint:
disable:
- django-not-available
- unsubscriptable-object
- invalid-name
- arguments-differ # to account for jobflow
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ enabled = true

[dependency-groups]
dev = [
"black>=24.8.0",
"prospector>=1.10.3",
"pytest>=7.1.3",
"pytest-cov>=4.0.0",
]
]
17 changes: 15 additions & 2 deletions src/pylattica/core/basic_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,18 @@ class has a single responsibility, which is to implement the update
SimulationState will be passed to this method, along with the ID of
the site at which the update rule should be applied. It is up to the
user to decide what updates should be produced using this information.

Attributes
----------
max_history : int, optional
Maximum number of step diffs to keep in memory during simulation.
Set this to limit memory usage for long simulations. When exceeded,
older steps are checkpointed and dropped. Default is None (unlimited).
"""

# Override this in subclasses to limit memory usage
max_history: int = None

@abstractmethod
def get_state_update(self, site_id: int, prev_state: SimulationState):
pass # pragma: no cover
Expand All @@ -25,7 +35,10 @@ def pre_run(self, initial_state: SimulationState) -> None:
pass

def get_random_site(self, state: SimulationState):
return random.randint(0, len(state.site_ids()) - 1)
# Use state.size (O(1)) instead of len(state.site_ids()) which is O(n)
return random.randint(0, state.size - 1)

def instantiate_result(self, starting_state: SimulationState):
return SimulationResult(starting_state=starting_state)
return SimulationResult(
starting_state=starting_state, max_history=self.max_history
)
25 changes: 8 additions & 17 deletions src/pylattica/core/neighborhood_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,14 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood
else:
sites_to_process = struct.sites(site_class=site_class)

n_sites = len(all_sites)

# Extract locations and IDs as arrays for vectorized operations
locations = np.array([s[LOCATION] for s in all_sites])
site_ids = np.array([s[SITE_ID] for s in all_sites])

# Convert to fractional coordinates for periodic KD-tree
frac_coords = np.array([
struct.lattice.get_fractional_coords(loc) for loc in locations
])
frac_coords = np.array(
[struct.lattice.get_fractional_coords(loc) for loc in locations]
)

# Compute the maximum fractional radius that could correspond to
# the Cartesian cutoff. For non-orthogonal lattices, we need to use
Expand All @@ -163,9 +161,7 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood
dim = struct.lattice.dim

# Build boxsize array: 1.0 for periodic dimensions, large value for non-periodic
boxsize = np.array([
1.0 if periodic[i] else 1e10 for i in range(dim)
])
boxsize = np.array([1.0 if periodic[i] else 1e10 for i in range(dim)])

# Wrap fractional coordinates to [0, 1) for periodic dimensions
frac_coords_wrapped = frac_coords.copy()
Expand All @@ -176,9 +172,6 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood
# Build KD-tree with periodic boundary conditions
tree = cKDTree(frac_coords_wrapped, boxsize=boxsize)

# Create index mapping from site_id to array index
id_to_idx = {sid: idx for idx, sid in enumerate(site_ids)}

# Process each site
sites_to_process_ids = set(s[SITE_ID] for s in sites_to_process)

Expand Down Expand Up @@ -296,9 +289,9 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood
site_ids = np.array([s[SITE_ID] for s in all_sites])

# Convert to fractional coordinates for periodic KD-tree
frac_coords = np.array([
struct.lattice.get_fractional_coords(loc) for loc in locations
])
frac_coords = np.array(
[struct.lattice.get_fractional_coords(loc) for loc in locations]
)

# Compute the maximum fractional radius for the outer cutoff.
# Use the maximum stretch factor of the inverse matrix for non-orthogonal lattices.
Expand All @@ -311,9 +304,7 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood
dim = struct.lattice.dim

# Build boxsize array
boxsize = np.array([
1.0 if periodic[i] else 1e10 for i in range(dim)
])
boxsize = np.array([1.0 if periodic[i] else 1e10 for i in range(dim)])

# Wrap fractional coordinates to [0, 1) for periodic dimensions
frac_coords_wrapped = frac_coords.copy()
Expand Down
2 changes: 1 addition & 1 deletion src/pylattica/core/runner/asynchronous_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class AsynchronousRunner(Runner):
that this mode should be used with the is_async initialization parameter.
"""

def _run(
def _run( # pylint: disable=too-many-positional-arguments
self,
_: SimulationState,
result: SimulationResult,
Expand Down
2 changes: 1 addition & 1 deletion src/pylattica/core/runner/synchronous_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, parallel: bool = False, workers: int = None) -> None:
self.parallel = parallel
self.workers = workers

def _run(
def _run( # pylint: disable=too-many-positional-arguments
self,
initial_state: SimulationState,
result: SimulationResult,
Expand Down
Loading
Loading