@@ -888,6 +888,7 @@ def conv_2d_nchw_fchw_q(
888888 - TypeFn .cast_signed (U , IZp )
889889 ) * (TypeFn .cast_signed (U , K [D .f , D .c , D .kh , D .kw ]) - TypeFn .cast_signed (U , KZp ))
890890
891+
891892@linalg_structured_op
892893def conv_2d_nchw_fchw (
893894 I = TensorDef (T1 , S .N , S .C , S .OH * S .SH + S .KH * S .DH , S .OW * S .SW + S .KW * S .DW ),
@@ -1082,16 +1083,19 @@ def conv_3d_ndhwc_dhwcf(
10821083 """
10831084 implements (ConvolutionOpInterface )
10841085 domain (D .n , D .od , D .oh , D .ow , D .f , D .kd , D .kh , D .kw , D .c )
1085- O [D .n , D .od , D .oh , D .ow , D .f ] += TypeFn .cast_signed (
1086- U ,
1087- I [
1088- D .n ,
1089- D .od * S .SD + D .kd * S .DD ,
1090- D .oh * S .SH + D .kh * S .DH ,
1091- D .ow * S .SW + D .kw * S .DW ,
1092- D .c ,
1093- ],
1094- ) * TypeFn .cast_signed (U , K [D .kd , D .kh , D .kw , D .c , D .f ])
1086+ O [D .n , D .od , D .oh , D .ow , D .f ] += (
1087+ TypeFn .cast_signed (
1088+ U ,
1089+ I [
1090+ D .n ,
1091+ D .od * S .SD + D .kd * S .DD ,
1092+ D .oh * S .SH + D .kh * S .DH ,
1093+ D .ow * S .SW + D .kw * S .DW ,
1094+ D .c ,
1095+ ],
1096+ )
1097+ * TypeFn .cast_signed (U , K [D .kd , D .kh , D .kw , D .c , D .f ])
1098+ )
10951099
10961100
10971101@linalg_structured_op
@@ -1159,16 +1163,19 @@ def conv_3d_ncdhw_fcdhw(
11591163 """
11601164 implements (ConvolutionOpInterface )
11611165 domain (D .n , D .od , D .oh , D .ow , D .f , D .kd , D .kh , D .kw , D .c )
1162- O [D .n , D .f , D .od , D .oh , D .ow ] += TypeFn .cast_signed (
1163- U ,
1164- I [
1165- D .n ,
1166- D .c ,
1167- D .od * S .SD + D .kd * S .DD ,
1168- D .oh * S .SH + D .kh * S .DH ,
1169- D .ow * S .SW + D .kw * S .DW ,
1170- ],
1171- ) * TypeFn .cast_signed (U , K [D .f , D .c , D .kd , D .kh , D .kw ])
1166+ O [D .n , D .f , D .od , D .oh , D .ow ] += (
1167+ TypeFn .cast_signed (
1168+ U ,
1169+ I [
1170+ D .n ,
1171+ D .c ,
1172+ D .od * S .SD + D .kd * S .DD ,
1173+ D .oh * S .SH + D .kh * S .DH ,
1174+ D .ow * S .SW + D .kw * S .DW ,
1175+ ],
1176+ )
1177+ * TypeFn .cast_signed (U , K [D .f , D .c , D .kd , D .kh , D .kw ])
1178+ )
11721179
11731180
11741181@linalg_structured_op
@@ -1368,16 +1375,19 @@ def depthwise_conv_3d_ndhwc_dhwc(
13681375 """
13691376 implements (ConvolutionOpInterface )
13701377 domain (D .n , D .od , D .oh , D .ow , D .kd , D .kh , D .kw , D .ic )
1371- O [D .n , D .od , D .oh , D .ow , D .ic ] += TypeFn .cast_signed (
1372- U ,
1373- I [
1374- D .n ,
1375- D .od * S .SD + D .kd * S .DD ,
1376- D .oh * S .SH + D .kh * S .DH ,
1377- D .ow * S .SW + D .kw * S .DW ,
1378- D .ic ,
1379- ],
1380- ) * TypeFn .cast_signed (U , K [D .kd , D .kh , D .kw , D .ic ])
1378+ O [D .n , D .od , D .oh , D .ow , D .ic ] += (
1379+ TypeFn .cast_signed (
1380+ U ,
1381+ I [
1382+ D .n ,
1383+ D .od * S .SD + D .kd * S .DD ,
1384+ D .oh * S .SH + D .kh * S .DH ,
1385+ D .ow * S .SW + D .kw * S .DW ,
1386+ D .ic ,
1387+ ],
1388+ )
1389+ * TypeFn .cast_signed (U , K [D .kd , D .kh , D .kw , D .ic ])
1390+ )
13811391
13821392
13831393@linalg_structured_op
@@ -1403,16 +1413,19 @@ def depthwise_conv_3d_ncdhw_cdhw(
14031413 """
14041414 implements (ConvolutionOpInterface )
14051415 domain (D .n , D .od , D .oh , D .ow , D .kd , D .kh , D .kw , D .ic )
1406- O [D .n , D .ic , D .od , D .oh , D .ow ] += TypeFn .cast_signed (
1407- U ,
1408- I [
1409- D .n ,
1410- D .ic ,
1411- D .od * S .SD + D .kd * S .DD ,
1412- D .oh * S .SH + D .kh * S .DH ,
1413- D .ow * S .SW + D .kw * S .DW ,
1414- ],
1415- ) * TypeFn .cast_signed (U , K [D .ic , D .kd , D .kh , D .kw ])
1416+ O [D .n , D .ic , D .od , D .oh , D .ow ] += (
1417+ TypeFn .cast_signed (
1418+ U ,
1419+ I [
1420+ D .n ,
1421+ D .ic ,
1422+ D .od * S .SD + D .kd * S .DD ,
1423+ D .oh * S .SH + D .kh * S .DH ,
1424+ D .ow * S .SW + D .kw * S .DW ,
1425+ ],
1426+ )
1427+ * TypeFn .cast_signed (U , K [D .ic , D .kd , D .kh , D .kw ])
1428+ )
14161429
14171430
14181431@linalg_structured_op
@@ -1437,16 +1450,19 @@ def depthwise_conv_3d_ndhwc_dhwcm(
14371450 """
14381451 implements (ConvolutionOpInterface )
14391452 domain (D .n , D .od , D .oh , D .ow , D .cm , D .kd , D .kh , D .kw , D .ic )
1440- O [D .n , D .od , D .oh , D .ow , D .ic , D .cm ] += TypeFn .cast_signed (
1441- U ,
1442- I [
1443- D .n ,
1444- D .od * S .SD + D .kd * S .DD ,
1445- D .oh * S .SH + D .kh * S .DH ,
1446- D .ow * S .SW + D .kw * S .DW ,
1447- D .ic ,
1448- ],
1449- ) * TypeFn .cast_signed (U , K [D .kd , D .kh , D .kw , D .ic , D .cm ])
1453+ O [D .n , D .od , D .oh , D .ow , D .ic , D .cm ] += (
1454+ TypeFn .cast_signed (
1455+ U ,
1456+ I [
1457+ D .n ,
1458+ D .od * S .SD + D .kd * S .DD ,
1459+ D .oh * S .SH + D .kh * S .DH ,
1460+ D .ow * S .SW + D .kw * S .DW ,
1461+ D .ic ,
1462+ ],
1463+ )
1464+ * TypeFn .cast_signed (U , K [D .kd , D .kh , D .kw , D .ic , D .cm ])
1465+ )
14501466
14511467
14521468@linalg_structured_op
0 commit comments