diff --git a/captum/optim/_param/image/images.py b/captum/optim/_param/image/images.py index aea46d9e19..af0bf2486b 100644 --- a/captum/optim/_param/image/images.py +++ b/captum/optim/_param/image/images.py @@ -416,6 +416,38 @@ def forward(self) -> torch.Tensor: return torch.stack(A).refine_names("B", "C", "H", "W") +class SimpleTensorParameterization(ImageParameterization): + """ + Parameterize a simple tensor with or without it requiring grad. + Compared to PixelImage, this parameterization has no specific shape requirements + and does not wrap inputs in nn.Parameter. + + This parameterization can for example be combined with StackImage for batch + dimensions that both require and don't require gradients. + + This parameterization can also be combined with nn.ModuleList as a workaround for + TorchScript / JIT not supporting nn.ParameterList. SharedImage uses this module + internally for this purpose. + """ + + def __init__(self, tensor: torch.Tensor = None) -> None: + """ + Args: + + tensor (torch.Tensor): The tensor to return every time this module is called. + """ + super().__init__() + assert isinstance(tensor, torch.Tensor) + self.tensor = tensor + + def forward(self) -> torch.Tensor: + """ + Returns: + tensor (torch.Tensor): The tensor stored during initialization. 
+ """ + return self.tensor + + class SharedImage(ImageParameterization): """ Share some image parameters across the batch to increase spatial alignment, @@ -429,6 +461,8 @@ class SharedImage(ImageParameterization): https://distill.pub/2018/differentiable-parameterizations/ """ + __constants__ = ["offset"] + def __init__( self, shapes: Union[Tuple[Tuple[int]], Tuple[int]] = None, @@ -454,8 +488,11 @@ def __init__( assert len(shape) >= 2 and len(shape) <= 4 shape = ([1] * (4 - len(shape))) + list(shape) batch, channels, height, width = shape - A.append(torch.nn.Parameter(torch.randn([batch, channels, height, width]))) - self.shared_init = torch.nn.ParameterList(A) + shape_param = torch.nn.Parameter( + torch.randn([batch, channels, height, width]) + ) + A.append(SimpleTensorParameterization(shape_param)) + self.shared_init = torch.nn.ModuleList(A) self.parameterization = parameterization self.offset = self._get_offset(offset, len(A)) if offset is not None else None @@ -484,6 +521,7 @@ def _get_offset(self, offset: Union[int, Tuple[int]], n: int) -> List[List[int]] assert all([all([type(o) is int for o in v]) for v in offset]) return offset + @torch.jit.ignore def _apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: """ Apply list of offsets to list of tensors. @@ -517,6 +555,63 @@ def _apply_offset(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: A.append(x) return A + def _interpolate_bilinear( + self, + x: torch.Tensor, + size: Tuple[int, int], + ) -> torch.Tensor: + """ + Perform interpolation without any warnings. + + Args: + + x (torch.Tensor): The NCHW tensor to resize. + size (tuple of int): The desired output size to resize the input + to, with a format of: [height, width]. + + Returns: + x (torch.Tensor): A resized NCHW tensor. 
+ """ + assert x.dim() == 4 + assert len(size) == 2 + + x = F.interpolate( + x, + size=size, + mode="bilinear", + align_corners=False, + recompute_scale_factor=False, + ) + return x + + def _interpolate_trilinear( + self, + x: torch.Tensor, + size: Tuple[int, int, int], + ) -> torch.Tensor: + """ + Perform interpolation without any warnings. + + Args: + + x (torch.Tensor): The NCHW tensor to resize. + size (tuple of int): The desired output size to resize the input + to, with a format of: [channels, height, width]. + + Returns: + x (torch.Tensor): A resized NCHW tensor. + """ + x = x.unsqueeze(0) + assert x.dim() == 5 + x = F.interpolate( + x, + size=size, + mode="trilinear", + align_corners=False, + recompute_scale_factor=False, + ) + return x.squeeze(0) + def _interpolate_tensor( self, x: torch.Tensor, batch: int, channels: int, height: int, width: int ) -> torch.Tensor: @@ -537,29 +632,26 @@ def _interpolate_tensor( """ if x.size(1) == channels: - mode = "bilinear" size = (height, width) + x = self._interpolate_bilinear(x, size=size) else: - mode = "trilinear" - x = x.unsqueeze(0) size = (channels, height, width) - x = F.interpolate(x, size=size, mode=mode) - x = x.squeeze(0) if len(size) == 3 else x + x = self._interpolate_trilinear(x, size=size) if x.size(0) != batch: x = x.permute(1, 0, 2, 3) - x = F.interpolate( - x.unsqueeze(0), - size=(batch, x.size(2), x.size(3)), - mode="trilinear", - ).squeeze(0) + x = self._interpolate_trilinear(x, size=(batch, x.size(2), x.size(3))) x = x.permute(1, 0, 2, 3) return x def forward(self) -> torch.Tensor: + """ + Returns: + output (torch.Tensor): An NCHW image parameterization output. 
+ """ + image = self.parameterization() + x = [ + self._interpolate_tensor( - shared_tensor, + shared_tensor(), image.size(0), image.size(1), image.size(2), @@ -569,7 +661,78 @@ def forward(self) -> torch.Tensor: ] if self.offset is not None: x = self._apply_offset(x) - return (image + sum(x)).refine_names("B", "C", "H", "W") + output = image + torch.cat(x, 0).sum(0, keepdim=True) + + if torch.jit.is_scripting(): + return output + return output.refine_names("B", "C", "H", "W") + + +class StackImage(ImageParameterization): + """ + Stack multiple NCHW image parameterizations along their batch dimensions. + """ + + __constants__ = ["dim", "output_device"] + + def __init__( + self, + parameterizations: List[Union[ImageParameterization, torch.Tensor]], + dim: int = 0, + output_device: Optional[torch.device] = None, + ) -> None: + """ + Args: + + parameterizations (list of ImageParameterization and torch.Tensor): A list + of image parameterizations to stack across their batch dimensions. + dim (int, optional): Optionally specify the dim to concatenate + parameterization outputs on. Default is set to the batch dimension. + Default: 0 + output_device (torch.device, optional): If the parameterizations are on + different devices, then their outputs will be moved to the device + specified by this variable. Default is set to None with the expectation + that all parameterizations are on the same device. 
+ Default: None + """ + super().__init__() + assert len(parameterizations) > 0 + assert isinstance(parameterizations, (list, tuple)) + assert all( + [ + isinstance(param, (ImageParameterization, torch.Tensor)) + for param in parameterizations + ] + ) + parameterizations = [ + SimpleTensorParameterization(p) if isinstance(p, torch.Tensor) else p + for p in parameterizations + ] + self.parameterizations = torch.nn.ModuleList(parameterizations) + self.dim = dim + self.output_device = output_device + + def forward(self) -> torch.Tensor: + """ + Returns: + image (torch.Tensor): A set of NCHW image parameterization outputs stacked + along the batch dimension. + """ + P = [] + for image_param in self.parameterizations: + img = image_param() + if self.output_device is not None: + img = img.to(self.output_device, dtype=img.dtype) + P.append(img) + + assert P[0].dim() == 4 + assert all([im.shape == P[0].shape for im in P]) + assert all([im.device == P[0].device for im in P]) + + image = torch.cat(P, dim=self.dim) + if torch.jit.is_scripting(): + return image + return image.refine_names("B", "C", "H", "W") class NaturalImage(ImageParameterization): @@ -683,5 +846,6 @@ def forward(self) -> torch.Tensor: "PixelImage", "LaplacianImage", "SharedImage", + "StackImage", "NaturalImage", ] diff --git a/tests/optim/param/test_images.py b/tests/optim/param/test_images.py index d574c8c756..617d34a3a3 100644 --- a/tests/optim/param/test_images.py +++ b/tests/optim/param/test_images.py @@ -75,7 +75,7 @@ def test_export_and_open_local_image(self) -> None: self.assertTrue(torch.is_tensor(new_tensor)) assertTensorAlmostEqual(self, image_tensor, new_tensor) - def test_natural_image_cuda(self) -> None: + def test_image_tensor_cuda(self) -> None: if not torch.cuda.is_available(): raise unittest.SkipTest( "Skipping ImageTensor CUDA test due to not supporting CUDA." 
@@ -84,7 +84,22 @@ def test_natural_image_cuda(self) -> None: self.assertTrue(image_t.is_cuda) +class TestInputParameterization(BaseTest): + def test_subclass(self) -> None: + self.assertTrue(issubclass(images.InputParameterization, torch.nn.Module)) + + +class TestImageParameterization(BaseTest): + def test_subclass(self) -> None: + self.assertTrue( + issubclass(images.ImageParameterization, images.InputParameterization) + ) + + class TestFFTImage(BaseTest): + def test_subclass(self) -> None: + self.assertTrue(issubclass(images.FFTImage, images.ImageParameterization)) + def test_pytorch_fftfreq(self) -> None: image = images.FFTImage((1, 1)) _, _, fftfreq = image.get_fft_funcs() @@ -295,6 +310,9 @@ def test_fftimage_forward_init_batch(self) -> None: class TestPixelImage(BaseTest): + def test_subclass(self) -> None: + self.assertTrue(issubclass(images.PixelImage, images.ImageParameterization)) + def test_pixelimage_random(self) -> None: size = (224, 224) channels = 3 @@ -360,6 +378,9 @@ def test_pixelimage_init_forward(self) -> None: class TestLaplacianImage(BaseTest): + def test_subclass(self) -> None: + self.assertTrue(issubclass(images.LaplacianImage, images.ImageParameterization)) + def test_laplacianimage_random_forward(self) -> None: size = (224, 224) channels = 3 @@ -379,7 +400,156 @@ def test_laplacianimage_init(self) -> None: assertTensorAlmostEqual(self, torch.ones_like(output) * 0.5, output, mode="max") +class TestSimpleTensorParameterization(BaseTest): + def test_subclass(self) -> None: + self.assertTrue( + issubclass( + images.SimpleTensorParameterization, images.ImageParameterization + ) + ) + + def test_simple_tensor_parameterization_no_grad(self) -> None: + test_input = torch.randn(1, 3, 4, 4) + image_param = images.SimpleTensorParameterization(test_input) + assertTensorAlmostEqual(self, image_param.tensor, test_input, 0.0) + self.assertFalse(image_param.tensor.requires_grad) + + test_output = image_param() + assertTensorAlmostEqual(self, 
test_output, test_input, 0.0) + self.assertFalse(image_param.tensor.requires_grad) + + def test_simple_tensor_parameterization_jit_module_no_grad(self) -> None: + if version.parse(torch.__version__) <= version.parse("1.8.0"): + raise unittest.SkipTest( + "Skipping SimpleTensorParameterization JIT module test due to" + + " insufficient Torch version." + ) + test_input = torch.randn(1, 3, 4, 4) + image_param = images.SimpleTensorParameterization(test_input) + jit_image_param = torch.jit.script(image_param) + + test_output = jit_image_param() + assertTensorAlmostEqual(self, test_output, test_input, 0.0) + self.assertFalse(image_param.tensor.requires_grad) + + def test_simple_tensor_parameterization_with_grad(self) -> None: + test_input = torch.nn.Parameter(torch.randn(1, 3, 4, 4)) + image_param = images.SimpleTensorParameterization(test_input) + assertTensorAlmostEqual(self, image_param.tensor, test_input, 0.0) + self.assertTrue(image_param.tensor.requires_grad) + + test_output = image_param() + assertTensorAlmostEqual(self, test_output, test_input, 0.0) + self.assertTrue(image_param.tensor.requires_grad) + + def test_simple_tensor_parameterization_jit_module_with_grad(self) -> None: + if version.parse(torch.__version__) <= version.parse("1.8.0"): + raise unittest.SkipTest( + "Skipping SimpleTensorParameterization JIT module test due to" + + " insufficient Torch version." + ) + test_input = torch.nn.Parameter(torch.randn(1, 3, 4, 4)) + image_param = images.SimpleTensorParameterization(test_input) + jit_image_param = torch.jit.script(image_param) + + test_output = jit_image_param() + assertTensorAlmostEqual(self, test_output, test_input, 0.0) + self.assertTrue(image_param.tensor.requires_grad) + + def test_simple_tensor_parameterization_cuda(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping SimpleTensorParameterization CUDA test due to not supporting" + + " CUDA." 
+ ) + test_input = torch.randn(1, 3, 4, 4).cuda() + image_param = images.SimpleTensorParameterization(test_input) + self.assertTrue(image_param.tensor.is_cuda) + assertTensorAlmostEqual(self, image_param.tensor, test_input, 0.0) + self.assertFalse(image_param.tensor.requires_grad) + + test_output = image_param() + self.assertTrue(test_output.is_cuda) + assertTensorAlmostEqual(self, test_output, test_input, 0.0) + self.assertFalse(image_param.tensor.requires_grad) + + class TestSharedImage(BaseTest): + def test_subclass(self) -> None: + self.assertTrue(issubclass(images.SharedImage, images.ImageParameterization)) + + def test_sharedimage_init(self) -> None: + shared_shapes = ( + (1, 3, 128 // 2, 128 // 2), + (1, 3, 128 // 4, 128 // 4), + (1, 3, 128 // 8, 128 // 8), + ) + test_param = images.SimpleTensorParameterization(torch.ones(4, 3, 4, 4)) + shared_param = images.SharedImage( + shapes=shared_shapes, parameterization=test_param + ) + + self.assertIsInstance(shared_param.shared_init, torch.nn.ModuleList) + self.assertEqual(len(shared_param.shared_init), len(shared_shapes)) + for shared_init, shape in zip(shared_param.shared_init, shared_shapes): + self.assertIsInstance(shared_init, images.SimpleTensorParameterization) + self.assertEqual(list(shared_init().shape), list(shape)) + + self.assertIsInstance( + shared_param.parameterization, images.SimpleTensorParameterization + ) + self.assertIsNone(shared_param.offset) + + def test_sharedimage_interpolate_bilinear(self) -> None: + shared_shapes = (128 // 2, 128 // 2) + test_param = lambda: torch.ones(3, 3, 224, 224) # noqa: E731 + image_param = images.SharedImage( + shapes=shared_shapes, parameterization=test_param + ) + + size = (224, 128) + test_input = torch.randn(1, 3, 128, 128) + + test_output = image_param._interpolate_bilinear(test_input.clone(), size=size) + expected_output = torch.nn.functional.interpolate( + test_input.clone(), size=size, mode="bilinear" + ) + assertTensorAlmostEqual(self, test_output, 
expected_output, 0.0) + + size = (128, 128) + test_input = torch.randn(1, 3, 224, 224) + + test_output = image_param._interpolate_bilinear(test_input.clone(), size=size) + expected_output = torch.nn.functional.interpolate( + test_input.clone(), size=size, mode="bilinear" + ) + assertTensorAlmostEqual(self, test_output, expected_output, 0.0) + + def test_sharedimage_interpolate_trilinear(self) -> None: + shared_shapes = (128 // 2, 128 // 2) + test_param = lambda: torch.ones(3, 3, 224, 224) # noqa: E731 + image_param = images.SharedImage( + shapes=shared_shapes, parameterization=test_param + ) + + size = (3, 224, 128) + test_input = torch.randn(1, 1, 128, 128) + + test_output = image_param._interpolate_trilinear(test_input.clone(), size=size) + expected_output = torch.nn.functional.interpolate( + test_input.clone().unsqueeze(0), size=size, mode="trilinear" + ).squeeze(0) + assertTensorAlmostEqual(self, test_output, expected_output, 0.0) + + size = (2, 128, 128) + test_input = torch.randn(1, 4, 224, 224) + + test_output = image_param._interpolate_trilinear(test_input.clone(), size=size) + expected_output = torch.nn.functional.interpolate( + test_input.clone().unsqueeze(0), size=size, mode="trilinear" + ).squeeze(0) + assertTensorAlmostEqual(self, test_output, expected_output, 0.0) + def test_sharedimage_get_offset_single_number(self) -> None: shared_shapes = (128 // 2, 128 // 2) test_param = lambda: torch.ones(3, 3, 224, 224) # noqa: E731 @@ -523,9 +693,9 @@ def test_sharedimage_single_shape_hw_forward(self) -> None: test_tensor = image_param.forward() self.assertIsNone(image_param.offset) - self.assertEqual(image_param.shared_init[0].dim(), 4) + self.assertEqual(image_param.shared_init[0]().dim(), 4) self.assertEqual( - list(image_param.shared_init[0].shape), [1, 1] + list(shared_shapes) + list(image_param.shared_init[0]().shape), [1, 1] + list(shared_shapes) ) self.assertEqual(test_tensor.dim(), 4) self.assertEqual(test_tensor.size(0), batch) @@ -545,9 +715,9 @@ def 
test_sharedimage_single_shape_chw_forward(self) -> None: test_tensor = image_param.forward() self.assertIsNone(image_param.offset) - self.assertEqual(image_param.shared_init[0].dim(), 4) + self.assertEqual(image_param.shared_init[0]().dim(), 4) self.assertEqual( - list(image_param.shared_init[0].shape), [1] + list(shared_shapes) + list(image_param.shared_init[0]().shape), [1] + list(shared_shapes) ) self.assertEqual(test_tensor.dim(), 4) self.assertEqual(test_tensor.size(0), batch) @@ -567,8 +737,8 @@ def test_sharedimage_single_shape_bchw_forward(self) -> None: test_tensor = image_param.forward() self.assertIsNone(image_param.offset) - self.assertEqual(image_param.shared_init[0].dim(), 4) - self.assertEqual(list(image_param.shared_init[0].shape), list(shared_shapes)) + self.assertEqual(image_param.shared_init[0]().dim(), 4) + self.assertEqual(list(image_param.shared_init[0]().shape), list(shared_shapes)) self.assertEqual(test_tensor.dim(), 4) self.assertEqual(test_tensor.size(0), batch) self.assertEqual(test_tensor.size(1), channels) @@ -595,9 +765,9 @@ def test_sharedimage_multiple_shapes_forward(self) -> None: self.assertIsNone(image_param.offset) for i in range(len(shared_shapes)): - self.assertEqual(image_param.shared_init[i].dim(), 4) + self.assertEqual(image_param.shared_init[i]().dim(), 4) self.assertEqual( - list(image_param.shared_init[i].shape), list(shared_shapes[i]) + list(image_param.shared_init[i]().shape), list(shared_shapes[i]) ) self.assertEqual(test_tensor.dim(), 4) self.assertEqual(test_tensor.size(0), batch) @@ -625,10 +795,10 @@ def test_sharedimage_multiple_shapes_diff_len_forward(self) -> None: self.assertIsNone(image_param.offset) for i in range(len(shared_shapes)): - self.assertEqual(image_param.shared_init[i].dim(), 4) + self.assertEqual(image_param.shared_init[i]().dim(), 4) s_shape = list(shared_shapes[i]) s_shape = ([1] * (4 - len(s_shape))) + list(s_shape) - self.assertEqual(list(image_param.shared_init[i].shape), s_shape) + 
self.assertEqual(list(image_param.shared_init[i]().shape), s_shape) self.assertEqual(test_tensor.dim(), 4) self.assertEqual(test_tensor.size(0), batch) @@ -636,8 +806,244 @@ def test_sharedimage_multiple_shapes_diff_len_forward(self) -> None: self.assertEqual(test_tensor.size(2), size[0]) self.assertEqual(test_tensor.size(3), size[1]) + def test_sharedimage_multiple_shapes_diff_len_forward_jit_module(self) -> None: + if version.parse(torch.__version__) <= version.parse("1.8.0"): + raise unittest.SkipTest( + "Skipping SharedImage JIT module test due to insufficient Torch" + + " version." + ) + + shared_shapes = ( + (128 // 2, 128 // 2), + (7, 3, 128 // 4, 128 // 4), + (3, 128 // 8, 128 // 8), + (2, 4, 128 // 8, 128 // 8), + (1, 3, 128 // 16, 128 // 16), + (2, 2, 128 // 16, 128 // 16), + ) + batch = 6 + channels = 3 + size = (224, 224) + test_input = torch.ones(batch, channels, size[0], size[1]) # noqa: E731 + test_param = images.SimpleTensorParameterization(test_input) + image_param = images.SharedImage( + shapes=shared_shapes, parameterization=test_param + ) + jit_image_param = torch.jit.script(image_param) + test_tensor = jit_image_param() + + self.assertEqual(test_tensor.dim(), 4) + self.assertEqual(test_tensor.size(0), batch) + self.assertEqual(test_tensor.size(1), channels) + self.assertEqual(test_tensor.size(2), size[0]) + self.assertEqual(test_tensor.size(3), size[1]) + + +class TestStackImage(BaseTest): + def test_subclass(self) -> None: + self.assertTrue(issubclass(images.StackImage, images.ImageParameterization)) + + def test_stackimage_init(self) -> None: + size = (4, 4) + fft_param_1 = images.FFTImage(size=size) + fft_param_2 = images.FFTImage(size=size) + param_list = [fft_param_1, fft_param_2] + stack_param = images.StackImage(parameterizations=param_list) + + self.assertIsInstance(stack_param.parameterizations, torch.nn.ModuleList) + self.assertEqual(len(stack_param.parameterizations), 2) + self.assertEqual(stack_param.dim, 0) + + for image_param in 
stack_param.parameterizations: + self.assertIsInstance(image_param, images.FFTImage) + self.assertEqual(list(image_param().shape), [1, 3] + list(size)) + self.assertTrue(image_param().requires_grad) + + def test_stackimage_dim(self) -> None: + img_param_r = images.SimpleTensorParameterization(torch.ones(1, 1, 4, 4)) + img_param_g = images.SimpleTensorParameterization(torch.ones(1, 1, 4, 4)) + img_param_b = images.SimpleTensorParameterization(torch.ones(1, 1, 4, 4)) + param_list = [img_param_r, img_param_g, img_param_b] + stack_param = images.StackImage(parameterizations=param_list, dim=1) + + self.assertEqual(stack_param.dim, 1) + + test_output = stack_param() + self.assertEqual(list(test_output.shape), [1, 3, 4, 4]) + + def test_stackimage_forward(self) -> None: + size = (4, 4) + fft_param_1 = images.FFTImage(size=size) + fft_param_2 = images.FFTImage(size=size) + param_list = [fft_param_1, fft_param_2] + stack_param = images.StackImage(parameterizations=param_list) + for image_param in stack_param.parameterizations: + self.assertIsInstance(image_param, images.FFTImage) + self.assertEqual(list(image_param().shape), [1, 3] + list(size)) + self.assertTrue(image_param().requires_grad) + + output_tensor = stack_param() + self.assertEqual(list(output_tensor.shape), [2, 3] + list(size)) + self.assertTrue(output_tensor.requires_grad) + self.assertIsNone(stack_param.output_device) + + def test_stackimage_forward_diff_image_params(self) -> None: + size = (4, 4) + fft_param = images.FFTImage(size=size) + pixel_param = images.PixelImage(size=size) + param_list = [fft_param, pixel_param] + + stack_param = images.StackImage(parameterizations=param_list) + + type_list = [images.FFTImage, images.PixelImage] + for image_param, expected_type in zip(stack_param.parameterizations, type_list): + self.assertIsInstance(image_param, expected_type) + self.assertEqual(list(image_param().shape), [1, 3] + list(size)) + self.assertTrue(image_param().requires_grad) + + output_tensor = 
stack_param() + self.assertEqual(list(output_tensor.shape), [2, 3] + list(size)) + self.assertTrue(output_tensor.requires_grad) + self.assertIsNone(stack_param.output_device) + + def test_stackimage_forward_diff_image_params_and_tensor_with_grad(self) -> None: + size = (4, 4) + fft_param = images.FFTImage(size=size) + pixel_param = images.PixelImage(size=size) + test_tensor = torch.nn.Parameter(torch.ones(1, 3, size[0], size[1])) + param_list = [fft_param, pixel_param, test_tensor] + + stack_param = images.StackImage(parameterizations=param_list) + + type_list = [ + images.FFTImage, + images.PixelImage, + images.SimpleTensorParameterization, + ] + for image_param, expected_type in zip(stack_param.parameterizations, type_list): + self.assertIsInstance(image_param, expected_type) + self.assertEqual(list(image_param().shape), [1, 3] + list(size)) + self.assertTrue(image_param().requires_grad) + + output_tensor = stack_param() + self.assertEqual(list(output_tensor.shape), [3, 3] + list(size)) + self.assertTrue(output_tensor.requires_grad) + self.assertIsNone(stack_param.output_device) + + def test_stackimage_forward_diff_image_params_and_tensor_no_grad(self) -> None: + size = (4, 4) + fft_param = images.FFTImage(size=size) + pixel_param = images.PixelImage(size=size) + test_tensor = torch.ones(1, 3, size[0], size[1]) + param_list = [fft_param, pixel_param, test_tensor] + + stack_param = images.StackImage(parameterizations=param_list) + + type_list = [ + images.FFTImage, + images.PixelImage, + images.SimpleTensorParameterization, + ] + for image_param, expected_type in zip(stack_param.parameterizations, type_list): + self.assertIsInstance(image_param, expected_type) + self.assertEqual(list(image_param().shape), [1, 3] + list(size)) + + self.assertTrue(stack_param.parameterizations[0]().requires_grad) + self.assertTrue(stack_param.parameterizations[1]().requires_grad) + self.assertFalse(stack_param.parameterizations[2]().requires_grad) + + output_tensor = stack_param() + 
self.assertEqual(list(output_tensor.shape), [3, 3] + list(size)) + self.assertTrue(output_tensor.requires_grad) + self.assertIsNone(stack_param.output_device) + + def test_stackimage_forward_multi_gpu(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping StackImage multi GPU test due to not supporting CUDA." + ) + if torch.cuda.device_count() == 1: + raise unittest.SkipTest( + "Skipping StackImage multi GPU device test due to not having enough" + + " GPUs available." + ) + size = (4, 4) + + num_cuda_devices = torch.cuda.device_count() + param_list, device_list = [], [] + + fft_param = images.FFTImage(size=size).cpu() + param_list.append(fft_param) + device_list.append(torch.device("cpu")) + + for i in range(num_cuda_devices - 1): + device = torch.device("cuda:" + str(i)) + device_list.append(device) + fft_param = images.FFTImage(size=size).to(device) + param_list.append(fft_param) + + output_device = torch.device("cuda:" + str(num_cuda_devices - 1)) + stack_param = images.StackImage( + parameterizations=param_list, output_device=output_device + ) + + for image_param, torch_device in zip( + stack_param.parameterizations, device_list + ): + self.assertIsInstance(image_param, images.FFTImage) + self.assertEqual(list(image_param().shape), [1, 3] + list(size)) + self.assertEqual(image_param().device, torch_device) + self.assertTrue(image_param().requires_grad) + + output_tensor = stack_param() + self.assertEqual( + list(output_tensor.shape), [len(param_list)] + [3] + list(size) + ) + self.assertTrue(output_tensor.requires_grad) + self.assertEqual(stack_param().device, output_device) + + def test_stackimage_forward_multi_device_cpu_gpu(self) -> None: + if not torch.cuda.is_available(): + raise unittest.SkipTest( + "Skipping StackImage multi device test due to not supporting CUDA." 
+ ) + size = (4, 4) + param_list, device_list = [], [] + + fft_param = images.FFTImage(size=size).cpu() + param_list.append(fft_param) + device_list.append(torch.device("cpu")) + + device = torch.device("cuda:0") + device_list.append(device) + fft_param = images.FFTImage(size=size).to(device) + param_list.append(fft_param) + + output_device = torch.device("cuda:0") + stack_param = images.StackImage( + parameterizations=param_list, output_device=output_device + ) + + for image_param, torch_device in zip( + stack_param.parameterizations, device_list + ): + self.assertIsInstance(image_param, images.FFTImage) + self.assertEqual(list(image_param().shape), [1, 3] + list(size)) + self.assertEqual(image_param().device, torch_device) + self.assertTrue(image_param().requires_grad) + + output_tensor = stack_param() + self.assertEqual( + list(output_tensor.shape), [len(param_list)] + [3] + list(size) + ) + self.assertTrue(output_tensor.requires_grad) + self.assertEqual(stack_param().device, output_device) + class TestNaturalImage(BaseTest): + def test_subclass(self) -> None: + self.assertTrue(issubclass(images.NaturalImage, images.ImageParameterization)) + def test_natural_image_init_func_default(self) -> None: image_param = images.NaturalImage(size=(4, 4)) self.assertIsInstance(image_param.parameterization, images.FFTImage)