2 changes: 1 addition & 1 deletion extension/training/README.md
@@ -8,7 +8,7 @@ current state.
## Layout
- `examples/` : Example end to end flows from model definition to optimizer.step()
- `module/`: Utility class to provide an improved UX when using ExecuTorch for Training.
- `optimizer/`: Cpp implementations of various optimizers, currently only SGD though Adam is planned.
- `optimizer/`: Cpp implementations of various optimizers, currently SGD and AdamW. Adam is planned.
- `test/`: Tests that cover multiple subdirs.

## Technical Birds Eye view
217 changes: 217 additions & 0 deletions extension/training/optimizer/adamw.cpp
@@ -0,0 +1,217 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/training/optimizer/adamw.h>

#include <executorch/extension/tensor/tensor_ptr.h>
#include <executorch/runtime/core/error.h>

#include <cmath>
#include <cstring>

using executorch::aten::Tensor;
using executorch::extension::make_tensor_ptr;
using executorch::extension::TensorPtr;
using ::executorch::runtime::Error;

namespace executorch {
namespace extension {
namespace training {
namespace optimizer {

namespace {
// out[i] = a[i] + alpha * b[i]
void add_out_hack(
Contributor:

ahh my legacy... @manuelcandales do you remember how to call executorch ops outside of the interpreter? I can't recall.

const Tensor& a,
const Tensor& b,
const double alpha,
Tensor& out) {
auto a_ptr = a.const_data_ptr<float>();
auto b_ptr = b.const_data_ptr<float>();
auto out_ptr = out.mutable_data_ptr<float>();
for (size_t i = 0; i < a.numel(); ++i) {
out_ptr[i] = a_ptr[i] + b_ptr[i] * alpha;
}
}

// out[i] = a[i] * alpha
void mul_out_hack(const Tensor& a, const double alpha, Tensor& out) {
auto a_ptr = a.const_data_ptr<float>();
auto out_ptr = out.mutable_data_ptr<float>();
for (size_t i = 0; i < a.numel(); ++i) {
out_ptr[i] = a_ptr[i] * alpha;
}
}

// Fused second-moment update: v[i] = beta2 * v[i] + (1 - beta2) * g[i]^2.
// Avoids materializing a separate g^2 tensor.
void addcmul_sq_out_hack(
const Tensor& v,
const Tensor& g,
const double beta2,
Tensor& out) {
auto v_ptr = v.const_data_ptr<float>();
auto g_ptr = g.const_data_ptr<float>();
auto out_ptr = out.mutable_data_ptr<float>();
const double one_minus_beta2 = 1.0 - beta2;
for (size_t i = 0; i < v.numel(); ++i) {
const double gi = static_cast<double>(g_ptr[i]);
out_ptr[i] = static_cast<float>(
static_cast<double>(v_ptr[i]) * beta2 + one_minus_beta2 * gi * gi);
}
}

// Fused AdamW parameter update:
// p[i] -= lr * (m[i] / bias_correction1) /
// (sqrt(v[i] / bias_correction2) + eps)
// Performed in double precision internally to limit accumulated FP error on
// the division-by-sqrt path.
void adamw_update_hack(
Tensor& p,
const Tensor& m,
const Tensor& v,
const double lr,
const double bias_correction1,
const double bias_correction2,
const double eps) {
auto p_ptr = p.mutable_data_ptr<float>();
auto m_ptr = m.const_data_ptr<float>();
auto v_ptr = v.const_data_ptr<float>();
const double inv_bc1 = 1.0 / bias_correction1;
const double inv_sqrt_bc2 = 1.0 / std::sqrt(bias_correction2);
for (size_t i = 0; i < p.numel(); ++i) {
const double m_hat = static_cast<double>(m_ptr[i]) * inv_bc1;
const double v_hat_sqrt =
std::sqrt(static_cast<double>(v_ptr[i])) * inv_sqrt_bc2;
p_ptr[i] = static_cast<float>(
static_cast<double>(p_ptr[i]) - lr * m_hat / (v_hat_sqrt + eps));
}
}
} // namespace

bool AdamWParamGroup::has_options() const {
return options_ != nullptr;
}

AdamWOptions& AdamWParamGroup::options() {
return *options_.get();
}

const AdamWOptions& AdamWParamGroup::options() const {
return *options_.get();
}

void AdamWParamGroup::set_options(std::unique_ptr<AdamWOptions> options) {
options_ = std::move(options);
}

const std::map<std::string_view, executorch::aten::Tensor>&
AdamWParamGroup::named_parameters() const {
return named_parameters_;
}

void AdamW::add_param_group(const AdamWParamGroup& param_group) {
AdamWParamGroup param_group_(param_group.named_parameters());
if (!param_group.has_options()) {
param_group_.set_options(defaults_.clone());
} else {
param_group_.set_options(param_group.options().clone());
}
param_groups_.emplace_back(std::move(param_group_));
}

Error AdamW::step(const std::map<std::string_view, executorch::aten::Tensor>&
named_gradients) {
for (auto& group : param_groups_) {
auto& options = static_cast<AdamWOptions&>(group.options());
const double lr = options.lr();
const double beta1 = options.beta1();
const double beta2 = options.beta2();
const double eps = options.eps();
const double weight_decay = options.weight_decay();

for (auto param_iter = group.named_parameters().begin();
param_iter != group.named_parameters().end();
++param_iter) {
const auto& named_gradient = named_gradients.find(param_iter->first);
if (named_gradient == named_gradients.end()) {
continue;
}
auto g = named_gradient->second;
auto p = param_iter->second;

Copilot AI Apr 13, 2026:

AdamW::step immediately treats both parameters and gradients as float tensors and iterates assuming matching shapes. If a caller passes a non-float tensor or a gradient whose shape/numel differs from the parameter, this can lead to undefined behavior or out-of-bounds reads/writes (e.g., in addcmul_sq_out_hack/adamw_update_hack). Add explicit scalar_type/shape checks (at least dtype == Float and p.numel() == g.numel()) and return Error::InvalidArgument when unsupported.

Suggested change:

      if (p.scalar_type() != executorch::aten::ScalarType::Float ||
          g.scalar_type() != executorch::aten::ScalarType::Float) {
        return Error::InvalidArgument;
      }
      if (p.numel() != g.numel()) {
        return Error::InvalidArgument;
      }
// Decoupled weight decay: p <- p - lr * weight_decay * p. Applied to
// the parameter directly, BEFORE the moment-based update, and NOT
// folded into the gradient. This is the defining property of AdamW
// (Loshchilov & Hutter, 2019).
if (weight_decay != 0.0) {
add_out_hack(p, p, -lr * weight_decay, p);
}

// Look up or lazily allocate the per-parameter state (two moment
// buffers sized and shaped like the gradient, plus a step counter).
auto param_state_it = state_.find(p.unsafeGetTensorImpl());
AdamWParamState* state_ptr = nullptr;
if (param_state_it == state_.end()) {
void* m_buf_ptr = malloc(g.nbytes());
void* v_buf_ptr = malloc(g.nbytes());
Copilot AI Apr 13, 2026:

The per-parameter state allocation uses malloc() but never checks for allocation failure. On memory-constrained targets this can turn into a null dereference in memset()/from_blob/TensorImpl construction. Check m_buf_ptr/v_buf_ptr for nullptr, free any partially-allocated buffers, and return Error::MemoryAllocationFailed.

Suggested change:

        void* v_buf_ptr = malloc(g.nbytes());
        if (m_buf_ptr == nullptr || v_buf_ptr == nullptr) {
          free(m_buf_ptr);
          free(v_buf_ptr);
          return Error::MemoryAllocationFailed;
        }
std::memset(m_buf_ptr, 0, g.nbytes());
std::memset(v_buf_ptr, 0, g.nbytes());

std::vector<executorch::aten::SizesType> sizes(
g.sizes().begin(), g.sizes().end());
auto m_ptr = make_tensor_ptr(
sizes,
m_buf_ptr,
g.scalar_type(),
executorch::aten::TensorShapeDynamism::STATIC,
[](void* p) { free(p); });
auto v_ptr = make_tensor_ptr(
sizes,
v_buf_ptr,
g.scalar_type(),
executorch::aten::TensorShapeDynamism::STATIC,
[](void* p) { free(p); });

auto state = std::make_unique<AdamWParamState>(
std::move(m_ptr), std::move(v_ptr));
state_ptr = state.get();
state_[p.unsafeGetTensorImpl()] = std::move(state);
} else {
state_ptr = param_state_it->second.get();
}

state_ptr->increment_step_count();
const int64_t step = state_ptr->step_count();

Tensor& exp_avg = state_ptr->exp_avg();
Tensor& exp_avg_sq = state_ptr->exp_avg_sq();

// First moment: m <- beta1 * m + (1 - beta1) * g
mul_out_hack(exp_avg, beta1, exp_avg);
add_out_hack(exp_avg, g, 1.0 - beta1, exp_avg);

// Second moment: v <- beta2 * v + (1 - beta2) * g^2
addcmul_sq_out_hack(exp_avg_sq, g, beta2, exp_avg_sq);

// Bias-corrected update.
const double bias_correction1 = 1.0 - std::pow(beta1, step);
const double bias_correction2 = 1.0 - std::pow(beta2, step);
adamw_update_hack(
Comment on lines +202 to +205

Copilot AI Apr 13, 2026:

AdamWOptions parameters are used in bias-correction divisors (1 - beta^step) and sqrt(bias_correction2). With beta1 == 1 or beta2 == 1 (or beta2 < 0), step() will divide by zero or take the square root of a negative number, producing NaNs/infs. Consider validating the lr/beta1/beta2/eps ranges (e.g., lr >= 0, 0 <= beta < 1, eps > 0), either in AdamWOptions construction or at the start of step(), and return Error::InvalidArgument on invalid values.

Author:

SGD also has no parameter validation (sgd.h), and PyTorch's C++ AdamW documents 0 <= beta < 1 as a precondition rather than a runtime check. Adding validation only to AdamW would be inconsistent. Happy to add range checks across both optimizers in a follow-up if the maintainer wants.

p, exp_avg, exp_avg_sq, lr, bias_correction1, bias_correction2, eps);
}
}
return Error::Ok;
}

AdamW::~AdamW() = default;

} // namespace optimizer
} // namespace training
} // namespace extension
} // namespace executorch