Skip to content

Commit 45d9a30

Browse files
fatcat-zdatumbox
andauthored
[ONNX] Fix ShuffleNetV2 model export issue. (#3158)
* Fix an issue that ShuffleNetV2 model is exported to a wrong ONNX file if dynamic_axes field was provided. * Add a ut for the bug fix. * Fix flake8 issue. * Don't access each element in x.shape, use x.size() instead. Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 2e5e058 commit 45d9a30

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

test/test_onnx.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,17 @@ def test_keypoint_rcnn(self):
482482
dynamic_axes={"images_tensors": [0, 1, 2]},
483483
tolerate_small_mismatch=True)
484484

485+
def test_shufflenet_v2_dynamic_axes(self):
486+
model = models.shufflenet_v2_x0_5(pretrained=True)
487+
dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
488+
test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0)
489+
490+
self.run_model(model, [(dummy_input,), (test_inputs,)],
491+
input_names=["input_images"],
492+
output_names=["output"],
493+
dynamic_axes={"input_images": {0: 'batch_size'}, "output": {0: 'batch_size'}},
494+
tolerate_small_mismatch=True)
495+
485496

486497
if __name__ == '__main__':
487498
unittest.main()

torchvision/models/shufflenetv2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020

2121
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
22-
batchsize, num_channels, height, width = x.data.size()
22+
batchsize, num_channels, height, width = x.size()
2323
channels_per_group = num_channels // groups
2424

2525
# reshape

0 commit comments

Comments
 (0)