Skip to content

Commit 5e2aaa3

Browse files
committed
rearrange modules, adds back matrix/grid types
Signed-off-by: Wenqi Li <[email protected]>
1 parent c95a62a commit 5e2aaa3

File tree

7 files changed

+141
-122
lines changed

7 files changed

+141
-122
lines changed

monai/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@
228228
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
229229
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
230230
from .lazy.functional import apply
231-
from .meta_matrix import matmul
231+
from .lazy.utils import matmul, resample
232232
from .meta_utility.dictionary import (
233233
FromMetaTensord,
234234
FromMetaTensorD,

monai/transforms/lazy/functional.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,44 +11,15 @@
1111

1212
from typing import Optional, Union
1313

14-
import numpy as np
1514
import torch
1615

1716
from monai.data.meta_tensor import MetaTensor
1817
from monai.data.utils import to_affine_nd
19-
from monai.transforms.meta_matrix import matmul
20-
from monai.transforms.utility.functional import resample
21-
from monai.utils import LazyAttr
18+
from monai.transforms.lazy.utils import is_compatible_kwargs, kwargs_from_pending, mat_from_pending, matmul, resample
2219

2320
__all__ = ["apply"]
2421

2522

26-
def mat_from_pending(pending_item):
27-
if isinstance(pending_item, (torch.Tensor, np.ndarray)):
28-
return pending_item
29-
if isinstance(pending_item, dict):
30-
return pending_item[LazyAttr.AFFINE]
31-
return pending_item
32-
33-
34-
def kwargs_from_pending(pending_item):
35-
if not isinstance(pending_item, dict):
36-
return {}
37-
ret = {
38-
LazyAttr.INTERP_MODE: pending_item.get(LazyAttr.INTERP_MODE, None), # interpolation mode
39-
LazyAttr.PADDING_MODE: pending_item.get(LazyAttr.PADDING_MODE, None), # padding mode
40-
}
41-
if LazyAttr.SHAPE in pending_item:
42-
ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE]
43-
if LazyAttr.DTYPE in pending_item:
44-
ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE]
45-
return ret
46-
47-
48-
def is_compatible_kwargs(kwargs_1, kwargs_2):
49-
return True
50-
51-
5223
def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None):
5324
"""
5425
This method applies pending transforms to tensors.

monai/transforms/lazy/utils.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Optional
13+
14+
import numpy as np
15+
import torch
16+
17+
import monai
18+
from monai.config import NdarrayOrTensor
19+
from monai.utils import LazyAttr
20+
21+
__all__ = ["resample", "matmul"]
22+
23+
24+
class Affine:
25+
"""A class to represent an affine transform matrix."""
26+
27+
__slots__ = ("data",)
28+
29+
def __init__(self, data):
30+
self.data = data
31+
32+
@staticmethod
33+
def is_affine_shaped(data):
34+
"""Check if the data is an affine matrix."""
35+
if isinstance(data, Affine):
36+
return True
37+
if isinstance(data, DDF):
38+
return False
39+
if not hasattr(data, "shape") or len(data.shape) < 2:
40+
return False
41+
return data.shape[-1] in (3, 4) and data.shape[-2] in (3, 4) and data.shape[-1] == data.shape[-2]
42+
43+
44+
class DDF:
45+
"""A class to represent a dense displacement field."""
46+
47+
__slots__ = ("data",)
48+
49+
def __init__(self, data):
50+
self.data = data
51+
52+
@staticmethod
53+
def is_ddf_shaped(data):
54+
"""Check if the data is a DDF."""
55+
if isinstance(data, DDF):
56+
return True
57+
if isinstance(data, Affine):
58+
return False
59+
if not hasattr(data, "shape") or len(data.shape) < 3:
60+
return False
61+
return not Affine.is_affine_shaped(data)
62+
63+
64+
def matmul(left: torch.Tensor, right: torch.Tensor):
65+
if Affine.is_affine_shaped(left) and Affine.is_affine_shaped(right): # linear transforms
66+
if isinstance(left, Affine):
67+
left = left.data
68+
if isinstance(right, Affine):
69+
right = right.data
70+
return torch.matmul(left, right)
71+
if DDF.is_ddf_shaped(left) and DDF.is_ddf_shaped(right): # adds DDFs
72+
return left + right
73+
raise NotImplementedError
74+
75+
76+
def mat_from_pending(pending_item):
77+
if isinstance(pending_item, (torch.Tensor, np.ndarray)):
78+
return pending_item
79+
if isinstance(pending_item, dict):
80+
return pending_item[LazyAttr.AFFINE]
81+
return pending_item
82+
83+
84+
def kwargs_from_pending(pending_item):
85+
if not isinstance(pending_item, dict):
86+
return {}
87+
ret = {
88+
LazyAttr.INTERP_MODE: pending_item.get(LazyAttr.INTERP_MODE, None), # interpolation mode
89+
LazyAttr.PADDING_MODE: pending_item.get(LazyAttr.PADDING_MODE, None), # padding mode
90+
}
91+
if LazyAttr.SHAPE in pending_item:
92+
ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE]
93+
if LazyAttr.DTYPE in pending_item:
94+
ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE]
95+
return ret
96+
97+
98+
def is_compatible_kwargs(kwargs_1, kwargs_2):
99+
return True
100+
101+
102+
def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: Optional[dict] = None):
103+
"""
104+
This is a minimal implementation of resample that always uses Affine.
105+
"""
106+
if not Affine.is_affine_shaped(matrix):
107+
raise NotImplementedError("calling dense grid resample API not implemented")
108+
kwargs = {} if kwargs is None else kwargs
109+
init_kwargs = {
110+
"spatial_size": kwargs.pop(LazyAttr.SHAPE, data.shape)[1:],
111+
"dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype),
112+
}
113+
call_kwargs = {
114+
"mode": kwargs.pop(LazyAttr.INTERP_MODE, None),
115+
"padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None),
116+
}
117+
resampler = monai.transforms.Affine(affine=matrix, image_only=True, **init_kwargs)
118+
with resampler.trace_transform(False): # don't track this transform in `data`
119+
return resampler(img=data, **call_kwargs)

monai/transforms/meta_matrix.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

monai/transforms/utility/functional.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

monai/transforms/utils.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -890,53 +890,52 @@ def _create_translate(
890890

891891

892892
def _create_rotate_90(
893-
spatial_dims: int, axis: Tuple[int, int], steps: Optional[int] = 1, eye_func: Callable = np.eye
893+
spatial_dims: int, axes: Tuple[int, int], steps: int = 1, eye_func: Callable = np.eye
894894
) -> NdarrayOrTensor:
895895

896896
values = [(1, 0, 0, 1), (0, -1, 1, 0), (-1, 0, 0, -1), (0, 1, -1, 0)]
897897

898898
if spatial_dims == 2:
899-
if axis != (0, 1):
900-
raise ValueError(f"if 'spatial_dims' is 2, 'axis' must be (0, 1) but is {axis}")
899+
if axes != (0, 1):
900+
raise ValueError(f"if 'spatial_dims' is 2, 'axis' must be (0, 1) but is {axes}")
901901
elif spatial_dims == 3:
902-
if axis not in ((0, 1), (0, 2), (1, 2)):
903-
raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0,1), (0, 2), or (1, 2) " f"but is {axis}")
902+
if axes not in ((0, 1), (0, 2), (1, 2)):
903+
raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0,1), (0, 2), or (1, 2) " f"but is {axes}")
904904
else:
905905
raise ValueError(f"'spatial_dims' must be 2 or 3 but is {spatial_dims}")
906906

907907
affine = eye_func(spatial_dims + 1)
908908

909-
a, b = (0, 1) if spatial_dims == 2 else axis
909+
a, b = (0, 1) if spatial_dims == 2 else axes
910910
affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps % 4]
911-
return affine
911+
return affine # type: ignore
912912

913913

914914
def create_rotate_90(
915915
spatial_dims: int,
916-
axis: int,
917-
steps: Optional[int] = 1,
916+
axes: Tuple[int, int] = (0, 1),
917+
steps: int = 1,
918918
device: Optional[torch.device] = None,
919919
backend: str = TransformBackends.NUMPY,
920920
) -> NdarrayOrTensor:
921921
"""
922-
create a 2D or 3D rotation matrix
922+
create a 2D or 3D rotation90 matrix.
923+
923924
Args:
924925
spatial_dims: {``2``, ``3``} spatial rank
925-
radians: rotation radians
926-
when spatial_dims == 3, the `radians` sequence corresponds to
927-
rotation in the 1st, 2nd, and 3rd dim respectively.
926+
axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.
927+
Default: (0, 1), this is the first two axis in spatial dimensions.
928+
If axis is negative it counts from the last to the first axis.
929+
steps: number of times to rotate by 90 degrees
928930
device: device to compute and store the output (when the backend is "torch").
929931
backend: APIs to use, ``numpy`` or ``torch``.
930-
Raises:
931-
ValueError: When ``radians`` is empty.
932-
ValueError: When ``spatial_dims`` is not one of [2, 3].
933932
"""
934933
_backend = look_up_option(backend, TransformBackends)
935934
if _backend == TransformBackends.NUMPY:
936-
return _create_rotate_90(spatial_dims=spatial_dims, axis=axis, steps=steps, eye_func=np.eye)
935+
return _create_rotate_90(spatial_dims=spatial_dims, axes=axes, steps=steps, eye_func=np.eye)
937936
if _backend == TransformBackends.TORCH:
938937
return _create_rotate_90(
939-
spatial_dims=spatial_dims, axis=axis, steps=steps, eye_func=lambda rank: torch.eye(rank, device=device)
938+
spatial_dims=spatial_dims, axes=axes, steps=steps, eye_func=lambda rank: torch.eye(rank, device=device)
940939
)
941940
raise ValueError(f"backend {backend} is not supported")
942941

@@ -974,9 +973,9 @@ def create_flip(
974973
) -> NdarrayOrTensor:
975974
_backend = look_up_option(backend, TransformBackends)
976975
if _backend == TransformBackends.NUMPY:
977-
return _create_flip(spatial_dims=spatial_dims, spatial_axis=spatial_axis, eye_func=np.eye)
976+
return _create_flip(spatial_dims=spatial_dims, spatial_axis=spatial_axis, eye_func=np.eye) # type: ignore
978977
if _backend == TransformBackends.TORCH:
979-
return _create_flip(
978+
return _create_flip( # type: ignore
980979
spatial_dims=spatial_dims, spatial_axis=spatial_axis, eye_func=lambda rank: torch.eye(rank, device=device)
981980
)
982981
raise ValueError(f"backend {backend} is not supported")

tests/test_resample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import torch
1515

16-
from monai.transforms.utility.functional import resample
16+
from monai.transforms.lazy.functional import resample
1717
from monai.utils import convert_to_tensor
1818
from tests.utils import get_arange_img
1919

0 commit comments

Comments
 (0)