diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 254378bf09..48cca8e3b5 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -2861,6 +2861,7 @@ def aten_embedding_bag(
     sparse: bool = False,
     per_sample_weights: Optional[TFloat] = None,
     include_last_offset: bool = False,
+    padding_idx: Optional[int] = None,
 ) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
     """embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)"""

@@ -2957,23 +2958,24 @@ def _aten_embedding_bag_onnx(
     # Only compute the shape of other 3 outputs, we don't care the value
     if mode == 0:  # sum
-        offset2bag = op.Shape(indices, start=0, end=0)  # Generate empty tensor
+        offset2bag = op.Cast(op.Shape(indices, start=0, end=0), to=INT64.dtype)
         if op.Equal(include_last_offset, True):
-            bag_size = op.Expand(0, op.Shape(offsets))
+            bag_size = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
+            max_indices = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
         else:
-            bag_size = op.Expand(0, op.Shape(offsets) - 1)
-        max_indices = op.Expand(0, op.Shape(bag_size))
+            bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
+            max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
     elif mode == 1:  # mean
-        offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
-        bag_size = op.Expand(0, op.Shape(offsets) - 1)
-        max_indices = op.Expand(0, op.Shape(bag_size))
+        offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
+        bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
+        max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
     else:  # max
-        offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
-        bag_size = op.Expand(0, op.Shape(offsets) - 1)
+        offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
+        bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
         # shape = (bag_size.dim[0], weight.dim[1])
         dim_0 = op.Shape(bag_size, start=0, end=1)
         dim_1 = op.Shape(weight, start=1, end=2)
-        max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0))
+        max_indices = op.Cast(op.Expand(0, op.Concat(dim_0, dim_1, axis=0)), to=INT64.dtype)

     return result, offset2bag, bag_size, max_indices

@@ -2995,27 +2997,40 @@ def aten_embedding_bag_padding_idx(
     sparse: bool = False,
     per_sample_weights: Optional[TFloat] = None,
     include_last_offset: bool = False,
-    padding_idx: int = -1,
+    padding_idx: Optional[int] = None,
 ) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
     """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)

     We add default values for the attributes to accommodate _embedding_bag as well:
     _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
     """
-    assert padding_idx is not None, (
-        "padding_idx must not be None. This is likely a dispatcher error"
-    )
     if per_sample_weights is None:
         per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
         per_sample_weights = op.CastLike(per_sample_weights, weight)

-    # Change padding_idx to positive value, -1 means the last index
-    if padding_idx < 0:
-        padding_idx = weight.shape[0] + padding_idx
+    if padding_idx is not None:
+        # Call the existing function for handling padding_idx
+        result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
+            weight,
+            indices,
+            offsets,
+            mode,
+            per_sample_weights,
+            include_last_offset,
+            padding_idx,
+        )
+
+        return result, offset2bag, bag_size, max_indices

-    result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
-        weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx
+    # When padding_idx is None, use the standard embedding_bag implementation
+    result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx(
+        weight,
+        indices,
+        offsets,
+        mode,
+        per_sample_weights,
+        include_last_offset,
     )

     return result, offset2bag, bag_size, max_indices

@@ -3032,6 +3047,12 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
     padding_idx: int,
 ) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
     neg_1 = op.Constant(value_ints=[-1])
+
+    num_embeddings = op.Shape(weight, start=0, end=1)  # Get number of rows in weight
+    num_embeddings_scalar = op.Squeeze(num_embeddings)
+    if padding_idx < 0:
+        padding_idx = padding_idx + num_embeddings_scalar
+
     # Get weight out according to indices,
     # e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
     indices_weight = op.Gather(weight, indices)
@@ -3067,7 +3088,10 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
         cond_2 = j < end_pos
         while cond_2:
             index = op.Gather(indices, j)
-            if not op.Equal(index, padding_idx):
+            normalized_index = index
+            if index < 0:
+                normalized_index = index + num_embeddings_scalar
+            if not op.Equal(normalized_index, padding_idx):
                 # Something like the 'append' operation
                 curr_offsets = op.Concat(curr_offsets, op.Reshape(j, neg_1), axis=0)
             j = j + 1
@@ -3096,23 +3120,24 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
     result = op.CastLike(result, weight)

     if mode == 0:  # sum
-        offset2bag = op.Expand(0, op.Shape(indices))
+        offset2bag = op.Cast(op.Expand(0, op.Shape(indices)), to=INT64.dtype)
         if op.Equal(include_last_offset, True):
-            bag_size = op.Expand(0, op.Shape(offsets))
+            bag_size = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
+            max_indices = op.Cast(op.Expand(0, op.Shape(offsets)), to=INT64.dtype)
         else:
-            bag_size = op.Expand(0, op.Shape(offsets) - 1)
-        max_indices = op.Expand(0, op.Shape(bag_size))
+            bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
+            max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
     elif mode == 1:  # mean
-        offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
-        bag_size = op.Expand(0, op.Shape(offsets) - 1)
-        max_indices = op.Expand(0, op.Shape(bag_size))
+        offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
+        bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
+        max_indices = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
     else:  # mode == 2, max
-        offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
-        bag_size = op.Expand(0, op.Shape(offsets) - 1)
+        offset2bag = op.Cast(op.Expand(0, op.Shape(indices, start=0, end=1)), to=INT64.dtype)
+        bag_size = op.Cast(op.Expand(0, op.Shape(offsets) - 1), to=INT64.dtype)
         # shape = (bag_size.dim[0], weight.dim[1])
         dim_0 = op.Shape(bag_size, start=0, end=1)
         dim_1 = op.Shape(weight, start=1, end=2)
-        max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0))
+        max_indices = op.Cast(op.Expand(0, op.Concat(dim_0, dim_1, axis=0)), to=INT64.dtype)

     return result, offset2bag, bag_size, max_indices
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index e87a0cc232..6fdde2086b 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -184,6 +184,25 @@ def xfail(
 # Modify this section ##########################################################


+def _embedding_bag_input_wrangler(
+    args: list[Any], kwargs: dict[str, Any]
+) -> tuple[list[Any], dict[str, Any]]:
+    # ONNX attributes cannot be None; omit padding_idx if it's None.
+    if "padding_idx" in kwargs:
+        padding_idx = kwargs.pop("padding_idx")
+        if padding_idx is not None:
+            kwargs["padding_idx"] = int(padding_idx)
+
+    # Ensure indices/offsets are int64 (positional: weight, indices, offsets, ...)
+    if len(args) >= 3:
+        if isinstance(args[1], torch.Tensor):
+            args[1] = args[1].to(torch.long)
+        if isinstance(args[2], torch.Tensor):
+            args[2] = args[2].to(torch.long)
+
+    return args, kwargs
+
+
 def _amin_amax_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
@@ -908,12 +927,17 @@ def _where_input_wrangler(
         core_ops.aten_embedding_bag,
         tolerance={torch.float32: (1e-4, 5e-4)},
         compare_shape_only_for_output=(1, 2, 3),
-    ).skip(dtypes=(torch.float16,), reason="fixme: results mismatch in torch nightly."),
+        input_wrangler=_embedding_bag_input_wrangler,
+    ).skip(
+        dtypes=(torch.float16,),
+        reason="fixme: results mismatch in torch nightly.",
+    ),
     TorchLibOpInfo(
         "ops.aten.embedding_bag.padding_idx",
         core_ops.aten_embedding_bag_padding_idx,
         tolerance={torch.float16: (1e-2, 1e-2)},
         compare_shape_only_for_output=(1, 2, 3),
+        input_wrangler=_embedding_bag_input_wrangler,
     ),
     TorchLibOpInfo(
         "ops.aten.embedding_renorm",
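For context, a minimal eager-mode sketch of the behavior the new dispatch is meant to mirror, using torch.nn.functional.embedding_bag as the reference: padding_idx=None leaves every index contributing to its bag (the _aten_embedding_bag_onnx path), while a concrete padding_idx excludes that row from each bag (the _aten_embedding_bag_1d_padding_idx_onnx path). The tensor values below are illustrative only, not taken from the test suite.

import torch
import torch.nn.functional as F

weight = torch.arange(20, dtype=torch.float32).reshape(5, 4)  # 5 embeddings of dim 4
indices = torch.tensor([0, 2, 1, 3], dtype=torch.int64)
offsets = torch.tensor([0, 2], dtype=torch.int64)  # two bags: indices [0, 2] and [1, 3]

# padding_idx=None: every index contributes to its bag's sum.
out_plain = F.embedding_bag(indices, weight, offsets, mode="sum", padding_idx=None)

# padding_idx=1: rows gathered for index 1 are excluded from the sums,
# which is the case the padding-aware ONNX decomposition handles.
out_padded = F.embedding_bag(indices, weight, offsets, mode="sum", padding_idx=1)

print(out_plain)   # bag sums over all indices
print(out_padded)  # second bag omits the contribution of index 1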