
Commit 42f94d7

CaoE authored and pytorchmergebot committed
add Half support for maxpool on CPU (pytorch#98819)
### Testing

Single socket (28 cores):

shape | fp32 forward / ms | fp16 forward / ms | bf16 forward / ms | fp32 backward / ms | fp16 backward / ms | bf16 backward / ms
-- | -- | -- | -- | -- | -- | --
size: (1, 56, 264, 264), kernel: 3, stride: 1, mem_format: contig | 4.12895 | 6.9669 | 5.30297 | 0.55775 | 1.98917 | 0.72233
size: (1, 56, 264, 264), kernel: 3, stride: 1, mem_format: CL | 0.85093 | 1.88813 | 1.38063 | 5.5742 | 36.5086 | 10.58552
size: (32, 16, 200, 200), kernel: 3, stride: 1, mem_format: contig | 22.37212 | 37.90383 | 30.94482 | 6.85868 | 10.6116 | 3.9993
size: (32, 16, 200, 200), kernel: 3, stride: 1, mem_format: CL | 5.41658 | 4.71098 | 4.66578 | 6.69875 | 14.7171 | 5.1167
size: (32, 32, 100, 100), kernel: 3, stride: 1, mem_format: contig | 10.69831 | 18.0468 | 13.71657 | 2.61192 | 4.96172 | 1.68635
size: (32, 32, 100, 100), kernel: 3, stride: 1, mem_format: CL | 2.52637 | 2.0096 | 2.0055 | 2.60314 | 7.2093 | 2.49843
size: (4, 19, 10, 16, 16), kernel: 3, stride: 1, mem_format: contig | 0.47605 | 0.88398 | 0.65326 | 0.06525 | 0.115489 | 0.0674
size: (4, 19, 10, 16, 16), kernel: 3, stride: 1, mem_format: CL3d | 0.10902 | 0.25293 | 0.157475 | 0.11386 | 0.53319 | 0.17836

Single core:

shape | fp32 forward / ms | fp16 forward / ms | bf16 forward / ms | fp32 backward / ms | fp16 backward / ms | bf16 backward / ms
-- | -- | -- | -- | -- | -- | --
size: (1, 56, 264, 264), kernel: 3, stride: 1, mem_format: contig | 90.9809 | 163.473 | 126.1276 | 6.57721 | 41.40833 | 11.82505
size: (1, 56, 264, 264), kernel: 3, stride: 1, mem_format: CL | 9.88405 | 38.39137 | 29.62069 | 7.10636 | 36.97535 | 11.0525
size: (32, 16, 200, 200), kernel: 3, stride: 1, mem_format: contig | 476.782 | 855.4769 | 648.2248 | 46.6488 | 219.2586 | 67.10599
size: (32, 16, 200, 200), kernel: 3, stride: 1, mem_format: CL | 80.29271 | 91.33854 | 87.80345 | 48.81692 | 203.9974 | 63.39004
size: (32, 32, 100, 100), kernel: 3, stride: 1, mem_format: contig | 235.2113 | 419.0799 | 315.4284 | 20.6049 | 107.1524 | 32.39169
size: (32, 32, 100, 100), kernel: 3, stride: 1, mem_format: CL | 29.47653 | 33.54905 | 32.82823 | 22.59674 | 98.5586 | 30.05763
size: (4, 19, 10, 16, 16), kernel: 3, stride: 1, mem_format: contig | 7.90684 | 13.9208 | 10.03272 | 0.23725 | 1.35269 | 0.41728
size: (4, 19, 10, 16, 16), kernel: 3, stride: 1, mem_format: CL3d | 2.33638 | 3.36894 | 2.64635 | 0.26535 | 1.244 | 0.38895

Pull Request resolved: pytorch#98819
Approved by: https://github.com/mingfeima, https://github.com/mikaylagawarecki
1 parent 1e0e55c commit 42f94d7
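For context, this is what the new dispatch makes reachable from user code. A minimal sketch (not part of the PR) that exercises the Half forward and backward paths on CPU with one of the benchmark shapes above; it assumes a libtorch build that contains this commit:

```cpp
// Minimal usage sketch (not from the PR): runs fp16 max-pool forward and
// backward on CPU, using one of the benchmark shapes from the tables above.
#include <torch/torch.h>
#include <iostream>

int main() {
  // Half-precision CPU input; before this commit the kernels below were
  // dispatched only for float, double, and BFloat16.
  auto x = torch::randn({1, 56, 264, 264},
                        torch::dtype(torch::kHalf).requires_grad(true));

  // Forward: kernel 3, stride 1, matching the benchmark configuration.
  auto y = torch::max_pool2d(x, /*kernel_size=*/3, /*stride=*/1);

  // Backward: exercises the newly dispatched Half backward kernel.
  y.sum().backward();
  std::cout << x.grad().sizes() << '\n';
}
```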

File tree

5 files changed: +110 −99 lines

aten/src/ATen/native/cpu/MaxPoolKernel.cpp

Lines changed: 57 additions & 50 deletions
@@ -5,6 +5,7 @@
 #include <ATen/Dispatch.h>
 #include <ATen/Parallel.h>
 #include <ATen/cpu/vec/vec.h>
+#include <ATen/cpu/vec/functional.h>
 #include <ATen/native/Pool.h>
 #include <ATen/native/cpu/utils.h>
 #include <c10/util/irange.h>
@@ -60,13 +61,15 @@ vec::Vectorized<int64_t> is_nan_vec<int64_t>(vec::Vectorized<int64_t> vec) {
   return ret;
 }
 
-template <typename scalar_t, typename accscalar_t>
-inline void compute_internal(
+template <typename scalar_t, typename opmath_t>
+inline
+typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, void>::type
+compute_internal(
     scalar_t* input_data,
     scalar_t* out_data,
-    accscalar_t* max_ptr,
-    vec::int_same_size_t<accscalar_t>* index_ptr,
-    int64_t* ind,
+    opmath_t* max_ptr,
+    vec::int_same_size_t<opmath_t>* index_ptr,
+    int64_t* ind,
     int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
     int64_t n,
     int64_t len,
@@ -78,7 +81,7 @@ inline void compute_internal(
     int64_t dilationH,
     int64_t dilationW) {
   using Vec = vec::Vectorized<scalar_t>;
-  using integer_t = vec::int_same_size_t<accscalar_t>;
+  using integer_t = vec::int_same_size_t<opmath_t>;
   using iVec = vec::Vectorized<integer_t>;
   // Pass I: init out lane
   iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
@@ -130,13 +133,16 @@ inline void compute_internal(
   }
 }
 
-template <>
-inline void compute_internal(
-  BFloat16* input_data,
-  BFloat16* out_data,
-  float* max_ptr,
-  int32_t* index_ptr,
-  int64_t* ind,
+// std::is_same<scalar_t, at::BFloat16> || std::is_same<scalar_t, at::Half>
+template <typename scalar_t, typename opmath_t>
+inline
+typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, void>::type
+compute_internal(
+    scalar_t* input_data,
+    scalar_t* out_data,
+    opmath_t* max_ptr,
+    vec::int_same_size_t<opmath_t>* index_ptr,
+    int64_t* ind,
     int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
     int64_t n,
     int64_t len,
@@ -147,34 +153,34 @@ inline void compute_internal(
     int64_t dilationD,
     int64_t dilationH,
     int64_t dilationW) {
-  using bVec = vec::Vectorized<BFloat16>;
-  using fVec = vec::Vectorized<float>;
+  using Vec = vec::Vectorized<scalar_t>;
+  using fVec = vec::Vectorized<opmath_t>;
   using iVec = vec::Vectorized<int32_t>;
   // Pass I: init out lane
   iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
-  fVec out_vec = fVec(-std::numeric_limits<float>::infinity());
+  fVec out_vec = fVec(-std::numeric_limits<opmath_t>::infinity());
   int64_t d1 = 0;
   for (; d1 < len; d1 += fVec::size()) {
     index0_vec.store(index_ptr + d1);
     out_vec.store(max_ptr + d1);
   }
   for (; d1 < size; d1++) {
     ind[d1] = ih0 * input_width + iw0;
-    max_ptr[d1] = -std::numeric_limits<float>::infinity();
+    max_ptr[d1] = -std::numeric_limits<opmath_t>::infinity();
   }
   // Pass II: compute local max
   for (int64_t id = id0; id < id1; id += dilationD) {
     for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
       for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
-        BFloat16* in = input_data + (n * input_depth * input_height * input_width +
+        scalar_t* in = input_data + (n * input_depth * input_height * input_width +
             id * input_height * input_width + ih * input_width + iw) * channels;
 
         int64_t d2 = 0;
-        for (; d2 < len; d2 += bVec::size()) {
+        for (; d2 < len; d2 += Vec::size()) {
           iVec index_ivec = iVec(id * input_height * input_width + ih * input_width + iw);
-          bVec val_bvec = bVec::loadu(in + d2);
+          Vec val_bvec = Vec::loadu(in + d2);
           fVec val_fvec0, val_fvec1;
-          std::tie(val_fvec0, val_fvec1) = convert_bfloat16_float(val_bvec);
+          std::tie(val_fvec0, val_fvec1) = convert_to_float<scalar_t>(val_bvec);
 
           iVec maxindex_ivec0 = iVec::loadu(index_ptr + d2);
           iVec maxindex_ivec1 = iVec::loadu(index_ptr + d2 + iVec::size());
@@ -200,9 +206,9 @@ inline void compute_internal(
         }
         for (; d2 < size; d2++) {
           int64_t index = id * input_height * input_width + ih * input_width + iw;
-          float val = float(in[d2]);
+          opmath_t val = opmath_t(in[d2]);
           int64_t maxindex = ind[d2];
-          float maxval = max_ptr[d2];
+          opmath_t maxval = max_ptr[d2];
 
           bool mask = (val > maxval) || std::isnan(val);
           max_ptr[d2] = mask ? val : maxval;
@@ -211,16 +217,16 @@ inline void compute_internal(
       }
     }
   }
-  // Convert max values from float to bfloat16
+  // Convert max values from float to bfloat16/half
   int64_t d3 = 0;
-  for (; d3 < len; d3 += bVec::size()) {
+  for (; d3 < len; d3 += Vec::size()) {
     fVec max_fvec0 = fVec::loadu(max_ptr + d3);
     fVec max_fvec1 = fVec::loadu(max_ptr + d3 + fVec::size());
-    bVec max_bvec = convert_float_bfloat16(max_fvec0, max_fvec1);
+    Vec max_bvec = convert_from_float<scalar_t>(max_fvec0, max_fvec1);
     max_bvec.store(out_data + d3);
   }
   for (; d3 < size; d3++) {
-    out_data[d3] = BFloat16(max_ptr[d3]);
+    out_data[d3] = scalar_t(max_ptr[d3]);
  }
 }
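The net effect of the hunks above: the reduced-precision path is no longer a BFloat16-only specialization. Both compute_internal overloads are now templates, and std::enable_if selects between them by whether scalar_t already equals its higher-precision compute type opmath_t. A standalone sketch of that selection pattern, using a hypothetical window_max helper (not the PR's code):

```cpp
// Sketch of the enable_if overload split used above (hypothetical names).
#include <iostream>
#include <type_traits>

// Chosen when the storage type already is the compute type (e.g. float):
// compare elements directly.
template <typename scalar_t, typename opmath_t>
typename std::enable_if<std::is_same<scalar_t, opmath_t>::value, opmath_t>::type
window_max(const scalar_t* in, int n) {
  opmath_t m = in[0];
  for (int i = 1; i < n; i++) {
    m = in[i] > m ? in[i] : m;
  }
  return m;
}

// Chosen when they differ (scalar_t = Half/BFloat16, opmath_t = float in the
// kernel): widen each element before comparing, mirroring the float accumulation.
template <typename scalar_t, typename opmath_t>
typename std::enable_if<!std::is_same<scalar_t, opmath_t>::value, opmath_t>::type
window_max(const scalar_t* in, int n) {
  opmath_t m = static_cast<opmath_t>(in[0]);
  for (int i = 1; i < n; i++) {
    opmath_t v = static_cast<opmath_t>(in[i]);
    m = v > m ? v : m;
  }
  return m;
}

int main() {
  float f[] = {1.0f, 3.0f, 2.0f};
  std::cout << window_max<float, float>(f, 3) << '\n';  // same-type overload
  // A narrow storage type stands in for at::Half (not a standard C++ type):
  short h[] = {1, 3, 2};
  std::cout << window_max<short, float>(h, 3) << '\n';  // widening overload
}
```

This is also why the old BFloat16-specific names (bVec, convert_bfloat16_float) become generic ones (Vec, convert_to_float<scalar_t>): the same body now serves both 16-bit types.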

@@ -281,7 +287,7 @@ void cpu_max_pool(
   int64_t output_height = output.size(-2);
   int64_t output_width = output.size(-1);
 
-  using accscalar_t = at::opmath_type<scalar_t>;
+  using opmath_t = at::opmath_type<scalar_t>;
   // parallel on dim N, C
   at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
     for (int64_t c = begin; c < end; c++) {
@@ -306,17 +312,18 @@ void cpu_max_pool(
 
       // compute local max
       int64_t maxindex = id0 * input_height * input_width + ih0 * input_width + iw0;
-      accscalar_t maxval;
-      if (std::numeric_limits<accscalar_t>::has_infinity) {
-        maxval = -std::numeric_limits<accscalar_t>::infinity();
+      opmath_t maxval;
+      if (std::numeric_limits<opmath_t>::has_infinity) {
+        maxval = -std::numeric_limits<opmath_t>::infinity();
       } else {
-        maxval = std::numeric_limits<accscalar_t>::min();
+        maxval = std::numeric_limits<opmath_t>::min();
       }
+
       for (int64_t id = id0; id < id1; id += dilationD) {
         for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
           for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
             int64_t index = id * input_height * input_width + ih * input_width + iw;
-            accscalar_t val = input_ptr[index];
+            opmath_t val = input_ptr[index];
             if ((val > maxval) || is_nan(static_cast<double>(val))) {
               maxval = val;
               maxindex = index;
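The scalar path above keeps its sentinel logic: start the running max at negative infinity when the compute type has one, otherwise at the type's minimum. Half and BFloat16 take the infinity branch because ATen provides std::numeric_limits specializations for them with has_infinity set. A small self-contained sketch of that initializer choice (hypothetical max_sentinel helper):

```cpp
// Sketch of the maxval initialization above: floating compute types start
// the running max at -infinity; integral types, which have no infinity,
// start at their smallest representable value.
#include <cstdint>
#include <iostream>
#include <limits>

template <typename T>
T max_sentinel() {
  return std::numeric_limits<T>::has_infinity
      ? -std::numeric_limits<T>::infinity()
      : std::numeric_limits<T>::min();
}

int main() {
  std::cout << max_sentinel<float>() << '\n';    // -inf
  std::cout << max_sentinel<int32_t>() << '\n';  // -2147483648
}
```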
@@ -396,9 +403,9 @@ void cpu_max_pool_channels_last(
   int64_t output_height = output.size(-2);
   int64_t output_width = output.size(-1);
 
-  using accscalar_t = at::opmath_type<scalar_t>;
+  using opmath_t = at::opmath_type<scalar_t>;
   using Vec = vec::Vectorized<scalar_t>;
-  using integer_t = vec::int_same_size_t<accscalar_t>;
+  using integer_t = vec::int_same_size_t<opmath_t>;
   // for the convience of vectorization, use integer of the same size of scalar_t,
   // e.g. int32_t for float, int64_t for double
   // need to make sure doesn't overflow
@@ -418,11 +425,11 @@ void cpu_max_pool_channels_last(
       // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
       std::unique_ptr<integer_t []> index_buffer(new integer_t[len]);
       integer_t * index_ptr = index_buffer.get();
-      // temp buffer holding max value with accscalar_t
-      std::unique_ptr<accscalar_t []> max_arr;
-      accscalar_t* max_ptr = nullptr;
-      if (!std::is_same<scalar_t, accscalar_t>::value) {
-        max_arr = std::make_unique<accscalar_t[]>(size);
+      // temp buffer holding max value with opmath_t
+      std::unique_ptr<opmath_t []> max_arr;
+      opmath_t* max_ptr = nullptr;
+      if (!std::is_same<scalar_t, opmath_t>::value) {
+        max_arr = std::make_unique<opmath_t[]>(size);
         max_ptr = max_arr.get();
       }

@@ -598,13 +605,13 @@ void max_pool2d_kernel_impl(
     int dilationW, int dilationH) {
   switch (input.suggest_memory_format()) {
     case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool2d", [&] {
+      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool2d", [&] {
         cpu_max_pool<scalar_t, /*is 3d*/false>(output, indices, input, {kW, kH}, {dW, dH}, {padW, padH}, {dilationW, dilationH});
       });
       break;
     }
     case at::MemoryFormat::ChannelsLast: {
-      AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool2d_channels_last", [&] {
+      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool2d_channels_last", [&] {
         cpu_max_pool_channels_last<scalar_t, false>(output, indices, input, {kW, kH}, {dW, dH}, {padW, padH}, {dilationW, dilationH});
       });
       break;
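At every dispatch site the change is mechanical: AT_DISPATCH_*_AND(BFloat16, ...) becomes AT_DISPATCH_*_AND2(BFloat16, Half, ...), adding one more dtype for which the kernel template is stamped out. A simplified sketch of what such a dispatch does (not the real macro expansion; ScalarType, Kernel, and dispatch_sketch here are illustrative stand-ins):

```cpp
// Simplified sketch (not the actual macro expansion): the dispatcher switches
// on the runtime dtype and instantiates the kernel template for the matching
// C++ type. The "AND2" variant effectively adds one more case to the switch.
#include <cstdint>
#include <iostream>

enum class ScalarType { Float, Double, BFloat16, Half };

struct Kernel {
  template <typename scalar_t>
  void run() {
    std::cout << "kernel instantiated for a " << sizeof(scalar_t)
              << "-byte element type\n";
  }
};

void dispatch_sketch(ScalarType st, Kernel k) {
  switch (st) {
    case ScalarType::Float:    k.run<float>(); break;
    case ScalarType::Double:   k.run<double>(); break;
    case ScalarType::BFloat16: k.run<uint16_t>(); break;  // stand-in for at::BFloat16
    case ScalarType::Half:     k.run<uint16_t>(); break;  // stand-in for at::Half (new case)
  }
}

int main() {
  dispatch_sketch(ScalarType::Half, Kernel{});
}
```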
@@ -637,7 +644,7 @@ void max_pool3d_kernel_impl(
       DimVector indices_sizes(indices.sizes().begin(), indices.sizes().end());
       indices_sizes.insert(indices_sizes.begin(), 1);
       indices.resize_(indices_sizes, at::MemoryFormat::ChannelsLast3d);
-      AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool3d_channels_last", [&] {
+      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d_channels_last", [&] {
         cpu_max_pool_channels_last<scalar_t, /*is 3d*/true>(output, indices, input_cl_check,
             {kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
       });
@@ -648,14 +655,14 @@ void max_pool3d_kernel_impl(
   }
   switch (input.suggest_memory_format()) {
     case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool3d", [&] {
+      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d", [&] {
         cpu_max_pool<scalar_t, /*is 3d*/true>(output, indices, input,
             {kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
       });
       break;
     }
     case at::MemoryFormat::ChannelsLast3d: {
-      AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool3d_channels_last", [&] {
+      AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "max_pool3d_channels_last", [&] {
         cpu_max_pool_channels_last<scalar_t, true>(output, indices, input,
             {kW, kH, kD}, {dW, dH, dD}, {padW, padH, padD}, {dilationW, dilationH, dilationD});
       });
@@ -672,13 +679,13 @@ void max_pool2d_backward_kernel_impl(
     const Tensor& indices) {
   switch (grad_output.suggest_memory_format()) {
     case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool2d_backward", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool2d_backward", [&] {
         cpu_max_pool_backward<scalar_t, /*is 3d*/ false>(grad_input, grad_output, indices);
       });
       break;
     }
     case at::MemoryFormat::ChannelsLast: {
-      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
         cpu_max_pool_backward_channels_last<scalar_t, /*is 3d*/ false>(grad_input, grad_output, indices);
       });
       break;
@@ -705,7 +712,7 @@ void max_pool3d_backward_kernel_impl(
       sizes.insert(sizes.begin(), 1);
       grad_input.resize_(sizes, at::MemoryFormat::ChannelsLast3d);
       auto _indices = indices.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d);
-      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
         cpu_max_pool_backward_channels_last<scalar_t, /*is_3d*/ true>(grad_input, grad_output_cl_check, _indices);
       });
       grad_input.squeeze_(0);
@@ -714,13 +721,13 @@ void max_pool3d_backward_kernel_impl(
   }
   switch (grad_output.suggest_memory_format()) {
     case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool3d_backward", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward", [&] {
         cpu_max_pool_backward<scalar_t, /*is_3d*/ true>(grad_input, grad_output, indices);
       });
       break;
     }
     case at::MemoryFormat::ChannelsLast3d: {
-      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "max_pool3d_backward_channels_last", [&] {
         cpu_max_pool_backward_channels_last<scalar_t, /*is_3d*/ true>(grad_input, grad_output, indices);
       });
       break;

aten/src/ATen/native/cpu/MaxPooling.cpp

Lines changed: 25 additions & 20 deletions
@@ -1,7 +1,7 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include <ATen/core/Tensor.h>
 #include <ATen/Dispatch.h>
 #include <ATen/Parallel.h>
+#include <ATen/core/Tensor.h>
 #include <ATen/cpu/vec/vec.h>
 #include <ATen/native/MaxPooling.h>
 #include <c10/util/irange.h>
@@ -31,25 +31,30 @@ void max_pool1d_impl(
     Tensor& output,
     const Tensor& input,
     const PoolingParams1D& p) {
-  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool1d_impl", [&] {
-    const Tensor in = input.contiguous();
-    scalar_t* const OP = output.data_ptr<scalar_t>();
-    const scalar_t* const IP = in.data_ptr<scalar_t>();
-
-    // Value used for padding
-    scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
-        ? -std::numeric_limits<scalar_t>::infinity()
-        : std::numeric_limits<scalar_t>::lowest();
-
-    at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
-      for (const auto it : c10::irange(begin, end)) {
-        scalar_t* op = OP + it * p.OW;
-        const scalar_t* ip = IP + it * p.IW;
-        std::fill_n(op, p.OW, FILL);
-        max_pool1d_kernel(op, ip, p);
-      }
-    });
-  });
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      ScalarType::BFloat16,
+      ScalarType::Half,
+      input.scalar_type(),
+      "max_pool1d_impl",
+      [&] {
+        const Tensor in = input.contiguous();
+        scalar_t* const OP = output.data_ptr<scalar_t>();
+        const scalar_t* const IP = in.data_ptr<scalar_t>();
+
+        // Value used for padding
+        scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
+            ? -std::numeric_limits<scalar_t>::infinity()
+            : std::numeric_limits<scalar_t>::lowest();
+
+        at::parallel_for(0, p.NB * p.NC, 0, [&](int64_t begin, int64_t end) {
+          for (const auto it : c10::irange(begin, end)) {
+            scalar_t* op = OP + it * p.OW;
+            const scalar_t* ip = IP + it * p.IW;
+            std::fill_n(op, p.OW, FILL);
+            max_pool1d_kernel(op, ip, p);
+          }
+        });
+      });
 }
 
 } // namespace
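The max_pool1d_impl body is unchanged apart from the macro swap and reindentation: fill the output row with the padding sentinel FILL, then let max_pool1d_kernel raise each element to its window maximum. A self-contained sketch of that fill-then-max pattern (simplified, hypothetical max_pool1d_sketch; no padding or dilation handling):

```cpp
// Fill-then-max sketch of max_pool1d_impl's logic (hypothetical helper).
#include <algorithm>
#include <limits>
#include <vector>

template <typename scalar_t>
std::vector<scalar_t> max_pool1d_sketch(
    const std::vector<scalar_t>& in, int kernel, int stride) {
  // Padding sentinel, chosen exactly as FILL is above.
  const scalar_t FILL = std::numeric_limits<scalar_t>::has_infinity
      ? -std::numeric_limits<scalar_t>::infinity()
      : std::numeric_limits<scalar_t>::lowest();
  const int OW = (static_cast<int>(in.size()) - kernel) / stride + 1;
  std::vector<scalar_t> out(OW, FILL);  // the std::fill_n step
  for (int ow = 0; ow < OW; ow++) {     // the max_pool1d_kernel step
    for (int k = 0; k < kernel; k++) {
      out[ow] = std::max(out[ow], in[ow * stride + k]);
    }
  }
  return out;
}
```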
