diff --git a/extension/training/README.md b/extension/training/README.md
index ed2d65ef343..1f2f23d7290 100644
--- a/extension/training/README.md
+++ b/extension/training/README.md
@@ -8,7 +8,7 @@ current state.
 ## Layout
 - `examples/` : Example end to end flows from model definition to optimizer.step()
 - `module/`: Utility class to provide an improved UX when using ExecuTorch for Training.
-- `optimizer/`: Cpp implementations of various optimizers, currently only SGD though Adam is planned.
+- `optimizer/`: Cpp implementations of various optimizers, currently SGD and AdamW. Adam is planned.
 - `test/`: Tests that cover multiple subdirs.
 
 ## Technical Birds Eye view
diff --git a/extension/training/optimizer/adamw.cpp b/extension/training/optimizer/adamw.cpp
new file mode 100644
index 00000000000..7f63aa4946c
--- /dev/null
+++ b/extension/training/optimizer/adamw.cpp
@@ -0,0 +1,217 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/training/optimizer/adamw.h>
+
+#include <cmath>
+#include <cstring>
+
+#include <executorch/extension/tensor/tensor.h>
+#include <executorch/runtime/core/error.h>
+
+using executorch::aten::Tensor;
+using executorch::extension::make_tensor_ptr;
+using executorch::extension::TensorPtr;
+using ::executorch::runtime::Error;
+
+namespace executorch {
+namespace extension {
+namespace training {
+namespace optimizer {
+
+namespace {
+// out[i] = a[i] + alpha * b[i]
+void add_out_hack(
+    const Tensor& a,
+    const Tensor& b,
+    const double alpha,
+    Tensor& out) {
+  auto a_ptr = a.const_data_ptr<float>();
+  auto b_ptr = b.const_data_ptr<float>();
+  auto out_ptr = out.mutable_data_ptr<float>();
+  for (size_t i = 0; i < a.numel(); ++i) {
+    out_ptr[i] = a_ptr[i] + b_ptr[i] * alpha;
+  }
+}
+
+// out[i] = a[i] * alpha
+void mul_out_hack(const Tensor& a, const double alpha, Tensor& out) {
+  auto a_ptr = a.const_data_ptr<float>();
+  auto out_ptr = out.mutable_data_ptr<float>();
+  for (size_t i = 0; i < a.numel(); ++i) {
+    out_ptr[i] = a_ptr[i] * alpha;
+  }
+}
+
+// Fused second-moment update: v[i] = beta2 * v[i] + (1 - beta2) * g[i]^2.
+// Avoids materializing a separate g^2 tensor.
+void addcmul_sq_out_hack(
+    const Tensor& v,
+    const Tensor& g,
+    const double beta2,
+    Tensor& out) {
+  auto v_ptr = v.const_data_ptr<float>();
+  auto g_ptr = g.const_data_ptr<float>();
+  auto out_ptr = out.mutable_data_ptr<float>();
+  const double one_minus_beta2 = 1.0 - beta2;
+  for (size_t i = 0; i < v.numel(); ++i) {
+    const double gi = static_cast<double>(g_ptr[i]);
+    out_ptr[i] = static_cast<float>(
+        static_cast<double>(v_ptr[i]) * beta2 + one_minus_beta2 * gi * gi);
+  }
+}
+
+// Fused AdamW parameter update:
+//   p[i] -= lr * (m[i] / bias_correction1) /
+//           (sqrt(v[i] / bias_correction2) + eps)
+// Performed in double precision internally to limit accumulated FP error on
+// the division-by-sqrt path.
+void adamw_update_hack(
+    Tensor& p,
+    const Tensor& m,
+    const Tensor& v,
+    const double lr,
+    const double bias_correction1,
+    const double bias_correction2,
+    const double eps) {
+  auto p_ptr = p.mutable_data_ptr<float>();
+  auto m_ptr = m.const_data_ptr<float>();
+  auto v_ptr = v.const_data_ptr<float>();
+  const double inv_bc1 = 1.0 / bias_correction1;
+  const double inv_sqrt_bc2 = 1.0 / std::sqrt(bias_correction2);
+  for (size_t i = 0; i < p.numel(); ++i) {
+    const double m_hat = static_cast<double>(m_ptr[i]) * inv_bc1;
+    const double v_hat_sqrt =
+        std::sqrt(static_cast<double>(v_ptr[i])) * inv_sqrt_bc2;
+    p_ptr[i] = static_cast<float>(
+        static_cast<double>(p_ptr[i]) - lr * m_hat / (v_hat_sqrt + eps));
+  }
+}
+} // namespace
+
+bool AdamWParamGroup::has_options() const {
+  return options_ != nullptr;
+}
+
+AdamWOptions& AdamWParamGroup::options() {
+  return *options_.get();
+}
+
+const AdamWOptions& AdamWParamGroup::options() const {
+  return *options_.get();
+}
+
+void AdamWParamGroup::set_options(std::unique_ptr<AdamWOptions> options) {
+  options_ = std::move(options);
+}
+
+const std::map<std::string_view, executorch::aten::Tensor>&
+AdamWParamGroup::named_parameters() const {
+  return named_parameters_;
+}
+
+void AdamW::add_param_group(const AdamWParamGroup& param_group) {
+  AdamWParamGroup param_group_(param_group.named_parameters());
+  if (!param_group.has_options()) {
+    param_group_.set_options(defaults_.clone());
+  } else {
+    param_group_.set_options(param_group.options().clone());
+  }
+  param_groups_.emplace_back(std::move(param_group_));
+}
+
+Error AdamW::step(const std::map<std::string_view, executorch::aten::Tensor>&
+                      named_gradients) {
+  for (auto& group : param_groups_) {
+    auto& options = static_cast<AdamWOptions&>(group.options());
+    const double lr = options.lr();
+    const double beta1 = options.beta1();
+    const double beta2 = options.beta2();
+    const double eps = options.eps();
+    const double weight_decay = options.weight_decay();
+
+    for (auto param_iter = group.named_parameters().begin();
+         param_iter != group.named_parameters().end();
+         ++param_iter) {
+      const auto& named_gradient = named_gradients.find(param_iter->first);
+      if (named_gradient == named_gradients.end()) {
+        continue;
+      }
+      auto g = named_gradient->second;
+      auto p = param_iter->second;
+
+      // Decoupled weight decay: p <- p - lr * weight_decay * p. Applied to
+      // the parameter directly, BEFORE the moment-based update, and NOT
+      // folded into the gradient. This is the defining property of AdamW
+      // (Loshchilov & Hutter, 2019).
+      if (weight_decay != 0.0) {
+        add_out_hack(p, p, -lr * weight_decay, p);
+      }
+
+      // Look up or lazily allocate the per-parameter state (two moment
+      // buffers sized and shaped like the gradient, plus a step counter).
+      auto param_state_it = state_.find(p.unsafeGetTensorImpl());
+      AdamWParamState* state_ptr = nullptr;
+      if (param_state_it == state_.end()) {
+        void* m_buf_ptr = malloc(g.nbytes());
+        void* v_buf_ptr = malloc(g.nbytes());
+        std::memset(m_buf_ptr, 0, g.nbytes());
+        std::memset(v_buf_ptr, 0, g.nbytes());
+
+        std::vector<executorch::aten::SizesType> sizes(
+            g.sizes().begin(), g.sizes().end());
+        auto m_ptr = make_tensor_ptr(
+            sizes,
+            m_buf_ptr,
+            g.scalar_type(),
+            executorch::aten::TensorShapeDynamism::STATIC,
+            [](void* p) { free(p); });
+        auto v_ptr = make_tensor_ptr(
+            sizes,
+            v_buf_ptr,
+            g.scalar_type(),
+            executorch::aten::TensorShapeDynamism::STATIC,
+            [](void* p) { free(p); });
+
+        auto state = std::make_unique<AdamWParamState>(
+            std::move(m_ptr), std::move(v_ptr));
+        state_ptr = state.get();
+        state_[p.unsafeGetTensorImpl()] = std::move(state);
+      } else {
+        state_ptr = param_state_it->second.get();
+      }
+
+      state_ptr->increment_step_count();
+      const int64_t step = state_ptr->step_count();
+
+      Tensor& exp_avg = state_ptr->exp_avg();
+      Tensor& exp_avg_sq = state_ptr->exp_avg_sq();
+
+      // First moment: m <- beta1 * m + (1 - beta1) * g
+      mul_out_hack(exp_avg, beta1, exp_avg);
+      add_out_hack(exp_avg, g, 1.0 - beta1, exp_avg);
+
+      // Second moment: v <- beta2 * v + (1 - beta2) * g^2
+      addcmul_sq_out_hack(exp_avg_sq, g, beta2, exp_avg_sq);
+
+      // Bias-corrected update.
+      const double bias_correction1 = 1.0 - std::pow(beta1, step);
+      const double bias_correction2 = 1.0 - std::pow(beta2, step);
+      adamw_update_hack(
+          p, exp_avg, exp_avg_sq, lr, bias_correction1, bias_correction2, eps);
+    }
+  }
+  return Error::Ok;
+}
+
+AdamW::~AdamW() = default;
+
+} // namespace optimizer
+} // namespace training
+} // namespace extension
+} // namespace executorch
diff --git a/extension/training/optimizer/adamw.h b/extension/training/optimizer/adamw.h
new file mode 100644
index 00000000000..131e752b9b4
--- /dev/null
+++ b/extension/training/optimizer/adamw.h
@@ -0,0 +1,221 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/**
+ * AdamW optimizer to perform on-device training. This is an adaptation of the
+ * PyTorch AdamW implementation (Loshchilov & Hutter, 2019) that decouples
+ * weight decay from the gradient-based update. Per-parameter state consists of
+ * first and second moment running averages and a scalar step counter used for
+ * bias correction.
+ */
+#pragma once
+
+#include <executorch/extension/tensor/tensor.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <map>
+#include <memory>
+#include <string_view>
+#include <unordered_map>
+#include <vector>
+
+namespace executorch {
+namespace extension {
+namespace training {
+namespace optimizer {
+
+/**
+ * AdamW optimizer state. Holds the two moment buffers and the step counter for
+ * a single parameter, to be reused across optimizer steps.
+ */
+class ET_EXPERIMENTAL AdamWParamState {
+ public:
+  /**
+   * Constructs a new AdamW param state.
+   *
+   * @param[in] exp_avg The first moment (EMA of gradients) buffer.
+   * @param[in] exp_avg_sq The second moment (EMA of squared gradients) buffer.
+   */
+  AdamWParamState(
+      executorch::extension::TensorPtr exp_avg,
+      executorch::extension::TensorPtr exp_avg_sq)
+      : exp_avg_(std::move(exp_avg)),
+        exp_avg_sq_(std::move(exp_avg_sq)),
+        step_count_(0) {}
+
+  executorch::aten::Tensor& exp_avg() {
+    return *exp_avg_;
+  }
+
+  executorch::aten::Tensor& exp_avg_sq() {
+    return *exp_avg_sq_;
+  }
+
+  int64_t step_count() const {
+    return step_count_;
+  }
+
+  void increment_step_count() {
+    ++step_count_;
+  }
+
+ private:
+  executorch::extension::TensorPtr exp_avg_;
+  executorch::extension::TensorPtr exp_avg_sq_;
+  int64_t step_count_;
+};
+
+/**
+ * AdamW optimizer options. Hyperparameters for a given parameter group.
+ */
+class ET_EXPERIMENTAL AdamWOptions {
+ public:
+  /**
+   * Constructs a new AdamW optimizer options.
+   *
+   * @param[in] lr The learning rate.
+   * @param[in] beta1 Exponential decay rate for the first moment estimate.
+   * @param[in] beta2 Exponential decay rate for the second moment estimate.
+   * @param[in] eps Small constant added to the denominator for numerical
+   * stability.
+   * @param[in] weight_decay Decoupled weight decay coefficient. Applied
+   * directly to the parameter (not folded into the gradient) per the AdamW
+   * formulation.
+   */
+  explicit AdamWOptions(
+      double lr = 1e-3,
+      double beta1 = 0.9,
+      double beta2 = 0.999,
+      double eps = 1e-8,
+      double weight_decay = 1e-2)
+      : lr_(lr),
+        beta1_(beta1),
+        beta2_(beta2),
+        eps_(eps),
+        weight_decay_(weight_decay) {}
+
+  std::unique_ptr<AdamWOptions> clone() const {
+    return std::make_unique<AdamWOptions>(
+        static_cast<const AdamWOptions&>(*this));
+  }
+
+  double lr() const {
+    return lr_;
+  }
+
+  double beta1() const {
+    return beta1_;
+  }
+
+  double beta2() const {
+    return beta2_;
+  }
+
+  double eps() const {
+    return eps_;
+  }
+
+  double weight_decay() const {
+    return weight_decay_;
+  }
+
+ private:
+  double lr_;
+  double beta1_;
+  double beta2_;
+  double eps_;
+  double weight_decay_;
+};
+
+/**
+ * AdamW optimizer param group. Holds a set of named parameters and the options
+ * governing their update.
+ */
+class ET_EXPERIMENTAL AdamWParamGroup {
+ public:
+  // NOTE: In order to store `AdamWParamGroup` in a `std::vector`, it has
+  // to be copy-constructible.
+  AdamWParamGroup(const AdamWParamGroup& param_group)
+      : named_parameters_(param_group.named_parameters()),
+        options_(
+            param_group.has_options() ? param_group.options().clone()
+                                      : nullptr) {}
+  AdamWParamGroup& operator=(const AdamWParamGroup& param_group) {
+    this->named_parameters_ = param_group.named_parameters_;
+    this->options_ =
+        param_group.has_options() ? param_group.options().clone() : nullptr;
+    return *this;
+  }
+
+  /* implicit */ AdamWParamGroup(
+      const std::map<std::string_view, executorch::aten::Tensor>&
+          named_parameters)
+      : named_parameters_(named_parameters) {}
+  AdamWParamGroup(
+      const std::map<std::string_view, executorch::aten::Tensor>&
+          named_parameters,
+      std::unique_ptr<AdamWOptions> options)
+      : named_parameters_(named_parameters), options_(std::move(options)) {}
+
+  bool has_options() const;
+  AdamWOptions& options();
+  const AdamWOptions& options() const;
+  void set_options(std::unique_ptr<AdamWOptions> options);
+  const std::map<std::string_view, executorch::aten::Tensor>&
+  named_parameters() const;
+
+ private:
+  std::map<std::string_view, executorch::aten::Tensor> named_parameters_;
+  std::unique_ptr<AdamWOptions> options_;
+};
+
+/**
+ * AdamW optimizer class. Performs the optimization step.
+ */
+class ET_EXPERIMENTAL AdamW {
+ public:
+  explicit AdamW(
+      const std::vector<AdamWParamGroup>& param_groups,
+      AdamWOptions defaults)
+      : defaults_(defaults) {
+    for (const auto& param_group : param_groups) {
+      add_param_group(param_group);
+    }
+  }
+
+  explicit AdamW(
+      const std::map<std::string_view, executorch::aten::Tensor>&
+          named_parameters,
+      AdamWOptions defaults)
+      : AdamW({AdamWParamGroup(named_parameters)}, defaults) {}
+
+  // Adds the given param_group to the optimizer's param_group list.
+  void add_param_group(const AdamWParamGroup& param_group);
+
+  ~AdamW();
+
+  /**
+   * Performs the optimization step.
+   *
+   * @param[in] named_gradients The gradients of the tensors specified by the
+   * fully qualified name.
+   */
+  ::executorch::runtime::Error step(
+      const std::map<std::string_view, executorch::aten::Tensor>&
+          named_gradients);
+
+ private:
+  std::vector<AdamWParamGroup> param_groups_;
+  std::unordered_map<void*, std::unique_ptr<AdamWParamState>> state_;
+  AdamWOptions defaults_;
+};
+
+} // namespace optimizer
+} // namespace training
+} // namespace extension
+} // namespace executorch
diff --git a/extension/training/optimizer/targets.bzl b/extension/training/optimizer/targets.bzl
index c99ae2a360d..3bfac5aa172 100644
--- a/extension/training/optimizer/targets.bzl
+++ b/extension/training/optimizer/targets.bzl
@@ -38,3 +38,19 @@ def define_common_targets():
             ],  # + kernel_deps,
             visibility = ["PUBLIC"],
         )
