diff --git a/torch2trt/converters/interpolate_custom.py b/torch2trt/converters/interpolate_custom.py index 913794c..207da54 100644 --- a/torch2trt/converters/interpolate_custom.py +++ b/torch2trt/converters/interpolate_custom.py @@ -67,7 +67,7 @@ def convert_interpolate(ctx): scale_factor = (1,)*2 + tuple(scale_factor) layer.scales = scale_factor else: - layer.shape = tuple(output.shape) + layer.scales = tuple([float(output_shape / input_shape) for input_shape, output_shape in zip(input.shape, output.shape)]) layer.align_corners = align_corners if mode=="nearest": @@ -80,7 +80,6 @@ def convert_interpolate(ctx): output._trt = layer.get_output(0) - class InterpolateTest(torch.nn.Module): def __init__(self, size=None, scale_factor=None, mode='nearest'):