55#include < ATen/Dispatch.h>
66#include < ATen/Parallel.h>
77#include < ATen/cpu/vec/vec.h>
8+ #include < ATen/cpu/vec/functional.h>
89#include < ATen/native/Pool.h>
910#include < ATen/native/cpu/utils.h>
1011#include < c10/util/irange.h>
@@ -60,13 +61,15 @@ vec::Vectorized<int64_t> is_nan_vec<int64_t>(vec::Vectorized<int64_t> vec) {
6061 return ret;
6162}
6263
63- template <typename scalar_t , typename accscalar_t >
64- inline void compute_internal (
64+ template <typename scalar_t , typename opmath_t >
65+ inline
66+ typename std::enable_if<std::is_same<scalar_t , opmath_t >::value, void >::type
67+ compute_internal (
6568 scalar_t * input_data,
6669 scalar_t * out_data,
67- accscalar_t * max_ptr,
68- vec::int_same_size_t <accscalar_t >* index_ptr,
69- int64_t * ind,
70+ opmath_t * max_ptr,
71+ vec::int_same_size_t <opmath_t >* index_ptr,
72+ int64_t * ind,
7073 int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
7174 int64_t n,
7275 int64_t len,
@@ -78,7 +81,7 @@ inline void compute_internal(
7881 int64_t dilationH,
7982 int64_t dilationW) {
8083 using Vec = vec::Vectorized<scalar_t >;
81- using integer_t = vec::int_same_size_t <accscalar_t >;
84+ using integer_t = vec::int_same_size_t <opmath_t >;
8285 using iVec = vec::Vectorized<integer_t >;
8386 // Pass I: init out lane
8487 iVec index0_vec = iVec (id0 * input_height * input_width + ih0 * input_width + iw0);
@@ -130,13 +133,16 @@ inline void compute_internal(
130133 }
131134}
132135
133- template <>
134- inline void compute_internal (
135- BFloat16* input_data,
136- BFloat16* out_data,
137- float * max_ptr,
138- int32_t * index_ptr,
139- int64_t * ind,
136+ // std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
137+ template <typename scalar_t , typename opmath_t >
138+ inline
139+ typename std::enable_if<!std::is_same<scalar_t , opmath_t >::value, void >::type
140+ compute_internal (
141+ scalar_t * input_data,
142+ scalar_t * out_data,
143+ opmath_t * max_ptr,
144+ vec::int_same_size_t <opmath_t >* index_ptr,
145+ int64_t * ind,
140146 int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
141147 int64_t n,
142148 int64_t len,
@@ -147,34 +153,34 @@ inline void compute_internal(
147153 int64_t dilationD,
148154 int64_t dilationH,
149155 int64_t dilationW) {
150- using bVec = vec::Vectorized<BFloat16 >;
151- using fVec = vec::Vectorized<float >;
156+ using Vec = vec::Vectorized<scalar_t >;
157+ using fVec = vec::Vectorized<opmath_t >;
152158 using iVec = vec::Vectorized<int32_t >;
153159 // Pass I: init out lane
154160 iVec index0_vec = iVec (id0 * input_height * input_width + ih0 * input_width + iw0);
155- fVec out_vec = fVec (-std::numeric_limits<float >::infinity ());
161+ fVec out_vec = fVec (-std::numeric_limits<opmath_t >::infinity ());
156162 int64_t d1 = 0 ;
157163 for (; d1 < len; d1 += fVec::size ()) {
158164 index0_vec.store (index_ptr + d1);
159165 out_vec.store (max_ptr + d1);
160166 }
161167 for (; d1 < size; d1++) {
162168 ind[d1] = ih0 * input_width + iw0;
163- max_ptr[d1] = -std::numeric_limits<float >::infinity ();
169+ max_ptr[d1] = -std::numeric_limits<opmath_t >::infinity ();
164170 }
165171 // Pass II: compute local max
166172 for (int64_t id = id0; id < id1; id += dilationD) {
167173 for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
168174 for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
169- BFloat16 * in = input_data + (n * input_depth * input_height * input_width +
175+ scalar_t * in = input_data + (n * input_depth * input_height * input_width +
170176 id * input_height * input_width + ih * input_width + iw) * channels;
171177
172178 int64_t d2 = 0 ;
173- for (; d2 < len; d2 += bVec ::size ()) {
179+ for (; d2 < len; d2 += Vec ::size ()) {
174180 iVec index_ivec = iVec (id * input_height * input_width + ih * input_width + iw);
175- bVec val_bvec = bVec ::loadu (in + d2);
181+ Vec val_bvec = Vec ::loadu (in + d2);
176182 fVec val_fvec0, val_fvec1;
177- std::tie (val_fvec0, val_fvec1) = convert_bfloat16_float (val_bvec);
183+ std::tie (val_fvec0, val_fvec1) = convert_to_float< scalar_t > (val_bvec);
178184
179185 iVec maxindex_ivec0 = iVec::loadu (index_ptr + d2);
180186 iVec maxindex_ivec1 = iVec::loadu (index_ptr + d2 + iVec::size ());
@@ -200,9 +206,9 @@ inline void compute_internal(
200206 }
201207 for (; d2 < size; d2++) {
202208 int64_t index = id * input_height * input_width + ih * input_width + iw;
203- float val = float (in[d2]);
209+ opmath_t val = opmath_t (in[d2]);
204210 int64_t maxindex = ind[d2];
205- float maxval = max_ptr[d2];
211+ opmath_t maxval = max_ptr[d2];
206212
207213 bool mask = (val > maxval) || std::isnan (val);
208214 max_ptr[d2] = mask ? val : maxval;
@@ -211,16 +217,16 @@ inline void compute_internal(
211217 }
212218 }
213219 }
214- // Convert max values from float to bfloat16
220+ // Convert max values from float to bfloat16/half
215221 int64_t d3 = 0 ;
216- for (; d3 < len; d3 += bVec ::size ()) {
222+ for (; d3 < len; d3 += Vec ::size ()) {
217223 fVec max_fvec0 = fVec::loadu (max_ptr + d3);
218224 fVec max_fvec1 = fVec::loadu (max_ptr + d3 + fVec::size ());
219- bVec max_bvec = convert_float_bfloat16 (max_fvec0, max_fvec1);
225+ Vec max_bvec = convert_from_float< scalar_t > (max_fvec0, max_fvec1);
220226 max_bvec.store (out_data + d3);
221227 }
222228 for (; d3 < size; d3++) {
223- out_data[d3] = BFloat16 (max_ptr[d3]);
229+ out_data[d3] = scalar_t (max_ptr[d3]);
224230 }
225231}
226232
@@ -281,7 +287,7 @@ void cpu_max_pool(
281287 int64_t output_height = output.size (-2 );
282288 int64_t output_width = output.size (-1 );
283289
284- using accscalar_t = at::opmath_type<scalar_t >;
290+ using opmath_t = at::opmath_type<scalar_t >;
285291 // parallel on dim N, C
286292 at::parallel_for (0 , channels, 0 , [&](int64_t begin, int64_t end) {
287293 for (int64_t c = begin; c < end; c++) {
@@ -306,17 +312,18 @@ void cpu_max_pool(
306312
307313 // compute local max
308314 int64_t maxindex = id0 * input_height * input_width + ih0 * input_width + iw0;
309- accscalar_t maxval;
310- if (std::numeric_limits<accscalar_t >::has_infinity) {
311- maxval = -std::numeric_limits<accscalar_t >::infinity ();
315+ opmath_t maxval;
316+ if (std::numeric_limits<opmath_t >::has_infinity) {
317+ maxval = -std::numeric_limits<opmath_t >::infinity ();
312318 } else {
313- maxval = std::numeric_limits<accscalar_t >::min ();
319+ maxval = std::numeric_limits<opmath_t >::min ();
314320 }
321+
315322 for (int64_t id = id0; id < id1; id += dilationD) {
316323 for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
317324 for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
318325 int64_t index = id * input_height * input_width + ih * input_width + iw;
319- accscalar_t val = input_ptr[index];
326+ opmath_t val = input_ptr[index];
320327 if ((val > maxval) || is_nan (static_cast <double >(val))) {
321328 maxval = val;
322329 maxindex = index;
@@ -396,9 +403,9 @@ void cpu_max_pool_channels_last(
396403 int64_t output_height = output.size (-2 );
397404 int64_t output_width = output.size (-1 );
398405
399- using accscalar_t = at::opmath_type<scalar_t >;
406+ using opmath_t = at::opmath_type<scalar_t >;
400407 using Vec = vec::Vectorized<scalar_t >;
401- using integer_t = vec::int_same_size_t <accscalar_t >;
408+ using integer_t = vec::int_same_size_t <opmath_t >;
402409 // for the convenience of vectorization, use an integer of the same size as scalar_t,
403410 // e.g. int32_t for float, int64_t for double
404411 // need to make sure it doesn't overflow
@@ -418,11 +425,11 @@ void cpu_max_pool_channels_last(
418425 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
419426 std::unique_ptr<integer_t []> index_buffer (new integer_t [len]);
420427 integer_t * index_ptr = index_buffer.get ();
421- // temp buffer holding max value with accscalar_t
422- std::unique_ptr<accscalar_t []> max_arr;
423- accscalar_t * max_ptr = nullptr ;
424- if (!std::is_same<scalar_t , accscalar_t >::value) {
425- max_arr = std::make_unique<accscalar_t []>(size);
428+ // temp buffer holding max value with opmath_t
429+ std::unique_ptr<opmath_t []> max_arr;
430+ opmath_t * max_ptr = nullptr ;
431+ if (!std::is_same<scalar_t , opmath_t >::value) {
432+ max_arr = std::make_unique<opmath_t []>(size);
426433 max_ptr = max_arr.get ();
427434 }
428435
@@ -598,13 +605,13 @@ void max_pool2d_kernel_impl(
598605 int dilationW, int dilationH) {
599606 switch (input.suggest_memory_format ()) {
600607 case at::MemoryFormat::Contiguous: {
601- AT_DISPATCH_ALL_TYPES_AND (ScalarType::BFloat16, input.scalar_type (), " max_pool2d" , [&] {
608+ AT_DISPATCH_ALL_TYPES_AND2 (ScalarType::BFloat16, ScalarType::Half , input.scalar_type (), " max_pool2d" , [&] {
602609 cpu_max_pool<scalar_t , /* is 3d*/ false >(output, indices, input, {kW , kH }, {dW, dH}, {padW, padH}, {dilationW, dilationH});
603610 });
604611 break ;
605612 }
606613 case at::MemoryFormat::ChannelsLast: {
607- AT_DISPATCH_ALL_TYPES_AND (ScalarType::BFloat16, input.scalar_type (), " max_pool2d_channels_last" , [&] {
614+ AT_DISPATCH_ALL_TYPES_AND2 (ScalarType::BFloat16, ScalarType::Half , input.scalar_type (), " max_pool2d_channels_last" , [&] {
608615 cpu_max_pool_channels_last<scalar_t , false >(output, indices, input, {kW , kH }, {dW, dH}, {padW, padH}, {dilationW, dilationH});
609616 });
610617 break ;
@@ -637,7 +644,7 @@ void max_pool3d_kernel_impl(
637644 DimVector indices_sizes (indices.sizes ().begin (), indices.sizes ().end ());
638645 indices_sizes.insert (indices_sizes.begin (), 1 );
639646 indices.resize_ (indices_sizes, at::MemoryFormat::ChannelsLast3d);
640- AT_DISPATCH_ALL_TYPES_AND (ScalarType::BFloat16, input.scalar_type (), " max_pool3d_channels_last" , [&] {
647+ AT_DISPATCH_ALL_TYPES_AND2 (ScalarType::BFloat16, ScalarType::Half , input.scalar_type (), " max_pool3d_channels_last" , [&] {
641648 cpu_max_pool_channels_last<scalar_t , /* is 3d*/ true >(output, indices, input_cl_check,
642649 {kW , kH , kD }, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
643650 });
@@ -648,14 +655,14 @@ void max_pool3d_kernel_impl(
648655 }
649656 switch (input.suggest_memory_format ()) {
650657 case at::MemoryFormat::Contiguous: {
651- AT_DISPATCH_ALL_TYPES_AND (ScalarType::BFloat16, input.scalar_type (), " max_pool3d" , [&] {
658+ AT_DISPATCH_ALL_TYPES_AND2 (ScalarType::BFloat16, ScalarType::Half , input.scalar_type (), " max_pool3d" , [&] {
652659 cpu_max_pool<scalar_t , /* is 3d*/ true >(output, indices, input,
653660 {kW , kH , kD }, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
654661 });
655662 break ;
656663 }
657664 case at::MemoryFormat::ChannelsLast3d: {
658- AT_DISPATCH_ALL_TYPES_AND (ScalarType::BFloat16, input.scalar_type (), " max_pool3d_channels_last" , [&] {
665+ AT_DISPATCH_ALL_TYPES_AND2 (ScalarType::BFloat16, ScalarType::Half , input.scalar_type (), " max_pool3d_channels_last" , [&] {
659666 cpu_max_pool_channels_last<scalar_t , true >(output, indices, input,
660667 {kW , kH , kD }, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
661668 });
@@ -672,13 +679,13 @@ void max_pool2d_backward_kernel_impl(
672679 const Tensor& indices) {
673680 switch (grad_output.suggest_memory_format ()) {
674681 case at::MemoryFormat::Contiguous: {
675- AT_DISPATCH_FLOATING_TYPES_AND (ScalarType::BFloat16, grad_output.scalar_type (), " max_pool2d_backward" , [&] {
682+ AT_DISPATCH_FLOATING_TYPES_AND2 (ScalarType::BFloat16, ScalarType::Half , grad_output.scalar_type (), " max_pool2d_backward" , [&] {
676683 cpu_max_pool_backward<scalar_t , /* is 3d*/ false >(grad_input, grad_output, indices);
677684 });
678685 break ;
679686 }
680687 case at::MemoryFormat::ChannelsLast: {
681- AT_DISPATCH_FLOATING_TYPES_AND (ScalarType::BFloat16, grad_output.scalar_type (), " max_pool2d_backward_channels_last" , [&] {
688+ AT_DISPATCH_FLOATING_TYPES_AND2 (ScalarType::BFloat16, ScalarType::Half , grad_output.scalar_type (), " max_pool2d_backward_channels_last" , [&] {
682689 cpu_max_pool_backward_channels_last<scalar_t , /* is 3d*/ false >(grad_input, grad_output, indices);
683690 });
684691 break ;
@@ -705,7 +712,7 @@ void max_pool3d_backward_kernel_impl(
705712 sizes.insert (sizes.begin (), 1 );
706713 grad_input.resize_ (sizes, at::MemoryFormat::ChannelsLast3d);
707714 auto _indices = indices.unsqueeze (0 ).contiguous (at::MemoryFormat::ChannelsLast3d);
708- AT_DISPATCH_FLOATING_TYPES_AND (ScalarType::BFloat16, grad_output.scalar_type (), " max_pool3d_backward_channels_last" , [&] {
715+ AT_DISPATCH_FLOATING_TYPES_AND2 (ScalarType::BFloat16, ScalarType::Half , grad_output.scalar_type (), " max_pool3d_backward_channels_last" , [&] {
709716 cpu_max_pool_backward_channels_last<scalar_t , /* is_3d*/ true >(grad_input, grad_output_cl_check, _indices);
710717 });
711718 grad_input.squeeze_ (0 );
@@ -714,13 +721,13 @@ void max_pool3d_backward_kernel_impl(
714721 }
715722 switch (grad_output.suggest_memory_format ()) {
716723 case at::MemoryFormat::Contiguous: {
717- AT_DISPATCH_FLOATING_TYPES_AND (ScalarType::BFloat16, grad_output.scalar_type (), " max_pool3d_backward" , [&] {
724+ AT_DISPATCH_FLOATING_TYPES_AND2 (ScalarType::BFloat16, ScalarType::Half , grad_output.scalar_type (), " max_pool3d_backward" , [&] {
718725 cpu_max_pool_backward<scalar_t , /* is_3d*/ true >(grad_input, grad_output, indices);
719726 });
720727 break ;
721728 }
722729 case at::MemoryFormat::ChannelsLast3d: {
723- AT_DISPATCH_FLOATING_TYPES_AND (ScalarType::BFloat16, grad_output.scalar_type (), " max_pool3d_backward_channels_last" , [&] {
730+ AT_DISPATCH_FLOATING_TYPES_AND2 (ScalarType::BFloat16, ScalarType::Half , grad_output.scalar_type (), " max_pool3d_backward_channels_last" , [&] {
724731 cpu_max_pool_backward_channels_last<scalar_t , /* is_3d*/ true >(grad_input, grad_output, indices);
725732 });
726733 break ;
0 commit comments