+
+        runtime.cxx_library(
+            name = "adamw" + aten_suffix,
+            srcs = [
+                "adamw.cpp",
+            ],
+            exported_headers = [
+                "adamw.h",
+            ],
+            exported_deps = [
+                "//executorch/extension/tensor:tensor" + aten_suffix,
+                "//executorch/runtime/core:core",
+                "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
+            ],
+            visibility = ["PUBLIC"],
+        )
diff --git a/extension/training/optimizer/test/adamw_test.cpp b/extension/training/optimizer/test/adamw_test.cpp
new file mode 100644
index 00000000000..cbc86f5b246
--- /dev/null
+++ b/extension/training/optimizer/test/adamw_test.cpp
@@ -0,0 +1,137 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/tensor/tensor.h>
+#include <executorch/extension/training/optimizer/adamw.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/platform/runtime.h>
+
+#include <gtest/gtest.h>
+
+// @lint-ignore-every CLANGTIDY facebook-hte-CArray
+
+using namespace ::testing;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
+using ::executorch::extension::training::optimizer::AdamW;
+using ::executorch::extension::training::optimizer::AdamWOptions;
+using ::executorch::extension::training::optimizer::AdamWParamState;
+using ::executorch::runtime::Error;
+using ::executorch::runtime::testing::TensorFactory;
+
+class AdamWOptimizerTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    torch::executor::runtime_init();
+  }
+};
+
+TEST_F(AdamWOptimizerTest, AdamWParamStateTest) {
+  auto exp_avg =
+      executorch::extension::make_tensor_ptr({2, 2}, {0.f, 0.f, 0.f, 0.f});
+  auto exp_avg_sq =
+      executorch::extension::make_tensor_ptr({2, 2}, {0.f, 0.f, 0.f, 0.f});
+  AdamWParamState state(std::move(exp_avg), std::move(exp_avg_sq));
+
+  EXPECT_EQ(state.step_count(), 0);
+  state.increment_step_count();
+  EXPECT_EQ(state.step_count(), 1);
+}
+
+TEST_F(AdamWOptimizerTest, AdamWOptionsDefaultValuesTest) {
+  AdamWOptions options;
+
+  EXPECT_DOUBLE_EQ(options.lr(), 1e-3);
+  EXPECT_DOUBLE_EQ(options.beta1(), 0.9);
+  EXPECT_DOUBLE_EQ(options.beta2(), 0.999);
+  EXPECT_DOUBLE_EQ(options.eps(), 1e-8);
+  EXPECT_DOUBLE_EQ(options.weight_decay(), 1e-2);
+}
+
+TEST_F(AdamWOptimizerTest, AdamWOptionsNonDefaultValuesTest) {
+  AdamWOptions options(0.1, 0.8, 0.99, 1e-6, 0.5);
+
+  EXPECT_DOUBLE_EQ(options.lr(), 0.1);
+  EXPECT_DOUBLE_EQ(options.beta1(), 0.8);
+  EXPECT_DOUBLE_EQ(options.beta2(), 0.99);
+  EXPECT_DOUBLE_EQ(options.eps(), 1e-6);
+  EXPECT_DOUBLE_EQ(options.weight_decay(), 0.5);
+}
+
+TEST_F(AdamWOptimizerTest, AdamWOptimizerSimple) {
+  TensorFactory<ScalarType::Float> tf;
+
+  std::map<std::string_view, Tensor> named_parameters;
+  named_parameters.insert({"param1", tf.make({1, 1}, {1.0})});
+
+  // lr=0.1, defaults otherwise, wd=0 to isolate the moment-based update.
+  AdamW optimizer(named_parameters, AdamWOptions{0.1, 0.9, 0.999, 1e-8, 0.0});
+
+  for (int i = 0; i < 10; ++i) {
+    std::map<std::string_view, Tensor> named_gradients;
+    named_gradients.insert({"param1", tf.make({1, 1}, {-1.0})});
+    optimizer.step(named_gradients);
+  }
+
+  auto p1 = static_cast<const float*>(
+      named_parameters.at("param1").const_data_ptr());
+  // With a constant gradient of -1 and no weight decay, the bias-corrected
+  // m_hat / sqrt(v_hat) is ~= -1 at every step, so each step shifts p by
+  // +lr. After 10 steps of lr=0.1, p should be near 2.0.
+  EXPECT_NEAR(p1[0], 2.0, 0.1);
+}
+
+TEST_F(AdamWOptimizerTest, AdamWOptimizerDecoupledWeightDecay) {
+  TensorFactory<ScalarType::Float> tf;
+
+  std::map<std::string_view, Tensor> named_parameters;
+  named_parameters.insert({"param1", tf.make({1, 1}, {1.0})});
+
+  // lr=0.1, wd=0.5. With a ZERO gradient, the moment update contributes
+  // nothing (m stays 0, v stays 0 -> m_hat/sqrt(v_hat+eps) ~= 0), so only
+  // the decoupled weight-decay term moves the parameter:
+  //   p <- p * (1 - lr * wd) = 1.0 * (1 - 0.05) = 0.95
+  // This is the test that distinguishes AdamW from Adam-with-L2.
+  AdamW optimizer(named_parameters, AdamWOptions{0.1, 0.9, 0.999, 1e-8, 0.5});
+
+  std::map<std::string_view, Tensor> named_gradients;
+  named_gradients.insert({"param1", tf.make({1, 1}, {0.0})});
+  optimizer.step(named_gradients);
+
+  auto p1 = static_cast<const float*>(
+      named_parameters.at("param1").const_data_ptr());
+  EXPECT_NEAR(p1[0], 0.95, 1e-5);
+}
+
+TEST_F(AdamWOptimizerTest, AdamWOptimizerMultipleParams) {
+  TensorFactory<ScalarType::Float> tf;
+
+  std::map<std::string_view, Tensor> named_parameters;
+  named_parameters.insert({"param1", tf.make({1, 1}, {1.0})});
+  named_parameters.insert({"param2", tf.make({1, 1}, {2.0})});
+
+  AdamW optimizer(named_parameters, AdamWOptions{0.1, 0.9, 0.999, 1e-8, 0.0});
+
+  for (int i = 0; i < 5; ++i) {
+    std::map<std::string_view, Tensor> named_gradients;
+    named_gradients.insert({"param1", tf.make({1, 1}, {-1.0})});
+    named_gradients.insert({"param2", tf.make({1, 1}, {1.0})});
+    optimizer.step(named_gradients);
+  }
+
+  auto p1 = static_cast<const float*>(
+      named_parameters.at("param1").const_data_ptr());
+  auto p2 = static_cast<const float*>(
+      named_parameters.at("param2").const_data_ptr());
+  // Each param sees a constant gradient of +/- 1 for 5 steps -> p shifts by
+  // roughly +/- 5 * lr = +/- 0.5. State is tracked independently per param.
+  EXPECT_NEAR(p1[0], 1.5, 0.1);
+  EXPECT_NEAR(p2[0], 1.5, 0.1);
+}
diff --git a/extension/training/optimizer/test/targets.bzl b/extension/training/optimizer/test/targets.bzl
index 7a93337a379..ea5bd94e6f0 100644
--- a/extension/training/optimizer/test/targets.bzl
+++ b/extension/training/optimizer/test/targets.bzl
@@ -20,3 +20,15 @@ def define_common_targets():
                 "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
             ],
         )
+
+        runtime.cxx_test(
+            name = "adamw_test" + aten_suffix,
+            srcs = [
+                "adamw_test.cpp",
+            ],
+            deps = [
+                "//executorch/extension/training/optimizer:adamw" + aten_suffix,
+                "//executorch/runtime/core:core",
+                "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+            ],
+        )
diff --git a/shim_et/xplat/executorch/build/build_variables.bzl b/shim_et/xplat/executorch/build/build_variables.bzl
index edddc1da916..e2e9b330c47 100644
--- a/shim_et/xplat/executorch/build/build_variables.bzl
+++ b/shim_et/xplat/executorch/build/build_variables.bzl
@@ -379,6 +379,7 @@ EXTENSION_THREADPOOL_SRCS = ["extension/threadpool/" + x for x in THREADPOOL_SRCS]
 
 EXTENSION_TRAINING_SRCS = [
     "extension/training/module/training_module.cpp",
+    "extension/training/optimizer/adamw.cpp",
     "extension/training/optimizer/sgd.cpp",
 ]