
Commit a370057

vfdev-5 authored and pytorchmergebot committed
Fixed a bug in interpolate uint8 AVX2 on non-contig input (pytorch#101136)
Description:
- Fixed a bug in interpolate uint8 AVX2 on non-contig input
- Added tests

Pull Request resolved: pytorch#101136
Approved by: https://github.com/NicolasHug
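A minimal sketch of the scenario this commit addresses (shapes, sizes, and tolerance below are illustrative, not taken from the PR): a uint8 NCHW tensor is sliced into a non-contiguous view, resized, and compared against resizing a contiguous copy of the same data. Before the fix, the AVX2 path could read the sliced view with the wrong strides; with the fix, both calls should agree.

import torch
import torch.nn.functional as F

# Build a non-contiguous uint8 view by slicing off a 10-pixel border.
x = torch.randint(0, 256, (1, 3, 400, 400), dtype=torch.uint8)
x_sliced = x[:, :, 10:-10, 10:-10]
assert not x_sliced.is_contiguous()

# Resize the non-contiguous view and a contiguous copy of the same data.
out_view = F.interpolate(x_sliced, size=(32, 32), mode="bilinear", antialias=True)
out_copy = F.interpolate(x_sliced.contiguous(), size=(32, 32), mode="bilinear", antialias=True)

# With the fix both calls should agree; a 1-unit tolerance is used here to
# allow for uint8 rounding differences between memory-format code paths.
torch.testing.assert_close(out_view, out_copy, rtol=0, atol=1)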
1 parent 4a7ee79

File tree

2 files changed: +43 −13 lines

aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h

Lines changed: 18 additions & 6 deletions
@@ -266,6 +266,7 @@ void ImagingResampleVertical(
   auto xout = unpacked_output.size(2);
   auto yout = unpacked_output.size(1);
   const auto num_channels = unpacked_input.size(0);
+  TORCH_INTERNAL_ASSERT(num_channels == unpacked_output.size(0));

   auto xout_stride = xout * num_channels;
   for (const auto yy : c10::irange(yout)) {
@@ -301,23 +302,34 @@ void ImagingResampleVertical(
 // weights, but when aa=False they could be optimized further.
 template <typename scale_type, class F>
 void upsample_avx_bilinear_uint8(
-    const at::Tensor& input,
+    const at::Tensor& input_,
     const at::Tensor& output,
     bool align_corners,
     const scale_type& scales,
     bool antialias) {
-  auto batch_size = input.size(0);
-  auto num_channels = input.size(1);
-  auto xin = input.size(3);
-  auto yin = input.size(2);
+  auto batch_size = input_.size(0);
+  auto num_channels = input_.size(1);
+  auto xin = input_.size(3);
+  auto yin = input_.size(2);
   auto xout = output.size(3);
   auto yout = output.size(2);

   if (xin == xout && yin == yout) {
-    output.copy_(input);
+    output.copy_(input_);
     return;
   }

+  at::Tensor input = input_;
+  if (!(input.is_contiguous() || input.is_contiguous(at::MemoryFormat::ChannelsLast))) {
+    // If the input is not contiguous in channels-first or channels-last memory
+    // format, we explicitly convert it to contiguous channels-last. This
+    // simplifies the rest of the code and lets us assume the format is either
+    // contiguous channels-first or channels-last. Most tensors going through
+    // this `if` block won't need to go through unpacking, but those having
+    // C < 3 may have to (this means 2 copies are made). We could avoid the
+    // extra copy by handling non-contiguous input directly within
+    // unpack_rgb() and pack_rgb(), but initial attempts showed that this is
+    // fairly complex.
+    input = input.contiguous(at::MemoryFormat::ChannelsLast);
+  }
+
   auto need_horizontal = xout != xin;
   auto need_vertical = yout != yin;

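The heart of the fix is the new layout guard above. Here is the same check rendered in Python for illustration (the function name `ensure_supported_layout` is hypothetical; the real logic is the C++ block above):

import torch

def ensure_supported_layout(t: torch.Tensor) -> torch.Tensor:
    # Mirror of the C++ guard: accept channels-first or channels-last
    # contiguous tensors as-is; materialize anything else as channels-last.
    if t.is_contiguous() or t.is_contiguous(memory_format=torch.channels_last):
        return t
    return t.contiguous(memory_format=torch.channels_last)

x = torch.randint(0, 256, (1, 3, 400, 400), dtype=torch.uint8)[:, :, ::2, ::2]
assert not x.is_contiguous()
y = ensure_supported_layout(x)
assert y.is_contiguous(memory_format=torch.channels_last)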
test/test_nn.py

Lines changed: 25 additions & 7 deletions
@@ -9674,18 +9674,35 @@ def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format):
     @parametrize_test("num_channels", [3, 5])
     @parametrize_test("output_size", [32, 600])
     @parametrize_test("check_as_unsqueezed_3d_tensor", [True, False])
+    @parametrize_test("non_contig", [False, "sliced", "restrided"])
+    @parametrize_test("batch_size", [1, 5])
     def test_upsamplingBiLinear2d_consistency(
-        self, device, memory_format, antialias, align_corners, num_channels, output_size, check_as_unsqueezed_3d_tensor
+        self,
+        device,
+        memory_format,
+        antialias,
+        align_corners,
+        num_channels,
+        output_size,
+        check_as_unsqueezed_3d_tensor,
+        non_contig,
+        batch_size,
     ):
         if torch.device(device).type == "cuda":
             raise SkipTest("CUDA implementation is not yet supporting uint8")

         mode = "bilinear"
-        # Check if Max Abs Error between resized input_uint8 and resized input_float is smaller than a tolerated value, e.g. 1.0
-        input_ui8 = torch.randint(0, 256, size=(1, num_channels, 400, 400), dtype=torch.uint8, device=device)
+        # Check if Max Abs Error between resized input_uint8 and resized input_float is
+        # smaller than a tolerated value, e.g. 1.0
+        input_ui8 = torch.randint(0, 256, size=(batch_size, num_channels, 400, 400), dtype=torch.uint8, device=device)
         input_ui8 = input_ui8.contiguous(memory_format=memory_format)

-        if check_as_unsqueezed_3d_tensor:
+        if non_contig == "sliced":
+            input_ui8 = input_ui8[:, :, 10:-10, 10:-10]
+        elif non_contig == "restrided":
+            input_ui8 = input_ui8[:, :, ::2, ::2]
+
+        if batch_size == 1 and check_as_unsqueezed_3d_tensor:
             input_ui8 = input_ui8[0, ...]
             input_ui8 = input_ui8[None, ...]

@@ -9698,15 +9715,16 @@ def test_upsamplingBiLinear2d_consistency(
             input_ui8, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias
         )

+        if non_contig is False:
+            self.assertTrue(input_ui8.is_contiguous(memory_format=memory_format))
+
         # FIXME if-clause shows the current behaviour which is definitely unexpected.
         # Ideally we want to fix it such that both the ui8 and f32 outputs are also channels_last
         # See for more details: https://github.com/pytorch/pytorch/pull/100373
-        if check_as_unsqueezed_3d_tensor and memory_format == torch.channels_last:
-            self.assertTrue(input_ui8.is_contiguous(memory_format=torch.channels_last))
+        if batch_size == 1 and check_as_unsqueezed_3d_tensor and memory_format == torch.channels_last:
            self.assertTrue(output_ui8.is_contiguous())
            self.assertTrue(output_f32.is_contiguous())
         else:
-            self.assertTrue(input_ui8.is_contiguous(memory_format=memory_format))
             self.assertTrue(output_ui8.is_contiguous(memory_format=memory_format))
             self.assertTrue(output_f32.is_contiguous(memory_format=memory_format))

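For reference, this is what the two new `non_contig` cases produce (standalone sketch, shapes illustrative): both slicing and restriding yield views that are contiguous in neither memory format, which is exactly the layout the new C++ guard normalizes.

import torch

base = torch.randint(0, 256, (5, 3, 400, 400), dtype=torch.uint8)

sliced = base[:, :, 10:-10, 10:-10]   # crop a border: offset storage, parent strides
restrided = base[:, :, ::2, ::2]      # keep every other row/column: doubled strides

for view in (sliced, restrided):
    assert not view.is_contiguous()
    assert not view.is_contiguous(memory_format=torch.channels_last)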