Commit 22c4c0f

Copilot and justinchuby committed
Re-implement aten_bilinear using MatMul and Transpose operations instead of Einsum
Co-authored-by: justinchuby <[email protected]>
1 parent 38bd90b commit 22c4c0f
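
The rewrite relies on a standard factoring of the bilinear contraction. The Einsum equation removed below, "...i,oij,...j->...o", computes

    output[..., o] = sum over (i, j) of input1[..., i] * weight[o, i, j] * input2[..., j]

which can equivalently be evaluated as an outer product of input1 and input2 over (i, j), a flatten of (i, j) into a single axis, and one MatMul against the weight reshaped to [out_features, in1_features * in2_features] and transposed.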

File tree

  • onnxscript/function_libs/torch_lib/ops/core.py

1 file changed: +47, -4 lines

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 47 additions & 4 deletions
@@ -1205,10 +1205,53 @@ def aten_bilinear(
     # bias shape: (out_features) - optional
     # output shape: (..., out_features)

-    # Use Einsum to compute the bilinear transformation
-    # "...i,oij,...j->...o" means:
-    # - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o]
-    result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o")
+    # Decompose bilinear into MatMul operations:
+    # 1. Create outer product of input1 and input2
+    # 2. Reshape to flatten feature dimensions
+    # 3. Use MatMul with reshaped weight
+
+    # Get shapes for reshaping
+    input1_shape = op.Shape(input1)
+    weight_shape = op.Shape(weight)
+
+    # Get dimensions
+    out_features = op.Gather(weight_shape, 0, axis=0)
+    in1_features = op.Gather(weight_shape, 1, axis=0)
+    in2_features = op.Gather(weight_shape, 2, axis=0)
+
+    # Get batch dimensions (everything except the last dimension)
+    input1_rank = Rank(input1)
+    batch_dims = op.Slice(input1_shape, [0], [input1_rank - 1])
+    batch_size = op.ReduceProd(batch_dims, keepdims=False)
+
+    # Create outer product: input1[..., i] * input2[..., j] -> [..., i, j]
+    # Reshape inputs to [batch_size, features] for easier handling
+    input1_2d = op.Reshape(input1, op.Concat([batch_size], [in1_features], axis=0))
+    input2_2d = op.Reshape(input2, op.Concat([batch_size], [in2_features], axis=0))
+
+    # Create outer product using unsqueeze and broadcasting
+    input1_expanded = op.Unsqueeze(input1_2d, axes=[2])  # [batch_size, in1_features, 1]
+    input2_expanded = op.Unsqueeze(input2_2d, axes=[1])  # [batch_size, 1, in2_features]
+
+    # Outer product via broadcasting multiplication
+    outer_product = op.Mul(input1_expanded, input2_expanded)  # [batch_size, in1_features, in2_features]
+
+    # Flatten the feature dimensions
+    features_total = op.Mul(in1_features, in2_features)
+    outer_flat = op.Reshape(outer_product, op.Concat([batch_size], [features_total], axis=0))
+
+    # Reshape weight to 2D: [out_features, in1_features * in2_features]
+    weight_2d = op.Reshape(weight, op.Concat([out_features], [features_total], axis=0))
+
+    # Transpose weight for MatMul: [in1_features * in2_features, out_features]
+    weight_t = op.Transpose(weight_2d, perm=[1, 0])
+
+    # Matrix multiplication: [batch_size, out_features]
+    result = op.MatMul(outer_flat, weight_t)
+
+    # Reshape back to original batch dimensions + out_features
+    output_shape = op.Concat(batch_dims, [out_features], axis=0)
+    result = op.Reshape(result, output_shape)

     # Add bias if provided
     if bias is not None:
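
To make the equivalence concrete, here is a minimal numpy sketch (not the torch_lib code) that checks the outer-product + MatMul pipeline above against the Einsum equation it replaces. Variable names mirror the diff, numpy operations stand in for the corresponding ONNX ops, and the bias step is omitted since the `if bias is not None:` branch is unchanged.

    # Minimal numpy check of the decomposition (illustration only, not part of the commit)
    import numpy as np

    rng = np.random.default_rng(0)
    batch_dims = (2, 3)  # arbitrary leading batch shape
    in1_features, in2_features, out_features = 4, 5, 6

    input1 = rng.standard_normal((*batch_dims, in1_features))
    input2 = rng.standard_normal((*batch_dims, in2_features))
    weight = rng.standard_normal((out_features, in1_features, in2_features))

    # Removed formulation: Einsum "...i,oij,...j->...o"
    expected = np.einsum("...i,oij,...j->...o", input1, weight, input2)

    # New formulation, step by step as in the diff
    batch_size = int(np.prod(batch_dims))
    input1_2d = input1.reshape(batch_size, in1_features)   # op.Reshape
    input2_2d = input2.reshape(batch_size, in2_features)   # op.Reshape
    outer = input1_2d[:, :, None] * input2_2d[:, None, :]  # op.Unsqueeze + op.Mul
    outer_flat = outer.reshape(batch_size, in1_features * in2_features)
    weight_2d = weight.reshape(out_features, in1_features * in2_features)
    result = outer_flat @ weight_2d.T                      # op.Transpose + op.MatMul
    result = result.reshape(*batch_dims, out_features)     # op.Reshape back

    np.testing.assert_allclose(result, expected, rtol=1e-6)
    print("MatMul decomposition matches Einsum")

A plausible motivation for the change, though the commit message does not state it explicitly, is that MatMul and Transpose are more uniformly supported and better optimized across ONNX backends than Einsum.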

0 commit comments