[mlir][vector] Add extra check on distribute types to avoid crashes #102952
Conversation
Signed-off-by: Bangtian Liu <[email protected]>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Bangtian Liu (bangtianliu)

Changes

This PR addresses the issue detailed in iree-org/iree#17948. The problem occurs when distributed types are set to NULL, leading to compilation crashes.

Full diff: https://github.com/llvm/llvm-project/pull/102952.diff

1 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 7285ad65fb549e..29899f44eb2e22 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1689,6 +1689,9 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
});
+  if (llvm::any_of(distTypes, [](Type type) { return !type; }))
+    return failure();
+
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
✅ With the latest revision this PR passed the C/C++ code formatter.
Signed-off-by: Bangtian Liu <[email protected]>
Could we add a test that captures this?
Not sure how to do it.
The way this is written, it is really hard to write a test here. Maybe if you have a repro of the pass failure before/after, reducing it manually to the smallest size that preserves the failure might help.
(wait, where is my comment! I swear I wrote something down...) According to Ian's log, it failed in VectorReductionToGPUPass. Can we trim IREE specifics from the IR and add a test to vector-warp-distribute.mlir?

// -----// IR Dump Before VectorReductionToGPUPass (iree-codegen-vector-reduction-to-gpu) //----- //
func.func @main$async_dispatch_6_generic_32x262144_f16xf32xf32xf32() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUWarpReduction workgroup_size = [1024, 1, 1] subgroup_size = 64>} {
%cst = arith.constant dense<0.000000e+00> : vector<4096xf32>
%cst_0 = arith.constant dense<9.99999997E-7> : vector<1xf32>
%cst_1 = arith.constant dense<2.621440e+05> : vector<1xf32>
%cst_2 = arith.constant dense<2.621440e+05> : vector<4096xf32>
%cst_3 = arith.constant 0.000000e+00 : f16
%cst_4 = arith.constant dense<0.000000e+00> : vector<1xf32>
%c262144 = arith.constant 262144 : index
%c16384 = arith.constant 16384 : index
%c16 = arith.constant 16 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4096 = arith.constant 4096 : index
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = arith.index_castui %0 : i32 to index
%3 = arith.index_castui %1 : i32 to index
%4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %4, 1 : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : memref<32x262144xf16, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %5, 1 : memref<32x262144xf16, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%3) : memref<32x262144xf32, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %6, 1 : memref<32x262144xf32, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%7 = scf.for %arg0 = %c0 to %c16 step %c1 iter_args(%arg1 = %cst) -> (vector<4096xf32>) {
%19 = scf.for %arg2 = %c0 to %c16384 step %c4096 iter_args(%arg3 = %arg1) -> (vector<4096xf32>) {
%20 = vector.transfer_read %4[%workgroup_id_x, %arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4096xf16>
%21 = arith.extf %20 : vector<4096xf16> to vector<4096xf32>
%22 = arith.addf %21, %arg3 : vector<4096xf32>
scf.yield %22 : vector<4096xf32>
}
scf.yield %19 : vector<4096xf32>
}
%8 = vector.broadcast %7 : vector<4096xf32> to vector<1x1x4096xf32>
%9 = vector.multi_reduction <add>, %8, %cst_4 [1, 2] : vector<1x1x4096xf32> to vector<1xf32>
%10 = vector.broadcast %9 : vector<1xf32> to vector<4096xf32>
%11 = arith.divf %10, %cst_2 : vector<4096xf32>
%12 = scf.for %arg0 = %c0 to %c16 step %c1 iter_args(%arg1 = %cst) -> (vector<4096xf32>) {
%19 = scf.for %arg2 = %c0 to %c16384 step %c4096 iter_args(%arg3 = %arg1) -> (vector<4096xf32>) {
%20 = vector.transfer_read %4[%workgroup_id_x, %arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4096xf16>
%21 = arith.extf %20 : vector<4096xf16> to vector<4096xf32>
%22 = arith.subf %21, %11 : vector<4096xf32>
%23 = arith.mulf %22, %22 : vector<4096xf32>
%24 = arith.addf %23, %arg3 : vector<4096xf32>
scf.yield %24 : vector<4096xf32>
}
scf.yield %19 : vector<4096xf32>
}
%13 = vector.broadcast %12 : vector<4096xf32> to vector<1x1x4096xf32>
%14 = vector.multi_reduction <add>, %13, %cst_4 [1, 2] : vector<1x1x4096xf32> to vector<1xf32>
%15 = arith.divf %9, %cst_1 : vector<1xf32>
%16 = arith.divf %14, %cst_1 : vector<1xf32>
%17 = arith.addf %16, %cst_0 : vector<1xf32>
%18 = math.rsqrt %17 : vector<1xf32>
scf.for %arg0 = %c0 to %c262144 step %c1 {
%19 = vector.transfer_read %5[%workgroup_id_x, %arg0], %cst_3 {in_bounds = [true]} : memref<32x262144xf16, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%20 = arith.extf %19 : vector<1xf16> to vector<1xf32>
%21 = arith.subf %20, %15 : vector<1xf32>
%22 = arith.mulf %21, %18 : vector<1xf32>
vector.transfer_write %22, %6[%workgroup_id_x, %arg0] {in_bounds = [true]} : vector<1xf32>, memref<32x262144xf32, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
Thanks for the suggestions @kuhar @MaheshRavishankar @hanhanW, I will see how to add a test for this PR.
Signed-off-by: Bangtian Liu <[email protected]>
Added a test that captures the failure. Without the fix from this PR, the newly added function causes a crash.
Thanks, just a few nits about the lit test.
Signed-off-by: Bangtian Liu <[email protected]>
LGTM