3 changes: 3 additions & 0 deletions monai/data/meta_obj.py
@@ -213,6 +213,9 @@ def push_pending_operation(self, t: Any) -> None:
def pop_pending_operation(self) -> Any:
return self._pending_operations.pop()

def clear_pending_operations(self) -> Any:
self._pending_operations = MetaObj.get_default_applied_operations()

@property
def is_batch(self) -> bool:
"""Return whether object is part of batch or not."""
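A minimal usage sketch of the new `clear_pending_operations` helper (the tensor shape and the pending entry below are illustrative):

import torch
from monai.data import MetaTensor
from monai.utils import LazyAttr

t = MetaTensor(torch.zeros(1, 16, 16))
t.push_pending_operation({LazyAttr.AFFINE: torch.eye(3)})  # queue a lazily-pending transform
t.clear_pending_operations()                               # reset the pending list to the default (empty)
assert len(t.pending_operations) == 0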
2 changes: 2 additions & 0 deletions monai/transforms/__init__.py
@@ -227,6 +227,8 @@
from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
from .lazy.functional import apply
from .lazy.utils import combine_transforms, resample
from .meta_utility.dictionary import (
FromMetaTensord,
FromMetaTensorD,
10 changes: 10 additions & 0 deletions monai/transforms/lazy/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
62 changes: 62 additions & 0 deletions monai/transforms/lazy/functional.py
@@ -0,0 +1,62 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch

from monai.data.meta_tensor import MetaTensor
from monai.data.utils import to_affine_nd
from monai.transforms.lazy.utils import (
affine_from_pending,
combine_transforms,
is_compatible_apply_kwargs,
kwargs_from_pending,
resample,
)

__all__ = ["apply"]


def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None):
"""
    This function applies the pending transforms to `data` and returns the transformed tensor
    together with the list of transforms that were applied.

    Args:
        data: a torch Tensor or a MONAI MetaTensor.
        pending: pending transforms to apply. This must be set if `data` is a Tensor, but is optional
            if `data` is a MetaTensor, whose own pending operations are used by default.
"""
if isinstance(data, MetaTensor) and pending is None:
pending = data.pending_operations
pending = [] if pending is None else pending

if not pending:
return data

cumulative_xform = affine_from_pending(pending[0])
cur_kwargs = kwargs_from_pending(pending[0])

for p in pending[1:]:
new_kwargs = kwargs_from_pending(p)
if not is_compatible_apply_kwargs(cur_kwargs, new_kwargs):
# carry out an intermediate resample here due to incompatibility between arguments
data = resample(data, cumulative_xform, cur_kwargs)
next_matrix = affine_from_pending(p)
cumulative_xform = combine_transforms(cumulative_xform, next_matrix)
cur_kwargs.update(new_kwargs)
data = resample(data, cumulative_xform, cur_kwargs)
if isinstance(data, MetaTensor):
data.clear_pending_operations()
data.affine = data.affine @ to_affine_nd(3, cumulative_xform)
for p in pending:
data.push_applied_operation(p)

return data, pending
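For reference, a minimal usage sketch of `apply`, mirroring the single-transform test cases added in tests/test_apply.py (the image size and rotation angles are arbitrary):

import numpy as np
import torch

from monai.transforms.lazy.functional import apply
from monai.transforms.utils import create_rotate
from monai.utils import LazyAttr

img = torch.zeros(1, 32, 32)  # channel-first 2D image
pending = [
    {LazyAttr.AFFINE: create_rotate(2, np.pi / 4)},
    {LazyAttr.AFFINE: create_rotate(2, -np.pi / 4)},
]
out, applied = apply(img, pending)  # both affines are fused and applied in a single resample
print(out.shape)  # torch.Size([1, 32, 32])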
125 changes: 125 additions & 0 deletions monai/transforms/lazy/utils.py
@@ -0,0 +1,125 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import numpy as np
import torch

import monai
from monai.config import NdarrayOrTensor
from monai.utils import LazyAttr, convert_to_tensor

__all__ = ["resample", "combine_transforms"]


class Affine:
"""A class to represent an affine transform matrix."""

__slots__ = ("data",)

def __init__(self, data):
self.data = data

@staticmethod
def is_affine_shaped(data):
"""Check if the data is an affine matrix."""
if isinstance(data, Affine):
return True
if isinstance(data, DisplacementField):
return False
if not hasattr(data, "shape") or len(data.shape) < 2:
return False
return data.shape[-1] in (3, 4) and data.shape[-2] in (3, 4) and data.shape[-1] == data.shape[-2]


class DisplacementField:
"""A class to represent a dense displacement field."""

__slots__ = ("data",)

def __init__(self, data):
self.data = data

@staticmethod
def is_ddf_shaped(data):
"""Check if the data is a DDF."""
if isinstance(data, DisplacementField):
return True
if isinstance(data, Affine):
return False
if not hasattr(data, "shape") or len(data.shape) < 3:
return False
return not Affine.is_affine_shaped(data)


def combine_transforms(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
"""Given transforms A and B to be applied to x, return the combined transform (AB), so that A(B(x)) becomes AB(x)"""
if Affine.is_affine_shaped(left) and Affine.is_affine_shaped(right): # linear transforms
left = convert_to_tensor(left.data if isinstance(left, Affine) else left, wrap_sequence=True)
right = convert_to_tensor(right.data if isinstance(right, Affine) else right, wrap_sequence=True)
return torch.matmul(left, right)
    if DisplacementField.is_ddf_shaped(left) and DisplacementField.is_ddf_shaped(
        right
    ):  # adds DDFs; TODO: do we need metadata if the input is a MetaTensor?
left = convert_to_tensor(left.data if isinstance(left, DisplacementField) else left, wrap_sequence=True)
right = convert_to_tensor(right.data if isinstance(right, DisplacementField) else right, wrap_sequence=True)
return left + right
raise NotImplementedError


def affine_from_pending(pending_item):
"""Extract the affine matrix from a pending transform item."""
if isinstance(pending_item, (torch.Tensor, np.ndarray)):
return pending_item
if isinstance(pending_item, dict):
return pending_item[LazyAttr.AFFINE]
return pending_item


def kwargs_from_pending(pending_item):
"""Extract kwargs from a pending transform item."""
if not isinstance(pending_item, dict):
return {}
ret = {
LazyAttr.INTERP_MODE: pending_item.get(LazyAttr.INTERP_MODE, None), # interpolation mode
LazyAttr.PADDING_MODE: pending_item.get(LazyAttr.PADDING_MODE, None), # padding mode
}
if LazyAttr.SHAPE in pending_item:
ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE]
if LazyAttr.DTYPE in pending_item:
ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE]
return ret


def is_compatible_apply_kwargs(kwargs_1, kwargs_2):
    """Check if two sets of kwargs are compatible (to be combined in `apply`)."""
    # placeholder implementation: all kwargs are currently treated as compatible, so `apply` never
    # triggers an intermediate resample on their account
    return True


def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: Optional[dict] = None):
"""
This is a minimal implementation of resample that always uses Affine.
"""
if not Affine.is_affine_shaped(matrix):
raise NotImplementedError("calling dense grid resample API not implemented")
kwargs = {} if kwargs is None else kwargs
init_kwargs = {
"spatial_size": kwargs.pop(LazyAttr.SHAPE, data.shape)[1:],
"dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype),
}
call_kwargs = {
"mode": kwargs.pop(LazyAttr.INTERP_MODE, None),
"padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None),
}
resampler = monai.transforms.Affine(affine=matrix, image_only=True, **init_kwargs)
with resampler.trace_transform(False): # don't track this transform in `data`
return resampler(img=data, **call_kwargs)
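A minimal sketch of how `combine_transforms` and `resample` compose (the 90-degree rotation matrix mirrors the resample test case in tests/test_resample.py; the image contents are illustrative):

import torch

from monai.transforms.lazy.utils import combine_transforms, resample
from monai.utils import LazyAttr

rot90 = torch.eye(3)
rot90[:, 0] = torch.tensor([0.0, -1.0, 0.0])
rot90[:, 1] = torch.tensor([1.0, 0.0, 0.0])

rot180 = combine_transforms(rot90, rot90)  # two 90-degree rotations collapse into one affine

img = torch.arange(9, dtype=torch.float32).reshape(1, 3, 3)  # channel-first 2D image
out = resample(img, rot180, {LazyAttr.INTERP_MODE: "nearest"})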
1 change: 1 addition & 0 deletions monai/utils/enums.py
@@ -630,3 +630,4 @@ class LazyAttr(StrEnum):
AFFINE = "lazy_affine"
PADDING_MODE = "lazy_padding_mode"
INTERP_MODE = "lazy_interpolation_mode"
DTYPE = "lazy_dtype"
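The pending-operation metadata consumed by `kwargs_from_pending` and `resample` above is keyed by these enum members; a sketch of a fully populated entry (all values are placeholders):

import torch
from monai.utils import LazyAttr

pending_op = {
    LazyAttr.AFFINE: torch.eye(4),     # transform to apply, as a homogeneous matrix
    LazyAttr.SHAPE: (1, 64, 64, 64),   # requested channel-first output shape
    LazyAttr.DTYPE: torch.float32,     # output dtype (the key added here)
    LazyAttr.INTERP_MODE: "bilinear",
    LazyAttr.PADDING_MODE: "border",
}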
71 changes: 71 additions & 0 deletions tests/test_apply.py
@@ -0,0 +1,71 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch

from monai.transforms.lazy.functional import apply
from monai.transforms.utils import create_rotate
from monai.utils import LazyAttr, convert_to_tensor
from tests.utils import get_arange_img


def single_2d_transform_cases():
return [
(
torch.as_tensor(get_arange_img((32, 32))),
[{LazyAttr.AFFINE: create_rotate(2, np.pi / 4)}, {LazyAttr.AFFINE: create_rotate(2, -np.pi / 4)}],
(1, 32, 32),
),
(torch.as_tensor(get_arange_img((32, 32))), [create_rotate(2, np.pi / 2)], (1, 32, 32)),
(
torch.as_tensor(get_arange_img((16, 16))),
[{LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (1, 45, 45)}],
(1, 45, 45),
),
]


class TestApply(unittest.TestCase):
def _test_apply_impl(self, tensor, pending_transforms, expected_shape):
result = apply(tensor, pending_transforms)
self.assertListEqual(result[1], pending_transforms)
self.assertEqual(result[0].shape, expected_shape)

def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape, pending_as_parameter):
tensor_ = convert_to_tensor(tensor, track_meta=True)
if pending_as_parameter:
result, transforms = apply(tensor_, pending_transforms)
else:
for p in pending_transforms:
tensor_.push_pending_operation(p)
result, transforms = apply(tensor_)
self.assertEqual(result.shape, expected_shape)

SINGLE_TRANSFORM_CASES = single_2d_transform_cases()

def test_apply_single_transform(self):
for case in self.SINGLE_TRANSFORM_CASES:
self._test_apply_impl(*case)

def test_apply_single_transform_metatensor(self):
for case in self.SINGLE_TRANSFORM_CASES:
self._test_apply_metatensor_impl(*case, False)

def test_apply_single_transform_metatensor_override(self):
for case in self.SINGLE_TRANSFORM_CASES:
self._test_apply_metatensor_impl(*case, True)


if __name__ == "__main__":
unittest.main()
40 changes: 40 additions & 0 deletions tests/test_resample.py
@@ -0,0 +1,40 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.transforms.lazy.functional import resample
from monai.utils import convert_to_tensor
from tests.utils import assert_allclose, get_arange_img


def rotate_90_2d():
t = torch.eye(3)
t[:, 0] = torch.FloatTensor([0, -1, 0])
t[:, 1] = torch.FloatTensor([1, 0, 0])
return t


RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[2, 5, 8], [1, 4, 7], [0, 3, 6]])]


class TestResampleFunction(unittest.TestCase):
@parameterized.expand(RESAMPLE_FUNCTION_CASES)
def test_resample_function_impl(self, img, matrix, expected):
out = resample(convert_to_tensor(img), matrix)
assert_allclose(out[0], expected, type_test=False)


if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions tests/utils.py
@@ -348,6 +348,16 @@ def make_rand_affine(ndim: int = 3, random_state: Optional[np.random.RandomState
return af


def get_arange_img(size, dtype=np.float32, offset=0):
"""
    Returns an image as a numpy array with a channel dimension prepended as dim 0,
    whose contents increase sequentially like `np.arange`.
"""
n_elem = np.prod(size)
img = np.arange(offset, offset + n_elem, dtype=dtype).reshape(size)
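    # e.g. for size=(2, 3) and offset=0: img becomes [[0., 1., 2.], [3., 4., 5.]] and the returned array has shape (1, 2, 3)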
return np.expand_dims(img, 0)


class DistTestCase(unittest.TestCase):
"""
testcase without _outcome, so that it's picklable.