-
Notifications
You must be signed in to change notification settings - Fork 276
Speed up Expr * Expr
#1175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Speed up Expr * Expr
#1175
Conversation
Replaces Term.__add__ with Term.__mul__ and updates Expr.__mul__ to use more efficient Cython dict iteration and item access. This improves performance and correctness when multiplying expressions, especially for large term dictionaries.
Replaces the simple concatenation in Term.__mul__ with an efficient merge that maintains variable order based on pointer values. This improves performance and correctness when multiplying Term objects.
Moved the 'Speed up MatrixExpr.sum(axis=...) via quicksum' entry from the Added section to the Changed section for better categorization and clarity.
Added a new entry to the changelog noting the performance improvement for Expr * Expr operations.
| def __len__(self): | ||
| return len(self.vartuple) | ||
|
|
||
| def __add__(self, other): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
__mul__ is better. We call this function actually to multiply
| while i < n1 and j < n2: | ||
| var1 = <Variable>PyTuple_GET_ITEM(self.vartuple, i) | ||
| var2 = <Variable>PyTuple_GET_ITEM(other.vartuple, j) | ||
| if var1.ptr() <= var2.ptr(): | ||
| vartuple[k] = var1 | ||
| i += 1 | ||
| else: | ||
| vartuple[k] = var2 | ||
| j += 1 | ||
| k += 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is the Merge Sort Algorithm, with a time complexity of
The time complexity of sorted is
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR optimizes the multiplication of Expr objects (polynomial expressions) by using low-level C API calls and introducing a more efficient Term.__mul__ method. According to the benchmarks provided, this results in at least 1.2x speedup for matrix multiplication operations.
Changes:
- Introduced
Term.__mul__method using an efficient merge algorithm for combining sorted variable tuples - Optimized
Expr.__mul__to use C-level Python dict iteration APIs (PyDict_Next, PyDict_GetItem) and skip zero coefficients - Updated CHANGELOG to document the performance improvement
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| src/pyscipopt/expr.pxi | Implements optimized Term.__mul__ method and refactors Expr.__mul__ to use low-level C APIs for faster dictionary iteration and term multiplication |
| CHANGELOG.md | Adds entry documenting the Expr * Expr performance improvement in the Changed section |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| while PyDict_Next(self.terms, &pos1, &k1_ptr, &v1_ptr): | ||
| if (v1_val := <double>(<object>v1_ptr)) == 0: | ||
| continue | ||
|
|
||
| pos2 = <Py_ssize_t>0 | ||
| while PyDict_Next(other.terms, &pos2, &k2_ptr, &v2_ptr): | ||
| if (v2_val := <double>(<object>v2_ptr)) == 0: | ||
| continue | ||
|
|
||
| child = (<Term>k1_ptr) * (<Term>k2_ptr) | ||
| prod_v = v1_val * v2_val | ||
| if (old_v_ptr := PyDict_GetItem(res, child)) != NULL: | ||
| res[child] = <double>(<object>old_v_ptr) + prod_v | ||
| else: | ||
| res[child] = prod_v | ||
| return Expr(res) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use Cython API to speed up.
Corrects the Term class in scip.pyi to define __mul__ instead of __add__, updating the method signature to accept and return Term objects.
This PR is at least faster 1.2x than before.
Expr.__mul__(Expr):Expr.__mul__(Expr)andTerm.__mul__(Term)