From 6b232d903407d91d111180e6069694705f789565 Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Sun, 7 Sep 2025 20:12:28 -0400 Subject: [PATCH 01/10] added fixed function logic --- .../function_libs/torch_lib/ops/core.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ab992e0580..e98e3c3f12 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3101,27 +3101,44 @@ def aten_embedding_bag_padding_idx( sparse: bool = False, per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, - padding_idx: int = -1, + padding_idx: Optional[int] = None, ) -> Tuple[TFloat, TFloat, TFloat, TFloat]: """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor) We add default values for the attributes to accommodate _embedding_bag as well: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) """ - assert padding_idx is not None, ( - "padding_idx must not be None. This is likely a dispatcher error" - ) if per_sample_weights is None: per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices)) per_sample_weights = op.CastLike(per_sample_weights, weight) - # Change padding_idx to positive value, -1 means the last index - if padding_idx < 0: - padding_idx = weight.shape[0] + padding_idx + if padding_idx is not None: + # Call the existing function for handling padding_idx + result, offset2bag, bag_size, max_indices =_aten_embedding_bag_1d_padding_idx_onnx( + weight, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ) - result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx( - weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx + return result, offset2bag, bag_size, max_indices + + # When padding_idx is None, use the standard embedding_bag implementation + result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx( + weight, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, ) return result, offset2bag, bag_size, max_indices From 40f487bcb5ce0a7f158119623b1b64ea8729ec26 Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Mon, 8 Sep 2025 14:12:28 -0400 Subject: [PATCH 02/10] added test cases for aten_embedding_bag_padding_idx --- tests/function_libs/torch_lib/extra_opinfo.py | 38 +++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 41 +++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index ca80cf5172..4e607ff36c 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -2210,6 +2210,44 @@ def __init__(self): sample_inputs_func=sample_inputs_embedding_bag_padding_idx, supports_out=False, ), + opinfo_core.OpInfo( + "test_embedding_bag_with_padding_idx_none", + op=torch.nn.functional.embedding_bag, + dtypes=(torch.float32,), + sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ + opinfo_core.SampleInput( + torch.tensor( + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]], + dtype=dtype, + device=device, + ), + args=( + torch.tensor([0, 1, 2, 3], dtype=torch.int64, device=device), + torch.tensor([0, 2], dtype=torch.int64, device=device), + ), + kwargs={"padding_idx": None}, + ) + ], + ), + opinfo_core.OpInfo( + "test_embedding_bag_with_padding_idx_int", + op=torch.nn.functional.embedding_bag, + dtypes=(torch.float32,), + sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ + opinfo_core.SampleInput( + torch.tensor( + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], + dtype=dtype, + device=device, + ), + args=( + torch.tensor([0, 1, 2], dtype=torch.int64, device=device), + torch.tensor([0, 2], dtype=torch.int64, device=device), + ), + kwargs={"padding_idx": 0}, + ) + ], + ), opinfo_core.OpInfo( "ops.aten.embedding_renorm", aten_name="embedding_renorm", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 7af7413185..2b06003162 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -185,6 +185,24 @@ def xfail( # Modify this section ########################################################## +def _embedding_bag_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # ONNX attributes cannot be None; omit padding_idx if it’s None. + padding_idx = kwargs.pop("padding_idx", "___MISSING___") + if padding_idx is not "___MISSING___": + if padding_idx is not None: + kwargs["padding_idx"] = int(padding_idx) + + # Ensure indices/offsets are int64 (positional: weight, indices, offsets, ...) + if len(args) >= 3: + if isinstance(args[1], torch.Tensor): + args[1] = args[1].to(torch.long) + if isinstance(args[2], torch.Tensor): + args[2] = args[2].to(torch.long) + + return args, kwargs + def _amin_amax_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -1035,15 +1053,38 @@ def _where_input_wrangler( core_ops.aten_embedding_bag, tolerance={torch.float32: (1e-4, 5e-4)}, compare_shape_only_for_output=(1, 2, 3), + input_wrangler=_embedding_bag_input_wrangler, ).skip( dtypes=(torch.float16,), reason="fixme: results mismatch in torch nightly.", ), + TorchLibOpInfo( + "test_embedding_bag_with_padding_idx_none", + core_ops.aten_embedding_bag, + input_wrangler=_embedding_bag_input_wrangler, + ), + TorchLibOpInfo( + "test_embedding_bag_with_padding_idx_int", + core_ops.aten_embedding_bag, + input_wrangler=_embedding_bag_input_wrangler, + ), + TorchLibOpInfo( "ops.aten.embedding_bag.padding_idx", core_ops.aten_embedding_bag_padding_idx, tolerance={torch.float16: (1e-2, 1e-2)}, compare_shape_only_for_output=(1, 2, 3), + input_wrangler=_embedding_bag_input_wrangler, + ), + TorchLibOpInfo( + "test_embedding_bag_with_padding_idx_none", + core_ops.aten_embedding_bag, + input_wrangler=_embedding_bag_input_wrangler, + ), + TorchLibOpInfo( + "test_embedding_bag_with_padding_idx_int", + core_ops.aten_embedding_bag, + input_wrangler=_embedding_bag_input_wrangler, ), TorchLibOpInfo( "ops.aten.embedding_renorm", From 294eca3e2d4691c552ff14c7dcabe812bdb8139d Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Mon, 8 Sep 2025 23:28:50 -0400 Subject: [PATCH 03/10] fix: resolve lint warnings and comparison issue --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ---- tests/function_libs/torch_lib/ops_test_data.py | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e98e3c3f12..589a5f4ba1 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3119,9 +3119,7 @@ def aten_embedding_bag_padding_idx( weight, indices, offsets, - scale_grad_by_freq, mode, - sparse, per_sample_weights, include_last_offset, padding_idx, @@ -3134,9 +3132,7 @@ def aten_embedding_bag_padding_idx( weight, indices, offsets, - scale_grad_by_freq, mode, - sparse, per_sample_weights, include_last_offset, ) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 2b06003162..6612c7c72a 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -188,9 +188,9 @@ def xfail( def _embedding_bag_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: - # ONNX attributes cannot be None; omit padding_idx if it’s None. + # ONNX attributes cannot be None; omit padding_idx if it's None. padding_idx = kwargs.pop("padding_idx", "___MISSING___") - if padding_idx is not "___MISSING___": + if padding_idx != "___MISSING___": if padding_idx is not None: kwargs["padding_idx"] = int(padding_idx) From 71920357875e3904fe372452572a34bc2cde9201 Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Tue, 9 Sep 2025 14:04:10 -0400 Subject: [PATCH 04/10] fix: clean up _embedding_bag_input_wrangler padding_idx check --- tests/function_libs/torch_lib/ops_test_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 6612c7c72a..7ec639fad0 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -189,8 +189,8 @@ def _embedding_bag_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: # ONNX attributes cannot be None; omit padding_idx if it's None. - padding_idx = kwargs.pop("padding_idx", "___MISSING___") - if padding_idx != "___MISSING___": + if "padding_idx" in kwargs: + padding_idx = kwargs.pop("padding_idx") if padding_idx is not None: kwargs["padding_idx"] = int(padding_idx) From 6e41cfe796809770469b27c8269f287baadd3c9d Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Sat, 13 Sep 2025 20:14:16 -0400 Subject: [PATCH 05/10] fix: fixed bugs and issues with test cases --- tests/function_libs/torch_lib/extra_opinfo.py | 8 ++++---- tests/function_libs/torch_lib/ops_test_data.py | 14 ++------------ 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 4e607ff36c..3d81896187 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -2211,9 +2211,9 @@ def __init__(self): supports_out=False, ), opinfo_core.OpInfo( - "test_embedding_bag_with_padding_idx_none", + "ops.aten.embedding_bag.padding_idx_none", op=torch.nn.functional.embedding_bag, - dtypes=(torch.float32,), + dtypes=common_dtype.floating_types_and_half(), sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ opinfo_core.SampleInput( torch.tensor( @@ -2230,9 +2230,9 @@ def __init__(self): ], ), opinfo_core.OpInfo( - "test_embedding_bag_with_padding_idx_int", + "ops.aten.embedding_bag.padding_idx_int", op=torch.nn.functional.embedding_bag, - dtypes=(torch.float32,), + dtypes=common_dtype.floating_types_and_half(), sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ opinfo_core.SampleInput( torch.tensor( diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 7ec639fad0..6c560ee126 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1059,12 +1059,12 @@ def _where_input_wrangler( reason="fixme: results mismatch in torch nightly.", ), TorchLibOpInfo( - "test_embedding_bag_with_padding_idx_none", + "ops.aten.embedding_bag.padding_idx_none", core_ops.aten_embedding_bag, input_wrangler=_embedding_bag_input_wrangler, ), TorchLibOpInfo( - "test_embedding_bag_with_padding_idx_int", + "ops.aten.embedding_bag.padding_idx_int", core_ops.aten_embedding_bag, input_wrangler=_embedding_bag_input_wrangler, ), @@ -1076,16 +1076,6 @@ def _where_input_wrangler( compare_shape_only_for_output=(1, 2, 3), input_wrangler=_embedding_bag_input_wrangler, ), - TorchLibOpInfo( - "test_embedding_bag_with_padding_idx_none", - core_ops.aten_embedding_bag, - input_wrangler=_embedding_bag_input_wrangler, - ), - TorchLibOpInfo( - "test_embedding_bag_with_padding_idx_int", - core_ops.aten_embedding_bag, - input_wrangler=_embedding_bag_input_wrangler, - ), TorchLibOpInfo( "ops.aten.embedding_renorm", core_ops.aten_embedding_renorm, From d2da96ecf5c7639d8162612d2f07c89d646084a4 Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Sat, 13 Sep 2025 20:24:50 -0400 Subject: [PATCH 06/10] style: fixed linting issues in code --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- tests/function_libs/torch_lib/ops_test_data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 589a5f4ba1..6eb9fb4cbb 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3115,7 +3115,7 @@ def aten_embedding_bag_padding_idx( if padding_idx is not None: # Call the existing function for handling padding_idx - result, offset2bag, bag_size, max_indices =_aten_embedding_bag_1d_padding_idx_onnx( + result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx( weight, indices, offsets, diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 6c560ee126..183b23cc4c 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -203,6 +203,7 @@ def _embedding_bag_input_wrangler( return args, kwargs + def _amin_amax_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -1068,7 +1069,6 @@ def _where_input_wrangler( core_ops.aten_embedding_bag, input_wrangler=_embedding_bag_input_wrangler, ), - TorchLibOpInfo( "ops.aten.embedding_bag.padding_idx", core_ops.aten_embedding_bag_padding_idx, From 5f6156bcf06988487cad26f9dcd58c31261a4705 Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Sun, 14 Dec 2025 17:10:21 -0500 Subject: [PATCH 07/10] fix: resolve embedding_bag untyped ONNX outputs with explicit casts --- .../function_libs/torch_lib/ops/core.py | 62 +++++++++++-------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 64905496ec..f6cfe8ec57 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2861,7 +2861,8 @@ def aten_embedding_bag( sparse: bool = False, per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, -) -> Tuple[TFloat, TFloat, TFloat, TFloat]: + padding_idx: Optional[int] = None, +) -> Tuple[TFloat, INT64, INT64, INT64]: """embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)""" # assert(rank(indices) in [1,2]) @@ -2889,7 +2890,7 @@ def _aten_embedding_bag_onnx( mode: int, per_sample_weights: TFloat, include_last_offset: bool, -) -> Tuple[TFloat, TFloat, TFloat, TFloat]: +) -> Tuple[TFloat, INT64, INT64, INT64]: neg_1 = op.Constant(value_ints=[-1]) # Assume indices is shape(5,2), indices_1d is shape(10,) indices_1d = op.Reshape(indices, neg_1) @@ -2957,23 +2958,24 @@ def _aten_embedding_bag_onnx( # Only compute the shape of other 3 outputs, we don't care the value if mode == 0: # sum - offset2bag = op.Shape(indices, start=0, end=0) # Generate empty tensor + offset2bag = op.Cast(op.Shape(indices, start=0, end=0), to=INT64.dtype) if op.Equal(include_last_offset, True): - bag_size = op.Expand(0, op.Shape(offsets)) + bag_size = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype) + max_indices = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype) else: - bag_size = op.Expand(0, op.Shape(offsets) - 1) - max_indices = op.Expand(0, op.Shape(bag_size)) + bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype) + max_indices = op.Cast(op.Expand(0, op.Shape(offsets) -1), to=INT64.dtype) elif mode == 1: # mean - offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1)) - bag_size = op.Expand(0, op.Shape(offsets) - 1) - max_indices = op.Expand(0, op.Shape(bag_size)) + offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype) + bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype) + max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype) else: # max - offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1)) - bag_size = op.Expand(0, op.Shape(offsets) - 1) + offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype) + bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype) # shape = (bag_size.dim[0], weight.dim[1]) dim_0 = op.Shape(bag_size, start=0, end=1) dim_1 = op.Shape(weight, start=1, end=2) - max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0)) + max_indices = op.Cast(op.Expand(0, op.Concat(dim_0, dim_1, axis=0)), to=INT64.dtype) return result, offset2bag, bag_size, max_indices @@ -2996,7 +2998,7 @@ def aten_embedding_bag_padding_idx( per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, padding_idx: Optional[int] = None, -) -> Tuple[TFloat, TFloat, TFloat, TFloat]: +) -> Tuple[TFloat, INT64, INT64, INT64]: """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor) We add default values for the attributes to accommodate _embedding_bag as well: @@ -3043,8 +3045,14 @@ def _aten_embedding_bag_1d_padding_idx_onnx( per_sample_weights: TFloat, include_last_offset: bool, padding_idx: int, -) -> Tuple[TFloat, TFloat, TFloat, TFloat]: +) -> Tuple[TFloat, INT64, INT64, INT64]: neg_1 = op.Constant(value_ints=[-1]) + + num_embeddings = op.Shape(weight, start=0, end=1) # Get number of rows in weight + num_embeddings_scalar = op.Squeeze(num_embeddings) + if padding_idx < 0: + padding_idx = padding_idx + num_embeddings_scalar + # Get weight out according to indices, # e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]] indices_weight = op.Gather(weight, indices) @@ -3080,7 +3088,10 @@ def _aten_embedding_bag_1d_padding_idx_onnx( cond_2 = j < end_pos while cond_2: index = op.Gather(indices, j) - if not op.Equal(index, padding_idx): + normalized_index = index + if index < 0: + normalized_index = index + num_embeddings_scalar + if not op.Equal(normalized_index, padding_idx): # Something like the 'append' operation curr_offsets = op.Concat(curr_offsets, op.Reshape(j, neg_1), axis=0) j = j + 1 @@ -3109,23 +3120,24 @@ def _aten_embedding_bag_1d_padding_idx_onnx( result = op.CastLike(result, weight) if mode == 0: # sum - offset2bag = op.Expand(0, op.Shape(indices)) + offset2bag = op.Cast(op.Expand(0, op.Shape(indices)), to=INT64.dtype) if op.Equal(include_last_offset, True): - bag_size = op.Expand(0, op.Shape(offsets)) + bag_size = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype) + max_indices = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype) else: - bag_size = op.Expand(0, op.Shape(offsets) - 1) - max_indices = op.Expand(0, op.Shape(bag_size)) + bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype) + max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype) elif mode == 1: # mean - offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1)) - bag_size = op.Expand(0, op.Shape(offsets) - 1) - max_indices = op.Expand(0, op.Shape(bag_size)) + offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype) + bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype) + max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype) else: # mode == 2, max - offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1)) - bag_size = op.Expand(0, op.Shape(offsets) - 1) + offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype) + bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype) # shape = (bag_size.dim[0], weight.dim[1]) dim_0 = op.Shape(bag_size, start=0, end=1) dim_1 = op.Shape(weight, start=1, end=2) - max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0)) + max_indices = op.Cast(op.Expand(0, op.Concat(dim_0, dim_1, axis=0)), to=INT64.dtype) return result, offset2bag, bag_size, max_indices From 4789d56f14783e57bf3a63b6b2ec465d257f0a37 Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Mon, 15 Dec 2025 18:24:57 -0500 Subject: [PATCH 08/10] fix: removed redundant tests --- tests/function_libs/torch_lib/extra_opinfo.py | 38 ------------------- .../function_libs/torch_lib/ops_test_data.py | 10 ----- 2 files changed, 48 deletions(-) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 6cc7ab6d35..2ce015b363 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -2483,44 +2483,6 @@ def __init__(self): sample_inputs_func=sample_inputs_embedding_bag_padding_idx, supports_out=False, ), - opinfo_core.OpInfo( - "ops.aten.embedding_bag.padding_idx_none", - op=torch.nn.functional.embedding_bag, - dtypes=common_dtype.floating_types_and_half(), - sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ - opinfo_core.SampleInput( - torch.tensor( - [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]], - dtype=dtype, - device=device, - ), - args=( - torch.tensor([0, 1, 2, 3], dtype=torch.int64, device=device), - torch.tensor([0, 2], dtype=torch.int64, device=device), - ), - kwargs={"padding_idx": None}, - ) - ], - ), - opinfo_core.OpInfo( - "ops.aten.embedding_bag.padding_idx_int", - op=torch.nn.functional.embedding_bag, - dtypes=common_dtype.floating_types_and_half(), - sample_inputs_func=lambda op_info, device, dtype, requires_grad: [ - opinfo_core.SampleInput( - torch.tensor( - [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], - dtype=dtype, - device=device, - ), - args=( - torch.tensor([0, 1, 2], dtype=torch.int64, device=device), - torch.tensor([0, 2], dtype=torch.int64, device=device), - ), - kwargs={"padding_idx": 0}, - ) - ], - ), opinfo_core.OpInfo( "ops.aten.embedding_renorm", aten_name="embedding_renorm", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 01653f74fe..6fdde2086b 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -932,16 +932,6 @@ def _where_input_wrangler( dtypes=(torch.float16,), reason="fixme: results mismatch in torch nightly.", ), - TorchLibOpInfo( - "ops.aten.embedding_bag.padding_idx_none", - core_ops.aten_embedding_bag, - input_wrangler=_embedding_bag_input_wrangler, - ), - TorchLibOpInfo( - "ops.aten.embedding_bag.padding_idx_int", - core_ops.aten_embedding_bag, - input_wrangler=_embedding_bag_input_wrangler, - ), TorchLibOpInfo( "ops.aten.embedding_bag.padding_idx", core_ops.aten_embedding_bag_padding_idx, From 9ba946efec6c65fb1f6da00c1820f20142467157 Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Mon, 15 Dec 2025 18:27:45 -0500 Subject: [PATCH 09/10] rollback return types --- onnxscript/function_libs/torch_lib/ops/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index f6cfe8ec57..3aed73705b 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2862,7 +2862,7 @@ def aten_embedding_bag( per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, padding_idx: Optional[int] = None, -) -> Tuple[TFloat, INT64, INT64, INT64]: +) -> Tuple[TFloat, TFloat, TFloat, TFloat]: """embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)""" # assert(rank(indices) in [1,2]) @@ -2890,7 +2890,7 @@ def _aten_embedding_bag_onnx( mode: int, per_sample_weights: TFloat, include_last_offset: bool, -) -> Tuple[TFloat, INT64, INT64, INT64]: +) -> Tuple[TFloat, TFloat, TFloat, TFloat]: neg_1 = op.Constant(value_ints=[-1]) # Assume indices is shape(5,2), indices_1d is shape(10,) indices_1d = op.Reshape(indices, neg_1) @@ -2998,7 +2998,7 @@ def aten_embedding_bag_padding_idx( per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, padding_idx: Optional[int] = None, -) -> Tuple[TFloat, INT64, INT64, INT64]: +) -> Tuple[TFloat, TFloat, TFloat, TFloat]: """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor) We add default values for the attributes to accommodate _embedding_bag as well: @@ -3045,7 +3045,7 @@ def _aten_embedding_bag_1d_padding_idx_onnx( per_sample_weights: TFloat, include_last_offset: bool, padding_idx: int, -) -> Tuple[TFloat, INT64, INT64, INT64]: +) -> Tuple[TFloat, TFloat, TFloat, TFloat]: neg_1 = op.Constant(value_ints=[-1]) num_embeddings = op.Shape(weight, start=0, end=1) # Get number of rows in weight From ca6db1aed2ff037b40f3d67d7bc3a5d9caac1952 Mon Sep 17 00:00:00 2001 From: Ali Rahbar Date: Mon, 15 Dec 2025 18:41:58 -0500 Subject: [PATCH 10/10] fixed style issue --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 3aed73705b..48cca8e3b5 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2964,7 +2964,7 @@ def _aten_embedding_bag_onnx( max_indices = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype) else: bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype) - max_indices = op.Cast(op.Expand(0, op.Shape(offsets) -1), to=INT64.dtype) + max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype) elif mode == 1: # mean offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype) bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)