Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@
{MakeAttribute("to", int64_t(IElemType(0)))})};
}

// Gradient builder for the ONNX CastLike op.
// CastLike(X, target_type) casts X to the element type of its second input,
// so the gradient flowing into X is simply the incoming gradient cast back
// to X's element type. The second input contributes only its type (not its
// values), so its gradient is a zero tensor of matching shape and type.
IMPLEMENT_GRADIENT_BUILDER(GetCastLikeGradient) {
  std::vector<NodeDef> result;
  // dX0 = Cast(dY, to = elem_type(I(0)))
  // static_cast instead of the C-style functional cast (cpplint readability/casting).
  result.push_back(
      NodeDef("Cast",
              {GO(0)}, {GI(0)},
              {MakeAttribute("to", static_cast<int64_t>(IElemType(0)))}));
  // dX1 = Expand(scalar 0 of elem_type(I(1)), Shape(I(1))) — an all-zero
  // gradient for the type-providing input.
  result.push_back(ConstantScalarNode(0, "0", IElemType(1)));
  result.push_back(NodeDef("Shape", {I(1)}, {ArgDef("gi_shape")}));
  result.push_back(NodeDef("Expand", {ArgDef("0"), ArgDef("gi_shape")}, {GI(1)}));
  return result;
}

IMPLEMENT_GRADIENT_BUILDER(GetSinGradient) {
std::vector<NodeDef> result;
result.push_back(NodeDef("Cos", {I(0)}, {IA("Cos_O0")}));
Expand Down
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace training {
};

DECLARE_GRADIENT_BUILDER(GetCastGradient)
// CastLike: like Cast, but the target element type comes from a second input.
DECLARE_GRADIENT_BUILDER(GetCastLikeGradient)
DECLARE_GRADIENT_BUILDER(GetSinGradient)
DECLARE_GRADIENT_BUILDER(GetCosGradient)
DECLARE_GRADIENT_BUILDER(GetLogGradient)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ GradientDef GetGradientForOp(const GradientGraphConfiguration& gradient_graph_co
void GradientBuilderRegistry::RegisterGradientBuilders() {
// Register gradient builders here.
REGISTER_GRADIENT_BUILDER("Cast", GetCastGradient);
REGISTER_GRADIENT_BUILDER("CastLike", GetCastLikeGradient);
REGISTER_GRADIENT_BUILDER("Sin", GetSinGradient);
REGISTER_GRADIENT_BUILDER("Cos", GetCosGradient);
REGISTER_GRADIENT_BUILDER("Log", GetLogGradient);
Expand Down
62 changes: 62 additions & 0 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,68 @@ TEST(GradientCheckerTest, CastGrad) {
}
}

// Numerical gradient check for the CastLike gradient builder.
// Covers: same-type pass-through, float<->double, and int32_t->float.
TEST(GradientCheckerTest, CastLikeGrad) {
  OpDef op_def{"CastLike", kOnnxDomain, 15};
  const float error_tolerance = 1e-3f;

  // Same element type on both inputs — mirrors the basic CastGrad check.
  {
    TensorShape shape({2, 3, 4});
    float max_error;
    GradientChecker<float, float, float> gradient_checker;

    ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {shape, shape}, {shape}, &max_error));
    EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
  }

  // Mixed float/double element types, both directions.
  {
    TensorInfo float_info({2, 3, 4}, true, nullptr, DataTypeImpl::GetTensorType<float>());
    TensorInfo double_info({2, 3, 4}, true, nullptr, DataTypeImpl::GetTensorType<double>());

    // Cast float up to double.
    {
      float max_error;
      GradientChecker<float, float, float> gradient_checker;

      ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {float_info, double_info}, {double_info}, &max_error));
      EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
    }

    // Cast double down to float.
    {
      float max_error;
      GradientChecker<float, float, float> gradient_checker;

      ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {double_info, float_info}, {float_info}, &max_error));
      EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
    }
  }

  // Mixed float/int32_t element types.
  {
    TensorInfo int32_info({2, 3, 4}, false, nullptr, DataTypeImpl::GetTensorType<int32_t>());
    TensorInfo float_info({2, 3, 4}, true, nullptr, DataTypeImpl::GetTensorType<float>());
    /*
    // float -> int32_t stays disabled: asking the checker for a gradient on
    // an integer-typed output trips a programming-error check.
    {
      float max_error;
      GradientChecker<float, float, float> gradient_checker;
      TensorInfo int32_grad_info({2, 3, 4}, true, nullptr, DataTypeImpl::GetTensorType<int32_t>());

      ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {float_info, int32_info}, {int32_grad_info}, &max_error));
      EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
    }
    */
    // int32_t -> float.
    {
      float max_error;
      GradientChecker<float, float, float> gradient_checker;

      ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {int32_info, float_info}, {float_info}, &max_error));
      EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
    }
  }
}

TEST(GradientCheckerTest, SplitGrad) {
TensorShape shape({9, 5});
float max_error;
Expand Down
Loading