diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 50f62701dc..4533a522bc 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2491,17 +2491,28 @@ def aten_upsample_nearest3d_backward( raise NotImplementedError() +@torch_op("aten::upsample_trilinear3d", trace_only=True) def aten_upsample_trilinear3d( - self: TensorType, + self: TReal, output_size: INT64, align_corners: bool, scales_d: Optional[float] = None, scales_h: Optional[float] = None, scales_w: Optional[float] = None, -) -> TensorType: +) -> TReal: """upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor""" - raise NotImplementedError() + del scales_d + del scales_h + del scales_w + + coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + return _aten_upsample_output_size( + self, + output_size, + mode="linear", + coordinate_transformation_mode=coordinate_transformation_mode, + ) def aten_upsample_trilinear3d_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index c274df2beb..d401b35898 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -1579,6 +1579,44 @@ def shape(size, rank, with_batch_channel=True): ) +def sample_inputs_upsample_trilinear3d(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + SS = 3 + L = 5 + + align_corners_options = (True, False) + rank = 3 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + for align_corners in align_corners_options: + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(SS, rank, False), align_corners + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(S, rank, False), align_corners + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(L, rank, False), align_corners + ) + + class _TestParamsMaxPoolEmptyStrideBase: # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203 def __init__(self): @@ -2079,6 +2117,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_linear1d, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_trilinear3d", + aten_name="upsample_trilinear3d", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_trilinear3d, + supports_out=False, + ), opinfo_core.OpInfo( "nn.functional.max_pool1d_with_indices", aten_name="max_pool1d_with_indices", diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 58ade8bf1b..d2bfab279e 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -2140,6 +2140,11 @@ def _where_input_wrangler( and sample.kwargs.get("scales") is not None, reason="fixme: align_corners=False output mismatch when scales are provided", ), + TorchLibOpInfo( + "ops.aten.upsample_trilinear3d", + nn_ops.aten_upsample_trilinear3d, + trace_only=True, + ), TorchLibOpInfo( "nn.functional.upsample_nearest2d", nn_ops.aten_upsample_nearest2d,