diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py
index 3ec89488ff..9e5490f9d6 100644
--- a/monai/networks/nets/vitautoenc.py
+++ b/monai/networks/nets/vitautoenc.py
@@ -10,7 +10,6 @@
 # limitations under the License.
 
-import math
 from typing import Sequence, Union
 
 import torch
 
@@ -19,6 +18,7 @@
 from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
 from monai.networks.blocks.transformerblock import TransformerBlock
 from monai.networks.layers import Conv
+from monai.utils import ensure_tuple_rep
 
 __all__ = ["ViTAutoEnc"]
 
@@ -74,6 +74,7 @@ def __init__(
 
         super().__init__()
 
+        self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
         self.spatial_dims = spatial_dims
 
         self.patch_embedding = PatchEmbeddingBlock(
@@ -105,6 +106,7 @@ def forward(self, x):
             x: input tensor must have isotropic spatial dimensions,
                 such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.
         """
+        spatial_size = x.shape[2:]
         x = self.patch_embedding(x)
         hidden_states_out = []
         for blk in self.blocks:
@@ -112,7 +114,7 @@ def forward(self, x):
             hidden_states_out.append(x)
         x = self.norm(x)
         x = x.transpose(1, 2)
-        d = [round(math.pow(x.shape[2], 1 / self.spatial_dims))] * self.spatial_dims
+        d = [s // p for s, p in zip(spatial_size, self.patch_size)]
         x = torch.reshape(x, [x.shape[0], x.shape[1], *d])
         x = self.conv3d_transpose(x)
         x = self.conv3d_transpose_1(x)
diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py
index c45cde68c2..8320fef02d 100644
--- a/tests/test_vitautoenc.py
+++ b/tests/test_vitautoenc.py
@@ -41,6 +41,25 @@
     TEST_CASE_Vitautoenc.append(test_case)
 
 
+TEST_CASE_Vitautoenc.append(
+    [
+        {
+            "in_channels": 1,
+            "img_size": (512, 512, 32),
+            "patch_size": (16, 16, 16),
+            "hidden_size": 768,
+            "mlp_dim": 3072,
+            "num_layers": 4,
+            "num_heads": 12,
+            "pos_embed": "conv",
+            "dropout_rate": 0.6,
+            "spatial_dims": 3,
+        },
+        (2, 1, 512, 512, 32),
+        (2, 1, 512, 512, 32),
+    ]
+)
+
 class TestPatchEmbeddingBlock(unittest.TestCase):
@parameterized.expand(TEST_CASE_Vitautoenc)