diff --git a/CHANGELOG.md b/CHANGELOG.md index 7cfed00f5..a29495758 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,6 @@ ### Added - Added automated script for generating type stubs - Include parameter names in type stubs -- Speed up MatrixExpr.sum(axis=...) via quicksum - Added pre-commit hook for automatic stub regeneration (see .pre-commit-config.yaml) - Wrapped isObjIntegral() and test - Added structured_optimization_trace recipe for structured optimization progress tracking @@ -19,8 +18,10 @@ - Fixed segmentation fault when using Variable or Constraint objects after freeTransform() or Model destruction ### Changed - changed default value of enablepricing flag to True +- Speed up MatrixExpr.sum(axis=...) via quicksum - Speed up MatrixExpr.add.reduce via quicksum - Speed up np.ndarray(..., dtype=np.float64) @ MatrixExpr +- Speed up Expr * Expr - MatrixExpr and MatrixExprCons use `__array_ufunc__` protocol to control all numpy.ufunc inputs and outputs ### Removed diff --git a/src/pyscipopt/expr.pxi b/src/pyscipopt/expr.pxi index 07d6ab031..c4257570d 100644 --- a/src/pyscipopt/expr.pxi +++ b/src/pyscipopt/expr.pxi @@ -45,9 +45,10 @@ import math from typing import TYPE_CHECKING -from pyscipopt.scip cimport Variable, Solution -from cpython.dict cimport PyDict_Next +from cpython.dict cimport PyDict_Next, PyDict_GetItem +from cpython.tuple cimport PyTuple_GET_ITEM from cpython.ref cimport PyObject +from pyscipopt.scip cimport Variable, Solution import numpy as np @@ -122,9 +123,39 @@ cdef class Term: def __len__(self): return len(self.vartuple) - def __add__(self, other): - both = self.vartuple + other.vartuple - return Term(*both) + def __mul__(self, Term other): + cdef int n1 = len(self) + cdef int n2 = len(other) + if n1 == 0: return other + if n2 == 0: return self + + cdef list vartuple = [None] * (n1 + n2) + cdef int i = 0, j = 0, k = 0 + cdef Variable var1, var2 + while i < n1 and j < n2: + var1 = PyTuple_GET_ITEM(self.vartuple, i) + var2 = PyTuple_GET_ITEM(other.vartuple, j) + if var1.ptr() <= var2.ptr(): + vartuple[k] = var1 + i += 1 + else: + vartuple[k] = var2 + j += 1 + k += 1 + while i < n1: + vartuple[k] = PyTuple_GET_ITEM(self.vartuple, i) + i += 1 + k += 1 + while j < n2: + vartuple[k] = PyTuple_GET_ITEM(other.vartuple, j) + j += 1 + k += 1 + + cdef Term res = Term.__new__(Term) + res.vartuple = tuple(vartuple) + res.ptrtuple = tuple(v.ptr() for v in res.vartuple) + res.hashval = hash(res.ptrtuple) + return res def __repr__(self): return 'Term(%s)' % ', '.join([str(v) for v in self.vartuple]) @@ -251,19 +282,42 @@ cdef class Expr: if isinstance(other, np.ndarray): return other * self + cdef dict res = {} + cdef Py_ssize_t pos1 = 0, pos2 = 0 + cdef PyObject *k1_ptr = NULL + cdef PyObject *v1_ptr = NULL + cdef PyObject *k2_ptr = NULL + cdef PyObject *v2_ptr = NULL + cdef PyObject *old_v_ptr = NULL + cdef Term child + cdef double v1_val, v2_val, prod_v + if _is_number(other): f = float(other) return Expr({v:f*c for v,c in self.terms.items()}) + elif _is_number(self): f = float(self) return Expr({v:f*c for v,c in other.terms.items()}) + elif isinstance(other, Expr): - terms = {} - for v1, c1 in self.terms.items(): - for v2, c2 in other.terms.items(): - v = v1 + v2 - terms[v] = terms.get(v, 0.0) + c1 * c2 - return Expr(terms) + while PyDict_Next(self.terms, &pos1, &k1_ptr, &v1_ptr): + if (v1_val := (v1_ptr)) == 0: + continue + + pos2 = 0 + while PyDict_Next(other.terms, &pos2, &k2_ptr, &v2_ptr): + if (v2_val := (v2_ptr)) == 0: + continue + + child = (k1_ptr) * (k2_ptr) + prod_v = v1_val * v2_val + if (old_v_ptr := PyDict_GetItem(res, child)) != NULL: + res[child] = (old_v_ptr) + prod_v + else: + res[child] = prod_v + return Expr(res) + elif isinstance(other, GenExpr): return buildGenExprObj(self) * other else: diff --git a/src/pyscipopt/scip.pyi b/src/pyscipopt/scip.pyi index 61c4ba773..31e4ad577 100644 --- a/src/pyscipopt/scip.pyi +++ b/src/pyscipopt/scip.pyi @@ -2150,7 +2150,7 @@ class Term: ptrtuple: Incomplete vartuple: Incomplete def __init__(self, *vartuple: Incomplete) -> None: ... - def __add__(self, other: Incomplete) -> Incomplete: ... + def __mul__(self, other: Term) -> Term: ... def __eq__(self, other: object) -> bool: ... def __ge__(self, other: object) -> bool: ... def __getitem__(self, index: Incomplete) -> Incomplete: ...