
Commit e4d3203

Revert "Migrate elementwise_util callers to the variants with out_dtypes in t…"
This reverts commit b01c7de.
1 parent 0fdc8df commit e4d3203

21 files changed: +118 −201 lines changed
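
Every hunk in this revert makes the same mechanical change, so it helps to see the two call shapes side by side. The sketch below is lifted from the op_add.cpp hunk that follows and is only meant to illustrate the pattern; it is not a standalone program, and the surrounding kernel names (op_name, CTYPE_COMPUTE, ctx, a, b, out, val_alpha) are assumed from that file.

// Shape being reverted: output dtype given as a template argument, generic lambda.
utils::apply_bitensor_elementwise_fn<
    CTYPE_COMPUTE,
    op_name,
    utils::SupportedTensorDtypes::REALHBBF16>(
    [val_alpha](const auto val_a, const auto val_b) {
      return val_a + val_alpha * val_b;
    },
    ctx,
    a,
    utils::SupportedTensorDtypes::REALHBBF16,
    b,
    utils::SupportedTensorDtypes::REALHBBF16,
    out);

// Shape restored by this revert: output dtype passed as the trailing runtime
// argument after `out`, and the lambda typed explicitly on CTYPE_COMPUTE.
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
    [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
      return val_a + val_alpha * val_b;
    },
    ctx,
    a,
    utils::SupportedTensorDtypes::REALHBBF16,
    b,
    utils::SupportedTensorDtypes::REALHBBF16,
    out,
    utils::SupportedTensorDtypes::REALHBBF16);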

kernels/portable/cpu/op_add.cpp

Lines changed: 8 additions & 12 deletions
@@ -52,19 +52,17 @@ Tensor& add_out(
 
   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::REALHBBF16>(
-        [val_alpha](const auto val_a, const auto val_b) {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
           return val_a + val_alpha * val_b;
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;
@@ -102,19 +100,17 @@ Tensor& add_scalar_out(
   static constexpr const char op_name[] = "add.Scalar_out";
 
   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_unitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
-        [b, alpha](const auto val_a) {
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [b, alpha](const CTYPE_COMPUTE val_a) {
           CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
           CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
           return val_a + val_alpha * val_b;
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
   });
 
   return out;

kernels/portable/cpu/op_addmm.cpp

Lines changed: 4 additions & 6 deletions
@@ -88,19 +88,17 @@ Tensor& addmm_out(
           n,
           p);
 
-      utils::apply_bitensor_elementwise_fn<
-          CTYPE,
-          op_name,
-          utils::SupportedTensorDtypes::REALHBF16>(
-          [alpha_val, beta_val](const auto val_a, const auto val_b) {
+      utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
+          [alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) {
             return val_a * alpha_val + val_b * beta_val;
           },
           ctx,
           out,
           utils::SupportedTensorDtypes::REALHBF16,
           in,
           utils::SupportedTensorDtypes::REALHBF16,
-          out);
+          out,
+          utils::SupportedTensorDtypes::REALHBF16);
     }
   });

kernels/portable/cpu/op_atan2.cpp

Lines changed: 4 additions & 6 deletions
@@ -55,19 +55,17 @@ Tensor& atan2_out(
   static constexpr const char op_name[] = "atan2.out";
 
   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::FLOATHBF16>(
-        [](const auto val_a, const auto val_b) {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
           return std::atan2(val_a, val_b);
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::FLOATHBF16);
   });
 
   return out;

kernels/portable/cpu/op_clamp.cpp

Lines changed: 6 additions & 12 deletions
@@ -134,12 +134,8 @@ Tensor& clamp_out(
   static constexpr const char op_name[] = "clamp.out";
 
   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_unitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
         [has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
-          // TODO: rewrite this to be vectorization-capable.
           CTYPE_COMPUTE val_out = val_in;
           if (has_min) {
             val_out = utils::max_override(
@@ -154,7 +150,8 @@ Tensor& clamp_out(
         ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
   });
 
   return out;
@@ -213,15 +210,11 @@ Tensor& clamp_tensor_out(
   static constexpr const char op_name[] = "clamp.Tensor_out";
 
   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_tritensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::REALHBBF16>(
+    utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
         [has_min, has_max](
             const CTYPE_COMPUTE val_in,
             const CTYPE_COMPUTE val_min,
             const CTYPE_COMPUTE val_max) {
-          // TODO: rewrite this to be vectorization-capable.
           CTYPE_COMPUTE val_out = val_in;
           if (has_min) {
             val_out = utils::max_override(val_out, val_min);
@@ -238,7 +231,8 @@ Tensor& clamp_tensor_out(
         utils::SupportedTensorDtypes::REALHBBF16,
         max,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;

kernels/portable/cpu/op_copy.cpp

Lines changed: 8 additions & 12 deletions
@@ -47,17 +47,15 @@ Tensor& copy_out(
   static constexpr const char op_name[] = "copy.out";
 
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE,
-        op_name,
-        utils::SupportedTensorDtypes::REALHBBF16>(
-        [](ET_UNUSED const auto _, const auto val_src) { return val_src; },
+    utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
+        [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
         ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
         src,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;
@@ -82,17 +80,15 @@ Tensor& copy_(
   static constexpr const char op_name[] = "copy_";
 
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE,
-        op_name,
-        utils::SupportedTensorDtypes::REALHBBF16>(
-        [](ET_UNUSED const auto _, const auto val_src) { return val_src; },
+    utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
+        [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
         ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
         src,
         utils::SupportedTensorDtypes::REALHBBF16,
-        in);
+        in,
+        utils::SupportedTensorDtypes::REALHBBF16);
   });
 
   return in;

kernels/portable/cpu/op_div.cpp

Lines changed: 13 additions & 18 deletions
@@ -58,17 +58,17 @@ Tensor& div_out(
   static constexpr const char op_name[] = "div.out";
 
   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::FLOATHBF16>(
-        [](const auto val_a, const auto val_b) { return val_a / val_b; },
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          return val_a / val_b;
+        },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::FLOATHBF16);
   });
 
   return out;
@@ -122,13 +122,9 @@ Tensor& div_out_mode(
   bool div_by_zero_error = false;
 
   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::REALHBF16>(
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
         [mode_is_trunc, &div_by_zero_error](
             const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          // TODO: rewrite this to be vectorization-capable.
           if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
             if (val_b == 0) {
               div_by_zero_error = true;
@@ -150,7 +146,8 @@ Tensor& div_out_mode(
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });
 
   ET_KERNEL_CHECK_MSG(
@@ -191,15 +188,13 @@ Tensor& div_scalar_out(
 
   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
-        [val_b](const auto val_a) { return val_a / val_b; },
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
   });
 
   return out;
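
One detail visible in the div.out_mode hunk above: the elementwise lambda cannot abort the kernel, so the kernel captures a div_by_zero_error flag by reference, sets it inside the lambda, and only validates it after the apply call returns (the ET_KERNEL_CHECK_MSG line in the trailing context). A compressed sketch of that structure; the non-error arithmetic (the trunc/floor rounding) and the check message are simplified here, and only the flag-capture-then-check shape is taken from the diff.

bool div_by_zero_error = false;

ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
  utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
      [&div_by_zero_error](
          const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
        if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value &&
            val_b == 0) {
          // Cannot bail out mid-loop; record the error and return a dummy value.
          div_by_zero_error = true;
          return static_cast<CTYPE_COMPUTE>(0);
        }
        // The real kernel also applies the requested trunc/floor rounding mode.
        return static_cast<CTYPE_COMPUTE>(val_a / val_b);
      },
      ctx,
      a,
      utils::SupportedTensorDtypes::REALHBBF16,
      b,
      utils::SupportedTensorDtypes::REALHBBF16,
      out,
      utils::SupportedTensorDtypes::REALHBF16);
});

// Checked only after the whole elementwise pass; message text is illustrative.
ET_KERNEL_CHECK_MSG(
    ctx, !div_by_zero_error, InvalidArgument, out, "Division by zero");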

kernels/portable/cpu/op_elu.cpp

Lines changed: 4 additions & 7 deletions
@@ -44,20 +44,17 @@ Tensor& elu_out(
     ET_EXTRACT_SCALAR(scale, math_scale);
     ET_EXTRACT_SCALAR(input_scale, math_input_scale);
     const auto negcoef = math_alpha * math_scale;
-    utils::apply_unitensor_elementwise_fn<
-        CTYPE,
-        op_name,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
-        [negcoef, math_scale, math_input_scale](const auto x) {
-          // TODO: rewrite this to be vectorization-capable.
+    utils::apply_unitensor_elementwise_fn<CTYPE, op_name>(
+        [negcoef, math_scale, math_input_scale](auto x) {
          return MathT(x) <= MathT(0)
              ? std::expm1(MathT(x) * math_input_scale) * negcoef
              : MathT(x) * math_scale;
         },
         ctx,
         in,
         utils::SupportedTensorDtypes::FLOATHBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
   });
   return out;
 }

kernels/portable/cpu/op_floor_divide.cpp

Lines changed: 3 additions & 6 deletions
@@ -53,13 +53,9 @@ Tensor& floor_divide_out(
   bool div_by_zero_error = false;
 
   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::REALHBF16>(
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
         [&div_by_zero_error](
             const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          // TODO: rewrite this to be vectorization-capable.
           if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
             if (val_b == 0) {
               div_by_zero_error = true;
@@ -73,7 +69,8 @@ Tensor& floor_divide_out(
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });
 
   ET_KERNEL_CHECK_MSG(

kernels/portable/cpu/op_fmod.cpp

Lines changed: 6 additions & 12 deletions
@@ -55,13 +55,9 @@ Tensor& fmod_Tensor_out(
   bool div_by_zero_error = false;
 
   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::REALHBF16>(
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
         [&div_by_zero_error](
             const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          // TODO: rewrite this to be vectorization-capable.
           CTYPE_COMPUTE value = 0;
           if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
             if (val_b == 0) {
@@ -77,7 +73,8 @@ Tensor& fmod_Tensor_out(
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });
 
   ET_KERNEL_CHECK_MSG(
@@ -134,19 +131,16 @@ Tensor& fmod_Scalar_out(
 
   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::REALHBF16>(
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
         [val_b](const CTYPE_COMPUTE val_a) {
-          // TODO: rewrite this to be vectorization-capable.
           CTYPE_COMPUTE value = std::fmod(val_a, val_b);
           return value;
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });
 
   return out;

kernels/portable/cpu/op_maximum.cpp

Lines changed: 3 additions & 5 deletions
@@ -45,10 +45,7 @@ Tensor& maximum_out(
   static constexpr const char op_name[] = "maximum.out";
 
   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::REALHBBF16>(
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
         [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
           return utils::max_override(val_a, val_b);
         },
@@ -57,7 +54,8 @@ Tensor& maximum_out(
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;
