Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
### Added
- Added automated script for generating type stubs
- Include parameter names in type stubs
- Speed up MatrixExpr.sum(axis=...) via quicksum
- Added pre-commit hook for automatic stub regeneration (see .pre-commit-config.yaml)
- Wrapped isObjIntegral() and test
- Added structured_optimization_trace recipe for structured optimization progress tracking
Expand All @@ -19,8 +18,10 @@
- Fixed segmentation fault when using Variable or Constraint objects after freeTransform() or Model destruction
### Changed
- changed default value of enablepricing flag to True
- Speed up MatrixExpr.sum(axis=...) via quicksum
- Speed up MatrixExpr.add.reduce via quicksum
- Speed up np.ndarray(..., dtype=np.float64) @ MatrixExpr
- Speed up Expr * Expr
- MatrixExpr and MatrixExprCons use `__array_ufunc__` protocol to control all numpy.ufunc inputs and outputs
### Removed

Expand Down
76 changes: 65 additions & 11 deletions src/pyscipopt/expr.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@
import math
from typing import TYPE_CHECKING

from pyscipopt.scip cimport Variable, Solution
from cpython.dict cimport PyDict_Next
from cpython.dict cimport PyDict_Next, PyDict_GetItem
from cpython.tuple cimport PyTuple_GET_ITEM
from cpython.ref cimport PyObject
from pyscipopt.scip cimport Variable, Solution

import numpy as np

Expand Down Expand Up @@ -122,9 +123,39 @@ cdef class Term:
def __len__(self):
return len(self.vartuple)

def __add__(self, other):
Copy link
Contributor Author

@Zeroto521 Zeroto521 Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__mul__ is better. We call this function actually to multiply

both = self.vartuple + other.vartuple
return Term(*both)
def __mul__(self, Term other):
cdef int n1 = len(self)
cdef int n2 = len(other)
if n1 == 0: return other
if n2 == 0: return self

cdef list vartuple = [None] * (n1 + n2)
cdef int i = 0, j = 0, k = 0
cdef Variable var1, var2
while i < n1 and j < n2:
var1 = <Variable>PyTuple_GET_ITEM(self.vartuple, i)
var2 = <Variable>PyTuple_GET_ITEM(other.vartuple, j)
if var1.ptr() <= var2.ptr():
vartuple[k] = var1
i += 1
else:
vartuple[k] = var2
j += 1
k += 1
Comment on lines +135 to +144
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the Merge Sort Algorithm, with a time complexity of $O(n)$.
The time complexity of sorted is $O(nlog(n))$.

while i < n1:
vartuple[k] = <Variable>PyTuple_GET_ITEM(self.vartuple, i)
i += 1
k += 1
while j < n2:
vartuple[k] = <Variable>PyTuple_GET_ITEM(other.vartuple, j)
j += 1
k += 1

cdef Term res = Term.__new__(Term)
res.vartuple = tuple(vartuple)
res.ptrtuple = tuple(v.ptr() for v in res.vartuple)
res.hashval = <Py_ssize_t>hash(res.ptrtuple)
return res

def __repr__(self):
return 'Term(%s)' % ', '.join([str(v) for v in self.vartuple])
Expand Down Expand Up @@ -251,19 +282,42 @@ cdef class Expr:
if isinstance(other, np.ndarray):
return other * self

cdef dict res = {}
cdef Py_ssize_t pos1 = <Py_ssize_t>0, pos2 = <Py_ssize_t>0
cdef PyObject *k1_ptr = NULL
cdef PyObject *v1_ptr = NULL
cdef PyObject *k2_ptr = NULL
cdef PyObject *v2_ptr = NULL
cdef PyObject *old_v_ptr = NULL
cdef Term child
cdef double v1_val, v2_val, prod_v

if _is_number(other):
f = float(other)
return Expr({v:f*c for v,c in self.terms.items()})

elif _is_number(self):
f = float(self)
return Expr({v:f*c for v,c in other.terms.items()})

elif isinstance(other, Expr):
terms = {}
for v1, c1 in self.terms.items():
for v2, c2 in other.terms.items():
v = v1 + v2
terms[v] = terms.get(v, 0.0) + c1 * c2
return Expr(terms)
while PyDict_Next(self.terms, &pos1, &k1_ptr, &v1_ptr):
if (v1_val := <double>(<object>v1_ptr)) == 0:
continue

pos2 = <Py_ssize_t>0
while PyDict_Next(other.terms, &pos2, &k2_ptr, &v2_ptr):
if (v2_val := <double>(<object>v2_ptr)) == 0:
continue

child = (<Term>k1_ptr) * (<Term>k2_ptr)
prod_v = v1_val * v2_val
if (old_v_ptr := PyDict_GetItem(res, child)) != NULL:
res[child] = <double>(<object>old_v_ptr) + prod_v
else:
res[child] = prod_v
return Expr(res)
Comment on lines +304 to +319
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Cython API to speed up.


elif isinstance(other, GenExpr):
return buildGenExprObj(self) * other
else:
Expand Down
2 changes: 1 addition & 1 deletion src/pyscipopt/scip.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2150,7 +2150,7 @@ class Term:
ptrtuple: Incomplete
vartuple: Incomplete
def __init__(self, *vartuple: Incomplete) -> None: ...
def __add__(self, other: Incomplete) -> Incomplete: ...
def __mul__(self, other: Term) -> Term: ...
def __eq__(self, other: object) -> bool: ...
def __ge__(self, other: object) -> bool: ...
def __getitem__(self, index: Incomplete) -> Incomplete: ...
Expand Down
Loading