diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index c69ecb69d08..86bf361fcbb 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -117,7 +117,7 @@ def to_pil_image(pic, mode=None): elif pic.ndimension() == 2: # if 2D image, add channel dimension (CHW) - pic.unsqueeze_(0) + pic = pic.unsqueeze(0) elif isinstance(pic, np.ndarray): if pic.ndim not in {2, 3}: