@@ -2371,9 +2371,9 @@ Value BroadcastOp::createOrFoldBroadcastOp(
2371
2371
return res;
2372
2372
}
2373
2373
2374
- BroadcastableToResult
2375
- mlir::vector::isBroadcastableTo ( Type srcType, VectorType dstVectorType,
2376
- std::pair<int , int > *mismatchingDims) {
2374
+ BroadcastableToResult mlir::vector::isBroadcastableTo (
2375
+ Type srcType, VectorType dstVectorType,
2376
+ std::pair<VectorDim, VectorDim > *mismatchingDims) {
2377
2377
// Broadcast scalar to vector of the same element type.
2378
2378
if (srcType.isIntOrIndexOrFloat () && dstVectorType &&
2379
2379
getElementTypeOrSelf (srcType) == getElementTypeOrSelf (dstVectorType))
@@ -2390,13 +2390,31 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
2390
2390
// Source has an exact match or singleton value for all trailing dimensions
2391
2391
// (all leading dimensions are simply duplicated).
2392
2392
int64_t lead = dstRank - srcRank;
2393
- for (int64_t r = 0 ; r < srcRank; ++r) {
2394
- int64_t srcDim = srcVectorType.getDimSize (r);
2395
- int64_t dstDim = dstVectorType.getDimSize (lead + r);
2396
- if (srcDim != 1 && srcDim != dstDim) {
2397
- if (mismatchingDims) {
2398
- mismatchingDims->first = srcDim;
2399
- mismatchingDims->second = dstDim;
2393
+ for (int64_t dimIdx = 0 ; dimIdx < srcRank; ++dimIdx) {
2394
+ // Have mismatching dims (in the sense of vector.broadcast semantics) been
2395
+ // encountered?
2396
+ bool foundMismatchingDims = false ;
2397
+
2398
+ // Check fixed-width dims.
2399
+ int64_t srcDim = srcVectorType.getDimSize (dimIdx);
2400
+ int64_t dstDim = dstVectorType.getDimSize (lead + dimIdx);
2401
+ if (srcDim != 1 && srcDim != dstDim)
2402
+ foundMismatchingDims = true ;
2403
+
2404
+ // Check scalable flags.
2405
+ bool srcDimScalableFlag = srcVectorType.getScalableDims ()[dimIdx];
2406
+ bool dstDimScalableFlag = dstVectorType.getScalableDims ()[lead + dimIdx];
2407
+ if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1 ) ||
2408
+ (srcDimScalableFlag != dstDimScalableFlag))
2409
+ foundMismatchingDims = true ;
2410
+
2411
+ if (foundMismatchingDims) {
2412
+ if (mismatchingDims != nullptr ) {
2413
+ mismatchingDims->first .dim = srcDim;
2414
+ mismatchingDims->first .isScalable = srcDimScalableFlag;
2415
+
2416
+ mismatchingDims->second .dim = dstDim;
2417
+ mismatchingDims->second .isScalable = dstDimScalableFlag;
2400
2418
}
2401
2419
return BroadcastableToResult::DimensionMismatch;
2402
2420
}
@@ -2406,16 +2424,22 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
2406
2424
}
2407
2425
2408
2426
LogicalResult BroadcastOp::verify () {
2409
- std::pair<int , int > mismatchingDims;
2427
+ std::pair<VectorDim, VectorDim > mismatchingDims;
2410
2428
BroadcastableToResult res = isBroadcastableTo (
2411
2429
getSourceType (), getResultVectorType (), &mismatchingDims);
2412
2430
if (res == BroadcastableToResult::Success)
2413
2431
return success ();
2414
2432
if (res == BroadcastableToResult::SourceRankHigher)
2415
2433
return emitOpError (" source rank higher than destination rank" );
2416
- if (res == BroadcastableToResult::DimensionMismatch)
2434
+ if (res == BroadcastableToResult::DimensionMismatch) {
2417
2435
return emitOpError (" dimension mismatch (" )
2418
- << mismatchingDims.first << " vs. " << mismatchingDims.second << " )" ;
2436
+ << (mismatchingDims.first .isScalable ? " [" : " " )
2437
+ << mismatchingDims.first .dim
2438
+ << (mismatchingDims.first .isScalable ? " ]" : " " ) << " vs. "
2439
+ << (mismatchingDims.second .isScalable ? " [" : " " )
2440
+ << mismatchingDims.second .dim
2441
+ << (mismatchingDims.second .isScalable ? " ]" : " " ) << " )" ;
2442
+ }
2419
2443
if (res == BroadcastableToResult::SourceTypeNotAVector)
2420
2444
return emitOpError (" source type is not a vector" );
2421
2445
llvm_unreachable (" unexpected vector.broadcast op error" );
0 commit comments