diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 4533a522bc..a6cbb18493 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2201,7 +2201,16 @@ def _get_upsample_align_corners_mode(align_corners: bool) -> str: return "align_corners" if align_corners else "pytorch_half_pixel" -@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True) +@torch_op( + ( + "aten::upsample_bicubic2d", + "aten::upsample_bilinear2d", + "aten::upsample_nearest1d", + "aten::upsample_nearest2d", + "aten::upsample_nearest3d", + ), + private=True, +) def _aten_upsample_output_size( self: TReal, output_size: INT64, @@ -2240,7 +2249,6 @@ def _aten_upsample_scales( None, mode=mode, coordinate_transformation_mode=coordinate_transformation_mode, - nearest_mode="floor", ) @@ -2396,12 +2404,33 @@ def aten_upsample_linear1d_backward( raise NotImplementedError() +@torch_op("aten::upsample_nearest1d", trace_only=True) def aten_upsample_nearest1d( - self: TensorType, output_size: INT64, scales: Optional[float] = None -) -> TensorType: + self: TReal, size: INT64, scale_factor: Optional[float] = None +) -> TReal: """upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor""" + if size is not None: + return _aten_upsample_output_size(self, size, "nearest", "asymmetric") + else: + return _aten_upsample_nearest1d_scales(self, scale_factor) - raise NotImplementedError() + +@torch_op("aten::upsample_nearest1d", private=True) +def _aten_upsample_nearest1d_scales( + self: TReal, + scale_factors: TFloat, +) -> TReal: + scale_factors = op.Cast(scale_factors, to=FLOAT.dtype) + scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0) + return op.Resize( + self, + None, + scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] + None, + mode="nearest", + coordinate_transformation_mode="asymmetric", + nearest_mode="floor", + ) def aten_upsample_nearest1d_backward( @@ -2429,29 +2458,7 @@ def aten_upsample_nearest2d( del scales_h del scales_w - return _aten_upsample_nearest2d_onnx(self, size) - - -@torch_op("aten::upsample_nearest2d", private=True) -def _aten_upsample_nearest2d_onnx( - self: TReal, - size: INT64, -) -> TReal: - self_shape = op.Shape(self) - batch_channel = self_shape[:2] # type: ignore[index] - output_size = op.Concat(batch_channel, size, axis=0) - - return op.Resize( - self, - None, - None, - output_size, - mode="nearest", - # NOTE(justinchuby): Both asymmetric and pytorch_half_pixel pass the test - # I used asymmetric because it aligns with the torch.onnx exporter - coordinate_transformation_mode="asymmetric", - nearest_mode="floor", - ) + return _aten_upsample_output_size(self, size, "nearest", "asymmetric") def aten_upsample_nearest2d_backward( @@ -2466,16 +2473,21 @@ def aten_upsample_nearest2d_backward( raise NotImplementedError() +@torch_op("aten::upsample_nearest3d", trace_only=True) def aten_upsample_nearest3d( - self: TensorType, - output_size: INT64, + self: TReal, + size: INT64, scales_d: Optional[float] = None, scales_h: Optional[float] = None, scales_w: Optional[float] = None, -) -> TensorType: +) -> TReal: """upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor""" - raise NotImplementedError() + del scales_h + del scales_w + del scales_d + + return _aten_upsample_output_size(self, size, "nearest", "asymmetric") def aten_upsample_nearest3d_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index d401b35898..086264e9bf 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -1579,6 +1579,149 @@ def shape(size, rank, with_batch_channel=True): ) +def sample_inputs_upsample_nearest1d(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + SS = 3 + L = 5 + + rank = 1 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) + + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + shape(S, rank, False), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + shape(L, rank, False), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, # output_size + (1.7,), # scaler + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, # if this is None, the scalar must be list + (0.6,), + ) + + +def sample_inputs_upsample_nearest2d(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + SS = 3 + L = 5 + + rank = 2 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) + + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + shape(S, rank, False), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + shape(L, rank, False), + ) + # ONNX don't support below cases: both output_size and scaler are not None + # yield opinfo_core.SampleInput( + # make_arg(shape(D, rank)), + # shape(L, rank, False), + # 1.7, # scaler + # ) + # yield opinfo_core.SampleInput( + # make_arg(shape(D, rank)), + # shape(L, rank, False), + # 0.6, + # ) + + +def sample_inputs_upsample_nearest3d(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + SS = 3 + L = 5 + + rank = 3 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) + + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + shape(S, rank, False), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + shape(L, rank, False), + ) + # ONNX don't support below cases: both output_size and scaler are not None + # yield opinfo_core.SampleInput( + # make_arg(shape(D, rank)), + # shape(L, rank, False), + # 1.7, # scaler + # ) + # yield opinfo_core.SampleInput( + # make_arg(shape(D, rank)), + # shape(L, rank, False), + # 0.6, + # ) + + def sample_inputs_upsample_trilinear3d(op_info, device, dtype, requires_grad, **kwargs): del op_info del kwargs @@ -2117,6 +2260,27 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_linear1d, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_nearest1d", + aten_name="upsample_nearest1d", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_nearest1d, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.upsample_nearest2d", + aten_name="upsample_nearest2d", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_nearest2d, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.upsample_nearest3d", + aten_name="upsample_nearest3d", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_nearest3d, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.upsample_trilinear3d", aten_name="upsample_trilinear3d", diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index d2bfab279e..3cfd4a1629 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -415,18 +415,6 @@ def _sum_input_wrangler( return args, kwargs -def _upsample_input_wrangler( - args: list[Any], kwargs: dict[str, Any] -) -> tuple[list[Any], dict[str, Any]]: - if "scale_factor" in kwargs: - kwargs["scales_h"] = kwargs["scale_factor"] - kwargs["scales_w"] = kwargs["scale_factor"] - del kwargs["scale_factor"] - if "size" in kwargs: - kwargs["size"] = np.array(kwargs["size"], dtype=np.int64) - return args, kwargs - - def _unflatten_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -2141,24 +2129,24 @@ def _where_input_wrangler( reason="fixme: align_corners=False output mismatch when scales are provided", ), TorchLibOpInfo( - "ops.aten.upsample_trilinear3d", - nn_ops.aten_upsample_trilinear3d, + "ops.aten.upsample_nearest1d", + nn_ops.aten_upsample_nearest1d, trace_only=True, ), TorchLibOpInfo( - "nn.functional.upsample_nearest2d", + "ops.aten.upsample_nearest2d", nn_ops.aten_upsample_nearest2d, - input_wrangler=_upsample_input_wrangler, trace_only=True, - ) - .skip( - # Shape should be [N, C, H, W] - matcher=lambda sample: len(sample.input.shape) != 2 + 2, - reason="only test on 2d inputs", - ) - .xfail( - matcher=lambda sample: "scale_factor" in sample.kwargs, - reason="fixme: the scale_factor tests", + ), + TorchLibOpInfo( + "ops.aten.upsample_nearest3d", + nn_ops.aten_upsample_nearest3d, + trace_only=True, + ), + TorchLibOpInfo( + "ops.aten.upsample_trilinear3d", + nn_ops.aten_upsample_trilinear3d, + trace_only=True, ), TorchLibOpInfo("ones_like", core_ops.aten_ones_like, trace_only=True), TorchLibOpInfo( @@ -2376,15 +2364,6 @@ def _where_input_wrangler( "nn.functional.celu", ("nn.functional.celu_type_promoted",), ) -ops_test_common.duplicate_opinfo( - OPS_DB, - "nn.functional.upsample_nearest", - ( - "nn.functional.upsample_nearest1d", - "nn.functional.upsample_nearest2d", - "nn.functional.upsample_nearest3d", - ), -) ops_test_common.duplicate_opinfo( OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",) )