Skip to content

Commit 9ea9313

Browse files
authored
Add direct copy fast path for portable copy op
Differential Revision: D73656456
Pull Request resolved: #10487
1 parent 9cc9f82 commit 9ea9313

File tree

1 file changed

+36
-22
lines changed

1 file changed

+36
-22
lines changed

kernels/portable/cpu/op_copy.cpp

Lines changed: 36 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -46,17 +46,24 @@ Tensor& copy_out(
4646
// @lint-ignore CLANGTIDY facebook-hte-CArray
4747
static constexpr const char op_name[] = "copy.out";
4848

49-
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
50-
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
51-
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
52-
ctx,
53-
in,
54-
utils::SupportedTensorDtypes::REALHBBF16,
55-
src,
56-
utils::SupportedTensorDtypes::REALHBBF16,
57-
out,
58-
utils::SupportedTensorDtypes::REALHBBF16);
59-
});
49+
// Use direct copy fast path if broadcast is not needed and tensors are
50+
// non-empty
51+
if (internal::sizes_match_ignoring_leading_1s(out.sizes(), src.sizes()) &&
52+
src.numel() > 0) {
53+
std::memcpy(out.mutable_data_ptr(), src.const_data_ptr(), src.nbytes());
54+
} else {
55+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
56+
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
57+
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
58+
ctx,
59+
in,
60+
utils::SupportedTensorDtypes::REALHBBF16,
61+
src,
62+
utils::SupportedTensorDtypes::REALHBBF16,
63+
out,
64+
utils::SupportedTensorDtypes::REALHBBF16);
65+
});
66+
}
6067

6168
return out;
6269
}
@@ -79,17 +86,24 @@ Tensor& copy_(
7986
// @lint-ignore CLANGTIDY facebook-hte-CArray
8087
static constexpr const char op_name[] = "copy_";
8188

82-
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
83-
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
84-
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
85-
ctx,
86-
in,
87-
utils::SupportedTensorDtypes::REALHBBF16,
88-
src,
89-
utils::SupportedTensorDtypes::REALHBBF16,
90-
in,
91-
utils::SupportedTensorDtypes::REALHBBF16);
92-
});
89+
// Use direct copy fast path if broadcast is not needed and tensors are
90+
// non-empty
91+
if (internal::sizes_match_ignoring_leading_1s(in.sizes(), src.sizes()) &&
92+
src.numel() > 0) {
93+
std::memcpy(in.mutable_data_ptr(), src.const_data_ptr(), in.nbytes());
94+
} else {
95+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
96+
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
97+
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
98+
ctx,
99+
in,
100+
utils::SupportedTensorDtypes::REALHBBF16,
101+
src,
102+
utils::SupportedTensorDtypes::REALHBBF16,
103+
in,
104+
utils::SupportedTensorDtypes::REALHBBF16);
105+
});
106+
}
93107

94108
return in;
95109
}

0 commit comments

Comments
 (0)