@@ -9674,18 +9674,35 @@ def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format):
96749674 @parametrize_test("num_channels", [3, 5])
96759675 @parametrize_test("output_size", [32, 600])
96769676 @parametrize_test("check_as_unsqueezed_3d_tensor", [True, False])
9677+ @parametrize_test("non_contig", [False, "sliced", "restrided"])
9678+ @parametrize_test("batch_size", [1, 5])
96779679 def test_upsamplingBiLinear2d_consistency(
9678- self, device, memory_format, antialias, align_corners, num_channels, output_size, check_as_unsqueezed_3d_tensor
9680+ self,
9681+ device,
9682+ memory_format,
9683+ antialias,
9684+ align_corners,
9685+ num_channels,
9686+ output_size,
9687+ check_as_unsqueezed_3d_tensor,
9688+ non_contig,
9689+ batch_size,
96799690 ):
96809691 if torch.device(device).type == "cuda":
96819692 raise SkipTest("CUDA implementation is not yet supporting uint8")
96829693
96839694 mode = "bilinear"
9684- # Check if Max Abs Error between resized input_uint8 and resized input_float is smaller than a tolerated value, e.g. 1.0
9685- input_ui8 = torch.randint(0, 256, size=(1, num_channels, 400, 400), dtype=torch.uint8, device=device)
9695+ # Check if Max Abs Error between resized input_uint8 and resized input_float is
9696+ # smaller than a tolerated value, e.g. 1.0
9697+ input_ui8 = torch.randint(0, 256, size=(batch_size, num_channels, 400, 400), dtype=torch.uint8, device=device)
96869698 input_ui8 = input_ui8.contiguous(memory_format=memory_format)
96879699
9688- if check_as_unsqueezed_3d_tensor:
9700+ if non_contig == "sliced":
9701+ input_ui8 = input_ui8[:, :, 10:-10, 10:-10]
9702+ elif non_contig == "restrided":
9703+ input_ui8 = input_ui8[:, :, ::2, ::2]
9704+
9705+ if batch_size == 1 and check_as_unsqueezed_3d_tensor:
96899706 input_ui8 = input_ui8[0, ...]
96909707 input_ui8 = input_ui8[None, ...]
96919708
@@ -9698,15 +9715,16 @@ def test_upsamplingBiLinear2d_consistency(
96989715 input_ui8, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias
96999716 )
97009717
9718+ if non_contig is False:
9719+ self.assertTrue(input_ui8.is_contiguous(memory_format=memory_format))
9720+
97019721 # FIXME if-clause shows the current behaviour which is definitely unexpected.
97029722 # Ideally we want to fix it such that both the ui8 and f32 outputs are also channels_last
97039723 # See for more details: https://github.com/pytorch/pytorch/pull/100373
9704- if check_as_unsqueezed_3d_tensor and memory_format == torch.channels_last:
9705- self.assertTrue(input_ui8.is_contiguous(memory_format=torch.channels_last))
9724+ if batch_size == 1 and check_as_unsqueezed_3d_tensor and memory_format == torch.channels_last:
97069725 self.assertTrue(output_ui8.is_contiguous())
97079726 self.assertTrue(output_f32.is_contiguous())
97089727 else:
9709- self.assertTrue(input_ui8.is_contiguous(memory_format=memory_format))
97109728 self.assertTrue(output_ui8.is_contiguous(memory_format=memory_format))
97119729 self.assertTrue(output_f32.is_contiguous(memory_format=memory_format))
97129730