@@ -85,7 +85,7 @@ class Volumes:
85
85
are linearly interpolated over the spatial dimensions of the volume.
86
86
- Note that the convention is the same as for the 5D version of the
87
87
`torch.nn.functional.grid_sample` function called with
88
- `align_corners==True` .
88
+ the same value of `align_corners` argument .
89
89
- Note that the local coordinate convention of `Volumes`
90
90
(+X = left to right, +Y = top to bottom, +Z = away from the user)
91
91
is *different* from the world coordinate convention of the
@@ -143,7 +143,7 @@ class Volumes:
143
143
torch.nn.functional.grid_sample(
144
144
v.densities(),
145
145
v.get_coord_grid(world_coordinates=False),
146
- align_corners=True ,
146
+ align_corners=align_corners ,
147
147
) == v.densities(),
148
148
149
149
i.e. sampling the volume at trivial local coordinates
@@ -157,6 +157,7 @@ def __init__(
157
157
features : Optional [_TensorBatch ] = None ,
158
158
voxel_size : _VoxelSize = 1.0 ,
159
159
volume_translation : _Translation = (0.0 , 0.0 , 0.0 ),
160
+ align_corners : bool = True ,
160
161
) -> None :
161
162
"""
162
163
Args:
@@ -186,6 +187,10 @@ def __init__(
186
187
b) a Tensor of shape (3,)
187
188
c) a Tensor of shape (minibatch, 3)
188
189
d) a Tensor of shape (1,) (square voxels)
190
+ **align_corners**: If set (default), the coordinates of the corner voxels are
191
+ exactly −1 or +1 in the local coordinate system. Otherwise, the coordinates
192
+ correspond to the centers of the corner voxels. Cf. the namesake argument to
193
+ `torch.nn.functional.grid_sample`.
189
194
"""
190
195
191
196
# handle densities
@@ -206,6 +211,7 @@ def __init__(
206
211
voxel_size = voxel_size ,
207
212
volume_translation = volume_translation ,
208
213
device = self .device ,
214
+ align_corners = align_corners ,
209
215
)
210
216
211
217
# handle features
@@ -336,6 +342,13 @@ def features_list(self) -> List[torch.Tensor]:
336
342
return None
337
343
return self ._features_densities_list (features_ )
338
344
345
+ def get_align_corners (self ) -> bool :
346
+ """
347
+ Return whether the corners of the voxels should be aligned with the
348
+ image pixels.
349
+ """
350
+ return self .locator ._align_corners
351
+
339
352
def _features_densities_list (self , x : torch .Tensor ) -> List [torch .Tensor ]:
340
353
"""
341
354
Retrieve the list representation of features/densities.
@@ -576,7 +589,7 @@ class VolumeLocator:
576
589
are linearly interpolated over the spatial dimensions of the volume.
577
590
- Note that the convention is the same as for the 5D version of the
578
591
`torch.nn.functional.grid_sample` function called with
579
- `align_corners==True` .
592
+ the same value of `align_corners` argument .
580
593
- Note that the local coordinate convention of `VolumeLocator`
581
594
(+X = left to right, +Y = top to bottom, +Z = away from the user)
582
595
is *different* from the world coordinate convention of the
@@ -634,7 +647,7 @@ class VolumeLocator:
634
647
torch.nn.functional.grid_sample(
635
648
v.densities(),
636
649
v.get_coord_grid(world_coordinates=False),
637
- align_corners=True ,
650
+ align_corners=align_corners ,
638
651
) == v.densities(),
639
652
640
653
i.e. sampling the volume at trivial local coordinates
@@ -651,6 +664,7 @@ def __init__(
651
664
device : torch .device ,
652
665
voxel_size : _VoxelSize = 1.0 ,
653
666
volume_translation : _Translation = (0.0 , 0.0 , 0.0 ),
667
+ align_corners : bool = True ,
654
668
):
655
669
"""
656
670
**batch_size** : Batch size of the underlying grids
@@ -674,15 +688,21 @@ def __init__(
674
688
b) a Tensor of shape (3,)
675
689
c) a Tensor of shape (minibatch, 3)
676
690
d) a Tensor of shape (1,) (square voxels)
691
+ **align_corners**: If set (default), the coordinates of the corner voxels are
692
+ exactly −1 or +1 in the local coordinate system. Otherwise, the coordinates
693
+ correspond to the centers of the corner voxels. Cf. the namesake argument to
694
+ `torch.nn.functional.grid_sample`.
677
695
"""
678
696
self .device = device
679
697
self ._batch_size = batch_size
680
698
self ._grid_sizes = self ._convert_grid_sizes2tensor (grid_sizes )
681
699
self ._resolution = tuple (torch .max (self ._grid_sizes .cpu (), dim = 0 ).values )
700
+ self ._align_corners = align_corners
682
701
683
702
# set the local_to_world transform
684
703
self ._set_local_to_world_transform (
685
- voxel_size = voxel_size , volume_translation = volume_translation
704
+ voxel_size = voxel_size ,
705
+ volume_translation = volume_translation ,
686
706
)
687
707
688
708
def _convert_grid_sizes2tensor (
@@ -806,8 +826,17 @@ def _calculate_coordinate_grid(
806
826
grid_sizes = self .get_grid_sizes ()
807
827
808
828
# generate coordinate axes
829
+ def corner_coord_adjustment (r ):
830
+ return 0.0 if self ._align_corners else 1.0 / r
831
+
809
832
vol_axes = [
810
- torch .linspace (- 1.0 , 1.0 , r , dtype = torch .float32 , device = self .device )
833
+ torch .linspace (
834
+ - 1.0 + corner_coord_adjustment (r ),
835
+ 1.0 - corner_coord_adjustment (r ),
836
+ r ,
837
+ dtype = torch .float32 ,
838
+ device = self .device ,
839
+ )
811
840
for r in (de , he , wi )
812
841
]
813
842
0 commit comments