@@ -301,6 +301,10 @@ class LayoutInfoPropagation
                          ArrayRef<LayoutInfoLattice *> operands,
                          ArrayRef<const LayoutInfoLattice *> results);

+  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
+
   void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                    ArrayRef<LayoutInfoLattice *> operands,
                                    ArrayRef<const LayoutInfoLattice *> results);
@@ -352,6 +356,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
         visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
       })
+      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
+        visitPrefetchNdOp(prefetchNdOp, operands, results);
+      })
       // No need to propagate the layout to operands in CreateNdDescOp because
       // they are scalars (offsets, sizes, etc.).
       .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
@@ -381,6 +388,18 @@ LogicalResult LayoutInfoPropagation::visitOperation(
   return success();
 }

+void LayoutInfoPropagation::visitPrefetchNdOp(
+    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // Here we assign the default layout to the tensor descriptor operand of
+  // prefetch.
+  auto tdescTy = prefetch.getTensorDescType();
+  auto prefetchLayout = getDefaultLayoutInfo(
+      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
+  // Propagate the layout to the source tensor descriptor.
+  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
+}
+
 void LayoutInfoPropagation::visitVectorMultiReductionOp(
     vector::MultiDimReductionOp reduction,
     ArrayRef<LayoutInfoLattice *> operands,
@@ -865,18 +884,6 @@ getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
   return VectorType::get(distributedShape, originalType.getElementType());
 }

-// Drop the layout attribute from the tensor descriptor type if layout is
-// present.
-static xegpu::TensorDescType dropLayouts(xegpu::TensorDescType tensorDesc) {
-  if (tensorDesc.getLayoutAttr() == xegpu::LayoutAttr())
-    return tensorDesc;
-
-  return xegpu::TensorDescType::get(
-      tensorDesc.getContext(), tensorDesc.getShape(),
-      tensorDesc.getElementType(), tensorDesc.getEncoding(),
-      xegpu::LayoutAttr());
-}
-
 /// Helper function to resolve types if the distributed type out of
 /// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
 /// Example 1:
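
The static helper removed above is replaced throughout the later hunks by a member call, `desc.dropLayouts()`, on `xegpu::TensorDescType`. For reference, here is a minimal sketch of what that member presumably does, assuming it keeps the exact semantics of the removed free function; the real definition lives with the XeGPU dialect's type declarations, not in this file.

// Sketch only: mirrors the removed static helper; not part of this diff.
xegpu::TensorDescType xegpu::TensorDescType::dropLayouts() {
  // Nothing to drop if no layout attribute is attached.
  if (getLayoutAttr() == xegpu::LayoutAttr())
    return *this;
  // Rebuild the descriptor type with every field preserved except the layout.
  return xegpu::TensorDescType::get(getContext(), getShape(), getElementType(),
                                    getEncoding(),
                                    /*layout=*/xegpu::LayoutAttr());
}
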
@@ -1023,12 +1030,12 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 /// Example:
 ///
 /// ```
-///   #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
 ///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
-///        (!xegpu.tensor_desc<4x8xf32, #lo0>) {
+///        (!xegpu.tensor_desc<4x8xf32, #layout0>) {
 ///     ...
 ///     %td = xegpu.create_nd_tdesc %arg0[0, 0]
-///       : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #lo0>
+///       : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
 ///     vector.yield %td
 ///   }
 /// ```
@@ -1037,7 +1044,7 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 ///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
 ///     ...
 ///     %dead = xegpu.create_nd_tdesc %arg0[0, 0]
-///       : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #lo0>
+///       : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
 ///     vector.yield %arg0, %dead
 ///   }
 ///   %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32>
@@ -1080,8 +1087,8 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
     }
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::TensorDescType distributedTensorDescTy =
-        dropLayouts(descOp.getType()); // Distributed tensor descriptor type
-                                       // does not contain layout info.
+        descOp.getType().dropLayouts(); // Distributed tensor descriptor type
+                                        // does not contain layout info.
     auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
         newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
         descOp->getAttrs());
@@ -1101,23 +1108,23 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
 /// Example:
 ///
 /// ```
-///   #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
 ///   gpu.warp_execute_on_lane_0(%laneid) -> () {
 ///     ...
 ///     xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
-///       !xegpu.tensor_desc<4x8xf32, #lo0>
+///       !xegpu.tensor_desc<4x8xf32, #layout0>
 ///   }
 /// ```
 /// To
 /// ```
 ///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
-///     !xegpu.tensor_desc<4x8xf32, #lo0>) {
+///     !xegpu.tensor_desc<4x8xf32, #layout0>) {
 ///     gpu.yield %arg0, %arg1: vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32,
-///       #lo0>
+///       #layout0>
 ///   }
 ///   %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
 ///   %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
-///     #lo0>
+///     #layout0>
 ///     -> !xegpu.tensor_desc<4x8xf32>
 ///   xegpu.store_nd %0, %1: vector<4xf32>,
 ///     !xegpu.tensor_desc<4x8xf32>
@@ -1173,10 +1180,10 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
     newStoreOperands.push_back(resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[0]),
         storeNdDistributedValueTyOrFailure.value(), rewriter));
-    // For the tensor descriptor operand, the layout attibute is dropped after
+    // For the tensor descriptor operand, the layout attribute is dropped after
     // distribution. Types needs to be resolved in this case also.
     xegpu::TensorDescType distributedTensorDescTy =
-        dropLayouts(storeOp.getTensorDescType());
+        storeOp.getTensorDescType().dropLayouts();
     newStoreOperands.push_back(
         resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                              distributedTensorDescTy, rewriter));
@@ -1201,25 +1208,26 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
 /// Example:
 ///
 /// ```
-///   #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
 ///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
 ///     (vector<4x1xf32>) {
 ///     ...
-///     %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #lo0> ->
+///     %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
+///       ->
 ///       vector<4x8xf32>
 ///     gpu.yield %ld
 ///   }
 /// ```
 /// To
 /// ```
 ///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
-///     !xegpu.tensor_desc<4x8xf32, #lo0>) {
+///     !xegpu.tensor_desc<4x8xf32, #layout0>) {
 ///     ...
-///     %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #lo0> ->
+///     %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
 ///       vector<4x8xf32> gpu.yield %dead, %arg0
 ///   }
 ///   %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
-///     #lo0> -> !xegpu.tensor_desc<4x8xf32>
+///     #layout0> -> !xegpu.tensor_desc<4x8xf32>
 ///   %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
 ///   %2 = vector.shape_cast %r#0: vector<4xf32> to vector<4x1xf32>
 ///
@@ -1260,9 +1268,9 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
       return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
     xegpu::TensorDescType distributedTensorDescTy =
-        dropLayouts(loadOp.getTensorDescType()); // Distributed tensor
-                                                 // descriptor type does not
-                                                 // contain layout info.
+        loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
+                                                  // descriptor type does not
+                                                  // contain layout info.
     auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
         newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
         resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
@@ -1412,6 +1420,152 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
   }
 };

+/// Sink an update_nd_offset op feeding into yield op of an enclosing
+/// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
+/// original op that will not be used by the yield op (and should be cleaned
+/// up later). The yield op will bypass the updateOp's arguments. The tensor
+/// descriptor type is not distributed. Appropriate cast ops are inserted if
+/// the distributed types does not match expected xegpu SIMT types.
+/// Example:
+/// ```
+///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
+///        (!xegpu.tensor_desc<4x8xf32, #layout0>) {
+///     ...
+///     %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+///       !xegpu.tensor_desc<4x8xf32, #layout0>
+///     gpu.yield %update
+///   }
+///   ...
+/// ```
+/// To
+/// ```
+///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (
+///     !xegpu.tensor_desc<4x8xf32, #layout0>,
+///     !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
+///     ...
+///     %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+///       !xegpu.tensor_desc<4x8xf32, #layout0> gpu.yield %dead, %arg0
+///     gpu.yield %dead, %arg0, %c32, %c16
+///   }
+///   %0 = xegpu.unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
+///     #layout0> -> !xegpu.tensor_desc<4x8xf32>
+///   %1 = xegpu.update_nd_offset %0, [%r#2, %r#3]:
+///     !xegpu.tensor_desc<4x8xf32>
+///   ...
+/// ```
+struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+    auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    // new update op does not have layout attribute.
+    xegpu::TensorDescType newTensorDescTy =
+        updateOp.getTensorDescType().dropLayouts();
+
+    SmallVector<Value, 3> newYieldValues;
+    SmallVector<Type, 3> newYieldTypes;
+    for (Value operand : updateOp->getOperands()) {
+      newYieldValues.push_back(operand);
+      if (isa<xegpu::TensorDescType>(operand.getType())) {
+        newYieldTypes.push_back(newTensorDescTy);
+      } else {
+        newYieldTypes.push_back(operand.getType());
+      }
+    }
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newUpdateOperands;
+    for (size_t i : newRetIndices) {
+      // For the tensor descriptor operand, the layout attribute is dropped
+      // after distribution. Types needs to be resolved in this case.
+      if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
+        newUpdateOperands.push_back(resolveDistributedTy(
+            newWarpOp.getResult(i), newTensorDescTy, rewriter));
+      } else {
+        newUpdateOperands.push_back(newWarpOp.getResult(i));
+      }
+    }
+    // Create a new update op outside the warp op.
+    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+        newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
+        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
+    return success();
+  }
+};
+
+/// Distribute a prefetch_nd op at the end of enclosing
+/// `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed
+/// through the warp op interface they would be propagated as returned values.
+/// Tensor descriptor shape is not distributed because it is a uniform value
+/// across all work items within the subgroup. Appropriate cast ops are inserted
+/// if the distributed types does not match expected xegpu SIMT types.
+///
+/// Example:
+///
+/// ```
+///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   gpu.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0>
+///   }
+/// ```
+/// To
+/// ```
+///   %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> (
+///     !xegpu.tensor_desc<4x8xf32, #layout0>) {
+///     gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #layout0>
+///   }
+///   %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
+///     #layout0> -> !xegpu.tensor_desc<4x8xf32>
+///   xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32>
+///
+/// ```
+struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
+    if (!prefetchOp)
+      return failure();
+    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          prefetchOp, "the source tensor descriptor lacks layout attribute");
+
+    SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
+    SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+    // Create a new prefetch op outside the warp op with updated tensor
+    // descriptor type. Source tensor descriptor require type resolution.
+    xegpu::TensorDescType newTensorDescTy =
+        prefetchOp.getTensorDescType().dropLayouts();
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
+        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
+    rewriter.create<xegpu::PrefetchNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
+        removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
+    rewriter.eraseOp(prefetchOp);
+    return success();
+  }
+};
+
 } // namespace

 namespace {
@@ -1430,7 +1584,8 @@ struct XeGPUSubgroupDistributePass final
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
   patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution>(patterns.getContext());
+               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+               UpdateNdOffsetDistribution>(patterns.getContext());
 }

 void XeGPUSubgroupDistributePass::runOnOperation() {
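
The final hunk registers the two new patterns alongside the existing ones. As a rough usage sketch (not part of this diff), a pass body would populate a `RewritePatternSet` and hand it to a rewrite driver; the helper name and driver choice below are illustrative, based on MLIR's generic greedy driver rather than the actual `XeGPUSubgroupDistributePass` wiring.

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch only: generic pattern application using the greedy driver (renamed
// applyPatternsGreedily in newer MLIR). The real pass may drive these
// patterns differently.
static void applyXeGPUDistribution(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  if (mlir::failed(
          mlir::applyPatternsAndFoldGreedily(root, std::move(patterns))))
    root->emitError("failed to apply XeGPU subgroup distribution patterns");
}
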