diff --git a/MxNet/Classification/RN50v1.5/models.py b/MxNet/Classification/RN50v1.5/models.py index 43b5ae099..5de06b99f 100644 --- a/MxNet/Classification/RN50v1.5/models.py +++ b/MxNet/Classification/RN50v1.5/models.py @@ -467,7 +467,7 @@ def create_resnet(builder, version, num_layers=50, resnext=False, classes=1000): block_class, layers, channels = resnet_spec[num_layers] assert not resnext or num_layers >= 50, \ "Cannot create resnext with less then 50 layers" - net = ResNet(builder, block_class, layers, channels, version=version, + net = ResNet(builder, block_class, layers, channels, classes, version=version, resnext_groups=args.num_groups if resnext else None) return net