diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 05e2cd9258..8506debede 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3089,28 +3089,30 @@ def aten_embedding_bag_padding_idx( sparse: bool = False, per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, - padding_idx: int = -1, + padding_idx: Optional[int] = None, ) -> Tuple[TFloat, TFloat, TFloat, TFloat]: """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor) We add default values for the attributes to accommodate _embedding_bag as well: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) """ - assert padding_idx is not None, ( - "padding_idx must not be None. This is likely a dispatcher error" - ) - if per_sample_weights is None: per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices)) per_sample_weights = op.CastLike(per_sample_weights, weight) - # Change padding_idx to positive value, -1 means the last index - if padding_idx < 0: - padding_idx = weight.shape[0] + padding_idx + # If padding_idx is None, behave like regular embedding_bag (no padding index to ignore) + if padding_idx is None: + result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx( + weight, indices, offsets, mode, per_sample_weights, include_last_offset + ) + else: + # Change padding_idx to positive value, -1 means the last index + if padding_idx < 0: + padding_idx = weight.shape[0] + padding_idx - result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx( - weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx - ) + result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx( + weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx + ) return result, offset2bag, bag_size, max_indices