-
Notifications
You must be signed in to change notification settings - Fork 7.1k
enable get_params alias for transforms v2 #7153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -56,10 +56,19 @@ def extra_repr(self) -> str: | |||||
|
||||||
return ", ".join(extra) | ||||||
|
||||||
# This attribute should be set on all transforms that have a v1 equivalent. Doing so enables the v2 transformation | ||||||
# to be scriptable. See `_extract_params_for_v1_transform()` and `__prepare_scriptable__` for details. | ||||||
# This attribute should be set on all transforms that have a v1 equivalent. Doing so enables two things: | ||||||
# 1. In case the v1 transform has a static `get_params` method, it will also be available under the same name on | ||||||
# the v2 transform. See `__init_subclass__` for details. | ||||||
# 2. The v2 transform will be JIT scriptable. See `_extract_params_for_v1_transform` and `__prepare_scriptable__` | ||||||
# for details. | ||||||
_v1_transform_cls: Optional[Type[nn.Module]] = None | ||||||
|
||||||
def __init_subclass__(cls) -> None: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As an alternative to @staticmethod
def get_params(cls):
if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
return cls._v1_transform_cls.get_params()
else:
raise AttributeError(
"cls {cls} has no get_params method. You probably don't need one anymore
as the same RNG is applied to all images, bboxes and masks in the same transform call.
If what you need is a way to transform different batches with the same RNG,
please reach out at #1234567 (the feedback issue.
") There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Two problems here with JIT:
I agree using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the details. My last try would be to declare There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You guessed right. But even if did, it would wouldn't work for us here. |
||||||
# Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance. | ||||||
# This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`. | ||||||
if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"): | ||||||
cls.get_params = cls._v1_transform_cls.get_params # type: ignore[attr-defined] | ||||||
|
||||||
def _extract_params_for_v1_transform(self) -> Dict[str, Any]: | ||||||
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current | ||||||
# v2 transform instance. It does two things: | ||||||
|
Uh oh!
There was an error while loading. Please reload this page.