diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 99fc6fb44f..91cf03499d 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -1189,6 +1189,7 @@ def aten_bernoulli_p(self: TTensor, p: float) -> TTensor:
     return op.CastLike(sampled, self)
 
 
+@torch_op("aten::bilinear", trace_only=True)
 def aten_bilinear(
     input1: TensorType,
     input2: TensorType,
@@ -1197,7 +1198,23 @@
 ) -> TensorType:
     """bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor"""
 
-    raise NotImplementedError()
+    # Bilinear transformation: y = x1^T A x2 + b
+    # input1 shape: (..., in1_features)
+    # input2 shape: (..., in2_features)
+    # weight shape: (out_features, in1_features, in2_features)
+    # bias shape: (out_features) - optional
+    # output shape: (..., out_features)
+
+    # Use Einsum to compute the bilinear transformation
+    # "...i,oij,...j->...o" means:
+    # - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o]
+    result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o")
+
+    # Add bias if provided
+    if bias is not None:
+        result = op.Add(result, bias)
+
+    return result
 
 
 def aten_binary_cross_entropy_with_logits(
diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py
index 4f4a3872e1..9455364ea2 100644
--- a/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/tests/function_libs/torch_lib/extra_opinfo.py
@@ -37,6 +37,37 @@ def sample_inputs_scalar_tensor(op_info, device, dtype, requires_grad, **kwargs)
     yield opinfo_core.SampleInput(item, dtype=dtype)
 
 
+def sample_inputs_bilinear(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for bilinear operation."""
+    del op_info
+    del kwargs
+
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    # Test cases: (batch_size, in1_features, in2_features, out_features)
+    cases = [
+        (2, 3, 4, 5),  # Basic case
+        (1, 2, 2, 1),  # Minimal case
+        (3, 5, 7, 4),  # Different dimensions
+        (2, 1, 1, 3),  # Single input features
+    ]
+
+    for batch_size, in1_features, in2_features, out_features in cases:
+        input1 = make_arg((batch_size, in1_features))
+        input2 = make_arg((batch_size, in2_features))
+        weight = make_arg((out_features, in1_features, in2_features))
+        bias = make_arg((out_features,))
+
+        # Test with bias
+        yield opinfo_core.SampleInput(input1, args=(input2, weight, bias))
+
+        # Test without bias (only for cases with batch_size == 2, to avoid too many tests)
+        if batch_size == 2:
+            yield opinfo_core.SampleInput(input1, args=(input2, weight, None))
+
+
 def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
 
@@ -2180,6 +2211,13 @@ def __init__(self):
 # To avoid name duplication, it is possible to rename the OpInfo and specify
 # the `op` field explicitly.
 OP_DB: List[opinfo_core.OpInfo] = [
+    opinfo_core.OpInfo(
+        "bilinear",
+        op=torch.nn.functional.bilinear,
+        dtypes=common_dtype.floating_types(),
+        sample_inputs_func=sample_inputs_bilinear,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten.bernoulli.p",
         aten_name="bernoulli.p",
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index b1e0c529ec..6f21bd518b 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -657,6 +657,9 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}),
     TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True),
+    TorchLibOpInfo(
+        "bilinear", core_ops.aten_bilinear, tolerance={torch.float32: (2e-5, 2e-5)}
+    ),
     TorchLibOpInfo(
         # This string is a unique ID. In extra_opinfo.py, we
         # also define test data for this ID with