35 changes: 26 additions & 9 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -3101,27 +3101,44 @@
     sparse: bool = False,
     per_sample_weights: Optional[TFloat] = None,
     include_last_offset: bool = False,
-    padding_idx: int = -1,
+    padding_idx: Optional[int] = None,
 ) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
     """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)

     We add default values for the attributes to accommodate _embedding_bag as well:
     _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1)
     """
-    assert padding_idx is not None, (
-        "padding_idx must not be None. This is likely a dispatcher error"
-    )

     if per_sample_weights is None:
         per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
         per_sample_weights = op.CastLike(per_sample_weights, weight)

-    # Change padding_idx to positive value, -1 means the last index
-    if padding_idx < 0:
-        padding_idx = weight.shape[0] + padding_idx
-
-    result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
-        weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx
-    )
+    if padding_idx is not None:
+        # Call the existing function for handling padding_idx
+        result, offset2bag, bag_size, max_indices = _aten_embedding_bag_1d_padding_idx_onnx(
+            weight,
+            indices,
+            offsets,
+            scale_grad_by_freq,
+            mode,
+            sparse,
+            per_sample_weights,
+            include_last_offset,
+            padding_idx,
+        )
+        return result, offset2bag, bag_size, max_indices
+
+    # When padding_idx is None, use the standard embedding_bag implementation
+    result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx(
+        weight,
+        indices,
+        offsets,
+        scale_grad_by_freq,
+        mode,
+        sparse,
+        per_sample_weights,
+        include_last_offset,
+    )

     return result, offset2bag, bag_size, max_indices
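
Editorial aside (not part of the diff): a minimal eager-mode sketch of the two call patterns the dispatch above has to cover, using torch.nn.functional.embedding_bag with made-up tensors.

import torch
import torch.nn.functional as F

weight = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])
indices = torch.tensor([0, 1, 2, 3], dtype=torch.int64)
offsets = torch.tensor([0, 2], dtype=torch.int64)

# padding_idx omitted (None): corresponds to the plain _aten_embedding_bag_onnx branch.
out_plain = F.embedding_bag(indices, weight, offsets, mode="sum")

# padding_idx given: entries equal to 0 contribute nothing to their bag,
# which is the behavior the padding_idx-aware branch has to reproduce.
out_padded = F.embedding_bag(indices, weight, offsets, mode="sum", padding_idx=0)

print(out_plain)   # first bag sums rows 0 and 1
print(out_padded)  # first bag sums row 1 only; index 0 is treated as padding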
38 changes: 38 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
@@ -2210,6 +2210,44 @@ def __init__(self):
sample_inputs_func=sample_inputs_embedding_bag_padding_idx,
supports_out=False,
),
opinfo_core.OpInfo(
"test_embedding_bag_with_padding_idx_none",
op=torch.nn.functional.embedding_bag,
dtypes=(torch.float32,),
sample_inputs_func=lambda op_info, device, dtype, requires_grad: [
opinfo_core.SampleInput(
torch.tensor(
[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]],
dtype=dtype,
device=device,
),
args=(
torch.tensor([0, 1, 2, 3], dtype=torch.int64, device=device),
torch.tensor([0, 2], dtype=torch.int64, device=device),
),
kwargs={"padding_idx": None},
)
],
),
opinfo_core.OpInfo(
"test_embedding_bag_with_padding_idx_int",
op=torch.nn.functional.embedding_bag,
dtypes=(torch.float32,),
Collaborator: I think it's this line: you need the all_float_types() etc. construct for specifying supported dtypes. See other existing op infos for reference.

Author: Changed it to common_dtype.floating_types_and_half(), similar to other test cases.

sample_inputs_func=lambda op_info, device, dtype, requires_grad: [
opinfo_core.SampleInput(
torch.tensor(
[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],
dtype=dtype,
device=device,
),
args=(
torch.tensor([0, 1, 2], dtype=torch.int64, device=device),
torch.tensor([0, 2], dtype=torch.int64, device=device),
),
kwargs={"padding_idx": 0},
)
],
),
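
Editorial aside on the review thread above: the diff shown here still has the hard-coded dtypes=(torch.float32,); the reply refers to a later revision. A minimal sketch of the suggested spelling, assuming the test module uses PyTorch's dtype helpers under the alias common_dtype (i.e. torch.testing._internal.common_dtype):

from torch.testing._internal import common_dtype

# Instead of hard-coding a single dtype:
#     dtypes=(torch.float32,),
# the OpInfo can advertise the supported float dtypes:
#     dtypes=common_dtype.floating_types_and_half(),
# which expands to (torch.float64, torch.float32, torch.float16).
print(common_dtype.floating_types_and_half())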
opinfo_core.OpInfo(
"ops.aten.embedding_renorm",
aten_name="embedding_renorm",
41 changes: 41 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -185,6 +185,24 @@
# Modify this section ##########################################################


def _embedding_bag_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# ONNX attributes cannot be None; omit padding_idx if it’s None.
padding_idx = kwargs.pop("padding_idx", "___MISSING___")
if padding_idx != "___MISSING___":
if padding_idx is not None:
kwargs["padding_idx"] = int(padding_idx)

# Ensure indices/offsets are int64 (positional: weight, indices, offsets, ...)
if len(args) >= 3:
if isinstance(args[1], torch.Tensor):
args[1] = args[1].to(torch.long)
if isinstance(args[2], torch.Tensor):
args[2] = args[2].to(torch.long)

return args, kwargs
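
Editorial aside (not part of the change set): a quick, made-up usage check of the wrangler defined above, assuming it is invoked the way the test harness passes args and kwargs.

import torch

weight = torch.randn(4, 3)
indices = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
offsets = torch.tensor([0, 2], dtype=torch.int32)

# padding_idx=None is dropped entirely, so no None-valued attribute reaches ONNX,
# and the integer index tensors are promoted to int64.
args, kwargs = _embedding_bag_input_wrangler(
    [weight, indices, offsets], {"padding_idx": None}
)
assert "padding_idx" not in kwargs
assert args[1].dtype == torch.int64 and args[2].dtype == torch.int64

# An integer padding_idx is kept and coerced to a plain int.
_, kwargs = _embedding_bag_input_wrangler([weight, indices, offsets], {"padding_idx": 0})
assert kwargs["padding_idx"] == 0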

def _amin_amax_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
@@ -1035,15 +1053,38 @@
core_ops.aten_embedding_bag,
tolerance={torch.float32: (1e-4, 5e-4)},
compare_shape_only_for_output=(1, 2, 3),
input_wrangler=_embedding_bag_input_wrangler,
).skip(
dtypes=(torch.float16,),
reason="fixme: results mismatch in torch nightly.",
),
TorchLibOpInfo(
"test_embedding_bag_with_padding_idx_none",
core_ops.aten_embedding_bag,
input_wrangler=_embedding_bag_input_wrangler,
),
TorchLibOpInfo(
"test_embedding_bag_with_padding_idx_int",
core_ops.aten_embedding_bag,
input_wrangler=_embedding_bag_input_wrangler,
),

TorchLibOpInfo(
"ops.aten.embedding_bag.padding_idx",
core_ops.aten_embedding_bag_padding_idx,
tolerance={torch.float16: (1e-2, 1e-2)},
compare_shape_only_for_output=(1, 2, 3),
input_wrangler=_embedding_bag_input_wrangler,
),
TorchLibOpInfo(
"ops.aten.embedding_renorm",