@@ -888,6 +888,7 @@ def conv_2d_nchw_fchw_q(
888
888
- TypeFn .cast_signed (U , IZp )
889
889
) * (TypeFn .cast_signed (U , K [D .f , D .c , D .kh , D .kw ]) - TypeFn .cast_signed (U , KZp ))
890
890
891
+
891
892
@linalg_structured_op
892
893
def conv_2d_nchw_fchw (
893
894
I = TensorDef (T1 , S .N , S .C , S .OH * S .SH + S .KH * S .DH , S .OW * S .SW + S .KW * S .DW ),
@@ -1082,16 +1083,19 @@ def conv_3d_ndhwc_dhwcf(
1082
1083
"""
1083
1084
implements (ConvolutionOpInterface )
1084
1085
domain (D .n , D .od , D .oh , D .ow , D .f , D .kd , D .kh , D .kw , D .c )
1085
- O [D .n , D .od , D .oh , D .ow , D .f ] += TypeFn .cast_signed (
1086
- U ,
1087
- I [
1088
- D .n ,
1089
- D .od * S .SD + D .kd * S .DD ,
1090
- D .oh * S .SH + D .kh * S .DH ,
1091
- D .ow * S .SW + D .kw * S .DW ,
1092
- D .c ,
1093
- ],
1094
- ) * TypeFn .cast_signed (U , K [D .kd , D .kh , D .kw , D .c , D .f ])
1086
+ O [D .n , D .od , D .oh , D .ow , D .f ] += (
1087
+ TypeFn .cast_signed (
1088
+ U ,
1089
+ I [
1090
+ D .n ,
1091
+ D .od * S .SD + D .kd * S .DD ,
1092
+ D .oh * S .SH + D .kh * S .DH ,
1093
+ D .ow * S .SW + D .kw * S .DW ,
1094
+ D .c ,
1095
+ ],
1096
+ )
1097
+ * TypeFn .cast_signed (U , K [D .kd , D .kh , D .kw , D .c , D .f ])
1098
+ )
1095
1099
1096
1100
1097
1101
@linalg_structured_op
@@ -1159,16 +1163,19 @@ def conv_3d_ncdhw_fcdhw(
1159
1163
"""
1160
1164
implements (ConvolutionOpInterface )
1161
1165
domain (D .n , D .od , D .oh , D .ow , D .f , D .kd , D .kh , D .kw , D .c )
1162
- O [D .n , D .f , D .od , D .oh , D .ow ] += TypeFn .cast_signed (
1163
- U ,
1164
- I [
1165
- D .n ,
1166
- D .c ,
1167
- D .od * S .SD + D .kd * S .DD ,
1168
- D .oh * S .SH + D .kh * S .DH ,
1169
- D .ow * S .SW + D .kw * S .DW ,
1170
- ],
1171
- ) * TypeFn .cast_signed (U , K [D .f , D .c , D .kd , D .kh , D .kw ])
1166
+ O [D .n , D .f , D .od , D .oh , D .ow ] += (
1167
+ TypeFn .cast_signed (
1168
+ U ,
1169
+ I [
1170
+ D .n ,
1171
+ D .c ,
1172
+ D .od * S .SD + D .kd * S .DD ,
1173
+ D .oh * S .SH + D .kh * S .DH ,
1174
+ D .ow * S .SW + D .kw * S .DW ,
1175
+ ],
1176
+ )
1177
+ * TypeFn .cast_signed (U , K [D .f , D .c , D .kd , D .kh , D .kw ])
1178
+ )
1172
1179
1173
1180
1174
1181
@linalg_structured_op
@@ -1368,16 +1375,19 @@ def depthwise_conv_3d_ndhwc_dhwc(
1368
1375
"""
1369
1376
implements (ConvolutionOpInterface )
1370
1377
domain (D .n , D .od , D .oh , D .ow , D .kd , D .kh , D .kw , D .ic )
1371
- O [D .n , D .od , D .oh , D .ow , D .ic ] += TypeFn .cast_signed (
1372
- U ,
1373
- I [
1374
- D .n ,
1375
- D .od * S .SD + D .kd * S .DD ,
1376
- D .oh * S .SH + D .kh * S .DH ,
1377
- D .ow * S .SW + D .kw * S .DW ,
1378
- D .ic ,
1379
- ],
1380
- ) * TypeFn .cast_signed (U , K [D .kd , D .kh , D .kw , D .ic ])
1378
+ O [D .n , D .od , D .oh , D .ow , D .ic ] += (
1379
+ TypeFn .cast_signed (
1380
+ U ,
1381
+ I [
1382
+ D .n ,
1383
+ D .od * S .SD + D .kd * S .DD ,
1384
+ D .oh * S .SH + D .kh * S .DH ,
1385
+ D .ow * S .SW + D .kw * S .DW ,
1386
+ D .ic ,
1387
+ ],
1388
+ )
1389
+ * TypeFn .cast_signed (U , K [D .kd , D .kh , D .kw , D .ic ])
1390
+ )
1381
1391
1382
1392
1383
1393
@linalg_structured_op
@@ -1403,16 +1413,19 @@ def depthwise_conv_3d_ncdhw_cdhw(
1403
1413
"""
1404
1414
implements (ConvolutionOpInterface )
1405
1415
domain (D .n , D .od , D .oh , D .ow , D .kd , D .kh , D .kw , D .ic )
1406
- O [D .n , D .ic , D .od , D .oh , D .ow ] += TypeFn .cast_signed (
1407
- U ,
1408
- I [
1409
- D .n ,
1410
- D .ic ,
1411
- D .od * S .SD + D .kd * S .DD ,
1412
- D .oh * S .SH + D .kh * S .DH ,
1413
- D .ow * S .SW + D .kw * S .DW ,
1414
- ],
1415
- ) * TypeFn .cast_signed (U , K [D .ic , D .kd , D .kh , D .kw ])
1416
+ O [D .n , D .ic , D .od , D .oh , D .ow ] += (
1417
+ TypeFn .cast_signed (
1418
+ U ,
1419
+ I [
1420
+ D .n ,
1421
+ D .ic ,
1422
+ D .od * S .SD + D .kd * S .DD ,
1423
+ D .oh * S .SH + D .kh * S .DH ,
1424
+ D .ow * S .SW + D .kw * S .DW ,
1425
+ ],
1426
+ )
1427
+ * TypeFn .cast_signed (U , K [D .ic , D .kd , D .kh , D .kw ])
1428
+ )
1416
1429
1417
1430
1418
1431
@linalg_structured_op
@@ -1437,16 +1450,19 @@ def depthwise_conv_3d_ndhwc_dhwcm(
1437
1450
"""
1438
1451
implements (ConvolutionOpInterface )
1439
1452
domain (D .n , D .od , D .oh , D .ow , D .cm , D .kd , D .kh , D .kw , D .ic )
1440
- O [D .n , D .od , D .oh , D .ow , D .ic , D .cm ] += TypeFn .cast_signed (
1441
- U ,
1442
- I [
1443
- D .n ,
1444
- D .od * S .SD + D .kd * S .DD ,
1445
- D .oh * S .SH + D .kh * S .DH ,
1446
- D .ow * S .SW + D .kw * S .DW ,
1447
- D .ic ,
1448
- ],
1449
- ) * TypeFn .cast_signed (U , K [D .kd , D .kh , D .kw , D .ic , D .cm ])
1453
+ O [D .n , D .od , D .oh , D .ow , D .ic , D .cm ] += (
1454
+ TypeFn .cast_signed (
1455
+ U ,
1456
+ I [
1457
+ D .n ,
1458
+ D .od * S .SD + D .kd * S .DD ,
1459
+ D .oh * S .SH + D .kh * S .DH ,
1460
+ D .ow * S .SW + D .kw * S .DW ,
1461
+ D .ic ,
1462
+ ],
1463
+ )
1464
+ * TypeFn .cast_signed (U , K [D .kd , D .kh , D .kw , D .ic , D .cm ])
1465
+ )
1450
1466
1451
1467
1452
1468
@linalg_structured_op
0 commit comments