Skip to content

Commit b98c3ab

Browse files
authored
Save some size in pattern/{bitwise,comparison}_op.h (#10489)
bloaty told me that we were paying a noticeable size cost for the ::value members of these structs (at least after the PR in this stack that reapplies #9841) and now we're not.

Test Plan: bash test/build_optimized_size_test.sh

```
before: adopt functionref ==========
ExecuTorch with no ops binary size, unstripped:
-rwxr-xr-x 1 swolchok staff 153928 Apr 25 11:08 cmake-out/test/size_test
ExecuTorch with portable ops binary size, unstripped:
-rwxr-xr-x 1 swolchok staff 2150960 Apr 25 11:08 cmake-out/test/size_test_all_ops
ExecuTorch with optimized ops binary size, unstripped:
-rwxr-xr-x 1 swolchok staff 5927336 Apr 25 11:08 cmake-out/test/size_test_all_optimized_ops
(.venv) swolchok@swolchok-mac ~/src/executorch> size cmake-out/test/size_test*
__TEXT __DATA __OBJC others dec hex
81920 81920 0 4295049216 4295213056 10003c000 cmake-out/test/size_test
1474560 81920 0 4295655424 4297211904 100224000 cmake-out/test/size_test_all_ops
4505600 98304 0 4296376320 4300980224 1005bc000 cmake-out/test/size_test_all_optimized_ops

after:
ExecuTorch with no ops binary size, unstripped:
-rwxr-xr-x 1 swolchok staff 153928 Apr 25 12:24 cmake-out/test/size_test
ExecuTorch with portable ops binary size, unstripped:
-rwxr-xr-x 1 swolchok staff 2150960 Apr 25 12:24 cmake-out/test/size_test_all_ops
ExecuTorch with optimized ops binary size, unstripped:
-rwxr-xr-x 1 swolchok staff 5887368 Apr 25 12:24 cmake-out/test/size_test_all_optimized_ops
(.venv) swolchok@swolchok-mac ~/src/executorch> size cmake-out/test/size_test*
__TEXT __DATA __OBJC others dec hex
81920 81920 0 4295049216 4295213056 10003c000 cmake-out/test/size_test
1474560 81920 0 4295655424 4297211904 100224000 cmake-out/test/size_test_all_ops
4489216 98304 0 4296359936 4300947456 1005b4000 cmake-out/test/size_test_all_optimized_ops
```

(yes it's neutral; improves size results for further diffs)
1 parent f7c906f commit b98c3ab

12 files changed

+63
-75
lines changed

.lintrunner.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,13 @@ exclude_patterns = [
220220
'extension/**',
221221
'kernels/optimized/**',
222222
# Justified <functional> include.
223+
'kernels/portable/cpu/op_bitwise*.cpp',
224+
'kernels/portable/cpu/op_eq.cpp',
225+
'kernels/portable/cpu/op_ge.cpp',
226+
'kernels/portable/cpu/op_gt.cpp',
227+
'kernels/portable/cpu/op_le.cpp',
228+
'kernels/portable/cpu/op_lt.cpp',
229+
'kernels/portable/cpu/op_ne.cpp',
223230
'runtime/kernel/thread_parallel_interface.h',
224231
'scripts/**',
225232
'third-party/**',

kernels/portable/cpu/op_bitwise_and.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
1010

11+
#include <functional>
12+
1113
namespace torch {
1214
namespace executor {
1315
namespace native {
@@ -19,7 +21,7 @@ Tensor& bitwise_and_Tensor_out(
1921
Tensor& out) {
2022
// @lint-ignore CLANGTIDY facebook-hte-CArray
2123
static constexpr const char op_name[] = "bitwise_and.Tensor_out";
22-
return internal::bitwise_tensor_out<op_name>(ctx, a, b, out);
24+
return internal::bitwise_tensor_out<std::bit_and, op_name>(ctx, a, b, out);
2325
}
2426

2527
Tensor& bitwise_and_Scalar_out(
@@ -29,7 +31,7 @@ Tensor& bitwise_and_Scalar_out(
2931
Tensor& out) {
3032
// @lint-ignore CLANGTIDY facebook-hte-CArray
3133
static constexpr const char op_name[] = "bitwise_and.Scalar_out";
32-
return internal::bitwise_scalar_out<op_name>(ctx, a, b, out);
34+
return internal::bitwise_scalar_out<std::bit_and, op_name>(ctx, a, b, out);
3335
}
3436

3537
} // namespace native

kernels/portable/cpu/op_bitwise_or.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
1010

11+
#include <functional>
12+
1113
namespace torch {
1214
namespace executor {
1315
namespace native {
@@ -19,7 +21,7 @@ Tensor& bitwise_or_Tensor_out(
1921
Tensor& out) {
2022
// @lint-ignore CLANGTIDY facebook-hte-CArray
2123
static constexpr const char op_name[] = "bitwise_or.Tensor_out";
22-
return internal::bitwise_tensor_out<op_name>(ctx, a, b, out);
24+
return internal::bitwise_tensor_out<std::bit_or, op_name>(ctx, a, b, out);
2325
}
2426

2527
Tensor& bitwise_or_Scalar_out(
@@ -29,7 +31,7 @@ Tensor& bitwise_or_Scalar_out(
2931
Tensor& out) {
3032
// @lint-ignore CLANGTIDY facebook-hte-CArray
3133
static constexpr const char op_name[] = "bitwise_or.Scalar_out";
32-
return internal::bitwise_scalar_out<op_name>(ctx, a, b, out);
34+
return internal::bitwise_scalar_out<std::bit_or, op_name>(ctx, a, b, out);
3335
}
3436

3537
} // namespace native

kernels/portable/cpu/op_bitwise_xor.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
1010

11+
#include <functional>
12+
1113
namespace torch {
1214
namespace executor {
1315
namespace native {
@@ -19,7 +21,7 @@ Tensor& bitwise_xor_Tensor_out(
1921
Tensor& out) {
2022
// @lint-ignore CLANGTIDY facebook-hte-CArray
2123
static constexpr const char op_name[] = "bitwise_xor.Tensor_out";
22-
return internal::bitwise_tensor_out<op_name>(ctx, a, b, out);
24+
return internal::bitwise_tensor_out<std::bit_xor, op_name>(ctx, a, b, out);
2325
}
2426

2527
Tensor& bitwise_xor_Scalar_out(
@@ -29,7 +31,7 @@ Tensor& bitwise_xor_Scalar_out(
2931
Tensor& out) {
3032
// @lint-ignore CLANGTIDY facebook-hte-CArray
3133
static constexpr const char op_name[] = "bitwise_xor.Scalar_out";
32-
return internal::bitwise_scalar_out<op_name>(ctx, a, b, out);
34+
return internal::bitwise_scalar_out<std::bit_xor, op_name>(ctx, a, b, out);
3335
}
3436

3537
} // namespace native

kernels/portable/cpu/op_eq.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
1010

11+
#include <functional>
12+
1113
namespace torch {
1214
namespace executor {
1315
namespace native {
@@ -19,7 +21,8 @@ Tensor& eq_tensor_out(
1921
Tensor& out) {
2022
// @lint-ignore CLANGTIDY facebook-hte-CArray
2123
static constexpr const char op_name[] = "eq.Tensor_out";
22-
return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
24+
return internal::comparison_tensor_out<std::equal_to, op_name>(
25+
ctx, a, b, out);
2326
}
2427

2528
Tensor& eq_scalar_out(
@@ -29,7 +32,8 @@ Tensor& eq_scalar_out(
2932
Tensor& out) {
3033
// @lint-ignore CLANGTIDY facebook-hte-CArray
3134
static constexpr const char op_name[] = "eq.Scalar_out";
32-
return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
35+
return internal::comparison_scalar_out<std::equal_to, op_name>(
36+
ctx, a, b, out);
3337
}
3438

3539
} // namespace native

kernels/portable/cpu/op_ge.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
1010

11+
#include <functional>
12+
1113
namespace torch {
1214
namespace executor {
1315
namespace native {
@@ -19,7 +21,8 @@ Tensor& ge_tensor_out(
1921
Tensor& out) {
2022
// @lint-ignore CLANGTIDY facebook-hte-CArray
2123
static constexpr const char op_name[] = "ge.Tensor_out";
22-
return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
24+
return internal::comparison_tensor_out<std::greater_equal, op_name>(
25+
ctx, a, b, out);
2326
}
2427

2528
Tensor& ge_scalar_out(
@@ -29,7 +32,8 @@ Tensor& ge_scalar_out(
2932
Tensor& out) {
3033
// @lint-ignore CLANGTIDY facebook-hte-CArray
3134
static constexpr const char op_name[] = "ge.Scalar_out";
32-
return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
35+
return internal::comparison_scalar_out<std::greater_equal, op_name>(
36+
ctx, a, b, out);
3337
}
3438

3539
} // namespace native

kernels/portable/cpu/op_gt.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
1010

11+
#include <functional>
12+
1113
namespace torch {
1214
namespace executor {
1315
namespace native {
@@ -19,7 +21,7 @@ Tensor& gt_tensor_out(
1921
Tensor& out) {
2022
// @lint-ignore CLANGTIDY facebook-hte-CArray
2123
static constexpr const char op_name[] = "gt.Tensor_out";
22-
return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
24+
return internal::comparison_tensor_out<std::greater, op_name>(ctx, a, b, out);
2325
}
2426

2527
Tensor& gt_scalar_out(
@@ -29,7 +31,7 @@ Tensor& gt_scalar_out(
2931
Tensor& out) {
3032
// @lint-ignore CLANGTIDY facebook-hte-CArray
3133
static constexpr const char op_name[] = "gt.Scalar_out";
32-
return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
34+
return internal::comparison_scalar_out<std::greater, op_name>(ctx, a, b, out);
3335
}
3436

3537
} // namespace native

kernels/portable/cpu/op_le.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
1010

11+
#include <functional>
12+
1113
namespace torch {
1214
namespace executor {
1315
namespace native {
@@ -19,7 +21,8 @@ Tensor& le_tensor_out(
1921
Tensor& out) {
2022
// @lint-ignore CLANGTIDY facebook-hte-CArray
2123
static constexpr const char op_name[] = "le.Tensor_out";
22-
return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
24+
return internal::comparison_tensor_out<std::less_equal, op_name>(
25+
ctx, a, b, out);
2326
}
2427

2528
Tensor& le_scalar_out(
@@ -29,7 +32,8 @@ Tensor& le_scalar_out(
2932
Tensor& out) {
3033
// @lint-ignore CLANGTIDY facebook-hte-CArray
3134
static constexpr const char op_name[] = "le.Scalar_out";
32-
return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
35+
return internal::comparison_scalar_out<std::less_equal, op_name>(
36+
ctx, a, b, out);
3337
}
3438

3539
} // namespace native

kernels/portable/cpu/op_lt.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
1010

11+
#include <functional>
12+
1113
namespace torch {
1214
namespace executor {
1315
namespace native {
@@ -19,7 +21,7 @@ Tensor& lt_tensor_out(
1921
Tensor& out) {
2022
// @lint-ignore CLANGTIDY facebook-hte-CArray
2123
static constexpr const char op_name[] = "lt.Tensor_out";
22-
return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
24+
return internal::comparison_tensor_out<std::less, op_name>(ctx, a, b, out);
2325
}
2426

2527
Tensor& lt_scalar_out(
@@ -29,7 +31,7 @@ Tensor& lt_scalar_out(
2931
Tensor& out) {
3032
// @lint-ignore CLANGTIDY facebook-hte-CArray
3133
static constexpr const char op_name[] = "lt.Scalar_out";
32-
return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
34+
return internal::comparison_scalar_out<std::less, op_name>(ctx, a, b, out);
3335
}
3436

3537
} // namespace native

kernels/portable/cpu/op_ne.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
1010

11+
#include <functional>
12+
1113
namespace torch {
1214
namespace executor {
1315
namespace native {
@@ -19,7 +21,8 @@ Tensor& ne_tensor_out(
1921
Tensor& out) {
2022
// @lint-ignore CLANGTIDY facebook-hte-CArray
2123
static constexpr const char op_name[] = "ne.Tensor_out";
22-
return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
24+
return internal::comparison_tensor_out<std::not_equal_to, op_name>(
25+
ctx, a, b, out);
2326
}
2427

2528
Tensor& ne_scalar_out(
@@ -29,7 +32,8 @@ Tensor& ne_scalar_out(
2932
Tensor& out) {
3033
// @lint-ignore CLANGTIDY facebook-hte-CArray
3134
static constexpr const char op_name[] = "ne.Scalar_out";
32-
return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
35+
return internal::comparison_scalar_out<std::not_equal_to, op_name>(
36+
ctx, a, b, out);
3337
}
3438

3539
} // namespace native

kernels/portable/cpu/pattern/bitwise_op.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,13 @@ constexpr bitwise_fn<T> get_bitwise_fn() {
4747

4848
template <typename T, const char* op_name>
4949
struct BitwiseFnForOp {
50-
static constexpr auto value = get_bitwise_fn<T, op_name>();
51-
static_assert(value != nullptr, "unknown op_name!");
50+
static constexpr auto get_value() {
51+
return get_bitwise_fn<T, op_name>();
52+
}
53+
static_assert(get_value() != nullptr, "unknown op_name!");
5254
};
5355

54-
template <const char* op_name>
56+
template <template <typename> class BitOp, const char* op_name>
5557
Tensor& bitwise_tensor_out(
5658
RuntimeContext& ctx,
5759
const Tensor& a,
@@ -81,7 +83,7 @@ Tensor& bitwise_tensor_out(
8183
ET_SWITCH_INT_TYPES_AND(
8284
Bool, compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
8385
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
84-
BitwiseFnForOp<CTYPE_COMPUTE, op_name>::value,
86+
BitOp<CTYPE_COMPUTE>(),
8587
ctx,
8688
a,
8789
utils::SupportedTensorDtypes::INTB,
@@ -94,7 +96,7 @@ Tensor& bitwise_tensor_out(
9496
return out;
9597
}
9698

97-
template <const char* op_name>
99+
template <template <typename> class BitOp, const char* op_name>
98100
Tensor& bitwise_scalar_out(
99101
RuntimeContext& ctx,
100102
const Tensor& a,
@@ -123,8 +125,7 @@ Tensor& bitwise_scalar_out(
123125
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
124126
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
125127
[val_b](const CTYPE_COMPUTE val_a) {
126-
return BitwiseFnForOp<CTYPE_COMPUTE, op_name>::value(
127-
val_a, val_b);
128+
// Name the specialization explicitly: BitOp is a template template
// parameter, which cannot be used for class template argument deduction,
// and this keeps the scalar path on the same BitOp<CTYPE_COMPUTE> functor
// that bitwise_tensor_out passes to its elementwise helper.
return BitOp<CTYPE_COMPUTE>()(val_a, val_b);
128129
},
129130
ctx,
130131
a,

kernels/portable/cpu/pattern/comparison_op.h

Lines changed: 4 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -17,53 +17,7 @@ namespace executor {
1717
namespace native {
1818
namespace internal {
1919

20-
#define DEFINE_BINARY_OPERATOR_TEMPLATE(name, op) \
21-
template <typename T> \
22-
T name(const T val_a, const T val_b) { \
23-
return val_a op val_b; \
24-
}
25-
26-
DEFINE_BINARY_OPERATOR_TEMPLATE(eq, ==)
27-
DEFINE_BINARY_OPERATOR_TEMPLATE(ne, !=)
28-
DEFINE_BINARY_OPERATOR_TEMPLATE(ge, >=)
29-
DEFINE_BINARY_OPERATOR_TEMPLATE(le, <=)
30-
DEFINE_BINARY_OPERATOR_TEMPLATE(gt, >)
31-
DEFINE_BINARY_OPERATOR_TEMPLATE(lt, <)
32-
33-
template <typename T>
34-
using comparison_fn = T (*)(const T, const T);
35-
36-
template <typename T, const char* op_name>
37-
constexpr comparison_fn<T> get_comparison_fn() {
38-
std::string_view op = op_name;
39-
if (op == "eq.Tensor_out" || op == "eq.Scalar_out") {
40-
return eq;
41-
}
42-
if (op == "ne.Tensor_out" || op == "ne.Scalar_out") {
43-
return ne;
44-
}
45-
if (op == "ge.Tensor_out" || op == "ge.Scalar_out") {
46-
return ge;
47-
}
48-
if (op == "le.Tensor_out" || op == "le.Scalar_out") {
49-
return le;
50-
}
51-
if (op == "gt.Tensor_out" || op == "gt.Scalar_out") {
52-
return gt;
53-
}
54-
if (op == "lt.Tensor_out" || op == "lt.Scalar_out") {
55-
return lt;
56-
}
57-
return nullptr;
58-
};
59-
60-
template <typename T, const char* op_name>
61-
struct ComparisonFnForOp {
62-
static constexpr auto value = get_comparison_fn<T, op_name>();
63-
static_assert(value != nullptr, "unknown op_name!");
64-
};
65-
66-
template <const char* op_name>
20+
template <template <typename> class Comparison, const char* op_name>
6721
Tensor& comparison_tensor_out(
6822
KernelRuntimeContext& ctx,
6923
const Tensor& a,
@@ -92,7 +46,7 @@ Tensor& comparison_tensor_out(
9246

9347
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
9448
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
95-
ComparisonFnForOp<CTYPE_COMPUTE, op_name>::value,
49+
Comparison<CTYPE_COMPUTE>(),
9650
ctx,
9751
a,
9852
utils::SupportedTensorDtypes::REALHBBF16,
@@ -105,7 +59,7 @@ Tensor& comparison_tensor_out(
10559
return out;
10660
}
10761

108-
template <const char* op_name>
62+
template <template <typename> class Comparison, const char* op_name>
10963
Tensor& comparison_scalar_out(
11064
KernelRuntimeContext& ctx,
11165
const Tensor& a,
@@ -129,7 +83,7 @@ Tensor& comparison_scalar_out(
12983
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
13084
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
13185
[val_b](const CTYPE_COMPUTE val_a) {
132-
return ComparisonFnForOp<CTYPE_COMPUTE, op_name>::value(val_a, val_b);
86+
return Comparison<CTYPE_COMPUTE>()(val_a, val_b);
13387
},
13488
ctx,
13589
a,

0 commit comments

Comments
 (0)