diff --git a/CHANGELOG.md b/CHANGELOG.md index 7cfed00f5..d5fc978d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ - Speed up MatrixExpr.add.reduce via quicksum - Speed up np.ndarray(..., dtype=np.float64) @ MatrixExpr - MatrixExpr and MatrixExprCons use `__array_ufunc__` protocol to control all numpy.ufunc inputs and outputs +- Return itself for abs to UnaryExpr(Operator.fabs) ### Removed ## 6.0.0 - 2025.xx.yy diff --git a/src/pyscipopt/expr.pxi b/src/pyscipopt/expr.pxi index 07d6ab031..aac14950e 100644 --- a/src/pyscipopt/expr.pxi +++ b/src/pyscipopt/expr.pxi @@ -45,9 +45,10 @@ import math from typing import TYPE_CHECKING -from pyscipopt.scip cimport Variable, Solution from cpython.dict cimport PyDict_Next +from cpython.object cimport Py_TYPE from cpython.ref cimport PyObject +from pyscipopt.scip cimport Variable, Solution import numpy as np @@ -647,6 +648,20 @@ cdef class GenExpr: '''returns operator of GenExpr''' return self._op + cdef GenExpr copy(self, bool copy = True): + cdef object cls = Py_TYPE(self) + cdef GenExpr res = cls.__new__(cls) + res._op = self._op + res.children = self.children.copy() if copy else self.children + if cls is SumExpr: + (res).constant = (self).constant + (res).coefs = (self).coefs.copy() if copy else (self).coefs + if cls is ProdExpr: + (res).constant = (self).constant + elif cls is PowExpr: + (res).expo = (self).expo + return res + # Sum Expressions cdef class SumExpr(GenExpr): @@ -736,6 +751,11 @@ cdef class UnaryExpr(GenExpr): self.children.append(expr) self._op = op + def __abs__(self) -> UnaryExpr: + if self._op == "abs": + return self.copy() + return UnaryExpr(Operator.fabs, self) + def __repr__(self): return self._op + "(" + self.children[0].__repr__() + ")" diff --git a/src/pyscipopt/scip.pyi b/src/pyscipopt/scip.pyi index 61c4ba773..527fa64b9 100644 --- a/src/pyscipopt/scip.pyi +++ b/src/pyscipopt/scip.pyi @@ -377,7 +377,7 @@ class GenExpr: def __init__(self) -> None: ... def degree(self) -> Incomplete: ... def getOp(self) -> Incomplete: ... - def __abs__(self) -> Incomplete: ... + def __abs__(self) -> GenExpr: ... def __add__(self, other: Incomplete) -> Incomplete: ... def __eq__(self, other: object) -> bool: ... def __ge__(self, other: object) -> bool: ... @@ -2161,9 +2161,9 @@ class Term: def __lt__(self, other: object) -> bool: ... def __ne__(self, other: object) -> bool: ... -@disjoint_base class UnaryExpr(GenExpr): def __init__(self, *args: Incomplete, **kwargs: Incomplete) -> None: ... + def __abs__(self) -> GenExpr: ... @disjoint_base class VarExpr(GenExpr): diff --git a/tests/test_expr.py b/tests/test_expr.py index c9135d2fa..fc6c9c193 100644 --- a/tests/test_expr.py +++ b/tests/test_expr.py @@ -218,3 +218,11 @@ def test_getVal_with_GenExpr(): with pytest.raises(ZeroDivisionError): m.getVal(1 / z) + + +def test_abs_abs_expr(): + m = Model() + x = m.addVar(name="x") + + # should print abs(x) not abs(abs(x)) + assert str(abs(abs(x))) == str(abs(x))