-
Notifications
You must be signed in to change notification settings - Fork 934
Add AdamW optimizer to extension/training #18848
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,217 @@ | ||||||||||||||||||||||||
| /* | ||||||||||||||||||||||||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||||||||||||||||||||
| * All rights reserved. | ||||||||||||||||||||||||
| * | ||||||||||||||||||||||||
| * This source code is licensed under the BSD-style license found in the | ||||||||||||||||||||||||
| * LICENSE file in the root directory of this source tree. | ||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| #include <executorch/extension/training/optimizer/adamw.h> | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| #include <executorch/extension/tensor/tensor_ptr.h> | ||||||||||||||||||||||||
| #include <executorch/runtime/core/error.h> | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| #include <cmath> | ||||||||||||||||||||||||
| #include <cstring> | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| using executorch::aten::Tensor; | ||||||||||||||||||||||||
| using executorch::extension::make_tensor_ptr; | ||||||||||||||||||||||||
| using executorch::extension::TensorPtr; | ||||||||||||||||||||||||
| using ::executorch::runtime::Error; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| namespace executorch { | ||||||||||||||||||||||||
| namespace extension { | ||||||||||||||||||||||||
| namespace training { | ||||||||||||||||||||||||
| namespace optimizer { | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| namespace { | ||||||||||||||||||||||||
| // out[i] = a[i] + alpha * b[i] | ||||||||||||||||||||||||
| void add_out_hack( | ||||||||||||||||||||||||
| const Tensor& a, | ||||||||||||||||||||||||
| const Tensor& b, | ||||||||||||||||||||||||
| const double alpha, | ||||||||||||||||||||||||
| Tensor& out) { | ||||||||||||||||||||||||
| auto a_ptr = a.const_data_ptr<float>(); | ||||||||||||||||||||||||
| auto b_ptr = b.const_data_ptr<float>(); | ||||||||||||||||||||||||
| auto out_ptr = out.mutable_data_ptr<float>(); | ||||||||||||||||||||||||
| for (size_t i = 0; i < a.numel(); ++i) { | ||||||||||||||||||||||||
| out_ptr[i] = a_ptr[i] + b_ptr[i] * alpha; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // out[i] = a[i] * alpha | ||||||||||||||||||||||||
| void mul_out_hack(const Tensor& a, const double alpha, Tensor& out) { | ||||||||||||||||||||||||
| auto a_ptr = a.const_data_ptr<float>(); | ||||||||||||||||||||||||
| auto out_ptr = out.mutable_data_ptr<float>(); | ||||||||||||||||||||||||
| for (size_t i = 0; i < a.numel(); ++i) { | ||||||||||||||||||||||||
| out_ptr[i] = a_ptr[i] * alpha; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // Fused second-moment update: v[i] = beta2 * v[i] + (1 - beta2) * g[i]^2. | ||||||||||||||||||||||||
| // Avoids materializing a separate g^2 tensor. | ||||||||||||||||||||||||
| void addcmul_sq_out_hack( | ||||||||||||||||||||||||
| const Tensor& v, | ||||||||||||||||||||||||
| const Tensor& g, | ||||||||||||||||||||||||
| const double beta2, | ||||||||||||||||||||||||
| Tensor& out) { | ||||||||||||||||||||||||
| auto v_ptr = v.const_data_ptr<float>(); | ||||||||||||||||||||||||
| auto g_ptr = g.const_data_ptr<float>(); | ||||||||||||||||||||||||
| auto out_ptr = out.mutable_data_ptr<float>(); | ||||||||||||||||||||||||
| const double one_minus_beta2 = 1.0 - beta2; | ||||||||||||||||||||||||
| for (size_t i = 0; i < v.numel(); ++i) { | ||||||||||||||||||||||||
| const double gi = static_cast<double>(g_ptr[i]); | ||||||||||||||||||||||||
| out_ptr[i] = static_cast<float>( | ||||||||||||||||||||||||
| static_cast<double>(v_ptr[i]) * beta2 + one_minus_beta2 * gi * gi); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // Fused AdamW parameter update: | ||||||||||||||||||||||||
| // p[i] -= lr * (m[i] / bias_correction1) / | ||||||||||||||||||||||||
| // (sqrt(v[i] / bias_correction2) + eps) | ||||||||||||||||||||||||
| // Performed in double precision internally to limit accumulated FP error on | ||||||||||||||||||||||||
| // the division-by-sqrt path. | ||||||||||||||||||||||||
| void adamw_update_hack( | ||||||||||||||||||||||||
| Tensor& p, | ||||||||||||||||||||||||
| const Tensor& m, | ||||||||||||||||||||||||
| const Tensor& v, | ||||||||||||||||||||||||
| const double lr, | ||||||||||||||||||||||||
| const double bias_correction1, | ||||||||||||||||||||||||
| const double bias_correction2, | ||||||||||||||||||||||||
| const double eps) { | ||||||||||||||||||||||||
| auto p_ptr = p.mutable_data_ptr<float>(); | ||||||||||||||||||||||||
| auto m_ptr = m.const_data_ptr<float>(); | ||||||||||||||||||||||||
| auto v_ptr = v.const_data_ptr<float>(); | ||||||||||||||||||||||||
| const double inv_bc1 = 1.0 / bias_correction1; | ||||||||||||||||||||||||
| const double inv_sqrt_bc2 = 1.0 / std::sqrt(bias_correction2); | ||||||||||||||||||||||||
| for (size_t i = 0; i < p.numel(); ++i) { | ||||||||||||||||||||||||
| const double m_hat = static_cast<double>(m_ptr[i]) * inv_bc1; | ||||||||||||||||||||||||
| const double v_hat_sqrt = | ||||||||||||||||||||||||
| std::sqrt(static_cast<double>(v_ptr[i])) * inv_sqrt_bc2; | ||||||||||||||||||||||||
| p_ptr[i] = static_cast<float>( | ||||||||||||||||||||||||
| static_cast<double>(p_ptr[i]) - lr * m_hat / (v_hat_sqrt + eps)); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } // namespace | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| bool AdamWParamGroup::has_options() const { | ||||||||||||||||||||||||
| return options_ != nullptr; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| AdamWOptions& AdamWParamGroup::options() { | ||||||||||||||||||||||||
| return *options_.get(); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| const AdamWOptions& AdamWParamGroup::options() const { | ||||||||||||||||||||||||
| return *options_.get(); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| void AdamWParamGroup::set_options(std::unique_ptr<AdamWOptions> options) { | ||||||||||||||||||||||||
| options_ = std::move(options); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| const std::map<std::string_view, executorch::aten::Tensor>& | ||||||||||||||||||||||||
| AdamWParamGroup::named_parameters() const { | ||||||||||||||||||||||||
| return named_parameters_; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| void AdamW::add_param_group(const AdamWParamGroup& param_group) { | ||||||||||||||||||||||||
| AdamWParamGroup param_group_(param_group.named_parameters()); | ||||||||||||||||||||||||
| if (!param_group.has_options()) { | ||||||||||||||||||||||||
| param_group_.set_options(defaults_.clone()); | ||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||
| param_group_.set_options(param_group.options().clone()); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| param_groups_.emplace_back(std::move(param_group_)); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Error AdamW::step(const std::map<std::string_view, executorch::aten::Tensor>& | ||||||||||||||||||||||||
| named_gradients) { | ||||||||||||||||||||||||
| for (auto& group : param_groups_) { | ||||||||||||||||||||||||
| auto& options = static_cast<AdamWOptions&>(group.options()); | ||||||||||||||||||||||||
| const double lr = options.lr(); | ||||||||||||||||||||||||
| const double beta1 = options.beta1(); | ||||||||||||||||||||||||
| const double beta2 = options.beta2(); | ||||||||||||||||||||||||
| const double eps = options.eps(); | ||||||||||||||||||||||||
| const double weight_decay = options.weight_decay(); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| for (auto param_iter = group.named_parameters().begin(); | ||||||||||||||||||||||||
| param_iter != group.named_parameters().end(); | ||||||||||||||||||||||||
| ++param_iter) { | ||||||||||||||||||||||||
| const auto& named_gradient = named_gradients.find(param_iter->first); | ||||||||||||||||||||||||
| if (named_gradient == named_gradients.end()) { | ||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| auto g = named_gradient->second; | ||||||||||||||||||||||||
| auto p = param_iter->second; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
| if ( | |
| p.scalar_type() != executorch::aten::ScalarType::Float || | |
| g.scalar_type() != executorch::aten::ScalarType::Float) { | |
| return Error::InvalidArgument; | |
| } | |
| if (p.numel() != g.numel()) { | |
| return Error::InvalidArgument; | |
| } |
Copilot
AI
Apr 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The per-parameter state allocation uses malloc() but never checks for allocation failure. On memory-constrained targets this can turn into a null dereference in memset()/from_blob/TensorImpl construction. Check m_buf_ptr/v_buf_ptr for nullptr, free any partially-allocated buffers, and return Error::MemoryAllocationFailed.
| void* v_buf_ptr = malloc(g.nbytes()); | |
| void* v_buf_ptr = malloc(g.nbytes()); | |
| if (m_buf_ptr == nullptr || v_buf_ptr == nullptr) { | |
| free(m_buf_ptr); | |
| free(v_buf_ptr); | |
| return Error::MemoryAllocationFailed; | |
| } |
Copilot
AI
Apr 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AdamWOptions parameters are used in bias-correction divisors (1 - beta^step) and sqrt(bias_correction2). With beta1==1 or beta2==1 (or beta2<0), step() will divide by zero / take sqrt of a negative, producing NaNs/inf. Consider validating lr/beta1/beta2/eps ranges (e.g., lr>=0, 0<=beta<1, eps>0) either in AdamWOptions construction or at the start of step() and return Error::InvalidArgument on invalid values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SGD also has no parameter validation (sgd.h), and PyTorch's C++ AdamW documents 0 <= beta < 1 as a precondition rather than a runtime check. Adding validation only to AdamW would be inconsistent. Happy to add range checks across both optimizers in a follow-up if the maintainer wants.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ahh my legacy... @manuelcandales do you remember how to call executorch ops outside of the interpreter. I cant recall