diff --git a/doc/api.rst b/doc/api.rst index a7612927..6fb3434f 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -136,9 +136,14 @@ Attributes Modification ------------ +``Variable.update`` is the canonical mutation API. The legacy ``lower`` / +``upper`` setters still forward to ``update`` but emit a +``DeprecationWarning`` and will be removed in a future release. + .. autosummary:: :toctree: generated/ + variables.Variable.update variables.Variable.fix variables.Variable.unfix variables.Variable.relax @@ -332,6 +337,19 @@ Structure constraints.Constraint.coeffs constraints.Constraint.vars +Modification +------------ + +``Constraint.update`` is the canonical mutation API. The legacy ``lhs`` / +``sign`` / ``rhs`` / ``coeffs`` / ``vars`` setters still forward to +``update`` but emit a ``DeprecationWarning`` and will be removed in a +future release. + +.. autosummary:: + :toctree: generated/ + + constraints.Constraint.update + Post-solve access ----------------- diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 9b1ecbd8..e581a64e 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -4,6 +4,19 @@ Release Notes Upcoming Version ---------------- +**Features** + +*In-place solver updates (persistent re-solve)* + +* A built solver can now be re-solved against a mutated ``Model`` without a full rebuild. Construct with ``Solver.from_name(..., track_updates=True)`` and re-call ``solver.solve(model)`` after edits — the diff against the previous build is applied in place when the backend supports it, falling back to a rebuild otherwise. Supported on HiGHS, Gurobi, Xpress, and Mosek (``io_api="direct"``). +* Pass ``disallow_rebuild=True`` to ``solve(model, ...)`` to guarantee an in-place update or raise ``RebuildRequiredError``. Inspect ``solver._last_rebuild_reason`` (a ``RebuildReason``, or ``None`` after an in-place update) to understand why a rebuild was triggered. +* New ``linopy.persistent`` module exposes ``ModelSnapshot``, ``ModelDiff``, and ``RebuildReason`` for users who want to introspect or build the diff themselves. ``ModelDiff.from_snapshot`` / ``from_models`` return the ``RebuildReason`` directly when the change cannot be applied in place. + +**Deprecations** + +* Mutation via assignment to ``Variable.lower`` / ``Variable.upper`` / ``Constraint.coeffs`` / ``Constraint.vars`` / ``Constraint.lhs`` / ``Constraint.sign`` / ``Constraint.rhs`` is deprecated and emits a ``DeprecationWarning``. Use ``Variable.update(...)`` / ``Constraint.update(...)`` instead — the canonical mutation API with one validation path and one place that flips the persistent-solver dirty flag. Read access to these properties is unchanged. The setters will be removed in a future release. +* Passing a raw ``DataArray`` of integer labels to ``Constraint.vars = ...`` setter is deprecated and emits a ``FutureWarning``. Pass a ``Variable`` to ``Constraint.update()`` instead — it is the supported input. The ``DataArray`` path will be removed in a future release. + Version 0.8.0 ------------- diff --git a/examples/creating-constraints.ipynb b/examples/creating-constraints.ipynb index 1b792b14..d504deb3 100644 --- a/examples/creating-constraints.ipynb +++ b/examples/creating-constraints.ipynb @@ -348,7 +348,7 @@ "\n", "`CSRConstraint` deliberately exposes a narrower API than the xarray-backed `Constraint`:\n", "\n", - "- **No in-place mutation.** Setters such as `con.coeffs = ...`, `con.vars = ...`, `con.sign = ...`, `con.rhs = ...`, and `con.lhs = ...` are only available on `Constraint`.\n", + "- **No in-place mutation.** `Constraint.update(...)` is only available on `Constraint`. (The legacy setters — `con.coeffs = ...`, `con.vars = ...`, `con.sign = ...`, `con.rhs = ...`, `con.lhs = ...` — still forward to `update` on `Constraint` but emit a `DeprecationWarning` and will be removed in a future release.)\n", "- **No label-based indexing.** `con.loc[...]` is only available on `Constraint`.\n", "- **Accessing `.coeffs` / `.vars` triggers reconstruction.** On a `CSRConstraint` these properties rebuild the full xarray `Dataset` on demand and emit a `PerformanceWarning`. For solver-oriented workflows prefer `con.to_matrix()` or work with the CSR data directly.\n", "\n", @@ -356,8 +356,8 @@ "\n", "```python\n", "con = m.constraints[\"my_constraint\"].mutable()\n", - "con.loc[{\"time\": 0}] # label-based indexing now available\n", - "con.rhs = 5 # mutation now available\n", + "con.loc[{\"time\": 0}] # label-based indexing now available\n", + "con.update(rhs=5) # mutation now available\n", "```" ] }, diff --git a/examples/manipulating-models.ipynb b/examples/manipulating-models.ipynb index 5762eda0..bd86399e 100644 --- a/examples/manipulating-models.ipynb +++ b/examples/manipulating-models.ipynb @@ -74,7 +74,7 @@ "metadata": {}, "outputs": [], "source": [ - "x.lower = 1" + "x.update(lower=1)" ] }, { @@ -83,7 +83,10 @@ "metadata": {}, "source": [ ".. note::\n", - " The same could have been achieved by calling `m.variables.x.lower = 1`\n", + " Assignment via the ``x.lower = 1`` setter still works but is\n", + " deprecated and will be removed in a future release. Use\n", + " ``Variable.update`` instead — it is the canonical mutation API\n", + " with a single validation path.\n", "\n", "Let's solve it again!" ] @@ -127,7 +130,7 @@ "metadata": {}, "outputs": [], "source": [ - "x.lower = xr.DataArray(range(10, 0, -1), coords=(time,))" + "x.update(lower=xr.DataArray(range(10, 0, -1), coords=(time,)))" ] }, { @@ -157,9 +160,12 @@ "source": [ "## Varying Constraints\n", "\n", - "A similar functionality is implemented for constraints. Here we can modify the left-hand-side, the sign and the right-hand-side.\n", + "A similar functionality is implemented for constraints. We use\n", + "``Constraint.update`` to change the left-hand-side, the sign,\n", + "and the right-hand-side.\n", "\n", - "Assume we want to relax the right-hand-side of the first constraint `con1` to `8 * factor`. This would translate to:" + "Assume we want to relax the right-hand-side of the first constraint\n", + "``con1`` to ``8 * factor``. This translates to:" ] }, { @@ -169,7 +175,7 @@ "metadata": {}, "outputs": [], "source": [ - "con1.rhs = 8 * factor" + "con1.update(rhs=8 * factor)" ] }, { @@ -178,7 +184,10 @@ "metadata": {}, "source": [ ".. note::\n", - " The same could have been achieved by calling `m.constraints.con1.rhs = 8 * factor`\n", + " Assignment via the ``con1.rhs = 8 * factor`` setter still works\n", + " but is deprecated and will be removed in a future release. Use\n", + " ``Constraint.update`` instead — it is the canonical mutation API\n", + " with a single validation path.\n", "\n", "Let's solve it again!" ] @@ -212,7 +221,7 @@ "metadata": {}, "outputs": [], "source": [ - "con1.lhs = 3 * x + 8 * y" + "con1.update(lhs=3 * x + 8 * y)" ] }, { @@ -221,9 +230,15 @@ "metadata": {}, "source": [ "**Note:**\n", - "The same could have been achieved by calling \n", - "```python \n", - "m.constraints['con1'].lhs = 3 * x + 8 * y\n", + "Assignment via the ``con1.lhs = 3 * x + 8 * y`` setter still works\n", + "but is deprecated and will be removed in a future release. Use\n", + "``Constraint.update`` instead — it is the canonical mutation API\n", + "with a single validation path.\n", + "\n", + "``Constraint.update`` also accepts a full constraint expression in one call:\n", + "\n", + "```python\n", + "con1.update(3 * x + 8 * y <= 8 * factor) # replaces lhs / sign / rhs at once\n", "```" ] }, diff --git a/linopy/constraints.py b/linopy/constraints.py index 96e2a843..ed104d06 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -55,7 +55,6 @@ maybe_group_terms_polars, maybe_replace_signs, replace_by_map, - require_constant, save_join, to_dataframe, to_polars, @@ -72,6 +71,7 @@ ) from linopy.types import ( ConstantLike, + ConstraintLike, CoordsLike, ExpressionLike, SignLike, @@ -956,9 +956,19 @@ def active_labels(self) -> np.ndarray: return self._con_labels def sanitize_zeros(self) -> CSRConstraint: - """Remove terms with zero or near-zero coefficients (mutates in-place).""" - self._csr.data[np.abs(self._csr.data) <= 1e-10] = 0 - self._csr.eliminate_zeros() + """ + Remove terms with zero or near-zero coefficients. + + Copy-on-write: rebinds ``_csr`` instead of mutating its arrays, so + external holders of the previous arrays (e.g. a ModelSnapshot + sharing them) keep a valid baseline. + """ + zeros = np.abs(self._csr.data) <= 1e-10 + if zeros.any(): + csr = self._csr.copy() + csr.data[zeros] = 0 + csr.eliminate_zeros() + self._csr = csr return self def sanitize_missings(self) -> CSRConstraint: @@ -1128,7 +1138,7 @@ class Constraint(ConstraintBase): Supports setters, xarray operations via conwrap, and from_rule construction. """ - __slots__ = ("_data", "_model", "_assigned") + __slots__ = ("_data", "_model", "_assigned", "_coef_dirty") def __init__( self, @@ -1157,6 +1167,7 @@ def __init__( self._assigned = "labels" in data self._data = data self._model = model + self._coef_dirty = False @property def data(self) -> Dataset: @@ -1204,8 +1215,14 @@ def coeffs(self) -> DataArray: @coeffs.setter def coeffs(self, value: ConstantLike) -> None: - value = DataArray(value).broadcast_like(self.vars, exclude=[self.term_dim]) - self._data = assign_multiindex_safe(self.data, coeffs=value) + """Syntactic sugar for :meth:`Constraint.update`. Do not add logic here; mutate via ``update`` so the contract stays single-sourced.""" + warn( + "Constraint.coeffs setter is deprecated and will be removed in a " + "future release; use Constraint.update(coeffs=...) instead.", + DeprecationWarning, + stacklevel=2, + ) + self.update(coeffs=value) @property def vars(self) -> DataArray: @@ -1213,34 +1230,44 @@ def vars(self) -> DataArray: @vars.setter def vars(self, value: variables.Variable | DataArray) -> None: - if isinstance(value, variables.Variable): - value = value.labels - if not isinstance(value, DataArray): - raise TypeError("Expected value to be of type DataArray or Variable") - value = value.broadcast_like(self.coeffs, exclude=[self.term_dim]) - self._data = assign_multiindex_safe(self.data, vars=value) + """Syntactic sugar for :meth:`Constraint.update`. Do not add logic here; mutate via ``update`` so the contract stays single-sourced.""" + warn( + "Constraint.vars setter is deprecated and will be removed in a " + "future release; use Constraint.update(variables=...) instead.", + DeprecationWarning, + stacklevel=2, + ) + self.update(variables=value) @property def sign(self) -> DataArray: return self.data.sign @sign.setter - @require_constant def sign(self, value: SignLike) -> None: - value = maybe_replace_signs(DataArray(value)).broadcast_like(self.sign) - self._data = assign_multiindex_safe(self.data, sign=value) + """Syntactic sugar for :meth:`Constraint.update`. Do not add logic here; mutate via ``update`` so the contract stays single-sourced.""" + warn( + "Constraint.sign setter is deprecated and will be removed in a " + "future release; use Constraint.update(sign=...) instead.", + DeprecationWarning, + stacklevel=2, + ) + self.update(sign=value) @property def rhs(self) -> DataArray: return self.data.rhs @rhs.setter - def rhs(self, value: ExpressionLike) -> None: - value = expressions.as_expression( - value, self.model, coords=self.coords, dims=self.coord_dims + def rhs(self, value: ExpressionLike | VariableLike | ConstantLike) -> None: + """Syntactic sugar for :meth:`Constraint.update`. Do not add logic here; mutate via ``update`` so the contract stays single-sourced.""" + warn( + "Constraint.rhs setter is deprecated and will be removed in a " + "future release; use Constraint.update(rhs=...) instead.", + DeprecationWarning, + stacklevel=2, ) - self.lhs = self.lhs - value.reset_const() - self._data = assign_multiindex_safe(self.data, rhs=value.const) + self.update(rhs=value) @property def is_indicator(self) -> bool: @@ -1263,12 +1290,213 @@ def lhs(self) -> expressions.LinearExpression: @lhs.setter def lhs(self, value: ExpressionLike | VariableLike | ConstantLike) -> None: - value = expressions.as_expression( - value, self.model, coords=self.coords, dims=self.coord_dims + """Syntactic sugar for :meth:`Constraint.update`. Do not add logic here; mutate via ``update`` so the contract stays single-sourced.""" + warn( + "Constraint.lhs setter is deprecated and will be removed in a " + "future release; use Constraint.update(lhs=...) instead.", + DeprecationWarning, + stacklevel=2, ) + self.update(lhs=value) + + def _assign_lhs( + self, expr: expressions.LinearExpression, rhs: DataArray | None = None + ) -> None: + """ + Internal: replace coeffs/vars from ``expr``, adjusting rhs for + the expression's constant part. Sets ``_coef_dirty``. + """ + base_rhs = self.rhs if rhs is None else rhs self._data = self.data.drop_vars(["coeffs", "vars"]).assign( - coeffs=value.coeffs, vars=value.vars, rhs=self.rhs - value.const + coeffs=expr.coeffs, + vars=expr.vars, + rhs=base_rhs - expr.const, ) + self._coef_dirty = True + + def _update_data(self, **fields: Any) -> None: + """ + Internal: write ``fields`` into ``self._data`` and update dirty bookkeeping. + + Writes that touch the lhs structure (``coeffs``, ``vars``) flip + ``_coef_dirty``. Other fields (``rhs``, ``sign``, …) leave it alone. + """ + self._data = assign_multiindex_safe(self.data, **fields) + if "coeffs" in fields or "vars" in fields: + self._coef_dirty = True + + def update( + self, + constraint: ConstraintLike | None = None, + *, + lhs: ExpressionLike | VariableLike | ConstantLike | None = None, + rhs: ExpressionLike | VariableLike | ConstantLike | None = None, + sign: SignLike | None = None, + coeffs: ConstantLike | None = None, + variables: variables.Variable | DataArray | None = None, + ) -> Constraint: + """ + Update the constraint in place. + + The only mutation API; setters forward here. Two call shapes: + + * ``c.update(x + 5 <= 3)`` — pass a complete constraint + expression (mirroring ``add_constraints``). Replaces lhs, + sign, and rhs at once. + * ``c.update(lhs=, rhs=, sign=, coeffs=, variables=)`` — pass + only what you want to change. + + Use the keyword form for targeted changes — it skips the + unchanged attributes entirely. The positional form always + rewrites lhs / sign / rhs (and flips ``_coef_dirty``), so it + is the wrong shape for hot loops that only touch one part: + + .. code-block:: python + + # Hot loop, rhs is the only thing changing per iteration: + for k in scenarios: + c.update(rhs=rhs_k) # ← targeted, cheap + + # Same loop written positionally rebuilds lhs every + # iteration even though it never changes: + for k in scenarios: + c.update(big_lhs_expr <= rhs_k) # ← avoid + + Parameters + ---------- + constraint : ConstraintLike, optional + A complete constraint expression (e.g. ``x + 5 <= 3``). + Mutually exclusive with the keyword arguments below. + lhs : ExpressionLike / VariableLike / ConstantLike, optional + Replace the LHS expression. Any constant part is moved to + ``rhs`` so ``c.lhs`` stays pure-variable. Cannot be combined + with ``coeffs`` / ``variables``. Sets the internal + ``_coef_dirty`` flag. + rhs : ExpressionLike / VariableLike / ConstantLike, optional + New right-hand side. + + * Constant rhs (scalar, array, DataArray) → assigned directly + to ``c.rhs``; ``c.lhs`` is untouched. + * Variable / Expression rhs → rearranged onto the lhs to + preserve the invariant that ``c.rhs`` is constant-only, + matching ``add_constraints``. **This rewrites ``c.lhs``.** + + Example — the two calls below produce the same final state:: + + # Form A: explicit, only changes rhs + c.update(rhs=5) + + # Form B: rhs carries a variable, so lhs is rewritten too. + # Starting from `2*x <= 3`, this gives `2*x - y <= 5`: + c.update(rhs=y + 5) + + If you want the rewrite to be loud, use the positional form + (``c.update(2*x - y <= 5)``) which makes both sides explicit. + sign : SignLike, optional + New sign. One of ``"<=" / "==" / ">="`` (or their ``< > =`` + aliases). + coeffs : ConstantLike, optional + Replace coefficient values (same sparsity / term structure). + Lower-level than ``lhs=``; sets ``_coef_dirty``. + variables : Variable, optional + Replace variable label array (same sparsity / term + structure). Lower-level than ``lhs=``; sets ``_coef_dirty``. + + A raw ``DataArray`` of integer labels is still accepted + for back-compat but emits a ``FutureWarning`` — pass a + ``Variable`` instead. The DataArray path will be removed + in a future release. + + Returns + ------- + Constraint + ``self`` for chaining. + """ + if constraint is not None: + if any(x is not None for x in (lhs, rhs, sign, coeffs, variables)): + raise TypeError( + "Constraint.update: positional `constraint` argument " + "cannot be combined with keyword arguments." + ) + con: ConstraintBase + if isinstance(constraint, AnonymousScalarConstraint): + con = constraint.to_constraint() + elif isinstance(constraint, ConstraintBase): + con = constraint + else: + raise TypeError( + "Constraint.update: positional argument must be a " + "ConstraintLike (e.g. `x + 5 <= 3`); got " + f"{type(constraint).__name__}." + ) + lhs, sign, rhs = con.lhs, con.sign, con.rhs + + if all(v is None for v in (lhs, rhs, sign, coeffs, variables)): + return self + + if lhs is not None and (coeffs is not None or variables is not None): + raise TypeError( + "Constraint.update: pass either `lhs=` (replace the whole " + "expression) or `coeffs=` / `variables=` (partial array " + "replacement), not both." + ) + + # 1. lhs replacement first so subsequent rhs= rearrangement sees the new lhs. + if lhs is not None: + expr = expressions.as_expression( + lhs, self.model, coords=self.coords, dims=self.coord_dims + ) + if isinstance(expr, expressions.QuadraticExpression): + raise TypeError( + "Constraint.update: lhs must be linear; got a quadratic expression." + ) + self._assign_lhs(expr) + + # 2. rhs (rearranges non-constant part onto lhs). + if rhs is not None: + expr = expressions.as_expression( + rhs, self.model, coords=self.coords, dims=self.coord_dims + ) + residual = expr.reset_const() + if residual.nterm != 0: + self._assign_lhs(self.lhs - residual, rhs=expr.const) + else: + self._update_data(rhs=expr.const) + + # 3. coeffs / variables partial updates (only valid without lhs=). + if coeffs is not None: + new_coeffs = DataArray(coeffs).broadcast_like( + self.vars, exclude=[self.term_dim] + ) + self._update_data(coeffs=new_coeffs) + if variables is not None: + from linopy.variables import Variable as _Variable + + if isinstance(variables, _Variable): + v = variables.labels + elif isinstance(variables, DataArray): + warnings.warn( + "Passing a DataArray to Constraint.update(variables=...) " + "is deprecated and will be removed in a future release; " + "pass a Variable instead.", + FutureWarning, + stacklevel=2, + ) + v = variables + else: + raise TypeError( + "Constraint.update(variables=...) expects a Variable; " + f"got {type(variables).__name__}." + ) + new_vars = v.broadcast_like(self.coeffs, exclude=[self.term_dim]) + self._update_data(vars=new_vars) + + # 4. sign last so it composes cleanly with the rest. + if sign is not None: + new_sign = maybe_replace_signs(DataArray(sign)).broadcast_like(self.sign) + self._update_data(sign=new_sign) + + return self @property @has_optimized_model @@ -1372,8 +1600,10 @@ def to_matrix_with_rhs( def sanitize_zeros(self) -> Constraint: """Remove terms with zero or near-zero coefficients.""" not_zero = abs(self.coeffs) > 1e-10 - self.vars = self.vars.where(not_zero, -1) - self.coeffs = self.coeffs.where(not_zero) + self._update_data( + vars=self.vars.where(not_zero, -1), + coeffs=self.coeffs.where(not_zero), + ) return self def sanitize_missings(self) -> Constraint: @@ -1603,14 +1833,14 @@ def __repr__(self) -> str: return r @overload - def __getitem__(self, names: str) -> ConstraintBase: ... + def __getitem__(self, names: str) -> Constraint: ... @overload def __getitem__(self, names: list[str]) -> Constraints: ... - def __getitem__(self, names: str | list[str]) -> ConstraintBase | Constraints: + def __getitem__(self, names: str | list[str]) -> Constraint | Constraints: if isinstance(names, str): - return self.data[names] + return self.data[names] # type: ignore[return-value] return Constraints({name: self.data[name] for name in names}, self.model) def __getattr__(self, name: str) -> ConstraintBase: diff --git a/linopy/model.py b/linopy/model.py index 884d59db..dcadd95c 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1986,7 +1986,7 @@ def assign_result( for _, var in self.variables.items(): start, end = var.range var.solution = xr.DataArray( - primal[start:end].reshape(var.shape), var.coords + primal[start:end].reshape(var.shape), var.coords, dims=var.dims ) if len(result.solution.dual): diff --git a/linopy/persistent/__init__.py b/linopy/persistent/__init__.py new file mode 100644 index 00000000..1058fce4 --- /dev/null +++ b/linopy/persistent/__init__.py @@ -0,0 +1,39 @@ +"""Persistent-solver snapshot and diff primitives.""" + +from __future__ import annotations + +from linopy.persistent.diff import ( + ConSlice, + ModelDiff, + RebuildReason, + VarSlice, +) +from linopy.persistent.errors import ( + RebuildRequiredError, + UnsupportedUpdate, + UpdatesDisabledError, +) +from linopy.persistent.snapshot import ( + ContainerConBuffers, + ContainerVarBuffers, + ModelSnapshot, + StructuralKey, + VarKind, + clear_coef_dirty, +) + +__all__ = [ + "ConSlice", + "ContainerConBuffers", + "ContainerVarBuffers", + "ModelDiff", + "ModelSnapshot", + "RebuildReason", + "RebuildRequiredError", + "StructuralKey", + "UnsupportedUpdate", + "UpdatesDisabledError", + "VarKind", + "VarSlice", + "clear_coef_dirty", +] diff --git a/linopy/persistent/diff.py b/linopy/persistent/diff.py new file mode 100644 index 00000000..71ba7e70 --- /dev/null +++ b/linopy/persistent/diff.py @@ -0,0 +1,627 @@ +from __future__ import annotations + +import enum +from collections.abc import Iterable +from dataclasses import dataclass +from functools import cached_property +from typing import TYPE_CHECKING + +import numpy as np + +from linopy.constants import short_GREATER_EQUAL, short_LESS_EQUAL +from linopy.constraints import Constraint +from linopy.persistent.snapshot import ( + ContainerConBuffers, + ContainerVarBuffers, + ModelSnapshot, + StructuralKey, + _coord_snapshot, + _extract_con_buffers, + _extract_var_buffers, + _objective_linear_vector, +) + +if TYPE_CHECKING: + from numpy.typing import DTypeLike + + from linopy.common import ConstraintLabelIndex, VariableLabelIndex + from linopy.constraints import ConstraintBase + from linopy.model import Model + from linopy.variables import Variable + + +class RebuildReason(enum.Enum): + STRUCTURAL_LABELS = "vlabels/clabels mismatch" + STRUCTURAL_CONTAINERS = "container set changed" + COORD_REINDEX = "coordinates changed" + SPARSITY = "coefficient sparsity changed" + QUAD_OBJ = "quadratic objective changed" + BACKEND_REJECTED = "backend raised UnsupportedUpdate" + + +@dataclass(frozen=True) +class VarSlice: + bounds: slice + type: slice + + +@dataclass(frozen=True) +class ConSlice: + coef: slice + rhs: slice + sign: slice + + +def _cat(parts: list[np.ndarray], dtype: DTypeLike) -> np.ndarray: + if not parts: + return np.empty(0, dtype=dtype) + return np.concatenate(parts).astype(dtype, copy=False) + + +def _same(a: np.ndarray, b: np.ndarray) -> bool: + return a is b or np.array_equal(a, b) + + +def _coords_equal( + a: dict[str, np.ndarray], b: dict[str, np.ndarray], ignored: frozenset[str] +) -> bool: + keys = a.keys() - ignored + if keys != b.keys() - ignored: + return False + return all(np.array_equal(a[k], b[k]) for k in keys) + + +def _structural_reason(base: StructuralKey, model: Model) -> RebuildReason | None: + if base.var_container_names != tuple( + model.variables + ) or base.con_container_names != tuple(model.constraints): + return RebuildReason.STRUCTURAL_CONTAINERS + if not np.array_equal(base.vlabels, model.variables.label_index.vlabels): + return RebuildReason.STRUCTURAL_LABELS + if not np.array_equal(base.clabels, model.constraints.label_index.clabels): + return RebuildReason.STRUCTURAL_LABELS + return None + + +@dataclass(frozen=True) +class _CoefDelta: + """Coefficient changes of one container, expanded to COO lazily.""" + + buf: ContainerConBuffers + changed_rows: np.ndarray + row_positions: np.ndarray + nnz: int + + +@dataclass +class ModelDiff: + """ + Flat-native delta between two structurally identical model states. + + Instances are produced by :meth:`from_snapshot` / :meth:`from_models`; + any condition that cannot be expressed as an in-place delta is returned + as a :class:`RebuildReason` instead of a diff. + + Coefficient changes are stored per container as ``coef_deltas`` + (changed rows referencing the container's CSR buffers) and expanded to + COO triplets — ``con_coef_rows`` / ``con_coef_cols`` / ``con_coef_vals`` + — on first access. + """ + + var_bounds_indices: np.ndarray + var_bounds_lower: np.ndarray + var_bounds_upper: np.ndarray + var_type_positions: np.ndarray + var_type_kinds: np.ndarray + + coef_deltas: list[_CoefDelta] + n_coef_updates: int + + con_rhs_indices: np.ndarray + con_rhs_values: np.ndarray + con_rhs_signs: np.ndarray + + con_sign_indices: np.ndarray + con_sign_values: np.ndarray + + obj_c_indices: np.ndarray | None + obj_c_values: np.ndarray | None + obj_sense: str | None + + var_slices: dict[str, VarSlice] + con_slices: dict[str, ConSlice] + + #: Snapshot of the diffed (target) model state, assembled from the + #: buffers the diff walk already extracted — adopting it after a + #: successful apply replaces a full re-capture. Note: holding a diff + #: therefore pins all container buffers for its lifetime. + snapshot: ModelSnapshot + + @property + def is_empty(self) -> bool: + return ( + self.var_bounds_indices.size == 0 + and self.var_type_positions.size == 0 + and self.n_coef_updates == 0 + and self.con_rhs_indices.size == 0 + and self.con_sign_indices.size == 0 + and self.obj_c_indices is None + and self.obj_sense is None + ) + + @property + def changed_variables(self) -> set[str]: + return set(self.var_slices) + + @property + def changed_constraints(self) -> set[str]: + return set(self.con_slices) + + @cached_property + def _coef_coo(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + rows = np.empty(self.n_coef_updates, dtype=np.int32) + cols = np.empty(self.n_coef_updates, dtype=np.int32) + vals = np.empty(self.n_coef_updates, dtype=np.float64) + cursor = 0 + for delta in self.coef_deltas: + indptr = delta.buf.indptr + starts = indptr[delta.changed_rows] + counts = indptr[delta.changed_rows + 1] - starts + run_offsets = np.repeat(np.cumsum(counts) - counts, counts) + flat = np.repeat(starts, counts) + np.arange(delta.nnz) - run_offsets + sl = slice(cursor, cursor + delta.nnz) + rows[sl] = np.repeat(delta.row_positions, counts) + cols[sl] = delta.buf.indices[flat] + vals[sl] = delta.buf.data[flat] + cursor += delta.nnz + return rows, cols, vals + + @property + def con_coef_rows(self) -> np.ndarray: + return self._coef_coo[0] + + @property + def con_coef_cols(self) -> np.ndarray: + return self._coef_coo[1] + + @property + def con_coef_vals(self) -> np.ndarray: + return self._coef_coo[2] + + def con_rhs_as_bounds(self) -> tuple[np.ndarray, np.ndarray]: + """Return (lower, upper) row-bounds form of the RHS updates.""" + vals = self.con_rhs_values + signs = self.con_rhs_signs + lower = np.where(signs == short_LESS_EQUAL, -np.inf, vals) + upper = np.where(signs == short_GREATER_EQUAL, np.inf, vals) + return lower, upper + + def summary(self) -> dict[str, int | bool | str | None]: + return { + "var_bounds": int(self.var_bounds_indices.size), + "var_type": int(self.var_type_positions.size), + "con_rhs": int(self.con_rhs_indices.size), + "con_sign": int(self.con_sign_indices.size), + "con_coef_updates": self.n_coef_updates, + "obj_linear_changed": self.obj_c_indices is not None, + "obj_sense_changed_to": self.obj_sense, + } + + def inspect_variable(self, name: str) -> dict[str, object]: + sl = self.var_slices.get(name) + if sl is None: + return {} + entry: dict[str, object] = {} + if sl.bounds.stop > sl.bounds.start: + entry["bounds_indices"] = self.var_bounds_indices[sl.bounds] + entry["lower"] = self.var_bounds_lower[sl.bounds] + entry["upper"] = self.var_bounds_upper[sl.bounds] + if sl.type.stop > sl.type.start: + entry["type_positions"] = self.var_type_positions[sl.type] + entry["type_kinds"] = self.var_type_kinds[sl.type] + return entry + + def inspect_constraint(self, name: str) -> dict[str, object]: + sl = self.con_slices.get(name) + if sl is None: + return {} + entry: dict[str, object] = {} + if sl.coef.stop > sl.coef.start: + entry["coef_rows"] = self.con_coef_rows[sl.coef] + entry["coef_cols"] = self.con_coef_cols[sl.coef] + entry["coef_vals"] = self.con_coef_vals[sl.coef] + if sl.rhs.stop > sl.rhs.start: + entry["rhs_indices"] = self.con_rhs_indices[sl.rhs] + entry["rhs_values"] = self.con_rhs_values[sl.rhs] + entry["rhs_signs"] = self.con_rhs_signs[sl.rhs] + if sl.sign.stop > sl.sign.start: + entry["sign_indices"] = self.con_sign_indices[sl.sign] + entry["sign_values"] = self.con_sign_values[sl.sign] + return entry + + def __repr__(self) -> str: + if self.is_empty: + return "ModelDiff(empty)" + parts = [ + f"{k}={v}" for k, v in self.summary().items() if v not in (0, False, None) + ] + return "ModelDiff(" + ", ".join(parts) + ")" + + @classmethod + def from_snapshot( + cls, + snapshot: ModelSnapshot, + model: Model, + same_model: bool = False, + ignore_dims: Iterable[str] = (), + ) -> ModelDiff | RebuildReason: + """ + Diff ``model`` against a captured ``snapshot``. + + Returns a :class:`ModelDiff` when the change is expressible in + place, or the :class:`RebuildReason` that prevents it. + + Coordinate values are compared on every dim *not* in + ``ignore_dims``; a mismatch triggers + ``RebuildReason.COORD_REINDEX``. Pass ``ignore_dims={"snapshot"}`` + for rolling-horizon use cases where the snapshot coord + legitimately shifts between solves. + + ``same_model`` is a perf hint, **default False**. When True, the + diff trusts ``Constraint._coef_dirty`` to short-circuit the CSR + walk for unchanged containers. That's only safe if every + coefficient mutation went through ``Constraint.update`` (or the + setters that forward there) — direct ``c.coeffs.values[...]`` + writes bypass the flag and would silently miss changes. Pass + ``same_model=True`` only when you own the mutation contract. + """ + reason = _structural_reason(snapshot.structural_key, model) + if reason is not None: + return reason + + builder = _DiffBuilder( + model.variables.label_index, + model.constraints.label_index, + frozenset(ignore_dims), + structural_key=snapshot.structural_key, + ) + + for name, var in model.variables.items(): + reason = builder.diff_var( + name, var, snapshot.var_buffers[name], snapshot.var_coords[name] + ) + if reason is not None: + return reason + + for name, con in model.constraints.items(): + skip = same_model and isinstance(con, Constraint) and not con._coef_dirty + reason = builder.diff_con( + name, + con, + snapshot.con_buffers[name], + snapshot.con_coords[name], + skip_coef_compare=skip, + ) + if reason is not None: + return reason + + reason = builder.diff_objective( + model, snapshot.obj_c, snapshot.obj_quad_present, snapshot.obj_sense + ) + if reason is not None: + return reason + + return builder.finalize() + + @classmethod + def from_models( + cls, + model_a: Model, + model_b: Model, + ignore_dims: Iterable[str] = (), + ) -> ModelDiff | RebuildReason: + """ + Diff two linopy models directly, without capturing a snapshot. + + ``model_a`` is the baseline, ``model_b`` is the target. The + coefficient comparison runs unconditionally — no ``_coef_dirty`` + shortcut applies between independently-built models. Returns a + :class:`ModelDiff` or the :class:`RebuildReason` preventing an + in-place update. + """ + var_idx_a = model_a.variables.label_index + key_a = StructuralKey( + var_container_names=tuple(model_a.variables), + con_container_names=tuple(model_a.constraints), + vlabels=var_idx_a.vlabels, + clabels=model_a.constraints.label_index.clabels, + ) + reason = _structural_reason(key_a, model_b) + if reason is not None: + return reason + + var_idx_b = model_b.variables.label_index + con_idx_b = model_b.constraints.label_index + key_b = StructuralKey( + var_container_names=tuple(model_b.variables), + con_container_names=tuple(model_b.constraints), + vlabels=var_idx_b.vlabels, + clabels=con_idx_b.clabels, + ) + builder = _DiffBuilder( + var_idx_b, + con_idx_b, + frozenset(ignore_dims), + structural_key=key_b, + ) + + for name, var_b in model_b.variables.items(): + var_a = model_a.variables[name] + reason = builder.diff_var( + name, var_b, _extract_var_buffers(var_a), _coord_snapshot(var_a) + ) + if reason is not None: + return reason + + for name, con_b in model_b.constraints.items(): + con_a = model_a.constraints[name] + reason = builder.diff_con( + name, + con_b, + _extract_con_buffers(con_a, var_idx_a), + _coord_snapshot(con_a), + skip_coef_compare=False, + ) + if reason is not None: + return reason + + reason = builder.diff_objective( + model_b, + _objective_linear_vector(model_a), + model_a.objective.is_quadratic, + model_a.objective.sense, + ) + if reason is not None: + return reason + + return builder.finalize() + + +class _DiffBuilder: + """Accumulates per-container deltas and finalizes them into a ModelDiff.""" + + def __init__( + self, + var_label_index: VariableLabelIndex, + con_label_index: ConstraintLabelIndex, + ignored: frozenset[str], + structural_key: StructuralKey, + ) -> None: + self.var_label_index = var_label_index + self.var_l2p = var_label_index.label_to_pos + self.con_l2p = con_label_index.label_to_pos + self.ignored = ignored + self.structural_key = structural_key + + # Target-state material for the snapshot assembled in finalize(). + self.var_buffers: dict[str, ContainerVarBuffers] = {} + self.con_buffers: dict[str, ContainerConBuffers] = {} + self.var_coords: dict[str, dict[str, np.ndarray]] = {} + self.con_coords: dict[str, dict[str, np.ndarray]] = {} + self._snap_obj_c: np.ndarray | None = None + self._snap_obj_sense: str | None = None + + self.var_bounds_idx: list[np.ndarray] = [] + self.var_bounds_lo: list[np.ndarray] = [] + self.var_bounds_up: list[np.ndarray] = [] + self.var_type_pos: list[np.ndarray] = [] + self.var_type_kinds: list[np.ndarray] = [] + + self.coef_deltas: list[_CoefDelta] = [] + self.con_rhs_idx: list[np.ndarray] = [] + self.con_rhs_vals: list[np.ndarray] = [] + self.con_rhs_signs: list[np.ndarray] = [] + self.con_sign_idx: list[np.ndarray] = [] + self.con_sign_vals: list[np.ndarray] = [] + + self.var_slices: dict[str, VarSlice] = {} + self.con_slices: dict[str, ConSlice] = {} + + self.obj_c_indices: np.ndarray | None = None + self.obj_c_values: np.ndarray | None = None + self.obj_sense: str | None = None + + self._vb_cur = 0 + self._vt_cur = 0 + self._cc_cur = 0 + self._cr_cur = 0 + self._cs_cur = 0 + + def diff_var( + self, + name: str, + var: Variable, + base_buf: ContainerVarBuffers, + base_coords: dict[str, np.ndarray], + ) -> RebuildReason | None: + new_buf = _extract_var_buffers(var) + new_coords = _coord_snapshot(var) + self.var_buffers[name] = new_buf + self.var_coords[name] = new_coords + if new_buf.lower.shape != base_buf.lower.shape: + return RebuildReason.COORD_REINDEX + if not _same(new_buf.active_labels, base_buf.active_labels): + return RebuildReason.STRUCTURAL_LABELS + if not _coords_equal(base_coords, new_coords, self.ignored): + return RebuildReason.COORD_REINDEX + + bound_mask = (new_buf.lower != base_buf.lower) | ( + new_buf.upper != base_buf.upper + ) + bounds_changed = bool(bound_mask.any()) + type_changed = new_buf.type != base_buf.type + if not (bounds_changed or type_changed): + return None + + b_start, t_start = self._vb_cur, self._vt_cur + if bounds_changed: + local_idx = np.flatnonzero(bound_mask) + positions = self.var_l2p[new_buf.active_labels[local_idx]] + self.var_bounds_idx.append(positions.astype(np.int32, copy=False)) + self.var_bounds_lo.append( + new_buf.lower[local_idx].astype(np.float64, copy=False) + ) + self.var_bounds_up.append( + new_buf.upper[local_idx].astype(np.float64, copy=False) + ) + self._vb_cur += local_idx.size + if type_changed: + positions = self.var_l2p[new_buf.active_labels].astype(np.int32, copy=False) + self.var_type_pos.append(positions) + self.var_type_kinds.append( + np.full(positions.size, new_buf.type, dtype=object) + ) + self._vt_cur += positions.size + self.var_slices[name] = VarSlice( + bounds=slice(b_start, self._vb_cur), + type=slice(t_start, self._vt_cur), + ) + return None + + def diff_con( + self, + name: str, + con: ConstraintBase, + base_buf: ContainerConBuffers, + base_coords: dict[str, np.ndarray], + skip_coef_compare: bool, + ) -> RebuildReason | None: + new_buf = _extract_con_buffers(con, self.var_label_index) + new_coords = _coord_snapshot(con) + self.con_buffers[name] = new_buf + self.con_coords[name] = new_coords + if new_buf.indptr.shape != base_buf.indptr.shape: + return RebuildReason.COORD_REINDEX + if not _same(new_buf.active_labels, base_buf.active_labels): + return RebuildReason.STRUCTURAL_LABELS + if not _coords_equal(base_coords, new_coords, self.ignored): + return RebuildReason.COORD_REINDEX + if not _same(new_buf.indptr, base_buf.indptr): + return RebuildReason.SPARSITY + if not _same(new_buf.indices, base_buf.indices): + return RebuildReason.SPARSITY + + n_rows = new_buf.active_labels.size + if n_rows == 0: + return None + + changed_rows = None + if not (skip_coef_compare or new_buf.data is base_buf.data): + data_diff = new_buf.data != base_buf.data + if data_diff.any(): + nnz_per_row = np.diff(new_buf.indptr) + row_of_nnz = np.repeat(np.arange(n_rows), nnz_per_row) + changed_rows = np.unique(row_of_nnz[data_diff]) + + rhs_idx = None + if new_buf.rhs is not base_buf.rhs: + rhs_idx = np.flatnonzero(new_buf.rhs != base_buf.rhs) + if rhs_idx.size == 0: + rhs_idx = None + sign_idx = None + if new_buf.sign is not base_buf.sign: + sign_idx = np.flatnonzero(new_buf.sign != base_buf.sign) + if sign_idx.size == 0: + sign_idx = None + + if changed_rows is None and rhs_idx is None and sign_idx is None: + return None + + c_start, r_start, s_start = self._cc_cur, self._cr_cur, self._cs_cur + if changed_rows is not None: + row_positions = self.con_l2p[new_buf.active_labels[changed_rows]].astype( + np.int32, copy=False + ) + indptr = new_buf.indptr + nnz = int((indptr[changed_rows + 1] - indptr[changed_rows]).sum()) + self.coef_deltas.append( + _CoefDelta(new_buf, changed_rows, row_positions, nnz) + ) + self._cc_cur += nnz + if rhs_idx is not None: + positions = self.con_l2p[new_buf.active_labels[rhs_idx]] + self.con_rhs_idx.append(positions.astype(np.int32, copy=False)) + self.con_rhs_vals.append( + new_buf.rhs[rhs_idx].astype(np.float64, copy=False) + ) + self.con_rhs_signs.append(new_buf.sign[rhs_idx]) + self._cr_cur += rhs_idx.size + if sign_idx is not None: + positions = self.con_l2p[new_buf.active_labels[sign_idx]] + self.con_sign_idx.append(positions.astype(np.int32, copy=False)) + self.con_sign_vals.append(new_buf.sign[sign_idx]) + self._cs_cur += sign_idx.size + self.con_slices[name] = ConSlice( + coef=slice(c_start, self._cc_cur), + rhs=slice(r_start, self._cr_cur), + sign=slice(s_start, self._cs_cur), + ) + return None + + def diff_objective( + self, + model: Model, + base_obj_c: np.ndarray, + base_obj_quad: bool, + base_obj_sense: str, + ) -> RebuildReason | None: + if model.objective.is_quadratic or base_obj_quad: + return RebuildReason.QUAD_OBJ + + obj_c = _objective_linear_vector(model) + self._snap_obj_c = obj_c + self._snap_obj_sense = model.objective.sense + if obj_c.shape != base_obj_c.shape: + return RebuildReason.COORD_REINDEX + obj_diff_mask = obj_c != base_obj_c + if obj_diff_mask.any(): + self.obj_c_indices = np.flatnonzero(obj_diff_mask).astype( + np.int32, copy=False + ) + self.obj_c_values = obj_c[self.obj_c_indices].astype(np.float64, copy=False) + if model.objective.sense != base_obj_sense: + self.obj_sense = model.objective.sense + return None + + def finalize(self) -> ModelDiff: + assert self._snap_obj_c is not None and self._snap_obj_sense is not None + snapshot = ModelSnapshot( + structural_key=self.structural_key, + var_buffers=self.var_buffers, + con_buffers=self.con_buffers, + var_coords=self.var_coords, + con_coords=self.con_coords, + obj_c=self._snap_obj_c, + obj_quad_present=False, + obj_sense=self._snap_obj_sense, + ) + return ModelDiff( + snapshot=snapshot, + var_bounds_indices=_cat(self.var_bounds_idx, np.int32), + var_bounds_lower=_cat(self.var_bounds_lo, np.float64), + var_bounds_upper=_cat(self.var_bounds_up, np.float64), + var_type_positions=_cat(self.var_type_pos, np.int32), + var_type_kinds=_cat(self.var_type_kinds, object), + coef_deltas=self.coef_deltas, + n_coef_updates=self._cc_cur, + con_rhs_indices=_cat(self.con_rhs_idx, np.int32), + con_rhs_values=_cat(self.con_rhs_vals, np.float64), + con_rhs_signs=_cat(self.con_rhs_signs, "U1"), + con_sign_indices=_cat(self.con_sign_idx, np.int32), + con_sign_values=_cat(self.con_sign_vals, "U1"), + obj_c_indices=self.obj_c_indices, + obj_c_values=self.obj_c_values, + obj_sense=self.obj_sense, + var_slices=self.var_slices, + con_slices=self.con_slices, + ) diff --git a/linopy/persistent/errors.py b/linopy/persistent/errors.py new file mode 100644 index 00000000..c6159207 --- /dev/null +++ b/linopy/persistent/errors.py @@ -0,0 +1,25 @@ +from __future__ import annotations + + +class UnsupportedUpdate(Exception): + pass + + +class RebuildRequiredError(RuntimeError): + """ + Raised when an in-place update is required but a rebuild is needed. + + Carries the :class:`RebuildReason` that forced the rebuild attempt. + """ + + def __init__(self, reason: object, message: str | None = None) -> None: + self.reason = reason + super().__init__(message or f"rebuild required: {reason}") + + +class UpdatesDisabledError(RuntimeError): + """ + Raised when an in-place update is requested on a solver built with + ``track_updates=False``. Reconstruct the solver with ``track_updates=True`` + to enable diff-based updates. + """ diff --git a/linopy/persistent/snapshot.py b/linopy/persistent/snapshot.py new file mode 100644 index 00000000..fd758ea3 --- /dev/null +++ b/linopy/persistent/snapshot.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +import enum +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import numpy as np + +from linopy import expressions +from linopy.constraints import Constraint + +if TYPE_CHECKING: + from linopy.constraints import ConstraintBase + from linopy.model import Model + from linopy.variables import Variable, VariableLabelIndex + + +class VarKind(enum.Enum): + CONTINUOUS = "continuous" + BINARY = "binary" + INTEGER = "integer" + SEMI_CONTINUOUS = "semi_continuous" + + +def _variable_type(var: Variable) -> VarKind: + attrs = var.attrs + if attrs.get("binary"): + return VarKind.BINARY + if attrs.get("integer"): + return VarKind.INTEGER + if attrs.get("semi_continuous"): + return VarKind.SEMI_CONTINUOUS + return VarKind.CONTINUOUS + + +def _objective_linear_vector(model: Model) -> np.ndarray: + vlabels = model.variables.label_index.vlabels + label_to_pos = model.variables.label_index.label_to_pos + result = np.zeros(len(vlabels), dtype=np.float64) + expr = model.objective.expression + if isinstance(expr, expressions.QuadraticExpression): + vars_2d = expr.data.vars.values + coeffs_all = expr.data.coeffs.values.ravel() + vars1, vars2 = vars_2d[0], vars_2d[1] + linear = (vars1 == -1) | (vars2 == -1) + var_labels = np.where(vars1[linear] != -1, vars1[linear], vars2[linear]) + coeffs = coeffs_all[linear] + else: + var_labels = expr.data.vars.values.ravel() + coeffs = expr.data.coeffs.values.ravel() + mask = var_labels != -1 + np.add.at(result, label_to_pos[var_labels[mask]], coeffs[mask]) + return result + + +def _extract_var_buffers(var: Variable) -> ContainerVarBuffers: + # Boolean masking copies, so the buffers never alias the live model + # arrays — the snapshot stays a valid baseline even after in-place + # ``.values[...]`` mutations. + labels_flat = var.labels.values.ravel() + mask = labels_flat != -1 + return ContainerVarBuffers( + lower=var.lower.values.ravel()[mask].astype(np.float64, copy=False), + upper=var.upper.values.ravel()[mask].astype(np.float64, copy=False), + type=_variable_type(var), + active_labels=labels_flat[mask].astype(np.int64, copy=False), + ) + + +def _extract_con_buffers( + con: ConstraintBase, var_label_index: VariableLabelIndex +) -> ContainerConBuffers: + """ + Extract flat constraint buffers without copying. + + Mutable ``Constraint`` objects build fresh arrays in + ``to_matrix_with_rhs``, so the buffers are exclusively owned. + ``CSRConstraint`` returns its stored arrays — the buffers share memory + with the constraint, every mutation path rebinds whole arrays + (copy-on-write), and the diff uses object identity to skip comparisons + on untouched containers. + """ + csr, con_labels, b, sense = con.to_matrix_with_rhs(var_label_index) + return ContainerConBuffers( + indptr=csr.indptr, + indices=csr.indices, + data=np.asarray(csr.data, dtype=np.float64), + rhs=np.asarray(b, dtype=np.float64), + sign=np.asarray(sense, dtype="U1"), + active_labels=np.asarray(con_labels, dtype=np.int64), + ) + + +@dataclass(frozen=True) +class StructuralKey: + var_container_names: tuple[str, ...] + con_container_names: tuple[str, ...] + vlabels: np.ndarray + clabels: np.ndarray + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, StructuralKey) + and self.var_container_names == other.var_container_names + and self.con_container_names == other.con_container_names + and np.array_equal(self.vlabels, other.vlabels) + and np.array_equal(self.clabels, other.clabels) + ) + + __hash__ = None # type: ignore[assignment] + + +@dataclass(frozen=True) +class ContainerVarBuffers: + lower: np.ndarray + upper: np.ndarray + type: VarKind + active_labels: np.ndarray + + +@dataclass(frozen=True) +class ContainerConBuffers: + indptr: np.ndarray + indices: np.ndarray + data: np.ndarray + rhs: np.ndarray + sign: np.ndarray + active_labels: np.ndarray + + +def _coord_snapshot(obj: Variable | ConstraintBase) -> dict[str, np.ndarray]: + return {str(name): np.asarray(idx) for name, idx in obj.indexes.items()} + + +def clear_coef_dirty(model: Model) -> None: + """ + Reset ``Constraint._coef_dirty`` on every constraint of ``model``. + + Must be called exactly when a snapshot reflecting the model's current + state is adopted by a tracking solver — clearing without adopting makes + a later ``same_model=True`` diff silently skip changed coefficients. + """ + for con in model.constraints.data.values(): + if isinstance(con, Constraint): + con._coef_dirty = False + + +@dataclass +class ModelSnapshot: + structural_key: StructuralKey + var_buffers: dict[str, ContainerVarBuffers] = field(default_factory=dict) + con_buffers: dict[str, ContainerConBuffers] = field(default_factory=dict) + var_coords: dict[str, dict[str, np.ndarray]] = field(default_factory=dict) + con_coords: dict[str, dict[str, np.ndarray]] = field(default_factory=dict) + obj_c: np.ndarray = field(default_factory=lambda: np.zeros(0, dtype=np.float64)) + obj_quad_present: bool = False + obj_sense: str = "min" + + @classmethod + def capture(cls, model: Model) -> ModelSnapshot: + var_label_index = model.variables.label_index + con_label_index = model.constraints.label_index + + structural_key = StructuralKey( + var_container_names=tuple(model.variables), + con_container_names=tuple(model.constraints), + vlabels=var_label_index.vlabels, + clabels=con_label_index.clabels, + ) + + var_buffers = { + name: _extract_var_buffers(var) for name, var in model.variables.items() + } + con_buffers = { + name: _extract_con_buffers(con, var_label_index) + for name, con in model.constraints.items() + } + var_coords = { + name: _coord_snapshot(var) for name, var in model.variables.items() + } + con_coords = { + name: _coord_snapshot(con) for name, con in model.constraints.items() + } + + return cls( + structural_key=structural_key, + var_buffers=var_buffers, + con_buffers=con_buffers, + var_coords=var_coords, + con_coords=con_coords, + obj_c=_objective_linear_vector(model), + obj_quad_present=model.objective.is_quadratic, + obj_sense=model.objective.sense, + ) diff --git a/linopy/solvers.py b/linopy/solvers.py index b71d0c2c..d154dd35 100644 --- a/linopy/solvers.py +++ b/linopy/solvers.py @@ -19,7 +19,7 @@ import warnings from abc import ABC from collections import namedtuple -from collections.abc import Callable, Generator, Iterator, Sequence +from collections.abc import Callable, Generator, Iterable, Iterator, Sequence from dataclasses import dataclass, field from enum import Enum, auto from importlib.metadata import PackageNotFoundError @@ -36,6 +36,7 @@ import linopy.io from linopy.common import count_initial_letters, values_to_lookup_array from linopy.constants import ( + EQUAL, SOS_DIM_ATTR, SOS_TYPE_ATTR, Result, @@ -44,7 +45,27 @@ SolverStatus, Status, TerminationCondition, + short_GREATER_EQUAL, + short_LESS_EQUAL, ) +from linopy.persistent import ( + ModelDiff, + ModelSnapshot, + RebuildReason, + RebuildRequiredError, + UnsupportedUpdate, + UpdatesDisabledError, + VarKind, + clear_coef_dirty, +) + + +def _int_list(arr: np.ndarray, dtype: type = np.int64) -> list[int]: + return arr.astype(dtype, copy=False).tolist() + + +def _float_list(arr: np.ndarray) -> list[float]: + return arr.astype(float, copy=False).tolist() def _parse_int_label(name: str) -> int: @@ -400,11 +421,22 @@ class Solver(ABC, Generic[EnvType]): Subclasses provide ``_build_direct`` / ``_run_direct`` (when supporting the direct API) and ``_run_file`` (when supporting LP/MPS files). Construction goes via :meth:`Solver.from_name` or :meth:`Solver.from_model`. + + ``track_updates`` toggles persistent-update support: + + * ``False`` (default) — one-shot mode. No :class:`ModelSnapshot` is + captured at build time; any later ``solve(model=...)`` or + ``update(model)`` raises :class:`UpdatesDisabledError`. Use for + throw-away solver instances and high-level ``Model.solve(...)``. + * ``True`` — long-lived mode. A snapshot is captured at build time and + re-captured after each successful in-place update, enabling + diff-based ``solve(model=...)`` / ``update(model)`` across iterations. """ model: Model | None = None io_api: str | None = None options: dict[str, Any] = field(default_factory=dict) + track_updates: bool = False # Runtime state — never set via constructor. status: Status | None = field(init=False, default=None, repr=False) @@ -422,9 +454,18 @@ class Solver(ABC, Generic[EnvType]): _n_cons: int = field(init=False, default=0, repr=False) _problem_fn: Path | None = field(init=False, default=None, repr=False) + snapshot: ModelSnapshot | None = field(init=False, default=None, repr=False) + _rebuilds: int = field(init=False, default=0, repr=False) + _in_place_updates: int = field(init=False, default=0, repr=False) + _last_rebuild_reason: RebuildReason | None = field( + init=False, default=None, repr=False + ) + display_name: ClassVar[str] = "" features: ClassVar[frozenset[SolverFeature]] = frozenset() accepted_io_apis: ClassVar[frozenset[str]] = frozenset() + supports_persistent_update: ClassVar[bool] = False + supports_sign_update: ClassVar[bool] = False def __post_init__(self) -> None: if type(self) is Solver: @@ -437,10 +478,116 @@ def __post_init__(self) -> None: "Please install first to initialize solver instance." ) raise ImportError(msg) + self._lock: threading.Lock = threading.Lock() + + def apply_update( + self, + diff: ModelDiff, + var_label_index: Any, + con_label_index: Any, + ) -> None: + """ + Apply an in-place :class:`ModelDiff` to the built native model. + + Template method: validates the diff up front (a rejected update + leaves the native model untouched), then walks the sections in a + fixed order, dispatching to the per-backend ``_apply_*`` hooks. + """ + if not type(self).supports_persistent_update: + raise UnsupportedUpdate(type(self).__name__) + self._validate_update(diff) + ctx = self._apply_begin(var_label_index, con_label_index) + if diff.var_bounds_indices.size: + self._apply_var_bounds( + ctx, + diff.var_bounds_indices, + diff.var_bounds_lower, + diff.var_bounds_upper, + ) + if diff.var_type_positions.size: + self._apply_var_types(ctx, diff.var_type_positions, diff.var_type_kinds) + self._reclamp_binary_bounds( + ctx, diff.var_type_positions, diff.var_type_kinds + ) + if diff.con_rhs_indices.size: + self._apply_con_rhs(ctx, diff) + if diff.con_sign_indices.size: + self._apply_con_signs(ctx, diff.con_sign_indices, diff.con_sign_values) + if diff.n_coef_updates: + self._apply_con_coefs( + ctx, diff.con_coef_rows, diff.con_coef_cols, diff.con_coef_vals + ) + if diff.obj_c_indices is not None: + assert diff.obj_c_values is not None + self._apply_obj_linear(ctx, diff.obj_c_indices, diff.obj_c_values) + if diff.obj_sense is not None: + self._apply_obj_sense(ctx, diff.obj_sense) + self.sense = diff.obj_sense + self._apply_end(ctx) + + def _validate_update(self, diff: ModelDiff) -> None: + """Reject unsupported diff content before any native mutation.""" + if diff.con_sign_indices.size and not type(self).supports_sign_update: + raise UnsupportedUpdate( + f"{self.display_name} does not support in-place constraint sign change" + ) + + def _apply_begin(self, var_label_index: Any, con_label_index: Any) -> Any: + """Backend prep + validation; the return value is passed to every hook.""" + return self.solver_model + + def _apply_end(self, ctx: Any) -> None: + return None + + def _apply_var_bounds( + self, ctx: Any, indices: np.ndarray, lower: np.ndarray, upper: np.ndarray + ) -> None: + raise NotImplementedError + + def _apply_var_types( + self, ctx: Any, positions: np.ndarray, kinds: np.ndarray + ) -> None: + raise NotImplementedError + + def _reclamp_binary_bounds( + self, ctx: Any, positions: np.ndarray, kinds: np.ndarray + ) -> None: + """ + Re-clamp variables switched to BINARY to [0, 1]. + + Compensates for backends whose native type system only has a generic + integer kind; backends where the binary type implies the bounds + (Gurobi) override with a no-op. + """ + binary_mask = kinds == VarKind.BINARY + if binary_mask.any(): + bin_positions = positions[binary_mask] + n = bin_positions.size + self._apply_var_bounds(ctx, bin_positions, np.zeros(n), np.ones(n)) + + def _apply_con_rhs(self, ctx: Any, diff: ModelDiff) -> None: + raise NotImplementedError + + def _apply_con_signs( + self, ctx: Any, indices: np.ndarray, signs: np.ndarray + ) -> None: + raise NotImplementedError + + def _apply_con_coefs( + self, ctx: Any, rows: np.ndarray, cols: np.ndarray, vals: np.ndarray + ) -> None: + raise NotImplementedError + + def _apply_obj_linear( + self, ctx: Any, indices: np.ndarray, values: np.ndarray + ) -> None: + raise NotImplementedError + + def _apply_obj_sense(self, ctx: Any, sense: str) -> None: + raise NotImplementedError @property def solver_options(self) -> dict[str, Any]: - """Back-compat alias for ``self.options``.""" return self.options @classmethod @@ -497,17 +644,35 @@ def supports(cls, feature: SolverFeature) -> bool: @staticmethod def from_name( name: str, - model: Model, + model: Model | None = None, io_api: str | None = None, options: dict[str, Any] | None = None, + track_updates: bool = False, **build_kwargs: Any, ) -> Solver: - """Construct and build the solver subclass registered as ``name``.""" + """ + Construct the solver subclass registered as ``name``. + + With ``model`` supplied, the solver is built immediately. Without it, + an unbuilt instance is returned and the first ``solve(model, ...)`` + call performs the build. See :class:`Solver` for ``track_updates``. + """ cls = _solver_class_for(name) if cls is None: raise ValueError(f"unknown solver: {name}") + if model is None: + return cls( + model=None, + io_api=io_api, + options=options or {}, + track_updates=track_updates, + ) return cls.from_model( - model, io_api=io_api, options=options or {}, **build_kwargs + model, + io_api=io_api, + options=options or {}, + track_updates=track_updates, + **build_kwargs, ) @classmethod @@ -516,10 +681,16 @@ def from_model( model: Model, io_api: str | None = None, options: dict[str, Any] | None = None, + track_updates: bool = False, **build_kwargs: Any, ) -> Solver: """Instantiate and build the solver against ``model``.""" - instance = cls(model=model, io_api=io_api, options=options or {}) + instance = cls( + model=model, + io_api=io_api, + options=options or {}, + track_updates=track_updates, + ) instance._build(**build_kwargs) return instance @@ -539,6 +710,9 @@ def _build(self, **build_kwargs: Any) -> None: self._validate_model() if self.io_api == "direct": self._build_direct(**build_kwargs) + if self.track_updates: + self.snapshot = ModelSnapshot.capture(self.model) + clear_coef_dirty(self.model) else: self._build_file(**build_kwargs) @@ -616,38 +790,196 @@ def _build_file(self, **build_kwargs: Any) -> None: self.io_api = read_io_api_from_problem_file(problem_fn) self._cache_model_sizes(model) - def solve(self, **run_kwargs: Any) -> Result: + def solve( + self, + model: Model | None = None, + assign: bool = False, + ignore_dims: Iterable[str] = (), + disallow_rebuild: bool = False, + **run_kwargs: Any, + ) -> Result: """ Run the prepared solver and return a :class:`Result`. - The canonical low-level pattern is:: + With ``model`` supplied, diff against the previous build and either + apply in place or rebuild before running. Requires ``io_api='direct'``. + With ``assign=True`` the Result is written back to the target Model + via :meth:`Model.assign_result`. + + Coordinate alignment is checked on every dim by default. Pass + ``ignore_dims`` to exclude dims whose coord values legitimately shift + between solves. + + Pass ``disallow_rebuild=True`` to guarantee that an existing solver + model is updated in place — any condition that would force a rebuild + (structural change, sparsity change, backend rejection, …) raises + :class:`RebuildRequiredError` instead. The initial build on the first + ``solve(model, ...)`` is still allowed. + + Thread safety: the solver lock is held for the entire call, + including the native run. This is deliberate — diff/apply and the + run must be atomic (otherwise a concurrent apply would change the + problem between apply and run), and native solver handles are not + thread-safe. Concurrent solves therefore serialize per Solver + instance; use separate instances for parallelism. Pure diff + computation (``update(model, apply=False)``) does not take the lock. + """ + if model is not None and self.io_api != "direct": + raise ValueError("solve(model=...) requires io_api='direct'") + + with self._lock: + if model is not None: + if self.solver_model is None: + self.model = model + self._build() + else: + if not self.track_updates and model is self.model: + raise UpdatesDisabledError( + "Solver was constructed with track_updates=False; " + "in-place mutations of the build-time Model cannot " + "be detected without a snapshot. Pass a freshly " + "built Model instance, or reconstruct the solver " + "with Solver.from_name(..., track_updates=True)." + ) + self._update_locked( + model, + apply=True, + ignore_dims=ignore_dims, + disallow_rebuild=disallow_rebuild, + ) + target = model + else: + target = self.model # type: ignore[assignment] - solver = Solver.from_name("gurobi", model, io_api="direct") - result = solver.solve() - model.assign_result(result, solver=solver) + if self.model is not None and self.model.objective.expression.empty: + raise ValueError( + "No objective has been set on the model. Use `m.add_objective(...)` " + "first (e.g. `m.add_objective(0 * x)` for a pure feasibility problem)." + ) + if self.io_api == "direct" or self.solver_model is not None: + result = self._run_direct(**run_kwargs) + elif self._problem_fn is not None: + result = self._run_file(**run_kwargs) + else: + raise RuntimeError( + "Solver has not been built; call Solver.from_name(...) or _build() first." + ) - Passing ``solver=`` to :meth:`Model.assign_result` wires - ``model.solver`` so post-solve helpers like - :meth:`Model.compute_infeasibilities` keep working. + if assign and target is not None: + target.assign_result(result, solver=self) + return result - Raises - ------ - ValueError - If the attached model has no objective set. Submit-time check - shared by both ``Model.solve()`` and direct-Solver callers. + def update( + self, + model: Model, + apply: bool = True, + ignore_dims: Iterable[str] = (), + ) -> ModelDiff | RebuildReason: """ - if self.model is not None and self.model.objective.expression.empty: - raise ValueError( - "No objective has been set on the model. Use `m.add_objective(...)` " - "first (e.g. `m.add_objective(0 * x)` for a pure feasibility problem)." + Diff ``model`` against the solver state and optionally apply it. + + With ``apply=False`` the diff is computed without taking the solver + lock, so it can overlap a concurrently running solve. The preview + always runs a full comparison (no ``_coef_dirty`` shortcut — a + concurrent apply may clear the flag against a newer snapshot), so it + can report raw in-place ``.values[...]`` mutations that the apply + path, which trusts the flag for the build-time model, would miss. + """ + if self.io_api != "direct": + raise ValueError("update requires io_api='direct'") + if self.solver_model is None: + raise RuntimeError("Solver has not been built") + if not self.track_updates and model is self.model: + raise UpdatesDisabledError( + "Solver was constructed with track_updates=False; " + "in-place mutations of the build-time Model cannot be " + "detected without a snapshot. Pass a freshly built Model " + "instance, or reconstruct the solver with " + "Solver.from_name(..., track_updates=True)." ) - if self.io_api == "direct" or self.solver_model is not None: - return self._run_direct(**run_kwargs) - if self._problem_fn is not None: - return self._run_file(**run_kwargs) - raise RuntimeError( - "Solver has not been built; call Solver.from_name(...) or _build() first." - ) + if not apply: + return self._diff_unlocked(model, ignore_dims) + with self._lock: + return self._update_locked(model, apply=apply, ignore_dims=ignore_dims) + + def _diff_unlocked( + self, model: Model, ignore_dims: Iterable[str] + ) -> ModelDiff | RebuildReason: + """ + Compute a diff without the solver lock. + + Snapshot and baseline refs are read once; snapshot buffers are + immutable after capture, so the walk is consistent even while a + concurrent apply swaps ``self.snapshot``. The fallback baseline + (``from_models``) is only consistent if no thread concurrently + mutates either Model. + """ + snapshot = self.snapshot + if snapshot is not None: + return ModelDiff.from_snapshot( + snapshot, model, same_model=False, ignore_dims=ignore_dims + ) + baseline = self.model + assert baseline is not None + return ModelDiff.from_models(baseline, model, ignore_dims=ignore_dims) + + def _update_locked( + self, + model: Model, + apply: bool, + ignore_dims: Iterable[str] = (), + disallow_rebuild: bool = False, + ) -> ModelDiff | RebuildReason: + if apply and not type(self).supports_persistent_update: + if disallow_rebuild: + raise RebuildRequiredError(RebuildReason.BACKEND_REJECTED) + self._rebuild(model, RebuildReason.BACKEND_REJECTED) + return RebuildReason.BACKEND_REJECTED + if self.snapshot is not None: + same_model = model is self.model + diff = ModelDiff.from_snapshot( + self.snapshot, model, same_model=same_model, ignore_dims=ignore_dims + ) + else: + assert self.model is not None + diff = ModelDiff.from_models(self.model, model, ignore_dims=ignore_dims) + if isinstance(diff, RebuildReason): + if not apply: + return diff + if disallow_rebuild: + raise RebuildRequiredError(diff) + self._rebuild(model, diff) + return diff + if not apply: + return diff + try: + self.apply_update( + diff, + model.variables.label_index, + model.constraints.label_index, + ) + except Exception as exc: + if disallow_rebuild: + raise RebuildRequiredError( + RebuildReason.BACKEND_REJECTED, str(exc) + ) from exc + self._last_rebuild_reason = RebuildReason.BACKEND_REJECTED + self._rebuild(model, RebuildReason.BACKEND_REJECTED) + return diff + self.model = model + if self.track_updates: + self.snapshot = diff.snapshot + clear_coef_dirty(model) + self._in_place_updates += 1 + self._last_rebuild_reason = None + return diff + + def _rebuild(self, model: Model, reason: RebuildReason) -> None: + self.close() + self.model = model + self._build() + self._rebuilds += 1 + self._last_rebuild_reason = reason def _run_direct(self, **run_kwargs: Any) -> Result: """Run the pre-built native solver model. Override per-solver.""" @@ -801,6 +1133,18 @@ def __del__(self) -> None: with contextlib.suppress(Exception): self.close() + def __getstate__(self) -> dict[str, Any]: + drop = {"solver_model", "env", "_env_stack", "snapshot", "_lock"} + return {k: v for k, v in self.__dict__.items() if k not in drop} + + def __setstate__(self, state: dict[str, Any]) -> None: + self.__dict__.update(state) + self.solver_model = None + self.env = None + self._env_stack = None + self.snapshot = None + self._lock = threading.Lock() + def __repr__(self) -> str: status = self.status.status.value if self.status is not None else "unsolved" parts = [f"name={self.solver_name.value!r}", f"status={status!r}"] @@ -1209,12 +1553,61 @@ class Highs(Solver[None]): SolverFeature.MIP_DUAL_BOUND_REPORT, } ) + supports_persistent_update: ClassVar[bool] = True @classmethod @functools.cache def is_available(cls) -> bool: return _has_module("highspy") + @classmethod + @functools.cache + def _vtype_map(cls) -> dict[VarKind, Any]: + return { + VarKind.CONTINUOUS: highspy.HighsVarType.kContinuous, + VarKind.BINARY: highspy.HighsVarType.kInteger, + VarKind.INTEGER: highspy.HighsVarType.kInteger, + VarKind.SEMI_CONTINUOUS: highspy.HighsVarType.kSemiContinuous, + } + + def _apply_var_bounds( + self, ctx: Any, indices: np.ndarray, lower: np.ndarray, upper: np.ndarray + ) -> None: + ctx.changeColsBounds(indices.size, indices, lower, upper) + + def _apply_var_types( + self, ctx: Any, positions: np.ndarray, kinds: np.ndarray + ) -> None: + type_map = self._vtype_map() + integrality = np.fromiter( + (int(type_map[k]) for k in kinds), + dtype=np.uint8, + count=positions.size, + ) + ctx.changeColsIntegrality(positions.size, positions, integrality) + + def _apply_con_rhs(self, ctx: Any, diff: ModelDiff) -> None: + lower, upper = diff.con_rhs_as_bounds() + for pos, lo, up in zip(diff.con_rhs_indices, lower, upper): + ctx.changeRowBounds(int(pos), float(lo), float(up)) + + def _apply_con_coefs( + self, ctx: Any, rows: np.ndarray, cols: np.ndarray, vals: np.ndarray + ) -> None: + for i in range(rows.size): + ctx.changeCoeff(int(rows[i]), int(cols[i]), float(vals[i])) + + def _apply_obj_linear( + self, ctx: Any, indices: np.ndarray, values: np.ndarray + ) -> None: + ctx.changeColsCost(indices.size, indices, values) + + def _apply_obj_sense(self, ctx: Any, sense: str) -> None: + native = ( + highspy.ObjSense.kMaximize if sense == "max" else highspy.ObjSense.kMinimize + ) + ctx.changeObjectiveSense(native) + def _build_direct( self, explicit_coordinate_names: bool = False, @@ -1515,6 +1908,8 @@ class Gurobi(Solver["gurobipy.Env | dict[str, Any] | None"]): SolverFeature.MIP_DUAL_BOUND_REPORT, } ) + supports_persistent_update: ClassVar[bool] = True + supports_sign_update: ClassVar[bool] = True @classmethod @functools.cache @@ -1628,6 +2023,91 @@ def _build_solver_model( gm.update() return gm + _GUROBI_VTYPE_MAP: ClassVar[dict[VarKind, str]] = { + VarKind.CONTINUOUS: "C", + VarKind.BINARY: "B", + VarKind.INTEGER: "I", + VarKind.SEMI_CONTINUOUS: "S", + } + _GUROBI_SIGN_MAP: ClassVar[dict[str, str]] = { + short_LESS_EQUAL: "<", + short_GREATER_EQUAL: ">", + EQUAL: "=", + } + _GUROBI_SENSE_MAP: ClassVar[dict[str, int]] = {"min": 1, "max": -1} + + def _apply_begin(self, var_label_index: Any, con_label_index: Any) -> Any: + gm = self.solver_model + gurobi_vars = gm.getVars() + gurobi_cons = gm.getConstrs() + if len(gurobi_vars) != var_label_index.n_active_vars: + raise UnsupportedUpdate("gurobi var count mismatch") + if len(gurobi_cons) != con_label_index.n_active_cons: + raise UnsupportedUpdate("gurobi con count mismatch") + return (gm, gurobi_vars, gurobi_cons) + + def _apply_end(self, ctx: Any) -> None: + ctx[0].update() + + def _apply_var_bounds( + self, ctx: Any, indices: np.ndarray, lower: np.ndarray, upper: np.ndarray + ) -> None: + gm, gvars, _ = ctx + subset = [gvars[int(i)] for i in indices] + gm.setAttr("LB", subset, lower.tolist()) + gm.setAttr("UB", subset, upper.tolist()) + + def _apply_var_types( + self, ctx: Any, positions: np.ndarray, kinds: np.ndarray + ) -> None: + gm, gvars, _ = ctx + subset = [gvars[int(p)] for p in positions] + vtypes = [self._GUROBI_VTYPE_MAP[k] for k in kinds] + gm.setAttr("VType", subset, vtypes) + + def _reclamp_binary_bounds( + self, ctx: Any, positions: np.ndarray, kinds: np.ndarray + ) -> None: + # Gurobi's VType 'B' natively implies [0, 1]; no bound writes needed. + return None + + def _apply_con_rhs(self, ctx: Any, diff: ModelDiff) -> None: + gm, _, gcons = ctx + subset = [gcons[int(r)] for r in diff.con_rhs_indices] + gm.setAttr("RHS", subset, diff.con_rhs_values.tolist()) + + def _apply_con_signs( + self, ctx: Any, indices: np.ndarray, signs: np.ndarray + ) -> None: + gm, _, gcons = ctx + senses = [] + for s in signs: + s_str = str(s) + if s_str not in self._GUROBI_SIGN_MAP: + raise UnsupportedUpdate(f"unknown sign {s_str!r}") + senses.append(self._GUROBI_SIGN_MAP[s_str]) + subset = [gcons[int(r)] for r in indices] + gm.setAttr("Sense", subset, senses) + + def _apply_con_coefs( + self, ctx: Any, rows: np.ndarray, cols: np.ndarray, vals: np.ndarray + ) -> None: + gm, gvars, gcons = ctx + for i in range(rows.size): + gm.chgCoeff(gcons[int(rows[i])], gvars[int(cols[i])], float(vals[i])) + + def _apply_obj_linear( + self, ctx: Any, indices: np.ndarray, values: np.ndarray + ) -> None: + gm, gvars, _ = ctx + subset = [gvars[int(i)] for i in indices] + gm.setAttr("Obj", subset, values.tolist()) + + def _apply_obj_sense(self, ctx: Any, sense: str) -> None: + if sense not in self._GUROBI_SENSE_MAP: + raise UnsupportedUpdate(f"unknown obj sense {sense!r}") + ctx[0].ModelSense = self._GUROBI_SENSE_MAP[sense] + def _run_direct( self, solution_fn: Path | None = None, @@ -2108,12 +2588,74 @@ class Xpress(Solver[None]): SolverFeature.SOS_CONSTRAINTS, } ) + supports_persistent_update: ClassVar[bool] = True + supports_sign_update: ClassVar[bool] = True + + _XPRESS_VTYPE_MAP: ClassVar[dict[VarKind, str]] = { + VarKind.CONTINUOUS: "C", + VarKind.BINARY: "B", + VarKind.INTEGER: "I", + VarKind.SEMI_CONTINUOUS: "S", + } + _XPRESS_ROWTYPE_MAP: ClassVar[dict[str, str]] = { + short_LESS_EQUAL: "L", + short_GREATER_EQUAL: "G", + EQUAL: "E", + } @classmethod @functools.cache def is_available(cls) -> bool: return _has_module("xpress") + def _apply_var_bounds( + self, ctx: Any, indices: np.ndarray, lower: np.ndarray, upper: np.ndarray + ) -> None: + cols = np.concatenate([indices, indices]).astype(np.int64, copy=False) + btypes = ["L"] * indices.size + ["U"] * indices.size + lb = np.where(np.isneginf(lower), -xpress.infinity, lower) + ub = np.where(np.isposinf(upper), xpress.infinity, upper) + vals = np.concatenate([lb, ub]).astype(float, copy=False) + ctx.chgbounds(cols.tolist(), btypes, vals.tolist()) + + def _apply_var_types( + self, ctx: Any, positions: np.ndarray, kinds: np.ndarray + ) -> None: + coltypes = [self._XPRESS_VTYPE_MAP[k] for k in kinds] + ctx.chgcoltype(positions.tolist(), coltypes) + + def _apply_con_rhs(self, ctx: Any, diff: ModelDiff) -> None: + ctx.chgrhs(_int_list(diff.con_rhs_indices), _float_list(diff.con_rhs_values)) + + def _apply_con_signs( + self, ctx: Any, indices: np.ndarray, signs: np.ndarray + ) -> None: + rowtypes = [] + for s in signs: + s_str = str(s) + if s_str not in self._XPRESS_ROWTYPE_MAP: + raise UnsupportedUpdate(f"unknown sign {s_str!r}") + rowtypes.append(self._XPRESS_ROWTYPE_MAP[s_str]) + ctx.chgrowtype(_int_list(indices), rowtypes) + + def _apply_con_coefs( + self, ctx: Any, rows: np.ndarray, cols: np.ndarray, vals: np.ndarray + ) -> None: + ctx.chgmcoef(_int_list(rows), _int_list(cols), _float_list(vals)) + + def _apply_obj_linear( + self, ctx: Any, indices: np.ndarray, values: np.ndarray + ) -> None: + ctx.chgobj(_int_list(indices), _float_list(values)) + + def _apply_obj_sense(self, ctx: Any, sense: str) -> None: + if sense == "max": + ctx.chgobjsense(xpress.maximize) + elif sense == "min": + ctx.chgobjsense(xpress.minimize) + else: + raise UnsupportedUpdate(f"unknown obj sense {sense!r}") + def _build_direct( self, explicit_coordinate_names: bool = False, @@ -2750,6 +3292,7 @@ class Mosek(Solver[None]): SolverFeature.SOLUTION_FILE_NOT_NEEDED, } ) + supports_persistent_update: ClassVar[bool] = True @classmethod @functools.cache @@ -2761,6 +3304,60 @@ def _license_probe(cls) -> None: with mosek.Env() as env, env.Task(0, 0) as task: task.optimize() + def _validate_update(self, diff: ModelDiff) -> None: + super()._validate_update(diff) + if (diff.var_type_kinds == VarKind.SEMI_CONTINUOUS).any(): + raise UnsupportedUpdate("MOSEK does not support semi-continuous variables") + + def _apply_var_bounds( + self, ctx: Any, indices: np.ndarray, lower: np.ndarray, upper: np.ndarray + ) -> None: + for k in range(indices.size): + j = int(indices[k]) + lb = float(lower[k]) + ub = float(upper[k]) + ctx.chgvarbound(j, 1, int(np.isfinite(lb)), lb) + ctx.chgvarbound(j, 0, int(np.isfinite(ub)), ub) + + def _apply_var_types( + self, ctx: Any, positions: np.ndarray, kinds: np.ndarray + ) -> None: + integer_mask = (kinds == VarKind.BINARY) | (kinds == VarKind.INTEGER) + vartypes = np.where( + integer_mask, + mosek.variabletype.type_int, + mosek.variabletype.type_cont, + ).tolist() + ctx.putvartypelist(_int_list(positions, np.int32), vartypes) + + def _apply_con_rhs(self, ctx: Any, diff: ModelDiff) -> None: + lower, upper = diff.con_rhs_as_bounds() + for k, i in enumerate(diff.con_rhs_indices): + lo = float(lower[k]) + up = float(upper[k]) + ctx.chgconbound(int(i), 1, int(np.isfinite(lo)), lo) + ctx.chgconbound(int(i), 0, int(np.isfinite(up)), up) + + def _apply_con_coefs( + self, ctx: Any, rows: np.ndarray, cols: np.ndarray, vals: np.ndarray + ) -> None: + ctx.putaijlist( + _int_list(rows, np.int32), _int_list(cols, np.int32), _float_list(vals) + ) + + def _apply_obj_linear( + self, ctx: Any, indices: np.ndarray, values: np.ndarray + ) -> None: + ctx.putclist(_int_list(indices, np.int32), _float_list(values)) + + def _apply_obj_sense(self, ctx: Any, sense: str) -> None: + if sense == "max": + ctx.putobjsense(mosek.objsense.maximize) + elif sense == "min": + ctx.putobjsense(mosek.objsense.minimize) + else: + raise UnsupportedUpdate(f"unknown obj sense {sense!r}") + def _run_direct( self, solution_fn: Path | None = None, diff --git a/linopy/variables.py b/linopy/variables.py index 0eacfdc4..772284cd 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -48,7 +48,6 @@ get_label_position, has_optimized_model, iterate_slices, - require_constant, save_join, set_int_index, to_dataframe, @@ -895,18 +894,18 @@ def upper(self) -> DataArray: return self.data.upper @upper.setter - @require_constant def upper(self, value: ConstantLike) -> None: """ - Set the upper bounds of the variables. - - The function raises an error in case no model is set as a - reference. + Syntactic sugar for :meth:`Variable.update`. Do not add logic + here; mutate via ``update`` so the contract stays single-sourced. """ - value = DataArray(value).broadcast_like(self.upper) - if not set(value.dims).issubset(self.model.variables[self.name].dims): - raise ValueError("Cannot assign new dimensions to existing variable.") - self._data = assign_multiindex_safe(self.data, upper=value) + warn( + "Variable.upper setter is deprecated and will be removed in a " + "future release; use Variable.update(upper=...) instead.", + DeprecationWarning, + stacklevel=2, + ) + self.update(upper=value) @property def lower(self) -> DataArray: @@ -919,18 +918,100 @@ def lower(self) -> DataArray: return self.data.lower @lower.setter - @require_constant def lower(self, value: ConstantLike) -> None: """ - Set the lower bounds of the variables. + Syntactic sugar for :meth:`Variable.update`. Do not add logic + here; mutate via ``update`` so the contract stays single-sourced. + """ + warn( + "Variable.lower setter is deprecated and will be removed in a " + "future release; use Variable.update(lower=...) instead.", + DeprecationWarning, + stacklevel=2, + ) + self.update(lower=value) - The function raises an error in case no model is set as a - reference. + def update( + self, + *, + lower: ConstantLike | None = None, + upper: ConstantLike | None = None, + ) -> Variable: + """ + Update variable bounds in place. + + Canonical mutation API. Validation and coord alignment live here. + Single-attribute setters (`var.lower = …`) forward to this method. + + Parameters + ---------- + lower : ConstantLike, optional + New lower bound. Accepts any constant — scalars, numpy + arrays, pandas Series / DataFrame, xarray DataArray (e.g. + time-varying bounds). Aligned via xarray broadcast against + the variable's existing shape; new dims are rejected. + Decision variables / linear expressions are not accepted. + upper : ConstantLike, optional + New upper bound. Same. + + Returns + ------- + Variable + ``self`` for chaining. + + Raises + ------ + TypeError + If either bound is a Variable / Expression (bounds must be + numeric, not symbolic). + ValueError + If the new bound introduces dimensions not in the variable's + coords, or if the resulting ``lower > upper`` anywhere. """ - value = DataArray(value).broadcast_like(self.lower) - if not set(value.dims).issubset(self.model.variables[self.name].dims): - raise ValueError("Cannot assign new dimensions to existing variable.") - self._data = assign_multiindex_safe(self.data, lower=value) + if lower is None and upper is None: + return self + + updates = self._validate_update(lower=lower, upper=upper) + self._data = assign_multiindex_safe(self.data, **updates) + return self + + def _validate_update( + self, + *, + lower: ConstantLike | None = None, + upper: ConstantLike | None = None, + ) -> dict[str, DataArray]: + """ + Validate, broadcast, and cross-check update inputs. + + Returns the broadcasted DataArrays ready for assignment. Raises + before any mutation if any input is wrong. + """ + updates: dict[str, DataArray] = {} + own_dims = self.model.variables[self.name].dims + for name, val, ref in ( + ("lower", lower, self.lower), + ("upper", upper, self.upper), + ): + if val is None: + continue + if not isinstance(val, ConstantLike): + raise TypeError( + f"Variable.update({name}=...) must be a constant; " + f"got {type(val).__name__}." + ) + new_val = DataArray(val).broadcast_like(ref) + if not set(new_val.dims).issubset(own_dims): + raise ValueError("Cannot assign new dimensions to existing variable.") + updates[name] = new_val + + final_lower = updates.get("lower", self.lower) + final_upper = updates.get("upper", self.upper) + if bool((final_lower > final_upper).any()): + raise ValueError( + "Variable.update would leave lower > upper at one or more coordinates." + ) + return updates @property @has_optimized_model @@ -1424,8 +1505,7 @@ def fix( **{STASHED_LOWER: lower, STASHED_UPPER: upper}, ) - self.lower = value - self.upper = value + self.update(lower=value, upper=value) def unfix(self) -> None: """ @@ -1434,8 +1514,7 @@ def unfix(self) -> None: if not self.fixed: return - self.lower = self.data[STASHED_LOWER] - self.upper = self.data[STASHED_UPPER] + self.update(lower=self.data[STASHED_LOWER], upper=self.data[STASHED_UPPER]) self._data = self.data.drop_vars(STASHED_ATTRS) @property diff --git a/test/test_constraint.py b/test/test_constraint.py index a684b966..225b76c1 100644 --- a/test/test_constraint.py +++ b/test/test_constraint.py @@ -357,7 +357,9 @@ def test_constraint_vars_setter( def test_constraint_vars_setter_with_array( mc: linopy.constraints.Constraint, x: linopy.Variable ) -> None: - mc.vars = x.labels + """Passing a raw DataArray is deprecated but still works for back-compat.""" + with pytest.warns(FutureWarning, match="DataArray"): + mc.vars = x.labels assert_equal(mc.vars, x.labels) @@ -421,15 +423,124 @@ def test_constraint_sign_setter_invalid( def test_constraint_rhs_setter(mc: linopy.constraints.Constraint) -> None: sizes = mc.sizes - mc.rhs = 2 # type: ignore + mc.rhs = 2 assert (mc.rhs == 2).all() assert mc.sizes == sizes +def test_constraint_update_rhs_and_sign(mc: linopy.constraints.Constraint) -> None: + mc.update(rhs=5, sign=EQUAL) + assert (mc.rhs == 5).all() + assert (mc.sign == EQUAL).all() + + +def test_constraint_update_no_kwargs_is_noop( + mc: linopy.constraints.Constraint, +) -> None: + old_rhs = mc.rhs.copy() + old_sign = mc.sign.copy() + mc.update() + assert (mc.rhs == old_rhs).all() + assert (mc.sign == old_sign).all() + + +def test_constraint_update_rearranges_variable_rhs( + mc: linopy.constraints.Constraint, x: linopy.Variable +) -> None: + """ + Variable / Expression rhs is moved onto lhs; only the constant + part lands on rhs (mirrors add_constraints and the .rhs setter). + """ + mc.update(rhs=x + 3) + assert (mc.rhs == 3).all() + assert mc.lhs.nterm == 2 # original term + the rearranged -x + + +def test_constraint_update_returns_self( + mc: linopy.constraints.Constraint, +) -> None: + out = mc.update(rhs=7) + assert out is mc + + +def test_constraint_update_positional_constraint_expression( + mc: linopy.constraints.Constraint, x: linopy.Variable, y: linopy.Variable +) -> None: + """``c.update(x + 5 <= 3)`` replaces lhs / sign / rhs in one call.""" + mc.update(x + y <= 7) + assert (mc.rhs == 7).all() + assert (mc.sign == LESS_EQUAL).all() + assert mc.lhs.nterm == 2 + + +def test_constraint_update_positional_rejects_mixing_kwargs( + mc: linopy.constraints.Constraint, x: linopy.Variable +) -> None: + """Positional constraint can't be combined with keyword updates.""" + with pytest.raises(TypeError, match="cannot be combined with keyword"): + mc.update(x <= 3, sign=EQUAL) + + +def test_constraint_update_positional_rejects_non_constraint( + mc: linopy.constraints.Constraint, +) -> None: + """Random objects are rejected with a clear error.""" + with pytest.raises(TypeError, match="must be a ConstraintLike"): + mc.update("not a constraint") # type: ignore + + +def test_constraint_update_lhs_only( + mc: linopy.constraints.Constraint, x: linopy.Variable, y: linopy.Variable +) -> None: + """lhs= alone replaces the expression; rhs and sign untouched.""" + old_rhs = mc.rhs.copy() + old_sign = mc.sign.copy() + mc.update(lhs=5 * x + 7 * y) + assert (mc.rhs == old_rhs).all() + assert (mc.sign == old_sign).all() + assert mc.lhs.nterm == 2 + + +def test_constraint_update_coeffs_only_keeps_values( + mc: linopy.constraints.Constraint, +) -> None: + """coeffs= alone replaces the coef array element-wise; vars untouched.""" + old_vars = mc.vars.copy() + mc.update(coeffs=mc.coeffs * 10) + assert (mc.vars == old_vars).all() + # original was mc.lhs with leading coeff; *10 → all coeffs *10 + assert mc.coeffs.max() >= 10 + + +def test_constraint_update_lhs_and_sign_together( + mc: linopy.constraints.Constraint, x: linopy.Variable +) -> None: + """Compound updates compose: lhs replacement + sign flip in one call.""" + mc.update(lhs=2 * x, sign=EQUAL) + assert (mc.sign == EQUAL).all() + assert mc.lhs.nterm == 1 + + +def test_constraint_update_lhs_and_coeffs_rejected( + mc: linopy.constraints.Constraint, x: linopy.Variable +) -> None: + """lhs= (full replacement) and coeffs= (partial) are mutually exclusive.""" + with pytest.raises(TypeError, match="lhs.*coeffs.*variables"): + mc.update(lhs=2 * x, coeffs=mc.coeffs * 2) + + +def test_constraint_update_lhs_and_variables_rejected( + mc: linopy.constraints.Constraint, x: linopy.Variable +) -> None: + """lhs= (full replacement) and variables= (partial) are mutually exclusive.""" + with pytest.raises(TypeError, match="lhs.*coeffs.*variables"): + mc.update(lhs=2 * x, variables=mc.vars) + + def test_constraint_rhs_setter_with_variable( mc: linopy.constraints.Constraint, x: linopy.Variable ) -> None: - mc.rhs = x # type: ignore + mc.rhs = x assert (mc.rhs == 0).all() assert (mc.coeffs.isel({mc.term_dim: -1}) == -1).all() assert mc.lhs.nterm == 2 @@ -461,7 +572,7 @@ def test_constraint_rhs_setter_broadcasts_missing_dim() -> None: ) con = m.add_constraints(1 * x >= 0, name="con") - con.rhs = xr.DataArray([1.0, 2.0], dims=["i"], coords={"i": [0, 1]}) # type: ignore + con.rhs = xr.DataArray([1.0, 2.0], dims=["i"], coords={"i": [0, 1]}) assert dict(con.rhs.sizes) == {"i": 2, "j": 3} assert (con.rhs.sel(i=1) == 2.0).all() @@ -486,7 +597,7 @@ def test_constraint_rhs_setter_projects_multiindex_level() -> None: [10.0, 20.0], coords={"level1": [1, 2]}, dims=["level1"] ) with pytest.warns(linopy.EvolvingAPIWarning, match="broadcasting level subset"): - con.rhs = rhs_by_level # type: ignore + con.rhs = rhs_by_level assert con.rhs.sel(dim_3=(1, "b")).item() == 10.0 assert con.rhs.sel(dim_3=(2, "a")).item() == 20.0 diff --git a/test/test_constraint_coef_dirty.py b/test/test_constraint_coef_dirty.py new file mode 100644 index 00000000..6e32217b --- /dev/null +++ b/test/test_constraint_coef_dirty.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import pytest + +from linopy import Model + + +@pytest.fixture +def m_with_c() -> tuple[Model, str]: + m = Model() + x = m.add_variables(0, 10, coords=[range(3)], name="x") + y = m.add_variables(0, 10, coords=[range(3)], name="y") + m.add_constraints(2 * x + y >= 5, name="c") + return m, "c" + + +def test_initial_coef_dirty_false(m_with_c: tuple[Model, str]) -> None: + m, name = m_with_c + assert m.constraints[name]._coef_dirty is False + + +def test_update_coeffs_sets_dirty(m_with_c: tuple[Model, str]) -> None: + m, name = m_with_c + c = m.constraints[name] + c.update(coeffs=c.coeffs * 2) + assert c._coef_dirty is True + + +def test_update_variables_sets_dirty(m_with_c: tuple[Model, str]) -> None: + m, name = m_with_c + c = m.constraints[name] + x = m.variables["x"] + c.update(variables=x) + assert c._coef_dirty is True + + +def test_update_lhs_sets_dirty(m_with_c: tuple[Model, str]) -> None: + m, name = m_with_c + c = m.constraints[name] + x = m.variables["x"] + c.update(lhs=3 * x) + assert c._coef_dirty is True + + +def test_update_pure_constant_rhs_short_circuits(m_with_c: tuple[Model, str]) -> None: + m, name = m_with_c + c = m.constraints[name] + coeffs_buf = c.data["coeffs"].values + vars_buf = c.data["vars"].values + c.update(rhs=9) + assert c._coef_dirty is False + assert c.data["coeffs"].values is coeffs_buf + assert c.data["vars"].values is vars_buf + + +def test_update_rhs_with_variable_sets_dirty(m_with_c: tuple[Model, str]) -> None: + m, name = m_with_c + c = m.constraints[name] + x = m.variables["x"] + c.update(rhs=x + 3) + assert c._coef_dirty is True + + +def test_update_sign_does_not_set_dirty(m_with_c: tuple[Model, str]) -> None: + m, name = m_with_c + c = m.constraints[name] + c.update(sign="<=") + assert c._coef_dirty is False + + +def test_flag_persists_across_container_access(m_with_c: tuple[Model, str]) -> None: + m, name = m_with_c + m.constraints[name].update(coeffs=m.constraints[name].coeffs * 2) + assert m.constraints[name]._coef_dirty is True + + +def test_update_positional_constraint_sets_dirty(m_with_c: tuple[Model, str]) -> None: + """Positional ``c.update(expr <= rhs)`` rewrites lhs and must flip the flag.""" + m, name = m_with_c + c = m.constraints[name] + x = m.variables["x"] + c.update(4 * x >= 7) + assert c._coef_dirty is True + + +def test_update_noop_does_not_set_dirty(m_with_c: tuple[Model, str]) -> None: + """``c.update()`` with no args is a no-op and must not flip the flag.""" + m, name = m_with_c + c = m.constraints[name] + c.update() + assert c._coef_dirty is False diff --git a/test/test_persistent_apply_update.py b/test/test_persistent_apply_update.py new file mode 100644 index 00000000..41663dab --- /dev/null +++ b/test/test_persistent_apply_update.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest + +from linopy import Model +from linopy.persistent import RebuildReason +from linopy.solvers import Gurobi, Highs, Mosek, Solver, Xpress + +_BACKENDS: dict[str, tuple[type[Solver], dict[str, Any]]] = { + "gurobi": (Gurobi, {"OutputFlag": 0}), + "highs": (Highs, {"output_flag": False}), + "xpress": (Xpress, {"OUTPUTLOG": 0}), + "mosek": (Mosek, {"MSK_IPAR_LOG": 0}), +} + +_SIGN_CHANGE_IN_PLACE: dict[str, bool] = { + "gurobi": True, + "highs": False, + "xpress": True, + "mosek": False, +} + + +def _have(name: str) -> bool: + cls = _BACKENDS[name][0] + if not cls.is_available(): + return False + try: + cls._license_probe() + except Exception: + return False + if name == "xpress": + try: + import xpress + + xpress.problem() + except Exception: + return False + return True + + +SOLVER_PARAMS = [ + pytest.param( + name, + marks=pytest.mark.skipif(not _have(name), reason=f"{name} not installed"), + ) + for name in _BACKENDS +] + + +def _base_model() -> Model: + m = Model() + x = m.add_variables(0, 10, coords=[range(3)], name="x") + y = m.add_variables(0, 10, coords=[range(3)], name="y") + m.add_constraints(x + y >= 4, name="c1") + m.add_constraints(2 * x + y <= 20, name="c2") + m.add_objective(x.sum() + 2 * y.sum()) + return m + + +def _built(solver_name: str, model: Model) -> Solver: + cls, opts = _BACKENDS[solver_name] + s = cls(model=model, io_api="direct", track_updates=True) + s.options = opts + s._build() + return s + + +def _solve(solver: Solver, model: Model) -> float: + result = solver.solve(model, assign=True) + assert result.solution is not None + return float(result.solution.objective) + + +def _obj(model: Model) -> float: + value = model.objective.value + assert value is not None + return float(value) + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_var_lb_in_place(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + base_obj = _obj(m) + + m.variables["x"].lower.values[...] = 5.0 + obj = _solve(s, m) + assert s._in_place_updates == 1 + assert s._rebuilds == 0 + assert s._last_rebuild_reason is None + assert obj > base_obj + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_var_ub_in_place(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + + m.variables["x"].upper.values[...] = 1.0 + _solve(s, m) + assert s._in_place_updates == 1 + assert s._rebuilds == 0 + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_rhs_only_in_place(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + base_obj = _obj(m) + + c = m.constraints["c1"] + c.rhs = 8.0 + assert c._coef_dirty is False + obj = _solve(s, m) + assert s._in_place_updates == 1 + assert s._rebuilds == 0 + assert obj > base_obj + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_constraint_coef_change_in_place(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + base_obj = _obj(m) + + c = m.constraints["c1"] + c.coeffs = c.coeffs * 2 + obj = _solve(s, m) + assert s._in_place_updates == 1 + assert s._rebuilds == 0 + assert not np.isclose(obj, base_obj) + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_objective_linear_change_in_place(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + base_obj = _obj(m) + + x = m.variables["x"] + y = m.variables["y"] + m.objective.expression = 5 * x.sum() + 3 * y.sum() + obj = _solve(s, m) + assert s._in_place_updates == 1 + assert s._rebuilds == 0 + assert not np.isclose(obj, base_obj) + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_objective_sense_flip_in_place(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + min_obj = _obj(m) + + m.objective.sense = "max" + max_obj = _solve(s, m) + assert s._in_place_updates == 1 + assert s._rebuilds == 0 + assert max_obj > min_obj + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_sparsity_change_triggers_rebuild(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + + x = m.variables["x"] + m.add_constraints(x <= 5, name="c3") + s.solve(m, assign=True) + assert s._rebuilds == 1 + assert s._last_rebuild_reason in { + RebuildReason.STRUCTURAL_LABELS, + RebuildReason.STRUCTURAL_CONTAINERS, + } + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_cross_model_in_place(solver_name: str) -> None: + m1 = _base_model() + s = _built(solver_name, m1) + s.solve(assign=True) + + m2 = _base_model() + m2.constraints["c1"].rhs = 8.0 + + s.solve(m2, assign=True) + assert s._in_place_updates == 1 + assert s._rebuilds == 0 + + cross_obj = _obj(m2) + m3 = _base_model() + m3.constraints["c1"].rhs = 8.0 + s_fresh = _built(solver_name, m3) + s_fresh.solve(assign=True) + assert np.isclose(cross_obj, _obj(m3)) + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_sign_flip(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + + m.constraints["c1"].sign = "<=" + s.solve(m, assign=True) + if _SIGN_CHANGE_IN_PLACE[solver_name]: + assert s._in_place_updates == 1 + assert s._rebuilds == 0 + else: + assert s._rebuilds == 1 + assert s._last_rebuild_reason is RebuildReason.BACKEND_REJECTED diff --git a/test/test_persistent_snapshot_buffers.py b/test/test_persistent_snapshot_buffers.py new file mode 100644 index 00000000..bb801ecf --- /dev/null +++ b/test/test_persistent_snapshot_buffers.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from linopy import Model +from linopy.persistent import ModelDiff, ModelSnapshot, RebuildReason +from linopy.persistent.snapshot import _extract_con_buffers + + +def _build_permuted_pair() -> tuple[Model, Model]: + m1 = Model() + x1 = m1.add_variables(0, 10, coords=[range(3)], name="x") + y1 = m1.add_variables(0, 5, coords=[range(2)], name="y") + m1.add_constraints(2 * x1 + 3 * y1.sum() >= 4, name="c1") + m1.add_objective(x1.sum()) + + m2 = Model() + x2 = m2.add_variables(0, 10, coords=[range(3)], name="x") + y2 = m2.add_variables(0, 5, coords=[range(2)], name="y") + m2.add_constraints(3 * y2.sum() + 2 * x2 >= 4, name="c1") + m2.add_objective(x2.sum()) + return m1, m2 + + +def test_permuted_term_order_produces_equal_buffers() -> None: + m1, m2 = _build_permuted_pair() + s1 = ModelSnapshot.capture(m1) + s2 = ModelSnapshot.capture(m2) + b1 = s1.con_buffers["c1"] + b2 = s2.con_buffers["c1"] + np.testing.assert_array_equal(b1.indptr, b2.indptr) + np.testing.assert_array_equal(b1.indices, b2.indices) + np.testing.assert_array_equal(b1.data, b2.data) + + +def test_active_labels_match_label_index(baseline_model: Model) -> None: + snap = ModelSnapshot.capture(baseline_model) + expected = baseline_model.constraints.label_index.clabels + concatenated = np.concatenate( + [buf.active_labels for buf in snap.con_buffers.values()] + ) + np.testing.assert_array_equal(concatenated, expected) + + +@pytest.fixture +def baseline_model() -> Model: + m = Model() + x = m.add_variables(0, 10, coords=[range(3)], name="x") + y = m.add_variables(0, 5, coords=[range(2)], name="y") + m.add_constraints(2 * x >= 4, name="c1") + m.add_constraints(x.sum() + y.sum() <= 20, name="c2") + m.add_objective(x.sum()) + return m + + +def test_shape_mismatch_triggers_sparsity_rebuild(baseline_model: Model) -> None: + snap = ModelSnapshot.capture(baseline_model) + x = baseline_model.variables["x"] + y = baseline_model.variables["y"] + baseline_model.constraints["c1"].lhs = 2 * x + 0 * y.sum() + diff = ModelDiff.from_snapshot(snap, baseline_model) + assert diff in { + RebuildReason.SPARSITY, + RebuildReason.STRUCTURAL_LABELS, + } + + +def test_zero_row_container_capture() -> None: + m = Model() + m.add_variables(0, 10, coords=[range(2)], name="x") + m.add_objective(0.0 * m.variables["x"].sum()) + snap = ModelSnapshot.capture(m) + assert snap.con_buffers == {} + diff = ModelDiff.from_snapshot(snap, m) + assert isinstance(diff, ModelDiff) + assert diff.is_empty + + +def test_con_buffers_dtypes(baseline_model: Model) -> None: + snap = ModelSnapshot.capture(baseline_model) + buf = snap.con_buffers["c1"] + assert buf.rhs.dtype == np.float64 + assert buf.sign.dtype == np.dtype("U1") + assert buf.data.dtype == np.float64 + assert np.issubdtype(buf.indices.dtype, np.integer) + assert np.issubdtype(buf.indptr.dtype, np.integer) + + +def test_masked_rows_excluded_from_active_labels() -> None: + m = Model() + x = m.add_variables(0, 10, coords=[range(4)], name="x") + mask = np.array([True, False, True, True]) + m.add_constraints(2 * x >= 1, mask=mask, name="c1") + m.add_objective(x.sum()) + snap = ModelSnapshot.capture(m) + buf = snap.con_buffers["c1"] + assert buf.active_labels.size == 3 + rebuilt = _extract_con_buffers(m.constraints["c1"], m.variables.label_index) + np.testing.assert_array_equal(rebuilt.active_labels, buf.active_labels) + + +def test_csr_capture_deterministic(baseline_model: Model) -> None: + s1 = ModelSnapshot.capture(baseline_model) + s2 = ModelSnapshot.capture(baseline_model) + for name in s1.con_buffers: + b1, b2 = s1.con_buffers[name], s2.con_buffers[name] + np.testing.assert_array_equal(b1.indptr, b2.indptr) + np.testing.assert_array_equal(b1.indices, b2.indices) + np.testing.assert_array_equal(b1.data, b2.data) + + +def test_duplicate_variable_terms_summed() -> None: + m1 = Model() + x1 = m1.add_variables(0, 10, coords=[range(3)], name="x") + m1.add_constraints(2 * x1 + 3 * x1 >= 1, name="c1") + m1.add_objective(x1.sum()) + + m2 = Model() + x2 = m2.add_variables(0, 10, coords=[range(3)], name="x") + m2.add_constraints(5 * x2 >= 1, name="c1") + m2.add_objective(x2.sum()) + + diff = ModelDiff.from_models(m1, m2) + assert isinstance(diff, ModelDiff) + assert diff.is_empty diff --git a/test/test_persistent_snapshot_diff.py b/test/test_persistent_snapshot_diff.py new file mode 100644 index 00000000..929d0575 --- /dev/null +++ b/test/test_persistent_snapshot_diff.py @@ -0,0 +1,294 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest + +from linopy import Model +from linopy.persistent import ( + ContainerConBuffers, + ContainerVarBuffers, + ModelDiff, + ModelSnapshot, + RebuildReason, + StructuralKey, +) + + +@pytest.fixture +def baseline() -> Model: + m = Model() + x = m.add_variables(0, 10, coords=[range(3)], name="x") + y = m.add_variables(0, 5, coords=[range(2)], name="y") + m.add_constraints(2 * x + 1 >= 4, name="c1") + m.add_constraints(x.sum() + y.sum() <= 20, name="c2") + m.add_objective(x.sum() + 2 * y.sum()) + return m + + +def test_capture_structural_key(baseline: Model) -> None: + snap = ModelSnapshot.capture(baseline) + assert isinstance(snap, ModelSnapshot) + assert isinstance(snap.structural_key, StructuralKey) + assert snap.structural_key.var_container_names == ("x", "y") + assert snap.structural_key.con_container_names == ("c1", "c2") + np.testing.assert_array_equal( + snap.structural_key.vlabels, baseline.variables.label_index.vlabels + ) + np.testing.assert_array_equal( + snap.structural_key.clabels, baseline.constraints.label_index.clabels + ) + assert isinstance(snap.var_buffers["x"], ContainerVarBuffers) + assert isinstance(snap.con_buffers["c1"], ContainerConBuffers) + + +def test_is_empty_on_unmutated(baseline: Model) -> None: + snap = ModelSnapshot.capture(baseline) + diff = ModelDiff.from_snapshot(snap, baseline) + assert isinstance(diff, ModelDiff) + assert diff.is_empty + + +def test_bounds_only_mutation(baseline: Model) -> None: + snap = ModelSnapshot.capture(baseline) + baseline.variables["x"].lower = 1 + diff = ModelDiff.from_snapshot(snap, baseline) + assert isinstance(diff, ModelDiff) + assert "x" in diff.changed_variables + assert "y" not in diff.changed_variables + sl = diff.var_slices["x"].bounds + np.testing.assert_array_equal(diff.var_bounds_lower[sl], np.ones(3)) + + +def test_rhs_only_mutation(baseline: Model) -> None: + snap = ModelSnapshot.capture(baseline) + baseline.constraints["c1"].rhs = 9 + diff = ModelDiff.from_snapshot(snap, baseline) + assert isinstance(diff, ModelDiff) + assert "c1" in diff.changed_constraints + sl = diff.con_slices["c1"] + assert sl.rhs.stop > sl.rhs.start + assert sl.coef.stop == sl.coef.start + + +def test_objective_linear_change(baseline: Model) -> None: + snap = ModelSnapshot.capture(baseline) + x = baseline.variables["x"] + y = baseline.variables["y"] + baseline.add_objective(3 * x.sum() + 2 * y.sum(), overwrite=True) + diff = ModelDiff.from_snapshot(snap, baseline) + assert isinstance(diff, ModelDiff) + assert diff.obj_c_indices is not None + assert diff.obj_c_values is not None + + +def test_objective_sense_flip(baseline: Model) -> None: + snap = ModelSnapshot.capture(baseline) + baseline.objective.sense = "max" + diff = ModelDiff.from_snapshot(snap, baseline) + assert isinstance(diff, ModelDiff) + assert diff.obj_sense == "max" + + +def test_add_constraints_is_structural(baseline: Model) -> None: + snap = ModelSnapshot.capture(baseline) + x = baseline.variables["x"] + baseline.add_constraints(x.sum() <= 99, name="c3") + diff = ModelDiff.from_snapshot(snap, baseline) + assert diff in ( + RebuildReason.STRUCTURAL_LABELS, + RebuildReason.STRUCTURAL_CONTAINERS, + ) + + +def test_remove_variables_is_structural(baseline: Model) -> None: + snap = ModelSnapshot.capture(baseline) + baseline.remove_variables("y") + diff = ModelDiff.from_snapshot(snap, baseline) + assert diff in ( + RebuildReason.STRUCTURAL_LABELS, + RebuildReason.STRUCTURAL_CONTAINERS, + ) + + +def test_coef_value_change_same_sparsity(baseline: Model) -> None: + snap = ModelSnapshot.capture(baseline) + c = baseline.constraints["c1"] + c.coeffs = c.coeffs * 3 + diff = ModelDiff.from_snapshot(snap, baseline) + assert isinstance(diff, ModelDiff) + assert "c1" in diff.changed_constraints + sl = diff.con_slices["c1"].coef + vals = diff.con_coef_vals[sl] + np.testing.assert_array_equal(vals, np.full(vals.size, 6.0)) + + +def test_coef_changes_across_containers(baseline: Model) -> None: + snap = ModelSnapshot.capture(baseline) + c1 = baseline.constraints["c1"] + c2 = baseline.constraints["c2"] + c1.update(coeffs=c1.coeffs * 3) + c2.update(coeffs=c2.coeffs * 2) + diff = ModelDiff.from_snapshot(snap, baseline) + assert isinstance(diff, ModelDiff) + sl1 = diff.con_slices["c1"].coef + sl2 = diff.con_slices["c2"].coef + assert diff.n_coef_updates == (sl1.stop - sl1.start) + (sl2.stop - sl2.start) + np.testing.assert_array_equal( + diff.con_coef_vals[sl1], np.full(sl1.stop - sl1.start, 6.0) + ) + np.testing.assert_array_equal( + diff.con_coef_vals[sl2], np.full(sl2.stop - sl2.start, 2.0) + ) + + +def test_coef_sparsity_change(baseline: Model) -> None: + snap = ModelSnapshot.capture(baseline) + x = baseline.variables["x"] + baseline.constraints["c2"].lhs = 2 * x.sum() + diff = ModelDiff.from_snapshot(snap, baseline) + assert diff is RebuildReason.SPARSITY + + +def test_deep_copy_invariant(baseline: Model) -> None: + snap = ModelSnapshot.capture(baseline) + baseline.variables["x"].lower.values[...] = 99 + diff = ModelDiff.from_snapshot(snap, baseline) + assert isinstance(diff, ModelDiff) + assert "x" in diff.changed_variables + + +def test_same_model_false_ignores_dirty_flag(baseline: Model) -> None: + snap = ModelSnapshot.capture(baseline) + c = baseline.constraints["c1"] + c.coeffs = c.coeffs * 5 + c._coef_dirty = False + diff_fast = ModelDiff.from_snapshot(snap, baseline, same_model=True) + assert isinstance(diff_fast, ModelDiff) + fast_coef = diff_fast.con_slices.get("c1") + assert fast_coef is None or fast_coef.coef.stop == fast_coef.coef.start + diff_full = ModelDiff.from_snapshot(snap, baseline, same_model=False) + assert isinstance(diff_full, ModelDiff) + full_coef = diff_full.con_slices["c1"].coef + assert full_coef.stop > full_coef.start + + +def test_from_models_diffs_two_models() -> None: + m1 = Model() + x1 = m1.add_variables(0, 10, coords=[range(3)], name="x") + m1.add_constraints(2 * x1 >= 4, name="c1") + m1.add_objective(x1.sum()) + + m2 = Model() + x2 = m2.add_variables(0, 10, coords=[range(3)], name="x") + m2.add_constraints(2 * x2 >= 7, name="c1") + m2.add_objective(x2.sum()) + + diff = ModelDiff.from_models(m1, m2) + assert isinstance(diff, ModelDiff) + assert "c1" in diff.changed_constraints + sl = diff.con_slices["c1"].rhs + np.testing.assert_array_equal(diff.con_rhs_values[sl], np.full(3, 7.0)) + + +def test_ignore_dims_detects_coord_change() -> None: + m1 = Model() + m1.add_variables(0, 10, coords=[pd.Index([0, 1, 2], name="t")], name="x") + m1.add_constraints(m1.variables["x"] >= 0, name="c1") + m1.add_objective(m1.variables["x"].sum()) + snap = ModelSnapshot.capture(m1) + + m2 = Model() + m2.add_variables(0, 10, coords=[pd.Index([10, 11, 12], name="t")], name="x") + m2.add_constraints(m2.variables["x"] >= 0, name="c1") + m2.add_objective(m2.variables["x"].sum()) + + assert ModelDiff.from_snapshot(snap, m2) is RebuildReason.COORD_REINDEX + assert isinstance(ModelDiff.from_snapshot(snap, m2, ignore_dims={"t"}), ModelDiff) + + +def _assert_snapshot_equal(a: ModelSnapshot, b: ModelSnapshot) -> None: + assert a.structural_key == b.structural_key + assert a.var_buffers.keys() == b.var_buffers.keys() + assert a.con_buffers.keys() == b.con_buffers.keys() + for name, va in a.var_buffers.items(): + vb = b.var_buffers[name] + np.testing.assert_array_equal(va.lower, vb.lower) + np.testing.assert_array_equal(va.upper, vb.upper) + np.testing.assert_array_equal(va.active_labels, vb.active_labels) + assert va.type is vb.type + for name, ca in a.con_buffers.items(): + cb = b.con_buffers[name] + for attr in ("indptr", "indices", "data", "rhs", "sign", "active_labels"): + np.testing.assert_array_equal(getattr(ca, attr), getattr(cb, attr)) + for coords_a, coords_b in ( + (a.var_coords, b.var_coords), + (a.con_coords, b.con_coords), + ): + assert coords_a.keys() == coords_b.keys() + for name in coords_a: + assert coords_a[name].keys() == coords_b[name].keys() + for dim in coords_a[name]: + np.testing.assert_array_equal(coords_a[name][dim], coords_b[name][dim]) + np.testing.assert_array_equal(a.obj_c, b.obj_c) + assert a.obj_quad_present == b.obj_quad_present + assert a.obj_sense == b.obj_sense + + +def test_capture_is_pure(baseline: Model) -> None: + c = baseline.constraints["c1"] + c.update(coeffs=c.coeffs * 2) + assert c._coef_dirty is True + ModelSnapshot.capture(baseline) + assert c._coef_dirty is True + + +@pytest.mark.parametrize( + "mutate", ["none", "rhs", "bounds", "coeffs", "objective", "combined"] +) +def test_diff_snapshot_matches_capture(baseline: Model, mutate: str) -> None: + snap = ModelSnapshot.capture(baseline) + x = baseline.variables["x"] + y = baseline.variables["y"] + if mutate in ("rhs", "combined"): + baseline.constraints["c1"].update(rhs=9) + if mutate in ("bounds", "combined"): + x.update(lower=1) + if mutate in ("coeffs", "combined"): + c2 = baseline.constraints["c2"] + c2.update(coeffs=c2.coeffs * 3) + if mutate in ("objective", "combined"): + baseline.add_objective(3 * x.sum() + 2 * y.sum(), overwrite=True) + diff = ModelDiff.from_snapshot(snap, baseline) + assert isinstance(diff, ModelDiff) + _assert_snapshot_equal(diff.snapshot, ModelSnapshot.capture(baseline)) + + +def test_diff_snapshot_matches_capture_under_ignore_dims() -> None: + def build(t0: int) -> Model: + m = Model() + t = pd.Index(range(t0, t0 + 3), name="t") + m.add_variables(0, 10, coords=[t], name="x") + m.add_constraints(m.variables["x"] >= 0, name="c1") + m.add_objective(m.variables["x"].sum()) + return m + + m1, m2 = build(0), build(10) + snap = ModelSnapshot.capture(m1) + diff = ModelDiff.from_snapshot(snap, m2, ignore_dims={"t"}) + assert isinstance(diff, ModelDiff) + _assert_snapshot_equal(diff.snapshot, ModelSnapshot.capture(m2)) + + +def test_from_models_snapshot_matches_capture() -> None: + def build(rhs: float) -> Model: + m = Model() + x = m.add_variables(0, 10, coords=[range(3)], name="x") + m.add_constraints(2 * x >= rhs, name="c1") + m.add_objective(x.sum()) + return m + + m1, m2 = build(4.0), build(7.0) + diff = ModelDiff.from_models(m1, m2) + assert isinstance(diff, ModelDiff) + _assert_snapshot_equal(diff.snapshot, ModelSnapshot.capture(m2)) diff --git a/test/test_persistent_solver_extras.py b/test/test_persistent_solver_extras.py new file mode 100644 index 00000000..3fc28d18 --- /dev/null +++ b/test/test_persistent_solver_extras.py @@ -0,0 +1,466 @@ +from __future__ import annotations + +import pickle +import threading +from typing import Any + +import numpy as np +import pytest + +from linopy import Model +from linopy.persistent import ModelDiff, RebuildReason, UpdatesDisabledError +from linopy.solvers import Gurobi, Highs, Solver + +_BACKENDS: dict[str, tuple[type[Solver], dict[str, Any]]] = { + "gurobi": (Gurobi, {"OutputFlag": 0}), + "highs": (Highs, {"output_flag": False}), +} + + +def _have(name: str) -> bool: + try: + if name == "gurobi": + import gurobipy # noqa: F401 + elif name == "highs": + import highspy # noqa: F401 + return True + except ImportError: + return False + + +SOLVER_PARAMS = [ + pytest.param( + "gurobi", + marks=pytest.mark.skipif(not _have("gurobi"), reason="gurobipy not installed"), + ), + pytest.param( + "highs", + marks=pytest.mark.skipif(not _have("highs"), reason="highspy not installed"), + ), +] + + +def _base_model() -> Model: + m = Model() + x = m.add_variables(0, 10, coords=[range(3)], name="x") + y = m.add_variables(0, 10, coords=[range(3)], name="y") + m.add_constraints(x + y >= 4, name="c1") + m.add_constraints(2 * x + y <= 20, name="c2") + m.add_objective(x.sum() + 2 * y.sum()) + return m + + +def _built(solver_name: str, model: Model) -> Solver: + cls, opts = _BACKENDS[solver_name] + s = cls(model=model, io_api="direct", track_updates=True) + s.options = opts + s._build() + return s + + +def _obj(model: Model) -> float: + value = model.objective.value + assert value is not None + return float(value) + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_noop_resolve_increments_in_place(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + first_obj = _obj(m) + + s.solve(m, assign=True) + assert s._in_place_updates == 1 + assert s._rebuilds == 0 + assert np.isclose(_obj(m), first_obj) + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_two_consecutive_solves_no_stale_state(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + first_status = s.status + + m.variables["x"].lower.values[...] = 5.0 + s.solve(m, assign=True) + assert s.status is not first_status + assert s.solution is not None + assert np.isclose(float(s.solution.objective), _obj(m)) + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_cross_model_scenario_sweep(solver_name: str) -> None: + m1 = _base_model() + m2 = _base_model() + m2.constraints["c1"].rhs = 6.0 + m3 = _base_model() + m3.variables["x"].lower.values[...] = 2.0 + + s = _built(solver_name, m1) + s.solve(assign=True) + obj1 = _obj(m1) + sol1 = m1.solution + + s.solve(m2, assign=True) + s.solve(m3, assign=True) + + assert s._rebuilds == 0 + assert s._in_place_updates >= 2 + + assert m1.objective._value == obj1 + np.testing.assert_array_equal(m1.solution.x.values, sol1.x.values) + assert m2.objective._value is not None + assert m3.objective._value is not None + + for mk in (m2, m3): + fresh = _base_model() + if mk is m2: + fresh.constraints["c1"].rhs = 6.0 + else: + fresh.variables["x"].lower.values[...] = 2.0 + s_fresh = _built(solver_name, fresh) + s_fresh.solve(assign=True) + assert np.isclose(_obj(mk), _obj(fresh)) + s_fresh.close() + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_cross_model_sparsity_change_rebuilds(solver_name: str) -> None: + def build(include_y_in_c1: bool) -> Model: + m = Model() + x = m.add_variables(0, 10, coords=[range(3)], name="x") + y = m.add_variables(0, 10, coords=[range(3)], name="y") + if include_y_in_c1: + m.add_constraints(x + y >= 4, name="c1") + else: + m.add_constraints(2 * x >= 4, name="c1") + m.add_constraints(2 * x + y <= 20, name="c2") + m.add_objective(x.sum() + 2 * y.sum()) + return m + + m1 = build(include_y_in_c1=True) + s = _built(solver_name, m1) + s.solve(assign=True) + + m2 = build(include_y_in_c1=False) + + s.solve(m2, assign=True) + assert s._rebuilds == 1 + assert s._last_rebuild_reason in { + RebuildReason.SPARSITY, + RebuildReason.STRUCTURAL_LABELS, + RebuildReason.STRUCTURAL_CONTAINERS, + } + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_cross_model_structural_mismatch_rebuilds(solver_name: str) -> None: + m1 = _base_model() + s = _built(solver_name, m1) + s.solve(assign=True) + + m2 = _base_model() + m2.add_variables(0, 5, coords=[range(3)], name="z") + + s.solve(m2, assign=True) + assert s._rebuilds == 1 + assert s.model is m2 + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_dirty_flag_ignored_across_models(solver_name: str) -> None: + m1 = _base_model() + s = _built(solver_name, m1) + s.solve(assign=True) + + m2 = _base_model() + c = m2.constraints["c1"] + c.coeffs = c.coeffs * 3 + c._coef_dirty = False + + s.solve(m2, assign=True) + assert s._rebuilds == 0 + assert s._in_place_updates == 1 + + fresh = _base_model() + cf = fresh.constraints["c1"] + cf.coeffs = cf.coeffs * 3 + s_fresh = _built(solver_name, fresh) + s_fresh.solve(assign=True) + assert np.isclose(_obj(m2), _obj(fresh)) + s_fresh.close() + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_solver_pickle_round_trip_drops_native(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + + state = s.__getstate__() + for key in ("solver_model", "env", "_env_stack", "snapshot", "_lock"): + assert key not in state + + restored = pickle.loads(pickle.dumps(s)) + assert restored.solver_model is None + assert restored.snapshot is None + assert restored._env_stack is None + assert isinstance(restored._lock, type(threading.Lock())) + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_model_pickle_round_trip_no_native_handle(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + + m2 = pickle.loads(pickle.dumps(m)) + s2 = _built(solver_name, m2) + assert s2.solver_model is not None + s2.solve(assign=True) + assert s2._rebuilds == 0 + assert np.isclose(_obj(m), _obj(m2)) + s2.close() + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_backend_exception_during_apply_rebuilds( + solver_name: str, monkeypatch: pytest.MonkeyPatch +) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + + c = m.constraints["c1"] + c.coeffs = c.coeffs * 2 + assert c._coef_dirty is True + + def _boom(*args: Any, **kwargs: Any) -> None: + raise RuntimeError("simulated backend failure") + + monkeypatch.setattr(s, "apply_update", _boom) + + dirty_at_rebuild: list[bool] = [] + original_build = s._build + + def _spy_build(**kwargs: Any) -> None: + dirty_at_rebuild.append(m.constraints["c1"]._coef_dirty) + original_build(**kwargs) + + monkeypatch.setattr(s, "_build", _spy_build) + + s.solve(m, assign=True) + assert s._rebuilds == 1 + assert s._last_rebuild_reason is RebuildReason.BACKEND_REJECTED + assert dirty_at_rebuild == [True] + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_concurrent_solves_serialize(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + expected = _obj(m) + + barrier = threading.Barrier(2) + results: list[float] = [] + errors: list[BaseException] = [] + + def _run() -> None: + try: + barrier.wait() + res = s.solve(m, assign=True) + assert res.solution is not None + results.append(float(res.solution.objective)) + except BaseException as e: + errors.append(e) + + threads = [threading.Thread(target=_run) for _ in range(2)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors, errors + assert len(results) == 2 + for r in results: + assert np.isclose(r, expected) + + +_SCENARIO_PARAMS = [ + "bound_only", + "rhs_only", + "single_cell_coef", + "multi_row_coef", + "mixed", +] + + +def _apply_scenario(model: Model, scenario: str) -> None: + if scenario == "bound_only": + model.variables["x"].lower.values[...] = 3.0 + elif scenario == "rhs_only": + model.constraints["c1"].rhs = 7.0 + elif scenario == "single_cell_coef": + c = model.constraints["c1"] + new = c.coeffs.copy() + new.values[0, 0] = 5.0 + c.coeffs = new + elif scenario == "multi_row_coef": + c = model.constraints["c2"] + c.coeffs = c.coeffs * 2 + elif scenario == "mixed": + model.variables["x"].lower.values[...] = 1.0 + model.constraints["c1"].rhs = 6.0 + c = model.constraints["c2"] + new = c.coeffs.copy() + new.values[0, 0] = 4.0 + c.coeffs = new + else: + raise ValueError(scenario) + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +@pytest.mark.parametrize("scenario", _SCENARIO_PARAMS) +@pytest.mark.parametrize("same_model", [True, False]) +def test_scenario_sweep_in_place( + solver_name: str, scenario: str, same_model: bool +) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + + target = m if same_model else _base_model() + _apply_scenario(target, scenario) + s.solve(target, assign=True) + + assert s._rebuilds == 0 + assert s._in_place_updates == 1 + assert s._last_rebuild_reason is None + + fresh = _base_model() + _apply_scenario(fresh, scenario) + s_fresh = _built(solver_name, fresh) + s_fresh.solve(assign=True) + assert np.isclose(_obj(target), _obj(fresh)) + s_fresh.close() + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_disallow_rebuild_raises_on_structural_change(solver_name: str) -> None: + from linopy.persistent import RebuildRequiredError + + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + + m2 = _base_model() + m2.add_variables(0, 5, coords=[range(3)], name="z") + + with pytest.raises(RebuildRequiredError): + s.solve(m2, disallow_rebuild=True, assign=True) + assert s._rebuilds == 0 + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_disallow_rebuild_passes_when_update_works(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + s.solve(assign=True) + + m.constraints["c1"].rhs = 6.0 + s.solve(m, disallow_rebuild=True, assign=True) + assert s._in_place_updates == 1 + assert s._rebuilds == 0 + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_solve_without_assign_does_not_mutate_model(solver_name: str) -> None: + m = _base_model() + s = _built(solver_name, m) + + assert m.objective._value is None + s.solve() + assert m.objective._value is None + + s.solve(assign=True) + assert m.objective._value is not None + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_track_updates_false_skips_snapshot(solver_name: str) -> None: + cls, opts = _BACKENDS[solver_name] + m = _base_model() + s = cls(model=m, io_api="direct", track_updates=False) + s.options = opts + s._build() + assert s.snapshot is None + s.solve(assign=True) + assert s.snapshot is None + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_track_updates_false_rejects_resolve_with_model(solver_name: str) -> None: + cls, opts = _BACKENDS[solver_name] + m = _base_model() + s = cls(model=m, io_api="direct", track_updates=False) + s.options = opts + s._build() + s.solve(assign=True) + + m.variables["x"].lower.values[...] = 6.0 + with pytest.raises(UpdatesDisabledError, match="track_updates=False"): + s.solve(m, assign=True) + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_track_updates_false_rejects_update(solver_name: str) -> None: + cls, opts = _BACKENDS[solver_name] + m = _base_model() + s = cls(model=m, io_api="direct", track_updates=False) + s.options = opts + s._build() + with pytest.raises(UpdatesDisabledError, match="track_updates=False"): + s.update(m) + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_track_updates_false_cross_instance_resolve(solver_name: str) -> None: + cls, opts = _BACKENDS[solver_name] + m1 = _base_model() + s = cls(model=m1, io_api="direct", track_updates=False) + s.options = opts + s._build() + s.solve(assign=True) + base_obj = _obj(m1) + + m2 = _base_model() + m2.constraints["c1"].rhs = 8.0 + result = s.solve(m2, assign=True) + assert s._in_place_updates == 1 + assert s._rebuilds == 0 + assert s.snapshot is None + assert s.model is m2 + assert result.solution is not None + assert float(result.solution.objective) > base_obj + + +@pytest.mark.parametrize("solver_name", SOLVER_PARAMS) +def test_track_updates_false_cross_instance_update(solver_name: str) -> None: + cls, opts = _BACKENDS[solver_name] + m1 = _base_model() + s = cls(model=m1, io_api="direct", track_updates=False) + s.options = opts + s._build() + s.solve(assign=True) + + m2 = _base_model() + m2.constraints["c1"].rhs = 8.0 + diff = s.update(m2, apply=False) + assert isinstance(diff, ModelDiff) + assert diff.summary()["con_rhs"] == 3 + assert "c1" in diff.changed_constraints + assert s.snapshot is None diff --git a/test/test_persistent_solver_orchestrator.py b/test/test_persistent_solver_orchestrator.py new file mode 100644 index 00000000..9495e9ea --- /dev/null +++ b/test/test_persistent_solver_orchestrator.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import pickle +import threading +from typing import Any + +import pytest + +from linopy import Model +from linopy.constants import ( + Result, + Solution, + SolverStatus, + Status, + TerminationCondition, +) +from linopy.persistent import ModelDiff, RebuildReason +from linopy.solvers import Solver, SolverFeature + + +class FakeSolver(Solver[None]): + display_name = "Fake" + features = frozenset({SolverFeature.DIRECT_API}) + accepted_io_apis = frozenset({"direct"}) + supports_persistent_update = False + + @classmethod + def is_available(cls) -> bool: # type: ignore[override] + return True + + @property + def solver_name(self) -> Any: + class _N: + value = "fake" + + return _N() + + def _validate_model(self) -> None: + return None + + def _build_direct(self, **kwargs: Any) -> None: + self.solver_model = object() + + def _run_direct(self, **kwargs: Any) -> Result: + status = Status(SolverStatus.ok, TerminationCondition.optimal) + return Result( + status=status, solution=Solution(objective=0.0), solver_name="fake" + ) + + +@pytest.fixture +def model() -> Model: + m = Model() + x = m.add_variables(0, 10, coords=[range(3)], name="x") + m.add_constraints(2 * x >= 4, name="c1") + m.add_objective(x.sum()) + return m + + +@pytest.fixture +def other_model() -> Model: + m = Model() + x = m.add_variables(0, 10, coords=[range(3)], name="x") + m.add_constraints(2 * x >= 4, name="c1") + m.add_objective(x.sum()) + return m + + +def _built(model: Model) -> FakeSolver: + s = FakeSolver(model=model, io_api="direct", track_updates=True) + s._build() + return s + + +def test_unsupported_falls_through_to_rebuild(model: Model, other_model: Model) -> None: + s = _built(model) + assert s._rebuilds == 0 + s.solve(other_model) + assert s._rebuilds == 1 + assert s._last_rebuild_reason is RebuildReason.BACKEND_REJECTED + assert s.model is other_model + + +def test_update_apply_false_returns_diff(model: Model) -> None: + s = _built(model) + diff = s.update(model, apply=False) + assert isinstance(diff, ModelDiff) + assert s._in_place_updates == 0 + assert s._rebuilds == 0 + + +def test_solve_no_model_still_works(model: Model) -> None: + s = _built(model) + result = s.solve() + assert result.status.status is SolverStatus.ok + + +def test_getstate_drops_native_fields(model: Model) -> None: + s = _built(model) + state = s.__getstate__() + for k in ("solver_model", "env", "_env_stack", "snapshot", "_lock"): + assert k not in state + restored = pickle.loads(pickle.dumps(s)) + assert restored.solver_model is None + assert restored.snapshot is None + + +def test_update_without_snapshot_raises(model: Model) -> None: + s = FakeSolver(model=model, io_api="direct") + with pytest.raises(RuntimeError, match="not been built"): + s.update(model) + + +def test_unmutated_resolve_diff_is_empty(model: Model) -> None: + s = _built(model) + diff = s.update(model, apply=False) + assert isinstance(diff, ModelDiff) + assert diff.is_empty + + +class FakePersistentSolver(FakeSolver): + supports_persistent_update = True + + def apply_update( + self, diff: ModelDiff, var_label_index: Any, con_label_index: Any + ) -> None: + return None + + +def _built_persistent(model: Model) -> FakePersistentSolver: + s = FakePersistentSolver(model=model, io_api="direct", track_updates=True) + s._build() + return s + + +def test_build_clears_coef_dirty(model: Model) -> None: + c = model.constraints["c1"] + c.update(coeffs=c.coeffs * 2) + assert c._coef_dirty is True + _built_persistent(model) + assert c._coef_dirty is False + + +def test_in_place_update_adopts_diff_snapshot(model: Model) -> None: + s = _built_persistent(model) + c = model.constraints["c1"] + c.update(coeffs=c.coeffs * 2) + diff = s.update(model) + assert isinstance(diff, ModelDiff) + assert s.snapshot is diff.snapshot + assert c._coef_dirty is False + rediff = s.update(model, apply=False) + assert isinstance(rediff, ModelDiff) + assert rediff.is_empty + + +def test_update_apply_false_leaves_state_untouched(model: Model) -> None: + s = _built_persistent(model) + snap_before = s.snapshot + c = model.constraints["c1"] + c.update(coeffs=c.coeffs * 2) + diff = s.update(model, apply=False) + assert isinstance(diff, ModelDiff) + assert c._coef_dirty is True + assert s.snapshot is snap_before + + +def test_update_apply_false_does_not_block_running_solve( + model: Model, monkeypatch: pytest.MonkeyPatch +) -> None: + s = _built_persistent(model) + solve_entered = threading.Event() + release_solve = threading.Event() + original_run = s._run_direct + + def _gated_run(**kwargs: Any) -> Result: + solve_entered.set() + assert release_solve.wait(timeout=5) + return original_run(**kwargs) + + monkeypatch.setattr(s, "_run_direct", _gated_run) + + solver_thread = threading.Thread(target=s.solve) + solver_thread.start() + try: + assert solve_entered.wait(timeout=5) + + result: list[ModelDiff | RebuildReason] = [] + preview_thread = threading.Thread( + target=lambda: result.append(s.update(model, apply=False)) + ) + preview_thread.start() + preview_thread.join(timeout=2) + assert not preview_thread.is_alive(), "preview blocked on a running solve" + assert isinstance(result[0], ModelDiff) + finally: + release_solve.set() + solver_thread.join(timeout=5) + + +def test_preview_detects_raw_mutation_apply_skips_it(model: Model) -> None: + """ + Pins the documented preview/apply asymmetry for unsupported raw + ``.values[...]`` coefficient mutations on the build-time model. + """ + s = _built_persistent(model) + c = model.constraints["c1"] + c.coeffs.values[...] = c.coeffs.values * 2 + assert c._coef_dirty is False + + preview = s.update(model, apply=False) + assert isinstance(preview, ModelDiff) + assert "c1" in preview.changed_constraints + + applied = s.update(model) + assert isinstance(applied, ModelDiff) + assert "c1" not in applied.changed_constraints diff --git a/test/test_variable.py b/test/test_variable.py index 18640821..4a0da9d6 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -186,6 +186,60 @@ def test_variable_lower_setter_with_array_invalid_dim(x: linopy.Variable) -> Non x.lower = lower +def test_variable_update_bounds(z: linopy.Variable) -> None: + z.update(lower=2, upper=20) + assert z.lower.item() == 2 + assert z.upper.item() == 20 + + +def test_variable_update_lower_only(z: linopy.Variable) -> None: + z.update(lower=3) + assert z.lower.item() == 3 + assert z.upper.item() == 10 # unchanged from fixture default + + +def test_variable_update_no_kwargs_is_noop(z: linopy.Variable) -> None: + old_lower, old_upper = z.lower.item(), z.upper.item() + z.update() + assert z.lower.item() == old_lower + assert z.upper.item() == old_upper + + +def test_variable_update_rejects_inverted_bounds(z: linopy.Variable) -> None: + with pytest.raises(ValueError, match="lower > upper"): + z.update(lower=20, upper=5) + + +def test_variable_update_rejects_non_constant(z: linopy.Variable) -> None: + with pytest.raises(TypeError, match="must be a constant"): + z.update(upper=z) + + +def test_variable_update_returns_self(z: linopy.Variable) -> None: + out = z.update(lower=1) + assert out is z + + +def test_variable_update_array_invalid_dim(x: linopy.Variable) -> None: + with pytest.raises(ValueError): + x.update(lower=pd.Series(range(15, 25))) + + +def test_variable_update_upper_only(z: linopy.Variable) -> None: + """upper= alone changes upper; lower untouched.""" + old_lower = z.lower.copy() + z.update(upper=25) + assert (z.upper == 25).all() + assert (z.lower == old_lower).all() + + +def test_variable_update_with_array(x: linopy.Variable) -> None: + """Array bound that aligns on the variable's coord is accepted.""" + lower = pd.Series(range(10, 20), index=pd.RangeIndex(10, name="first")) + x.update(lower=lower) + np.testing.assert_array_equal(x.lower.values, lower.values) + + def test_variable_sum(x: linopy.Variable) -> None: res = x.sum() assert res.nterm == 10