File tree 6 files changed +131
-0
lines changed
extension/training/optimizer 6 files changed +131
-0
lines changed Original file line number Diff line number Diff line change
1
+ # Any targets that should be shared between fbcode and xplat must be defined in
2
+ # targets.bzl. This file can contain fbcode-only targets.
3
+
4
+ load(":targets.bzl", "define_common_targets")
5
+
6
+ oncall("executorch")
7
+
8
+ define_common_targets()
Original file line number Diff line number Diff line change
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ /* *
10
+ * SGD (stochastic gradient descent) optimizer to perform on-device training.
11
+ * This uses the gradients calculated in the backwards pass of the loss function
12
+ * and updates the parameters such that it minimizes the loss.
13
+ *
14
+ * This is similar to the Lite Interpreter implementation of the SGD optimizer,
15
+ * but without the dependency on ATen Tensors and autograd.
16
+ */
17
+ #pragma once
18
+
19
+ namespace torch {
20
+ namespace executor {
21
+ namespace optimizer {
22
+
23
+ /* *
24
+ * SGD optimizer state. This keeps track of the state of a given parameter to
25
+ * be used in later epochs.
26
+ */
27
+ class SGDParamState {};
28
+
29
+ /* *
30
+ * SGD optimizer options. This contains options for performing training on a
31
+ * param group, such as the learning rate.
32
+ */
33
+ class SGDOptions {};
34
+
35
+ /* *
36
+ * SGD optimizer param group. This contains the parameters and
37
+ * the OptimizerOptions associated to it.
38
+ */
39
+ class SGDParamGroup {};
40
+
41
+ /* *
42
+ * SGD optimizer class. This is responsible for performing the optimization
43
+ * step.
44
+ */
45
+ class SGD {};
46
+
47
+ } // namespace optimizer
48
+ } // namespace executor
49
+ } // namespace torch
Original file line number Diff line number Diff line change
1
+ load ("@fbsource//xplat/executorch/build:runtime_wrapper.bzl" , "runtime" )
2
+
3
+ def define_common_targets ():
4
+ """Defines targets that should be shared between fbcode and xplat.
5
+
6
+ The directory containing this targets.bzl file should also contain both
7
+ TARGETS and BUCK files that call this function.
8
+ """
9
+
10
+ runtime .cxx_library (
11
+ name = "optimizer" ,
12
+ exported_headers = [
13
+ "sgd.h" ,
14
+ ],
15
+ exported_deps = [
16
+ ],
17
+ visibility = [
18
+ "@EXECUTORCH_CLIENTS" ,
19
+ ],
20
+ )
Original file line number Diff line number Diff line change
1
+ # Any targets that should be shared between fbcode and xplat must be defined in
2
+ # targets.bzl. This file can contain fbcode-only targets.
3
+
4
+ load(":targets.bzl", "define_common_targets")
5
+
6
+ oncall("executorch")
7
+
8
+ define_common_targets()
Original file line number Diff line number Diff line change
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #include < executorch/extension/training/optimizer/sgd.h>
10
+
11
+ #include < gtest/gtest.h>
12
+
13
+ using namespace ::testing;
14
+ using namespace torch ::executor::optimizer;
15
+
16
+ class SGDOptimizerTest : public ::testing::Test {};
17
+
18
+ TEST_F (SGDOptimizerTest, InstantiateTypes) {
19
+ SGDParamState state;
20
+ SGDOptions options;
21
+ SGDParamGroup param_group;
22
+ SGD sgd;
23
+
24
+ EXPECT_TRUE (dynamic_cast <SGDParamState*>(&state) != nullptr );
25
+ EXPECT_TRUE (dynamic_cast <SGDOptions*>(&options) != nullptr );
26
+ EXPECT_TRUE (dynamic_cast <SGDParamGroup*>(¶m_group) != nullptr );
27
+ EXPECT_TRUE (dynamic_cast <SGD*>(&sgd) != nullptr );
28
+ }
Original file line number Diff line number Diff line change
1
+ load ("@fbsource//xplat/executorch/build:runtime_wrapper.bzl" , "runtime" )
2
+
3
+ def define_common_targets ():
4
+ """Defines targets that should be shared between fbcode and xplat.
5
+
6
+ The directory containing this targets.bzl file should also contain both
7
+ TARGETS and BUCK files that call this function.
8
+ """
9
+
10
+ runtime .cxx_test (
11
+ name = "sgd_test" ,
12
+ srcs = [
13
+ "sgd_test.cpp" ,
14
+ ],
15
+ deps = [
16
+ "//executorch/extension/training/optimizer:optimizer" ,
17
+ ],
18
+ )
You can’t perform that action at this time.
0 commit comments