Add AdamW optimizer to extension/training #18848
BryanBradfo wants to merge 2 commits into pytorch:main
Conversation
Ports AdamW alongside the existing SGD implementation, following the pattern in extension/training/optimizer/sgd.{h,cpp}. Weight decay is decoupled (applied to the parameter directly, not folded into the gradient) per Loshchilov & Hutter (2019); this is the property that distinguishes AdamW from Adam-with-L2.
Fixes pytorch#18766
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18848

Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV — there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 1 Awaiting Approval, 3 New Failures, 3 Pending, 3 Unrelated Failures. As of commit bda405f with merge base c11ba1b:
- NEW FAILURES — the following jobs have failed.
- FLAKY — the following job failed but was likely due to flakiness present on trunk.
- BROKEN TRUNK — the following jobs failed but were also failing on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "release notes: training"
Pull request overview
Adds an AdamW optimizer implementation to ExecuTorch’s on-device training extension, aligning behavior with torch.optim.AdamW (decoupled weight decay) and integrating it into the existing C++ optimizer build/test setup.
Changes:
- Introduces the AdamW optimizer implementation (adamw.{h,cpp}) and exposes it as a training optimizer target.
- Adds new gtests for AdamW and wires them into the optimizer test targets.
- Updates training extension build source lists and documentation to reflect AdamW availability.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
Summary per file:
| File | Description |
|---|---|
| shim_et/xplat/executorch/build/build_variables.bzl | Adds AdamW source to extension training sources list. |
| extension/training/optimizer/targets.bzl | Defines a new adamw C++ library target. |
| extension/training/optimizer/adamw.h | Declares AdamW API, options, param group, and state types. |
| extension/training/optimizer/adamw.cpp | Implements AdamW step logic and state allocation/freeing. |
| extension/training/optimizer/test/targets.bzl | Adds adamw_test target. |
| extension/training/optimizer/test/adamw_test.cpp | New unit tests for AdamW behavior and defaults. |
| extension/training/README.md | Updates optimizer list to include AdamW. |
```cpp
TensorImpl* m_impl = new TensorImpl(
    g.scalar_type(),
    g.sizes().size(),
    const_cast<TensorImpl::SizesType*>(g.sizes().data()),
    m_buf_ptr,
    const_cast<TensorImpl::DimOrderType*>(g.dim_order().data()));
TensorImpl* v_impl = new TensorImpl(
    g.scalar_type(),
    g.sizes().size(),
    const_cast<TensorImpl::SizesType*>(g.sizes().data()),
    v_buf_ptr,
    const_cast<TensorImpl::DimOrderType*>(g.dim_order().data()));
```
In portable (non-USE_ATEN_LIB) mode, the state tensors' TensorImpls are constructed with sizes/dim_order pointers borrowed from the gradient tensor (g.sizes().data()/g.dim_order().data()). TensorImpl does not own these arrays, and the optimizer state outlives the gradient tensor, so this can become a use-after-free if the gradient's metadata storage doesn't live for the lifetime of the optimizer. Allocate and store your own sizes/dim_order (and strides if needed) with the state, or use a construction path that ensures the metadata buffers are owned/stable for the optimizer lifetime.
Good catch in principle. This matches the existing pattern in SGD (sgd.cpp:119-125) which constructs TensorImpl with the same d_p.sizes().data() / d_p.dim_order().data() pointers from the gradient. In practice, parameter shapes are stable across training steps in ExecuTorch's captured graph model. Happy to address both optimizers together in a follow-up if the maintainer prefers.
```cpp
// Bias-corrected update.
const double bias_correction1 = 1.0 - std::pow(beta1, step);
const double bias_correction2 = 1.0 - std::pow(beta2, step);
adamw_update_hack(
```
AdamWOptions parameters are used in bias-correction divisors (1 - beta^step) and sqrt(bias_correction2). With beta1==1 or beta2==1 (or beta2<0), step() will divide by zero / take sqrt of a negative, producing NaNs/inf. Consider validating lr/beta1/beta2/eps ranges (e.g., lr>=0, 0<=beta<1, eps>0) either in AdamWOptions construction or at the start of step() and return Error::InvalidArgument on invalid values.
SGD also has no parameter validation (sgd.h), and PyTorch's C++ AdamW documents 0 <= beta < 1 as a precondition rather than a runtime check. Adding validation only to AdamW would be inconsistent. Happy to add range checks across both optimizers in a follow-up if the maintainer wants.
```cpp
}
auto g = named_gradient->second;
auto p = param_iter->second;
```
AdamW::step immediately treats both parameters and gradients as float tensors and iterates assuming matching shapes. If a caller passes a non-float tensor or a gradient whose shape/numel differs from the parameter, this can lead to undefined behavior or out-of-bounds reads/writes (e.g., in addcmul_sq_out_hack/adamw_update_hack). Add explicit scalar_type/shape checks (at least dtype==Float and p.numel()==g.numel()) and return Error::InvalidArgument when unsupported.
```cpp
if (p.scalar_type() != executorch::aten::ScalarType::Float ||
    g.scalar_type() != executorch::aten::ScalarType::Float) {
  return Error::InvalidArgument;
}
if (p.numel() != g.numel()) {
  return Error::InvalidArgument;
}
```
```cpp
AdamWParamState* state_ptr = nullptr;
if (param_state_it == state_.end()) {
  void* m_buf_ptr = malloc(g.nbytes());
  void* v_buf_ptr = malloc(g.nbytes());
```
The per-parameter state allocation uses malloc() but never checks for allocation failure. On memory-constrained targets this can turn into a null dereference in memset()/from_blob/TensorImpl construction. Check m_buf_ptr/v_buf_ptr for nullptr, free any partially-allocated buffers, and return Error::MemoryAllocationFailed.
```cpp
void* v_buf_ptr = malloc(g.nbytes());
if (m_buf_ptr == nullptr || v_buf_ptr == nullptr) {
  free(m_buf_ptr);
  free(v_buf_ptr);
  return Error::MemoryAllocationFailed;
}
```
You can ignore Copilot; its noise ratio is pretty bad. I've been trying to figure out how to turn it off for the repo.
```cpp
namespace {
// out[i] = a[i] + alpha * b[i]
void add_out_hack(
```
Ahh, my legacy... @manuelcandales, do you remember how to call ExecuTorch ops outside of the interpreter? I can't recall.
Replace manual TensorImpl construction with make_tensor_ptr from extension/tensor, removing the #ifdef USE_ATEN_LIB block and simplifying the destructor. Store defaults_ by value since it is always initialized.
Adds AdamW to the training optimizer extension. It's a port of the existing SGD implementation at extension/training/optimizer/sgd.{h,cpp}, with the main algorithmic difference being decoupled weight decay (the parameter gets decayed directly instead of mixing the decay into the gradient). Matches torch.optim.AdamW with default settings.

Fixes #18766
Scope
C++ only for this PR. Python bindings are left out on purpose: the pybindings file has a TODO to build a generic optimizer interface first, so copying PySGD to PyAdamW now would just add duplication. Happy to follow up with that. amsgrad and maximize are also left out; both are rarely used and easy to add later if needed.

Test plan
Six new gtests pass, and the SGD regression stays green:
Output was also cross-checked against torch.optim.AdamW on four small cases (simple convergence, decoupled weight decay, multi-parameter). All four match to six decimal places.

cc @JacobSzwejbka