@@ -46,17 +46,24 @@ Tensor& copy_out(
46
46
// @lint-ignore CLANGTIDY facebook-hte-CArray
47
47
static constexpr const char op_name[] = " copy.out" ;
48
48
49
- ET_SWITCH_REALHBBF16_TYPES (in.scalar_type (), ctx, " copy.out" , CTYPE, [&]() {
50
- utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
51
- [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
52
- ctx,
53
- in,
54
- utils::SupportedTensorDtypes::REALHBBF16,
55
- src,
56
- utils::SupportedTensorDtypes::REALHBBF16,
57
- out,
58
- utils::SupportedTensorDtypes::REALHBBF16);
59
- });
49
+ // Use direct copy fast path if broadcast is not needed and tensors are
50
+ // non-empty
51
+ if (internal::sizes_match_ignoring_leading_1s (out.sizes (), src.sizes ()) &&
52
+ src.numel () > 0 ) {
53
+ std::memcpy (out.mutable_data_ptr (), src.const_data_ptr (), src.nbytes ());
54
+ } else {
55
+ ET_SWITCH_REALHBBF16_TYPES (in.scalar_type (), ctx, " copy.out" , CTYPE, [&]() {
56
+ utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
57
+ [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
58
+ ctx,
59
+ in,
60
+ utils::SupportedTensorDtypes::REALHBBF16,
61
+ src,
62
+ utils::SupportedTensorDtypes::REALHBBF16,
63
+ out,
64
+ utils::SupportedTensorDtypes::REALHBBF16);
65
+ });
66
+ }
60
67
61
68
return out;
62
69
}
@@ -79,17 +86,24 @@ Tensor& copy_(
79
86
// @lint-ignore CLANGTIDY facebook-hte-CArray
80
87
static constexpr const char op_name[] = " copy_" ;
81
88
82
- ET_SWITCH_REALHBBF16_TYPES (in.scalar_type (), ctx, " copy_" , CTYPE, [&]() {
83
- utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
84
- [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
85
- ctx,
86
- in,
87
- utils::SupportedTensorDtypes::REALHBBF16,
88
- src,
89
- utils::SupportedTensorDtypes::REALHBBF16,
90
- in,
91
- utils::SupportedTensorDtypes::REALHBBF16);
92
- });
89
+ // Use direct copy fast path if broadcast is not needed and tensors are
90
+ // non-empty
91
+ if (internal::sizes_match_ignoring_leading_1s (in.sizes (), src.sizes ()) &&
92
+ src.numel () > 0 ) {
93
+ std::memcpy (in.mutable_data_ptr (), src.const_data_ptr (), in.nbytes ());
94
+ } else {
95
+ ET_SWITCH_REALHBBF16_TYPES (in.scalar_type (), ctx, " copy_" , CTYPE, [&]() {
96
+ utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
97
+ [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
98
+ ctx,
99
+ in,
100
+ utils::SupportedTensorDtypes::REALHBBF16,
101
+ src,
102
+ utils::SupportedTensorDtypes::REALHBBF16,
103
+ in,
104
+ utils::SupportedTensorDtypes::REALHBBF16);
105
+ });
106
+ }
93
107
94
108
return in;
95
109
}
0 commit comments