@@ -1205,10 +1205,53 @@ def aten_bilinear(
     # bias shape: (out_features) - optional
     # output shape: (..., out_features)
 
-    # Use Einsum to compute the bilinear transformation
-    # "...i,oij,...j->...o" means:
-    #   - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o]
-    result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o")
+    # Decompose bilinear into MatMul operations:
+    # 1. Create outer product of input1 and input2
+    # 2. Reshape to flatten feature dimensions
+    # 3. Use MatMul with reshaped weight
+
+    # Get shapes for reshaping
+    weight_shape = op.Shape(weight)
+
+    # Get dimensions as 1-D tensors so they can be concatenated into shapes below
+    out_features = op.Gather(weight_shape, [0], axis=0)
+    in1_features = op.Gather(weight_shape, [1], axis=0)
+    in2_features = op.Gather(weight_shape, [2], axis=0)
+
+    # Get batch dimensions (everything except the last dimension)
+    batch_dims = op.Shape(input1, start=0, end=-1)
+    batch_size = op.ReduceProd(batch_dims, keepdims=True)
+
+    # Create outer product: input1[..., i] * input2[..., j] -> [..., i, j]
+    # Reshape inputs to [batch_size, features] for easier handling
+    input1_2d = op.Reshape(input1, op.Concat(batch_size, in1_features, axis=0))
+    input2_2d = op.Reshape(input2, op.Concat(batch_size, in2_features, axis=0))
+
+    # Create outer product using unsqueeze and broadcasting
+    input1_expanded = op.Unsqueeze(input1_2d, [2])  # [batch_size, in1_features, 1]
+    input2_expanded = op.Unsqueeze(input2_2d, [1])  # [batch_size, 1, in2_features]
+
+    # Outer product via broadcasting multiplication
+    outer_product = op.Mul(input1_expanded, input2_expanded)  # [batch_size, in1_features, in2_features]
+
+    # Flatten the feature dimensions
+    features_total = op.Mul(in1_features, in2_features)
+    outer_flat = op.Reshape(outer_product, op.Concat(batch_size, features_total, axis=0))
+
+    # Reshape weight to 2D: [out_features, in1_features * in2_features]
+    weight_2d = op.Reshape(weight, op.Concat(out_features, features_total, axis=0))
+
+    # Transpose weight for MatMul: [in1_features * in2_features, out_features]
+    weight_t = op.Transpose(weight_2d, perm=[1, 0])
+
+    # Matrix multiplication: [batch_size, out_features]
+    result = op.MatMul(outer_flat, weight_t)
+
+    # Reshape back to original batch dimensions + out_features
+    output_shape = op.Concat(batch_dims, out_features, axis=0)
+    result = op.Reshape(result, output_shape)
 
     # Add bias if provided
     if bias is not None:
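For a quick sanity check of this rewrite, here is a minimal standalone NumPy sketch (shapes and variable names are illustrative only, not part of the commit) showing that the outer-product-plus-MatMul decomposition reproduces the Einsum formulation "...i,oij,...j->...o" that it replaces:

```python
import numpy as np

# Illustrative sizes only; the real op accepts arbitrary leading batch dims.
batch, in1, in2, out = 3, 4, 5, 6
rng = np.random.default_rng(0)
x1 = rng.standard_normal((batch, in1)).astype(np.float32)
x2 = rng.standard_normal((batch, in2)).astype(np.float32)
w = rng.standard_normal((out, in1, in2)).astype(np.float32)

# Reference: the Einsum formulation removed by this commit.
ref = np.einsum("...i,oij,...j->...o", x1, w, x2)

# Decomposition used above: outer product, flatten features, single MatMul.
outer = x1[:, :, None] * x2[:, None, :]        # [batch, in1, in2]
outer_flat = outer.reshape(batch, in1 * in2)   # [batch, in1 * in2]
w_2d = w.reshape(out, in1 * in2)               # [out, in1 * in2]
dec = outer_flat @ w_2d.T                      # [batch, out]

assert np.allclose(ref, dec, atol=1e-5)
```

The trade-off is that the decomposition materializes the [batch, in1_features, in2_features] outer product as an intermediate tensor, which a backend with a native Einsum kernel may be able to contract without storing.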