
Commit 02b29e7

Episkey0109 authored and pytorchmergebot committed
Add meta function for channel_shuffle operation (#123033)
This commit introduces a meta function for the `channel_shuffle` operation, enabling PyTorch to perform shape inference and optimizations related to this operation without actual computation. The meta function assumes an input shape of (*, C, H, W) and validates that the number of channels (C) is divisible by the specified number of groups.

Fixes #122771

Pull Request resolved: #123033
Approved by: https://github.com/ezyang, https://github.com/mikaylagawarecki
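As a quick illustration of what the registration buys: with this meta function in place, `channel_shuffle` can be traced on "meta" tensors, which carry only shape, dtype, and device metadata and allocate no storage. A minimal sketch, assuming the registration from this commit:

    import torch

    # Meta tensors have shapes and dtypes but no data, so this call
    # exercises only the shape-inference path added by this commit.
    x = torch.empty(2, 6, 8, 8, device="meta")  # (N, C, H, W) with C=6
    out = torch.channel_shuffle(x, 3)           # C=6 is divisible by groups=3

    print(out.shape)   # torch.Size([2, 6, 8, 8]) -- same shape as the input
    print(out.device)  # meta -- no real computation was performed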
1 parent 84580f7 · commit 02b29e7

2 files changed · +49 −0 lines changed


torch/_meta_registrations.py

Lines changed: 16 additions & 0 deletions
@@ -6161,6 +6161,22 @@ def meta_polygamma(n: int, self: Tensor) -> Tensor:
     return torch.empty_like(self, dtype=result_dtype)
 
 
+@register_meta(aten.channel_shuffle.default)
+def meta_channel_shuffle(input, groups):
+    # Assume the input shape is (*, C, H, W), where * represents any number of leading dimensions
+    *leading_dims, C, H, W = input.size()
+    # The output shape is the same as the input
+    return torch.empty(
+        *leading_dims,
+        C,
+        H,
+        W,
+        dtype=input.dtype,
+        layout=input.layout,
+        device=input.device,
+    )
+
+
 def _create_unary_float_meta_func(func):
     @register_meta(func)
     @out_wrapper()
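For context, the output-shape invariant that `meta_channel_shuffle` encodes follows from the eager semantics of the op: the C channels are viewed as (groups, C // groups), the two axes are swapped, and the result is flattened back to C channels, so the overall shape never changes. A rough reference sketch of those semantics (not the library's actual kernel):

    import torch

    def channel_shuffle_reference(x: torch.Tensor, groups: int) -> torch.Tensor:
        # View C as (groups, C // groups), swap the two axes, flatten back.
        # The channel count -- and hence the overall shape -- is unchanged,
        # which is why the meta function can simply return an empty tensor
        # of the input's shape.
        *leading_dims, C, H, W = x.shape
        assert C % groups == 0, "channel count must be divisible by groups"
        x = x.reshape(*leading_dims, groups, C // groups, H, W)
        x = x.transpose(-4, -3).contiguous()
        return x.reshape(*leading_dims, C, H, W)

    x = torch.arange(2 * 6 * 2 * 2, dtype=torch.float32).reshape(2, 6, 2, 2)
    assert torch.equal(channel_shuffle_reference(x, 3), torch.channel_shuffle(x, 3))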

torch/testing/_internal/common_methods_invocations.py

Lines changed: 33 additions & 0 deletions
@@ -8828,6 +8828,20 @@ def sample_inputs_pixel_unshuffle(op_info, device, dtype, requires_grad, **kwarg
         ]
     )
 
+def sample_inputs_channel_shuffle(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    shapes_groups = [
+        ((1, 4, 10, 10), 2),
+        ((2, 6, 8, 8), 3),
+        ((2, 8, 5, 5), 4),
+    ]
+
+    yield from (
+        SampleInput(make_arg(shape), args=(groups,))
+        for shape, groups in shapes_groups
+    )
+
 def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, logits=False, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype)
     # Lower bounds must be greater than 'eps' defined in gradcheck.py::gradgradcheck() -> eps

@@ -19610,6 +19624,25 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
             ),
         ),
     ),
+    OpInfo(
+        "nn.functional.channel_shuffle",
+        sample_inputs_func=sample_inputs_channel_shuffle,
+        dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+        backward_dtypes=integral_types_and(torch.bool),
+        supports_out=False,
+        supports_autograd=False,
+        allow_cow_input_materialize_forward=[0],
+        skips=(
+            # Skip due to NotImplementedError for MPS device.
+            DecorateInfo(unittest.expectedFailure, 'TestConsistency'),
+            # vmap: calling random operator not supported
+            DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+            DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+            DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
+            DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+        ),
+    ),
     OpInfo(
         "nn.functional.kl_div",
         sample_inputs_func=sample_inputs_kl_div,
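The property these OpInfo samples feed into the meta tests is, roughly, that the meta kernel agrees with eager execution on output metadata. A hand-rolled sketch of that check over the same (shape, groups) pairs:

    import torch

    # Mirrors the (shape, groups) pairs from sample_inputs_channel_shuffle.
    for shape, groups in [((1, 4, 10, 10), 2), ((2, 6, 8, 8), 3), ((2, 8, 5, 5), 4)]:
        eager = torch.channel_shuffle(torch.randn(shape), groups)
        meta = torch.channel_shuffle(torch.empty(shape, device="meta"), groups)
        assert eager.shape == meta.shape
        assert eager.dtype == meta.dtype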
