From 918250efbb5f07918cbc0ca001d379878389d6c7 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 7 Dec 2023 18:31:17 +0800 Subject: [PATCH 01/17] update --- onnxscript/function_libs/torch_lib/ops/nn.py | 52 ++++++++++++++++++- .../function_libs/torch_lib/extra_opinfo.py | 34 ++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 5 ++ 3 files changed, 90 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index a37ec84c6e..d7068ee020 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2192,6 +2192,7 @@ def aten_unflatten_dense_tensors( raise NotImplementedError() +@torch_op("aten::upsample_bicubic2d", trace_only=True) def aten_upsample_bicubic2d( self: TensorType, output_size: INT64, @@ -2201,7 +2202,56 @@ def aten_upsample_bicubic2d( ) -> TensorType: """upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" - raise NotImplementedError() + if output_size is not None: + result = _aten_upsample_bicubic2d_output_size(self, output_size) + else: + assert scales_h is not None + assert scales_h == scales_w + result = _aten_upsample_bicubic2d_scales(self, scales_h, scales_w) + return result + + +@torch_op("aten::upsample_bicubic2d", private=True) +def _aten_upsample_bicubic2d_output_size( + self: TReal, + output_size: INT64, +) -> TReal: + self_shape = op.Shape(self) + starts = op.Constant(value_ints=[0]) + ends = op.Constant(value_ints=[2]) + batch_channel = op.Slice(self_shape, starts, ends) + output_size = op.Concat(batch_channel, output_size, axis=0) + return op.Resize( + self, + None, + None, + output_size, + mode="cubic", + coordinate_transformation_mode="align_corners", + ) + + +@torch_op("aten::upsample_bicubic2d", private=True) +def _aten_upsample_bicubic2d_scales( + self: TReal, + scales_h: float, + scales_w: float, +) -> TReal: + neg_1 = op.Constant(value_ints=[-1]) + scales = op.Concat( + op.Constant(value_floats=[1.0, 1.0]), + op.Reshape(op.Constant(value_float=scales_h), neg_1), + op.Reshape(op.Constant(value_float=scales_w), neg_1), + axis=0, + ) + return op.Resize( + self, + None, + scales, # format should be: [1.0, 1.0, scale_h, scale_w] + None, + mode="cubic", + coordinate_transformation_mode="align_corners", + ) def aten_upsample_bicubic2d_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 60b9eb4f84..dc8432815a 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -70,6 +70,33 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra yield opinfo_core.SampleInput(t, kwargs={"p": p}) +def sample_inputs_upsample_bicubic2d(op_info, device, dtype, requires_grad, **kwargs): + N, C = 2, 3 + D = 4 + S = 3 + L = 5 + + align_corners_options = (True, False, None) + rank = 2 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial(torch_testing.make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-1, high=1) + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), True) + + # for align_corners in align_corners_options: + # yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), None, align_corners) + # yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(L, rank, False), None, align_corners) + # for recompute_scale_factor in [False, True]: + # yield opinfo_core.SampleInput(make_arg(shape(D, rank)), None, 1.7, align_corners, recompute_scale_factor=recompute_scale_factor) + # yield opinfo_core.SampleInput(make_arg(shape(D, rank)), None, 0.6, align_corners, recompute_scale_factor=recompute_scale_factor) + + + def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs): del op_info # input_shape, output_size, kernal, dilation, padding, stride @@ -1497,6 +1524,13 @@ def __init__(self): sample_inputs_func=sample_inputs_bernoulli_p_deterministic, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_bicubic2d", + aten_name="upsample_bicubic2d", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_bicubic2d, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.col2im", aten_name="col2im", diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 2d5b1ab720..91d8886619 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -2080,6 +2080,11 @@ def _where_input_wrangler( input_wrangler=_upsample_bilinear2d_input_wrangler, trace_only=True, ), + TorchLibOpInfo( + "ops.aten.upsample_bicubic2d", + nn_ops.aten_upsample_bicubic2d, + trace_only=True, + ), TorchLibOpInfo( "nn.functional.upsample_nearest2d", nn_ops.aten_upsample_nearest2d, From 2950096cd25c5540189085e1cf9dc145bd40a47f Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 8 Dec 2023 17:24:05 +0800 Subject: [PATCH 02/17] Update extra_opinfo.py --- .../function_libs/torch_lib/extra_opinfo.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index dc8432815a..86bc546b95 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -84,16 +84,22 @@ def shape(size, rank, with_batch_channel=True): return tuple([N, C] + ([size] * rank)) return tuple([size] * rank) - make_arg = functools.partial(torch_testing.make_tensor, device=device, dtype=dtype, - requires_grad=requires_grad, low=-1, high=1) - yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), True) - - # for align_corners in align_corners_options: - # yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), None, align_corners) - # yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(L, rank, False), None, align_corners) - # for recompute_scale_factor in [False, True]: - # yield opinfo_core.SampleInput(make_arg(shape(D, rank)), None, 1.7, align_corners, recompute_scale_factor=recompute_scale_factor) - # yield opinfo_core.SampleInput(make_arg(shape(D, rank)), None, 0.6, align_corners, recompute_scale_factor=recompute_scale_factor) + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1 + ) + #yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), True) + + for align_corners in align_corners_options: + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), align_corners) + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(L, rank, False), align_corners) + for recompute_scale_factor in [False, True]: + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), None, 1.7, align_corners, recompute_scale_factor=recompute_scale_factor) + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), None, 0.6, align_corners, recompute_scale_factor=recompute_scale_factor) From cf8ecba0c657684d79a5e344749b052f390eee7a Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 8 Dec 2023 17:27:16 +0800 Subject: [PATCH 03/17] Update nn.py --- onnxscript/function_libs/torch_lib/ops/nn.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index d7068ee020..b39151bdb9 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2203,18 +2203,19 @@ def aten_upsample_bicubic2d( """upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" if output_size is not None: - result = _aten_upsample_bicubic2d_output_size(self, output_size) + result = _aten_upsample_output_size(self, output_size, "cubic") else: assert scales_h is not None assert scales_h == scales_w - result = _aten_upsample_bicubic2d_scales(self, scales_h, scales_w) + result = _aten_upsample_scales(self, scales_h, scales_w, "cubic") return result @torch_op("aten::upsample_bicubic2d", private=True) -def _aten_upsample_bicubic2d_output_size( +def _aten_upsample_output_size( self: TReal, output_size: INT64, + str_mode: str, ) -> TReal: self_shape = op.Shape(self) starts = op.Constant(value_ints=[0]) @@ -2226,16 +2227,17 @@ def _aten_upsample_bicubic2d_output_size( None, None, output_size, - mode="cubic", + mode=str_mode, coordinate_transformation_mode="align_corners", ) @torch_op("aten::upsample_bicubic2d", private=True) -def _aten_upsample_bicubic2d_scales( +def _aten_upsample_scales( self: TReal, scales_h: float, scales_w: float, + str_mode: str, ) -> TReal: neg_1 = op.Constant(value_ints=[-1]) scales = op.Concat( @@ -2249,7 +2251,7 @@ def _aten_upsample_bicubic2d_scales( None, scales, # format should be: [1.0, 1.0, scale_h, scale_w] None, - mode="cubic", + mode=str_mode, coordinate_transformation_mode="align_corners", ) From 4d2b17e6636003e41f5b8f80092a3d3d2eb9161e Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 8 Dec 2023 17:34:24 +0800 Subject: [PATCH 04/17] Update extra_opinfo.py --- .../function_libs/torch_lib/extra_opinfo.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 86bc546b95..60cf12b6c6 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -71,6 +71,9 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra def sample_inputs_upsample_bicubic2d(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + N, C = 2, 3 D = 4 S = 3 @@ -92,14 +95,31 @@ def shape(size, rank, with_batch_channel=True): low=-1, high=1 ) - #yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), True) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), True) for align_corners in align_corners_options: - yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), align_corners) - yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(L, rank, False), align_corners) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(S, rank, False), align_corners + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(L, rank, False), align_corners + ) for recompute_scale_factor in [False, True]: - yield opinfo_core.SampleInput(make_arg(shape(D, rank)), None, 1.7, align_corners, recompute_scale_factor=recompute_scale_factor) - yield opinfo_core.SampleInput(make_arg(shape(D, rank)), None, 0.6, align_corners, recompute_scale_factor=recompute_scale_factor) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, + 1.7, + align_corners, + recompute_scale_factor=recompute_scale_factor + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, + 0.6, + align_corners, + recompute_scale_factor=recompute_scale_factor + ) From 0831914ca7c647feffd8b111ed8f54a744a718c0 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 8 Dec 2023 17:42:02 +0800 Subject: [PATCH 05/17] Update extra_opinfo.py --- .../tests/function_libs/torch_lib/extra_opinfo.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 60cf12b6c6..099e079ee9 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -76,7 +76,7 @@ def sample_inputs_upsample_bicubic2d(op_info, device, dtype, requires_grad, **kw N, C = 2, 3 D = 4 - S = 3 + SS = 3 L = 5 align_corners_options = (True, False, None) @@ -93,10 +93,10 @@ def shape(size, rank, with_batch_channel=True): dtype=dtype, requires_grad=requires_grad, low=-1, - high=1 + high=1, ) - yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), True) + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) for align_corners in align_corners_options: yield opinfo_core.SampleInput( @@ -111,14 +111,14 @@ def shape(size, rank, with_batch_channel=True): None, 1.7, align_corners, - recompute_scale_factor=recompute_scale_factor + recompute_scale_factor=recompute_scale_factor, ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), None, 0.6, align_corners, - recompute_scale_factor=recompute_scale_factor + recompute_scale_factor=recompute_scale_factor, ) From 2d6f55c455ada25e1d76432d94048c04cc836fc3 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 8 Dec 2023 17:55:00 +0800 Subject: [PATCH 06/17] Update extra_opinfo.py --- onnxscript/tests/function_libs/torch_lib/extra_opinfo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 099e079ee9..60d2dbc597 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -122,7 +122,6 @@ def shape(size, rank, with_batch_channel=True): ) - def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs): del op_info # input_shape, output_size, kernal, dilation, padding, stride From 9083b65513a04e48d57d492eb270330413691c48 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Wed, 20 Dec 2023 19:07:45 +0800 Subject: [PATCH 07/17] update --- onnxscript/function_libs/torch_lib/ops/nn.py | 61 ++++++++++++++----- .../function_libs/torch_lib/extra_opinfo.py | 31 +++++----- .../tests/function_libs/torch_lib/ops_test.py | 1 + 3 files changed, 61 insertions(+), 32 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index b39151bdb9..5b66f05311 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2197,17 +2197,17 @@ def aten_upsample_bicubic2d( self: TensorType, output_size: INT64, align_corners: bool, - scales_h: Optional[float] = None, - scales_w: Optional[float] = None, + scales: FLOAT = None, ) -> TensorType: + + """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" """upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" + """upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)""" if output_size is not None: - result = _aten_upsample_output_size(self, output_size, "cubic") + result = _aten_upsample_output_size(self, output_size, align_corners, "cubic") else: - assert scales_h is not None - assert scales_h == scales_w - result = _aten_upsample_scales(self, scales_h, scales_w, "cubic") + result = _aten_upsample_scales(self, scales[0], scales[1], align_corners, "cubic") return result @@ -2215,6 +2215,7 @@ def aten_upsample_bicubic2d( def _aten_upsample_output_size( self: TReal, output_size: INT64, + align_corners: bool, str_mode: str, ) -> TReal: self_shape = op.Shape(self) @@ -2222,7 +2223,24 @@ def _aten_upsample_output_size( ends = op.Constant(value_ints=[2]) batch_channel = op.Slice(self_shape, starts, ends) output_size = op.Concat(batch_channel, output_size, axis=0) - return op.Resize( + # if align_corners: + # result = op.Resize( + # self, + # None, + # None, + # output_size, + # mode=str_mode, + # coordinate_transformation_mode="align_corners", + # ) + # else: + # result = op.Resize( + # self, + # None, + # None, + # output_size, + # mode=str_mode, + # ) + result = op.Resize( self, None, None, @@ -2231,12 +2249,15 @@ def _aten_upsample_output_size( coordinate_transformation_mode="align_corners", ) + return result + @torch_op("aten::upsample_bicubic2d", private=True) def _aten_upsample_scales( self: TReal, scales_h: float, scales_w: float, + align_corners: bool, str_mode: str, ) -> TReal: neg_1 = op.Constant(value_ints=[-1]) @@ -2246,14 +2267,24 @@ def _aten_upsample_scales( op.Reshape(op.Constant(value_float=scales_w), neg_1), axis=0, ) - return op.Resize( - self, - None, - scales, # format should be: [1.0, 1.0, scale_h, scale_w] - None, - mode=str_mode, - coordinate_transformation_mode="align_corners", - ) + if align_corners: + result = op.Resize( + self, + None, + scales, # format should be: [1.0, 1.0, scale_h, scale_w] + None, + mode=str_mode, + coordinate_transformation_mode="align_corners", + ) + else: + result = op.Resize( + self, + None, + scales, # format should be: [1.0, 1.0, scale_h, scale_w] + None, + mode=str_mode, + ) + return result def aten_upsample_bicubic2d_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 60d2dbc597..421e6f67ae 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -79,7 +79,7 @@ def sample_inputs_upsample_bicubic2d(op_info, device, dtype, requires_grad, **kw SS = 3 L = 5 - align_corners_options = (True, False, None) + align_corners_options = (True, False) rank = 2 def shape(size, rank, with_batch_channel=True): @@ -103,23 +103,20 @@ def shape(size, rank, with_batch_channel=True): make_arg(shape(D, rank)), shape(S, rank, False), align_corners ) yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), shape(L, rank, False), align_corners + make_arg(shape(D, rank)), shape(L, rank, False), align_corners, + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, # output_size + align_corners, + [1.7, 1.7], # scaler + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, # if this is None, the scalar must be list + align_corners, + [0.6, 0.6], ) - for recompute_scale_factor in [False, True]: - yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), - None, - 1.7, - align_corners, - recompute_scale_factor=recompute_scale_factor, - ) - yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), - None, - 0.6, - align_corners, - recompute_scale_factor=recompute_scale_factor, - ) def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs): diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py index 9cae237c80..712f628a43 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test.py @@ -199,6 +199,7 @@ def run_test_output_match( ), kwargs=repr(cpu_sample.kwargs), ): + if i != 0: continue test_behavior, reason = _should_skip_xfail_test_sample(op.name, cpu_sample, dtype) with ops_test_common.normal_xfail_skip_test_behaviors(test_behavior, reason): From 06ae926bef22dd7b1e3ffb6d5ae630ebf9244f71 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Wed, 20 Dec 2023 19:31:17 +0800 Subject: [PATCH 08/17] update --- onnxscript/function_libs/torch_lib/ops/nn.py | 44 ++++++++----------- .../tests/function_libs/torch_lib/ops_test.py | 1 - 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 30006fe250..4fd6af7263 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2228,31 +2228,24 @@ def _aten_upsample_output_size( ends = op.Constant(value_ints=[2]) batch_channel = op.Slice(self_shape, starts, ends) output_size = op.Concat(batch_channel, output_size, axis=0) - # if align_corners: - # result = op.Resize( - # self, - # None, - # None, - # output_size, - # mode=str_mode, - # coordinate_transformation_mode="align_corners", - # ) - # else: - # result = op.Resize( - # self, - # None, - # None, - # output_size, - # mode=str_mode, - # ) - result = op.Resize( - self, - None, - None, - output_size, - mode=str_mode, - coordinate_transformation_mode="align_corners", - ) + if align_corners: + result = op.Resize( + self, + None, + None, + output_size, + mode=str_mode, + coordinate_transformation_mode="align_corners", + ) + else: + result = op.Resize( + self, + None, + None, + output_size, + mode=str_mode, + coordinate_transformation_mode="pytorch_half_pixel", + ) return result @@ -2288,6 +2281,7 @@ def _aten_upsample_scales( scales, # format should be: [1.0, 1.0, scale_h, scale_w] None, mode=str_mode, + coordinate_transformation_mode="pytorch_half_pixel", ) return result diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py index 712f628a43..9cae237c80 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test.py @@ -199,7 +199,6 @@ def run_test_output_match( ), kwargs=repr(cpu_sample.kwargs), ): - if i != 0: continue test_behavior, reason = _should_skip_xfail_test_sample(op.name, cpu_sample, dtype) with ops_test_common.normal_xfail_skip_test_behaviors(test_behavior, reason): From 0d2eb44988524490bf6ce0902191fc8ee28afa5e Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 21 Dec 2023 17:15:10 +0800 Subject: [PATCH 09/17] update --- onnxscript/function_libs/torch_lib/ops/nn.py | 1 - onnxscript/tests/function_libs/torch_lib/extra_opinfo.py | 8 ++++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 4fd6af7263..bd033c49fe 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2204,7 +2204,6 @@ def aten_upsample_bicubic2d( align_corners: bool, scales: FLOAT = None, ) -> TensorType: - """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" """upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" """upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)""" diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 421e6f67ae..5941a9b984 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -100,10 +100,14 @@ def shape(size, rank, with_batch_channel=True): for align_corners in align_corners_options: yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), shape(S, rank, False), align_corners + make_arg(shape(D, rank)), + shape(S, rank, False), + align_corners ) yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), shape(L, rank, False), align_corners, + make_arg(shape(D, rank)), + shape(L, rank, False), + align_corners, ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), From 67fe7a9bf8d855959ae1edd2450e583765590e93 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 21 Dec 2023 17:33:11 +0800 Subject: [PATCH 10/17] fix lint --- onnxscript/function_libs/torch_lib/ops/nn.py | 4 ++-- onnxscript/tests/function_libs/torch_lib/extra_opinfo.py | 5 +---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index bd033c49fe..c3c7389592 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2205,8 +2205,8 @@ def aten_upsample_bicubic2d( scales: FLOAT = None, ) -> TensorType: """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" - """upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" - """upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)""" + """upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" # pylint: disable=pointless-string-statement + """upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)""" # pylint: disable=pointless-string-statement if output_size is not None: result = _aten_upsample_output_size(self, output_size, align_corners, "cubic") diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 5941a9b984..f7d2272be3 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -100,10 +100,7 @@ def shape(size, rank, with_batch_channel=True): for align_corners in align_corners_options: yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), - shape(S, rank, False), - align_corners - ) + make_arg(shape(D, rank)), shape(S, rank, False), align_corners) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), shape(L, rank, False), From 30052735890e50914ed3fd7cc6b2fbd083b2e126 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 21 Dec 2023 17:37:46 +0800 Subject: [PATCH 11/17] Update extra_opinfo.py --- onnxscript/tests/function_libs/torch_lib/extra_opinfo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index f7d2272be3..a6b4610bad 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -100,7 +100,8 @@ def shape(size, rank, with_batch_channel=True): for align_corners in align_corners_options: yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), shape(S, rank, False), align_corners) + make_arg(shape(D, rank)), shape(S, rank, False), align_corners + ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), shape(L, rank, False), From ed0c2efc23a86bf2d8c6cfc7a3cecb74797c2674 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Thu, 21 Dec 2023 18:47:14 +0800 Subject: [PATCH 12/17] update --- onnxscript/function_libs/torch_lib/ops/nn.py | 20 +++++++------------ .../function_libs/torch_lib/extra_opinfo.py | 4 ++-- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index c3c7389592..bacb43378f 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2199,11 +2199,11 @@ def aten_unflatten_dense_tensors( @torch_op("aten::upsample_bicubic2d", trace_only=True) def aten_upsample_bicubic2d( - self: TensorType, + self: TReal, output_size: INT64, align_corners: bool, - scales: FLOAT = None, -) -> TensorType: + scales: TFloat = None, +) -> TReal: """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" """upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" # pylint: disable=pointless-string-statement """upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)""" # pylint: disable=pointless-string-statement @@ -2211,7 +2211,7 @@ def aten_upsample_bicubic2d( if output_size is not None: result = _aten_upsample_output_size(self, output_size, align_corners, "cubic") else: - result = _aten_upsample_scales(self, scales[0], scales[1], align_corners, "cubic") + result = _aten_upsample_scales(self, scales, align_corners, "cubic") return result @@ -2252,18 +2252,12 @@ def _aten_upsample_output_size( @torch_op("aten::upsample_bicubic2d", private=True) def _aten_upsample_scales( self: TReal, - scales_h: float, - scales_w: float, + scales: TFloat, align_corners: bool, str_mode: str, ) -> TReal: - neg_1 = op.Constant(value_ints=[-1]) - scales = op.Concat( - op.Constant(value_floats=[1.0, 1.0]), - op.Reshape(op.Constant(value_float=scales_h), neg_1), - op.Reshape(op.Constant(value_float=scales_w), neg_1), - axis=0, - ) + scales = op.Cast(scales, to=FLOAT.dtype) + scales = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scales, axis=0) if align_corners: result = op.Resize( self, diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index a6b4610bad..1d698c8ef9 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -111,13 +111,13 @@ def shape(size, rank, with_batch_channel=True): make_arg(shape(D, rank)), None, # output_size align_corners, - [1.7, 1.7], # scaler + (1.7, 1.7), # scaler ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), None, # if this is None, the scalar must be list align_corners, - [0.6, 0.6], + (0.6, 0.6), ) From e547331db19c72e1cb474c15f199596e3f0f2591 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 22 Dec 2023 14:56:20 +0800 Subject: [PATCH 13/17] Update nn.py --- onnxscript/function_libs/torch_lib/ops/nn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index bacb43378f..274e3ad43e 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2206,7 +2206,6 @@ def aten_upsample_bicubic2d( ) -> TReal: """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" """upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" # pylint: disable=pointless-string-statement - """upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)""" # pylint: disable=pointless-string-statement if output_size is not None: result = _aten_upsample_output_size(self, output_size, align_corners, "cubic") From 6e231e5e6d57c1d7dfa0d11dca09e3285c984038 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Fri, 22 Dec 2023 16:40:15 +0800 Subject: [PATCH 14/17] Update extra_opinfo.py --- .../function_libs/torch_lib/extra_opinfo.py | 116 +++++++++--------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py index 1d698c8ef9..fc808a5ecc 100644 --- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py +++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py @@ -70,57 +70,6 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra yield opinfo_core.SampleInput(t, kwargs={"p": p}) -def sample_inputs_upsample_bicubic2d(op_info, device, dtype, requires_grad, **kwargs): - del op_info - del kwargs - - N, C = 2, 3 - D = 4 - SS = 3 - L = 5 - - align_corners_options = (True, False) - rank = 2 - - def shape(size, rank, with_batch_channel=True): - if with_batch_channel: - return tuple([N, C] + ([size] * rank)) - return tuple([size] * rank) - - make_arg = functools.partial( - torch_testing.make_tensor, - device=device, - dtype=dtype, - requires_grad=requires_grad, - low=-1, - high=1, - ) - - yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) - - for align_corners in align_corners_options: - yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), shape(S, rank, False), align_corners - ) - yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), - shape(L, rank, False), - align_corners, - ) - yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), - None, # output_size - align_corners, - (1.7, 1.7), # scaler - ) - yield opinfo_core.SampleInput( - make_arg(shape(D, rank)), - None, # if this is None, the scalar must be list - align_corners, - (0.6, 0.6), - ) - - def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs): del op_info # input_shape, output_size, kernal, dilation, padding, stride @@ -1446,6 +1395,57 @@ def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs): yield opinfo_core.SampleInput(t, args=(dimension, size, step)) +def sample_inputs_upsample_bicubic2d(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + SS = 3 + L = 5 + + align_corners_options = (True, False) + rank = 2 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) + + for align_corners in align_corners_options: + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), shape(S, rank, False), align_corners + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + shape(L, rank, False), + align_corners, + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, # output_size + align_corners, + (1.7, 1.7), # scaler + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, # if this is None, the scalar must be list + align_corners, + (0.6, 0.6), + ) + + class _TestParamsMaxPoolEmptyStrideBase: # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203 def __init__(self): @@ -1548,13 +1548,6 @@ def __init__(self): sample_inputs_func=sample_inputs_bernoulli_p_deterministic, supports_out=False, ), - opinfo_core.OpInfo( - "ops.aten.upsample_bicubic2d", - aten_name="upsample_bicubic2d", - dtypes=common_dtype.floating_types_and(torch.bfloat16), - sample_inputs_func=sample_inputs_upsample_bicubic2d, - supports_out=False, - ), opinfo_core.OpInfo( "ops.aten.col2im", aten_name="col2im", @@ -1918,6 +1911,13 @@ def __init__(self): sample_inputs_func=sample_inputs_unfold, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_bicubic2d", + aten_name="upsample_bicubic2d", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_bicubic2d, + supports_out=False, + ), opinfo_core.OpInfo( "nn.functional.max_pool1d_with_indices", aten_name="max_pool1d_with_indices", From 4858062389acbcc21ff358514cb81d19bb270737 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Tue, 2 Jan 2024 14:08:25 +0800 Subject: [PATCH 15/17] Update nn.py --- onnxscript/function_libs/torch_lib/ops/nn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 274e3ad43e..e183180827 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2197,15 +2197,15 @@ def aten_unflatten_dense_tensors( raise NotImplementedError() -@torch_op("aten::upsample_bicubic2d", trace_only=True) +@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bicubic2d.vec"), trace_only=True) def aten_upsample_bicubic2d( self: TReal, output_size: INT64, align_corners: bool, - scales: TFloat = None, + scales_factors: Optional[TFloat] = None, ) -> TReal: - """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" - """upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" # pylint: disable=pointless-string-statement + """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" if output_size is not None: result = _aten_upsample_output_size(self, output_size, align_corners, "cubic") From ca6931ef952809d05827ac45e6995fa04e7f3670 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Tue, 2 Jan 2024 14:16:17 +0800 Subject: [PATCH 16/17] Update nn.py --- onnxscript/function_libs/torch_lib/ops/nn.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index e183180827..5ebb6101e7 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2202,7 +2202,7 @@ def aten_upsample_bicubic2d( self: TReal, output_size: INT64, align_corners: bool, - scales_factors: Optional[TFloat] = None, + scale_factors: Optional[TFloat] = None, ) -> TReal: """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" @@ -2210,7 +2210,7 @@ def aten_upsample_bicubic2d( if output_size is not None: result = _aten_upsample_output_size(self, output_size, align_corners, "cubic") else: - result = _aten_upsample_scales(self, scales, align_corners, "cubic") + result = _aten_upsample_scales(self, scale_factors, align_corners, "cubic") return result @@ -2251,17 +2251,17 @@ def _aten_upsample_output_size( @torch_op("aten::upsample_bicubic2d", private=True) def _aten_upsample_scales( self: TReal, - scales: TFloat, + scale_factors: TFloat, align_corners: bool, str_mode: str, ) -> TReal: - scales = op.Cast(scales, to=FLOAT.dtype) - scales = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scales, axis=0) + scale_factors = op.Cast(scale_factors, to=FLOAT.dtype) + scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0) if align_corners: result = op.Resize( self, None, - scales, # format should be: [1.0, 1.0, scale_h, scale_w] + scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] None, mode=str_mode, coordinate_transformation_mode="align_corners", @@ -2270,7 +2270,7 @@ def _aten_upsample_scales( result = op.Resize( self, None, - scales, # format should be: [1.0, 1.0, scale_h, scale_w] + scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] None, mode=str_mode, coordinate_transformation_mode="pytorch_half_pixel", From 82220496aef667aeb0cc892c958a26dd717f82e1 Mon Sep 17 00:00:00 2001 From: xiaowuhu Date: Tue, 2 Jan 2024 14:34:10 +0800 Subject: [PATCH 17/17] Update nn.py --- onnxscript/function_libs/torch_lib/ops/nn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 5ebb6101e7..bb767071e7 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2205,7 +2205,8 @@ def aten_upsample_bicubic2d( scale_factors: Optional[TFloat] = None, ) -> TReal: """upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor - upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" + upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + """ if output_size is not None: result = _aten_upsample_output_size(self, output_size, align_corners, "cubic")