Skip to content

Commit c853b3c

Browse files
David Linfacebook-github-bot
David Lin
authored andcommitted
Add base for sgd optimizer (#3496)
Summary: Pull Request resolved: #3496 This adds the sgd_optimizer header to executorch. would appreciate some thoughts on where to place this file. Reviewed By: JacobSzwejbka Differential Revision: D56888378 fbshipit-source-id: 17d6bb3975ae2d58aee911ee91a3ff07acbc6850
1 parent ebe701e commit c853b3c

File tree

6 files changed

+131
-0
lines changed

6 files changed

+131
-0
lines changed

extension/training/optimizer/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
oncall("executorch")
7+
8+
define_common_targets()

extension/training/optimizer/sgd.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
/**
10+
* SGD (stochastic gradient descent) optimizer to perform on-device training.
11+
* This uses the gradients calculated in the backwards pass of the loss function
12+
* and updates the parameters such that it minimizes the loss.
13+
*
14+
* This is similar to the Lite Interpreter implementation of the SGD optimizer,
15+
* but without the dependency on ATen Tensors and autograd.
16+
*/
17+
#pragma once
18+
19+
namespace torch {
20+
namespace executor {
21+
namespace optimizer {
22+
23+
/**
24+
* SGD optimizer state. This keeps track of the state of a given parameter to
25+
* be used in later epochs.
26+
*/
27+
class SGDParamState {};
28+
29+
/**
30+
* SGD optimizer options. This contains options for performing training on a
31+
* param group, such as the learning rate.
32+
*/
33+
class SGDOptions {};
34+
35+
/**
36+
* SGD optimizer param group. This contains the parameters and
37+
* the OptimizerOptions associated to it.
38+
*/
39+
class SGDParamGroup {};
40+
41+
/**
42+
* SGD optimizer class. This is responsible for performing the optimization
43+
* step.
44+
*/
45+
class SGD {};
46+
47+
} // namespace optimizer
48+
} // namespace executor
49+
} // namespace torch
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
4+
"""Defines targets that should be shared between fbcode and xplat.
5+
6+
The directory containing this targets.bzl file should also contain both
7+
TARGETS and BUCK files that call this function.
8+
"""
9+
10+
runtime.cxx_library(
11+
name = "optimizer",
12+
exported_headers = [
13+
"sgd.h",
14+
],
15+
exported_deps = [
16+
],
17+
visibility = [
18+
"@EXECUTORCH_CLIENTS",
19+
],
20+
)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
oncall("executorch")
7+
8+
define_common_targets()
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/training/optimizer/sgd.h>
10+
11+
#include <gtest/gtest.h>
12+
13+
using namespace ::testing;
14+
using namespace torch::executor::optimizer;
15+
16+
class SGDOptimizerTest : public ::testing::Test {};
17+
18+
TEST_F(SGDOptimizerTest, InstantiateTypes) {
19+
SGDParamState state;
20+
SGDOptions options;
21+
SGDParamGroup param_group;
22+
SGD sgd;
23+
24+
EXPECT_TRUE(dynamic_cast<SGDParamState*>(&state) != nullptr);
25+
EXPECT_TRUE(dynamic_cast<SGDOptions*>(&options) != nullptr);
26+
EXPECT_TRUE(dynamic_cast<SGDParamGroup*>(&param_group) != nullptr);
27+
EXPECT_TRUE(dynamic_cast<SGD*>(&sgd) != nullptr);
28+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
4+
"""Defines targets that should be shared between fbcode and xplat.
5+
6+
The directory containing this targets.bzl file should also contain both
7+
TARGETS and BUCK files that call this function.
8+
"""
9+
10+
runtime.cxx_test(
11+
name = "sgd_test",
12+
srcs = [
13+
"sgd_test.cpp",
14+
],
15+
deps = [
16+
"//executorch/extension/training/optimizer:optimizer",
17+
],
18+
)

0 commit comments

Comments
 (0)