-
Notifications
You must be signed in to change notification settings - Fork 7.1k
[prototype] Minor speed and nit optimizations on Transform Classes #6837
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
49f7e5a
5ea065c
99b1685
17d8184
5e0be6e
08ae56f
7b8be17
5f7e1ee
b0b9b55
ee31969
88328f5
e15f536
53f12bb
843bcc9
8e6af8d
11af094
4f90ce1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,7 +51,7 @@ def _check_input( | |
|
||
@staticmethod | ||
def _generate_value(left: float, right: float) -> float: | ||
return float(torch.distributions.Uniform(left, right).sample()) | ||
return torch.empty(1).uniform_(left, right).item() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Switching to this random generator we get a performance boost on GPU. Moreover this option is JIT-scriptable (if on the future we decide to add support) and doesn't require to constantly initialize a distribution object as before:
|
||
|
||
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: | ||
fn_idx = torch.randperm(4) | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,4 +1,4 @@ | ||||||
from typing import Any, cast, Dict, Optional, Union | ||||||
from typing import Any, Dict, Optional, Union | ||||||
|
||||||
import numpy as np | ||||||
import PIL.Image | ||||||
|
@@ -13,7 +13,7 @@ class DecodeImage(Transform): | |||||
_transformed_types = (features.EncodedImage,) | ||||||
|
||||||
def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image: | ||||||
return cast(features.Image, F.decode_image_with_pil(inpt)) | ||||||
return F.decode_image_with_pil(inpt) # type: ignore[no-any-return] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This has to be here, because it seems
doesn't "forward" the type annotations 🙄 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In all other places we took the decision to silence with ignore rather than cast, do we really need the cast here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nono, I was just explaining why we need the ignore for future me that is looking confused at the blame why we introduced it in the first place. |
||||||
|
||||||
|
||||||
class LabelToOneHot(Transform): | ||||||
|
@@ -27,7 +27,7 @@ def _transform(self, inpt: features.Label, params: Dict[str, Any]) -> features.O | |||||
num_categories = self.num_categories | ||||||
if num_categories == -1 and inpt.categories is not None: | ||||||
num_categories = len(inpt.categories) | ||||||
output = one_hot(inpt, num_classes=num_categories) | ||||||
output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories) | ||||||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
return features.OneHotLabel(output, categories=inpt.categories) | ||||||
|
||||||
def extra_repr(self) -> str: | ||||||
|
@@ -50,7 +50,7 @@ class ToImageTensor(Transform): | |||||
def _transform( | ||||||
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] | ||||||
) -> features.Image: | ||||||
return cast(features.Image, F.to_image_tensor(inpt)) | ||||||
return F.to_image_tensor(inpt) # type: ignore[no-any-return] | ||||||
|
||||||
|
||||||
class ToImagePIL(Transform): | ||||||
|
Uh oh!
There was an error while loading. Please reload this page.