Skip to content

Commit 2473d22

Browse files
committed
Remove redundant decollate condition for torch/numpy scalars
Signed-off-by: Arthur Dujardin <[email protected]>
1 parent 1e0a554 commit 2473d22

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

monai/data/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -625,13 +625,12 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
625625
type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
626626
):
627627
return batch
628+
# if scalar tensor/array, return the item itself.
628629
if getattr(batch, "ndim", -1) == 0 and hasattr(batch, "item"):
629630
return batch.item() if detach else batch
630631
if isinstance(batch, torch.Tensor):
631632
if detach:
632633
batch = batch.detach()
633-
if batch.ndim == 0:
634-
return batch.item() if detach else batch
635634
out_list = torch.unbind(batch, dim=0)
636635
# if of type MetaObj, decollate the metadata
637636
if isinstance(batch, MetaObj):

0 commit comments

Comments
 (0)