20 changes: 12 additions & 8 deletions kernels/portable/cpu/op_add.cpp
@@ -52,17 +52,19 @@ Tensor& add_out(

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[val_alpha](const auto val_a, const auto val_b) {
return val_a + val_alpha * val_b;
},
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
out);
});

return out;
@@ -100,17 +102,19 @@ Tensor& add_scalar_out(
static constexpr const char op_name[] = "add.Scalar_out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[b, alpha](const CTYPE_COMPUTE val_a) {
utils::apply_unitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[b, alpha](const auto val_a) {
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
return val_a + val_alpha * val_b;
},
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::SAME_AS_COMMON);
out);
});

return out;
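Note: the two hunks above make the same mechanical change as the rest of this diff — the output dtype restriction moves from a trailing runtime argument into the third template parameter of the `apply_*_elementwise_fn` helpers (so `out` is now the final argument on its own), and the lambda parameters become `auto`, letting one body be instantiated for both the scalar compute type and a SIMD lane type. Below is a minimal, self-contained sketch of why the `auto` parameters matter; `Vec` and `apply_elementwise` are illustrative stand-ins under that assumption, not the real `utils::` API.

```cpp
// Why the lambdas switch to `auto` parameters: a generic lambda can be
// instantiated both with the scalar compute type and with a SIMD wrapper
// type, so the same body serves the vectorized main loop and the scalar
// tail loop. `Vec` is a toy stand-in for a Vectorized<T>-style wrapper.
#include <array>
#include <cstddef>
#include <iostream>

template <typename T, std::size_t N>
struct Vec {
  std::array<T, N> lane;
  friend Vec operator+(Vec a, Vec b) {
    for (std::size_t i = 0; i < N; ++i) a.lane[i] += b.lane[i];
    return a;
  }
  friend Vec operator*(T s, Vec v) {
    for (std::size_t i = 0; i < N; ++i) v.lane[i] *= s;
    return v;
  }
};

template <typename CTYPE, typename Op>
void apply_elementwise(
    Op op, const CTYPE* a, const CTYPE* b, CTYPE* out, std::size_t n) {
  constexpr std::size_t W = 4;
  std::size_t i = 0;
  for (; i + W <= n; i += W) {  // vectorized main loop: op called on Vec
    Vec<CTYPE, W> va, vb;
    for (std::size_t j = 0; j < W; ++j) {
      va.lane[j] = a[i + j];
      vb.lane[j] = b[i + j];
    }
    auto vo = op(va, vb);
    for (std::size_t j = 0; j < W; ++j) out[i + j] = vo.lane[j];
  }
  for (; i < n; ++i) out[i] = op(a[i], b[i]);  // scalar tail: op on CTYPE
}

int main() {
  float a[6] = {1, 2, 3, 4, 5, 6}, b[6] = {6, 5, 4, 3, 2, 1}, out[6];
  const float alpha = 2.0f;
  // Same generic-lambda shape as the updated add_out: one body, two types.
  apply_elementwise<float>(
      [alpha](const auto val_a, const auto val_b) {
        return val_a + alpha * val_b;
      },
      a, b, out, 6);
  for (float v : out) std::cout << v << ' ';  // prints: 13 12 11 10 9 8
  std::cout << '\n';
}
```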
10 changes: 6 additions & 4 deletions kernels/portable/cpu/op_addmm.cpp
@@ -88,17 +88,19 @@ Tensor& addmm_out(
n,
p);

utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
[alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) {
utils::apply_bitensor_elementwise_fn<
CTYPE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
[alpha_val, beta_val](const auto val_a, const auto val_b) {
return val_a * alpha_val + val_b * beta_val;
},
ctx,
out,
utils::SupportedTensorDtypes::REALHBF16,
in,
utils::SupportedTensorDtypes::REALHBF16,
out,
utils::SupportedTensorDtypes::REALHBF16);
out);
}
});

10 changes: 6 additions & 4 deletions kernels/portable/cpu/op_atan2.cpp
@@ -55,17 +55,19 @@ Tensor& atan2_out(
static constexpr const char op_name[] = "atan2.out";

ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::FLOATHBF16>(
[](const auto val_a, const auto val_b) {
return std::atan2(val_a, val_b);
},
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::FLOATHBF16);
out);
});

return out;
18 changes: 12 additions & 6 deletions kernels/portable/cpu/op_clamp.cpp
@@ -134,8 +134,12 @@ Tensor& clamp_out(
static constexpr const char op_name[] = "clamp.out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_unitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
// TODO: rewrite this to be vectorization-capable.
CTYPE_COMPUTE val_out = val_in;
if (has_min) {
val_out = utils::max_override(
@@ -150,8 +154,7 @@ Tensor& clamp_out(
ctx,
in,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::SAME_AS_COMMON);
out);
});

return out;
@@ -210,11 +213,15 @@ Tensor& clamp_tensor_out(
static constexpr const char op_name[] = "clamp.Tensor_out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_tritensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[has_min, has_max](
const CTYPE_COMPUTE val_in,
const CTYPE_COMPUTE val_min,
const CTYPE_COMPUTE val_max) {
// TODO: rewrite this to be vectorization-capable.
CTYPE_COMPUTE val_out = val_in;
if (has_min) {
val_out = utils::max_override(val_out, val_min);
@@ -231,8 +238,7 @@ Tensor& clamp_tensor_out(
utils::SupportedTensorDtypes::REALHBBF16,
max,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
out);
});

return out;
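Note: several of the new lambdas — here in clamp, and in the div/fmod hunks below — keep explicit `const CTYPE_COMPUTE` parameters and carry a `// TODO: rewrite this to be vectorization-capable.` marker. A plausible reading, sketched below: a generic lambda can only run on a vectorized path if every operation in its body is also defined for the vector type, so bodies built on per-element branches or reference-captured error flags (like `div_by_zero_error`) stay scalar for now. `Vec`, `min_with`, and `max_with` are toy stand-ins under that assumption, not ExecuTorch types.

```cpp
// A lambda written purely in terms of helpers that are overloaded for both
// scalar and vector types instantiates cleanly for either; per-element `if`
// branches would not. This is the shape a vectorization-capable clamp body
// could take.
#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>

template <typename T, std::size_t N>
struct Vec {  // toy stand-in for a Vectorized<T>-style wrapper
  std::array<T, N> lane;
};

// min/max defined for scalars...
template <typename T>
T max_with(T a, T b) { return std::max(a, b); }
template <typename T>
T min_with(T a, T b) { return std::min(a, b); }

// ...and, lane-wise, for Vec; partial ordering picks these for Vec args.
template <typename T, std::size_t N>
Vec<T, N> max_with(Vec<T, N> a, Vec<T, N> b) {
  for (std::size_t i = 0; i < N; ++i) a.lane[i] = std::max(a.lane[i], b.lane[i]);
  return a;
}
template <typename T, std::size_t N>
Vec<T, N> min_with(Vec<T, N> a, Vec<T, N> b) {
  for (std::size_t i = 0; i < N; ++i) a.lane[i] = std::min(a.lane[i], b.lane[i]);
  return a;
}

int main() {
  // Branch-free clamp body: every call in it has scalar and Vec overloads,
  // so the same generic lambda serves both paths.
  auto clamp = [](const auto v, const auto lo, const auto hi) {
    return min_with(max_with(v, lo), hi);
  };
  std::cout << clamp(1.5f, 0.0f, 1.0f) << '\n';  // scalar call: prints 1
  Vec<float, 4> v{{-1.0f, 0.25f, 0.75f, 2.0f}};
  Vec<float, 4> lo{{0.0f, 0.0f, 0.0f, 0.0f}}, hi{{1.0f, 1.0f, 1.0f, 1.0f}};
  auto r = clamp(v, lo, hi);  // same lambda, SIMD-shaped call
  for (float x : r.lane) std::cout << x << ' ';  // prints: 0 0.25 0.75 1
  std::cout << '\n';
}
```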
16 changes: 10 additions & 6 deletions kernels/portable/cpu/op_copy.cpp
@@ -53,15 +53,17 @@ Tensor& copy_out(
std::memcpy(out.mutable_data_ptr(), src.const_data_ptr(), src.nbytes());
} else {
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
utils::apply_bitensor_elementwise_fn<
CTYPE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
ctx,
in,
utils::SupportedTensorDtypes::REALHBBF16,
src,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
out);
});
}

@@ -93,15 +95,17 @@ Tensor& copy_(
std::memcpy(in.mutable_data_ptr(), src.const_data_ptr(), in.nbytes());
} else {
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
utils::apply_bitensor_elementwise_fn<
CTYPE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
ctx,
in,
utils::SupportedTensorDtypes::REALHBBF16,
src,
utils::SupportedTensorDtypes::REALHBBF16,
in,
utils::SupportedTensorDtypes::REALHBBF16);
in);
});
}

31 changes: 18 additions & 13 deletions kernels/portable/cpu/op_div.cpp
@@ -58,17 +58,17 @@ Tensor& div_out(
static constexpr const char op_name[] = "div.out";

ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return val_a / val_b;
},
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::FLOATHBF16>(
[](const auto val_a, const auto val_b) { return val_a / val_b; },
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::FLOATHBF16);
out);
});

return out;
@@ -122,9 +122,13 @@ Tensor& div_out_mode(
bool div_by_zero_error = false;

ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
[mode_is_trunc, &div_by_zero_error](
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
// TODO: rewrite this to be vectorization-capable.
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
if (val_b == 0) {
div_by_zero_error = true;
@@ -146,8 +150,7 @@ Tensor& div_out_mode(
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBF16);
out);
});

ET_KERNEL_CHECK_MSG(
@@ -188,13 +191,15 @@ Tensor& div_scalar_out(

ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
utils::apply_unitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[val_b](const auto val_a) { return val_a / val_b; },
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::SAME_AS_COMMON);
out);
});

return out;
11 changes: 7 additions & 4 deletions kernels/portable/cpu/op_elu.cpp
@@ -44,17 +44,20 @@ Tensor& elu_out(
ET_EXTRACT_SCALAR(scale, math_scale);
ET_EXTRACT_SCALAR(input_scale, math_input_scale);
const auto negcoef = math_alpha * math_scale;
utils::apply_unitensor_elementwise_fn<CTYPE, op_name>(
[negcoef, math_scale, math_input_scale](auto x) {
utils::apply_unitensor_elementwise_fn<
CTYPE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[negcoef, math_scale, math_input_scale](const auto x) {
// TODO: rewrite this to be vectorization-capable.
return MathT(x) <= MathT(0)
? std::expm1(MathT(x) * math_input_scale) * negcoef
: MathT(x) * math_scale;
},
ctx,
in,
utils::SupportedTensorDtypes::FLOATHBF16,
out,
utils::SupportedTensorDtypes::SAME_AS_COMMON);
out);
});
return out;
}
9 changes: 6 additions & 3 deletions kernels/portable/cpu/op_floor_divide.cpp
@@ -53,9 +53,13 @@ Tensor& floor_divide_out(
bool div_by_zero_error = false;

ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
[&div_by_zero_error](
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
// TODO: rewrite this to be vectorization-capable.
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
if (val_b == 0) {
div_by_zero_error = true;
@@ -69,8 +73,7 @@ Tensor& floor_divide_out(
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBF16);
out);
});

ET_KERNEL_CHECK_MSG(
18 changes: 12 additions & 6 deletions kernels/portable/cpu/op_fmod.cpp
@@ -55,9 +55,13 @@ Tensor& fmod_Tensor_out(
bool div_by_zero_error = false;

ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
[&div_by_zero_error](
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
// TODO: rewrite this to be vectorization-capable.
CTYPE_COMPUTE value = 0;
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
if (val_b == 0) {
@@ -73,8 +77,7 @@ Tensor& fmod_Tensor_out(
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBF16);
out);
});

ET_KERNEL_CHECK_MSG(
@@ -131,16 +134,19 @@ Tensor& fmod_Scalar_out(

ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_unitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
[val_b](const CTYPE_COMPUTE val_a) {
// TODO: rewrite this to be vectorization-capable.
CTYPE_COMPUTE value = std::fmod(val_a, val_b);
return value;
},
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBF16);
out);
});

return out;
8 changes: 5 additions & 3 deletions kernels/portable/cpu/op_maximum.cpp
@@ -45,7 +45,10 @@ Tensor& maximum_out(
static constexpr const char op_name[] = "maximum.out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return utils::max_override(val_a, val_b);
},
@@ -54,8 +57,7 @@ Tensor& maximum_out(
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
out);
});

return out;