-
Notifications
You must be signed in to change notification settings - Fork 7.1k
add support for single channel images #174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, sorry for the delay in reviewing.
This looks good for the most part, I made some small comments (mostly cosmetics). Could you address them?
torchvision/transforms.py
Outdated
@@ -43,7 +43,10 @@ class ToTensor(object): | |||
def __call__(self, pic): | |||
if isinstance(pic, np.ndarray): | |||
# handle numpy array | |||
img = torch.from_numpy(pic.transpose((2, 0, 1))) | |||
if len(pic.shape) >= 3: |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torchvision/transforms.py
Outdated
if len(pic.shape) >= 3: | ||
img = torch.from_numpy(pic.transpose((2, 0, 1))) | ||
else: | ||
img = torch.from_numpy(pic.reshape((1,) + pic.shape)) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
@@ -89,6 +92,8 @@ def __call__(self, pic): | |||
if torch.is_tensor(pic): | |||
npimg = np.transpose(pic.numpy(), (1, 2, 0)) | |||
assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray' | |||
if len(npimg.shape) < 3: |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torchvision/transforms.py
Outdated
@@ -89,6 +92,8 @@ def __call__(self, pic): | |||
if torch.is_tensor(pic): | |||
npimg = np.transpose(pic.numpy(), (1, 2, 0)) | |||
assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray' | |||
if len(npimg.shape) < 3: | |||
npimg = np.reshape(npimg, npimg.shape + (1,)) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
@edgarriba are you able to make the changes requested by @fmassa? |
@alykhantejani @fmassa sure, I'll allocate some time during this weekend |
@@ -1,6 +1,7 @@ | |||
import torch | |||
import torchvision.transforms as transforms | |||
import unittest | |||
from parameterized import parameterized |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
||
if accimage is not None and isinstance(pic, accimage.Image): | ||
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) | ||
nppic=np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
@@ -1,6 +1,7 @@ | |||
import torch | |||
import torchvision.transforms as transforms | |||
import unittest | |||
from parameterized import parameterized |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
('3channel', 3, 4, 4), | ||
('1channel', 1, 4, 4), | ||
]) | ||
def test_pil_to_tensor(self, _, channels, height, width): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
@edgarriba Thanks for this - in general it looks good just a few small things to fix. I don't think we want to add another dependency to the library if we can avoid it. Would you be able to make this change and the linting ones? |
Checks for image/numpy channels are implemented with refactoring of transforms to functional and transforms in #311 vision/torchvision/transforms/functional.py Lines 43 to 44 in 50b2f91
Although single channel images (HxW) still produce following error: ValueError: axes don't match array due to vision/torchvision/transforms/functional.py Lines 46 to 48 in 50b2f91
Wouldn't it be easier to just expand these arrays with a simple if, like: ...
if isinstance(pic, np.ndarray): ## Existing
# handle numpy array
if pic.ndim == 2: ## Addition
# expand by one dimension as the last channel ## Addition
img = torch.from_numpy(pic.transpose((2, 0, 1))) ## Existing
... Either way, this PR might as well be closed as a clean up due to massive conflicts? |
I'm planning on modifying the underlying data structure that we use internally (instead of using a PIL image, have a custom class), so that those corner cases can be fixed. |
Sounds like a plan, I would like to work on that if you will need a pair of hands. |
* adding tag to handle evaluations by iterations * better naming * better naming round 2 * update the python package version
No description provided.