
Add a lot of optimizers including optax support to sgd #2041

Merged: YigitElma merged 41 commits into master from yge/adam on Mar 2, 2026
Conversation

@YigitElma (Collaborator) commented Dec 16, 2025

  • Unifies all the SGD-type optimizers; developers only need to implement the update rule
  • x_scale is now used with the SGD methods too
  • Adds wrappers for optax optimizers, which can be called by their optax name (see the sketch after the code example below)
  • Any custom optax optimizer can be used via:
        import optax
        from desc.optimize import Optimizer
        from desc.examples import get

        eq = get("DSHAPE")

        # Build any optax update rule, e.g. SGD chained with a zoom line search
        opt = optax.chain(
            optax.sgd(learning_rate=1.0),
            optax.scale_by_zoom_linesearch(max_linesearch_steps=15),
        )
        # "optax-custom" tells DESC to use the update rule passed through the options
        optimizer = Optimizer("optax-custom")
        eq.solve(optimizer=optimizer, options={"optax-options": {"update_rule": opt}})
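
For the built-in wrappers, calling an optax optimizer by name presumably looks like the sketch below. The exact registered name strings are not spelled out in this description, so "optax-adam" here is an assumption; check the Optimizer documentation for the actual names.

        from desc.optimize import Optimizer
        from desc.examples import get

        eq = get("DSHAPE")
        # "optax-adam" is a hypothetical name string, assuming the wrappers are
        # registered as "optax-<optimizer name>"; no custom update rule is needed here.
        eq.solve(optimizer=Optimizer("optax-adam"))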

@YigitElma YigitElma self-assigned this Dec 16, 2025
github-actions bot commented Dec 16, 2025

Memory benchmark result

| Test Name                             |   %Δ    | Master (MB) |  PR (MB)  | Δ (MB)  | Time PR (s) | Time Master (s) |
| ------------------------------------- | ------- | ----------- | --------- | ------- | ----------- | --------------- |
| test_objective_jac_w7x                |  3.00 % |   3.895e+03 | 4.012e+03 |  117.03 |       39.74 |           36.33 |
| test_proximal_jac_w7x_with_eq_update  |  0.76 % |   6.585e+03 | 6.634e+03 |   49.79 |      162.70 |          163.71 |
| test_proximal_freeb_jac               | -0.37 % |   1.323e+04 | 1.318e+04 |  -48.81 |       85.70 |           84.40 |
| test_proximal_freeb_jac_blocked       | -0.48 % |   7.537e+03 | 7.502e+03 |  -35.93 |       75.27 |           74.88 |
| test_proximal_freeb_jac_batched       |  0.32 % |   7.484e+03 | 7.508e+03 |   24.32 |       73.76 |           75.69 |
| test_proximal_jac_ripple              |  1.47 % |   3.468e+03 | 3.519e+03 |   51.07 |       65.97 |           68.48 |
| test_proximal_jac_ripple_bounce1d     |  2.40 % |   3.571e+03 | 3.656e+03 |   85.59 |       77.62 |           80.36 |
| test_eq_solve                         |  4.16 % |   2.044e+03 | 2.129e+03 |   84.94 |       94.81 |           96.74 |

For the memory plots, go to the summary of the Memory Benchmarks workflow and download the artifact.

codecov bot commented Dec 16, 2025

Codecov Report

❌ Patch coverage is 98.07692% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 94.51%. Comparing base (5bb92a4) to head (8278279).
⚠️ Report is 1 commit behind head on master.

| Files with missing lines    | Patch % | Lines        |
| --------------------------- | ------- | ------------ |
| desc/optimize/stochastic.py | 98.03%  | 1 Missing ⚠️ |
Additional details and impacted files
@@           Coverage Diff           @@
##           master    #2041   +/-   ##
=======================================
  Coverage   94.51%   94.51%           
=======================================
  Files         102      102           
  Lines       28640    28675   +35     
=======================================
+ Hits        27068    27103   +35     
  Misses       1572     1572           
| Files with missing lines        | Coverage Δ                   |
| ------------------------------- | ---------------------------- |
| desc/optimize/_desc_wrappers.py | 91.17% <100.00%> (+0.13%) ⬆️ |
| desc/optimize/stochastic.py     | 98.01% <98.03%> (+1.00%) ⬆️  |

... and 3 files with indirect coverage changes


@YigitElma YigitElma marked this pull request as ready for review December 18, 2025 01:42
@YigitElma YigitElma requested review from a team, ddudt, dpanici, f0uriest, rahulgaur104 and unalmis and removed request for a team December 18, 2025 01:42
@YigitElma YigitElma changed the title from "Add ADAM optimizer" to "Add ADAM and RMSProp optimizers" Dec 18, 2025
@dpanici (Collaborator) left a comment

Just the small docstring fix: it should be explicit that ``x_scale="auto"`` does no scaling here.

dpanici previously approved these changes Dec 22, 2025
@f0uriest (Member) left a comment

I'd double-check that the x_scale logic is correct.

Also, did you look at whether we could just wrap stuff from optax?

From the examples, e.g. https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adam, it looks like the user could just pass in an optax solver and then we can just do

opt_state = solver.init(x0)
...
g = grad(x)*x_scale
updates, opt_state = solver.update(g, opt_state, x)
x = optax.apply_updates(x, x_scale*updates)

or something similar. That would give users access to a much wider array of first-order optimizers, and save us having to do it all ourselves.
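
A runnable sketch of that wrapping pattern, using a toy quadratic objective rather than a DESC objective (the x_scale handling follows the suggestion above; everything else here is illustrative, not DESC's actual implementation):

    import jax
    import jax.numpy as jnp
    import optax

    # Toy objective standing in for a DESC objective (illustrative only)
    def fun(x):
        return jnp.sum((x - 1.0) ** 2)

    grad = jax.grad(fun)

    x = jnp.zeros(3)
    x_scale = jnp.ones(3)  # characteristic scale of each variable (assumed)
    solver = optax.adam(learning_rate=1e-1)
    opt_state = solver.init(x)

    for _ in range(200):
        g = grad(x) * x_scale  # gradient with respect to the scaled variables
        updates, opt_state = solver.update(g, opt_state, x)
        x = optax.apply_updates(x, x_scale * updates)  # step back in unscaled variables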

Comment thread on desc/optimize/stochastic.py (outdated), on the rename:

- def sgd(
+ def generic_sgd(
Member:

``sgd`` is technically public (https://desc-docs.readthedocs.io/en/stable/_api/optimize/desc.optimize.sgd.html#desc.optimize.sgd), so if we want to change the name we should keep an alias to the old one with a deprecation warning. That said, I'm not sure we really need to change the name: "SGD" is already used pretty generically in the ML community for a bunch of first-order stochastic methods like ADAM, ADAGRAD, RMSPROP, etc.

Collaborator Author:

Yeah, SGD is in fact the general name. I can revert to the old name and just emphasize that the "sgd" option uses Nesterov momentum. I was trying to make a distinction, I guess.

Comment thread on desc/optimize/stochastic.py (outdated), on the docstring:
for the update rule chosen.

- ``"alpha"`` : (float > 0) Learning rate. Defaults to
1e-1 * ||x_scaled|| / ||g_scaled||.
Member:

This seems pretty large (steps would be 10% of x); have you checked how robust this is?

Collaborator Author:

I was trying to solve equilibria with these, and even though none of them converged, 10% was better for a variety of equilibria. I haven't checked other optimization problems. Reverted the change and added a safeguard against 0 and NaNs.
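
For illustration, a guard of that sort might look like the following sketch (the function name, fallback value, and exact condition are assumptions, not DESC's actual code):

    import jax.numpy as jnp

    def default_alpha(x_scaled, g_scaled, frac=0.1, fallback=1e-2):
        # Heuristic learning rate: a fraction of ||x_scaled|| / ||g_scaled||,
        # guarded against division by zero and NaN (e.g. a zero gradient or zero x).
        alpha = frac * jnp.linalg.norm(x_scaled) / jnp.linalg.norm(g_scaled)
        return jnp.where(jnp.isfinite(alpha) & (alpha > 0), alpha, fallback)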

Comment thread on desc/optimize/stochastic.py (outdated), on the docstring:
Where alpha is the step size and beta is the momentum parameter.
Update rule for ``'sgd'``:

.. math::
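
(The body of the ``.. math::`` directive is truncated in the diff shown above. For reference only, the textbook Nesterov-momentum update that this kind of SGD typically uses, with alpha the step size and beta the momentum parameter, is:)

    v_{k+1} = \beta \, v_k - \alpha \, \nabla f(x_k + \beta \, v_k), \qquad x_{k+1} = x_k + v_{k+1}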
Member:

Personally, I prefer Unicode for stuff like this. TeX looks nice in the rendered HTML docs, but is much harder to read as code.

Collaborator Author:

I mostly agree and don't have a strong stance either way. My general preference is to use LaTeX for complex equations or public-facing objectives that users will first encounter in the documentation (like guiding center equations, optimization algorithms). For internal development notes or specific compute functions that aren't usually viewed on the web, I’m fine with Unicode since it keeps the source code more readable.

@unalmis (Collaborator) commented Jan 6, 2026

> Also, did you look at whether we could just wrap stuff from optax?
>
> or something similar. That would give users access to a much wider array of first-order optimizers, and save us having to do it all ourselves

Also, optimistix has trust-region methods with easy-to-use linear solvers, e.g. normal conjugate gradient, etc.

dpanici previously approved these changes Feb 19, 2026
@YigitElma YigitElma requested review from ddudt, dpanici and f0uriest and removed request for ddudt and f0uriest February 20, 2026 04:18
@YigitElma YigitElma added the "easy" label (Short and simple to code or review) Feb 23, 2026
dpanici previously approved these changes Feb 24, 2026
ddudt previously approved these changes Feb 26, 2026
@YigitElma YigitElma dismissed stale reviews from ddudt and dpanici via 736c34b February 26, 2026 20:07
@YigitElma YigitElma requested review from ddudt and dpanici February 26, 2026 20:08
@YigitElma YigitElma merged commit ba11e9a into master Mar 2, 2026
27 checks passed
@YigitElma YigitElma deleted the yge/adam branch March 2, 2026 19:11

Labels

easy (Short and simple to code or review)

5 participants