
[mlir][vector] Add extra check on distribute types to avoid crashes #102952


Merged
merged 4 commits into llvm:main from the warpreduction-llvm-fixes branch on Aug 14, 2024

Conversation

bangtianliu
Contributor

This PR addresses the issue detailed in iree-org/iree#17948.

The problem occurs when distributed types are computed as null, which later leads to compiler crashes.

@llvmbot
Member

llvmbot commented Aug 12, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Bangtian Liu (bangtianliu)

Changes

This PR addresses the issue detailed in iree-org/iree#17948.

The problem occurs when distributed types are computed as null, which later leads to compiler crashes.


Full diff: https://github.com/llvm/llvm-project/pull/102952.diff

1 file affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+3)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 7285ad65fb549e..29899f44eb2e22 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1689,6 +1689,9 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
           }
         });
 
+    if(llvm::any_of(distTypes, [](Type type){return !type;}))
+      return failure();
+
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, escapingValues.getArrayRef(), distTypes,


github-actions bot commented Aug 12, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Signed-off-by: Bangtian Liu <[email protected]>
Member

@kuhar left a comment

Could we add a test that captures this?

@bangtianliu
Contributor Author

> Could we add a test that captures this?

Not sure how to do it

@hanhanW changed the title from "add extra check on distribute types to avoid crashes" to "[mlir][vector] Add extra check on distribute types to avoid crashes" on Aug 12, 2024
@MaheshRavishankar
Contributor

The way this is written, it is really hard to write a test here. Maybe if you have a repro of the pass failure before/after, reducing it manually to the smallest size that preserves the failure might help.

@hanhanW
Contributor

hanhanW commented Aug 12, 2024

(wait, where is my comment! I swear I wrote something down...)

According to Ian's log, it failed in VectorReductionToGPUPass. Can we trim IREE specifics from the IR and add a test to vector-warp-distribute.mlir?

// -----// IR Dump Before VectorReductionToGPUPass (iree-codegen-vector-reduction-to-gpu) //----- //
func.func @main$async_dispatch_6_generic_32x262144_f16xf32xf32xf32() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUWarpReduction workgroup_size = [1024, 1, 1] subgroup_size = 64>} {
  %cst = arith.constant dense<0.000000e+00> : vector<4096xf32>
  %cst_0 = arith.constant dense<9.99999997E-7> : vector<1xf32>
  %cst_1 = arith.constant dense<2.621440e+05> : vector<1xf32>
  %cst_2 = arith.constant dense<2.621440e+05> : vector<4096xf32>
  %cst_3 = arith.constant 0.000000e+00 : f16
  %cst_4 = arith.constant dense<0.000000e+00> : vector<1xf32>
  %c262144 = arith.constant 262144 : index
  %c16384 = arith.constant 16384 : index
  %c16 = arith.constant 16 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4096 = arith.constant 4096 : index
  %0 = hal.interface.constant.load[0] : i32
  %1 = hal.interface.constant.load[1] : i32
  %2 = arith.index_castui %0 : i32 to index
  %3 = arith.index_castui %1 : i32 to index
  %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %4, 1 : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : memref<32x262144xf16, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %5, 1 : memref<32x262144xf16, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%3) : memref<32x262144xf32, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %6, 1 : memref<32x262144xf32, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %7 = scf.for %arg0 = %c0 to %c16 step %c1 iter_args(%arg1 = %cst) -> (vector<4096xf32>) {
    %19 = scf.for %arg2 = %c0 to %c16384 step %c4096 iter_args(%arg3 = %arg1) -> (vector<4096xf32>) {
      %20 = vector.transfer_read %4[%workgroup_id_x, %arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4096xf16>
      %21 = arith.extf %20 : vector<4096xf16> to vector<4096xf32>
      %22 = arith.addf %21, %arg3 : vector<4096xf32>
      scf.yield %22 : vector<4096xf32>
    }
    scf.yield %19 : vector<4096xf32>
  }
  %8 = vector.broadcast %7 : vector<4096xf32> to vector<1x1x4096xf32>
  %9 = vector.multi_reduction <add>, %8, %cst_4 [1, 2] : vector<1x1x4096xf32> to vector<1xf32>
  %10 = vector.broadcast %9 : vector<1xf32> to vector<4096xf32>
  %11 = arith.divf %10, %cst_2 : vector<4096xf32>
  %12 = scf.for %arg0 = %c0 to %c16 step %c1 iter_args(%arg1 = %cst) -> (vector<4096xf32>) {
    %19 = scf.for %arg2 = %c0 to %c16384 step %c4096 iter_args(%arg3 = %arg1) -> (vector<4096xf32>) {
      %20 = vector.transfer_read %4[%workgroup_id_x, %arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4096xf16>
      %21 = arith.extf %20 : vector<4096xf16> to vector<4096xf32>
      %22 = arith.subf %21, %11 : vector<4096xf32>
      %23 = arith.mulf %22, %22 : vector<4096xf32>
      %24 = arith.addf %23, %arg3 : vector<4096xf32>
      scf.yield %24 : vector<4096xf32>
    }
    scf.yield %19 : vector<4096xf32>
  }
  %13 = vector.broadcast %12 : vector<4096xf32> to vector<1x1x4096xf32>
  %14 = vector.multi_reduction <add>, %13, %cst_4 [1, 2] : vector<1x1x4096xf32> to vector<1xf32>
  %15 = arith.divf %9, %cst_1 : vector<1xf32>
  %16 = arith.divf %14, %cst_1 : vector<1xf32>
  %17 = arith.addf %16, %cst_0 : vector<1xf32>
  %18 = math.rsqrt %17 : vector<1xf32>
  scf.for %arg0 = %c0 to %c262144 step %c1 {
    %19 = vector.transfer_read %5[%workgroup_id_x, %arg0], %cst_3 {in_bounds = [true]} : memref<32x262144xf16, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
    %20 = arith.extf %19 : vector<1xf16> to vector<1xf32>
    %21 = arith.subf %20, %15 : vector<1xf32>
    %22 = arith.mulf %21, %18 : vector<1xf32>
    vector.transfer_write %22, %6[%workgroup_id_x, %arg0] {in_bounds = [true]} : vector<1xf32>, memref<32x262144xf32, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  }
  return
}
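
For reference, the lit tests in mlir/test/Dialect/Vector/vector-warp-distribute.mlir are driven by the -test-vector-warp-distribute test pass rather than the IREE pipeline shown above, so trimming the dump mostly means replacing the hal ops and IREE-specific attributes with unregistered placeholder ops and keeping only the scf.for loops with vector iter_args, which is the part the WarpOpScfForOp pattern rewrites. The RUN line below follows the conventions of that file but is illustrative only; the exact pass options and FileCheck prefixes vary between test groups.

// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck %s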

@bangtianliu
Contributor Author

Thanks for the suggestions @kuhar @MaheshRavishankar @hanhanW, I will see how to add a test for this PR.

Signed-off-by: Bangtian Liu <[email protected]>
@bangtianliu
Contributor Author

Added a test that can capture the failure. Without the fixes from this PR, the newly added function causes crashes.
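
The exact test is in the PR diff; purely as an illustration of the shape such a case can take, a warp-distribution test with an scf.for might look like the sketch below. It assumes the crash is triggered by a value defined in the warp region and used inside the loop whose vector type cannot be split evenly across the 32 lanes; every op name other than the func/arith/scf/vector ones is an unregistered placeholder, so this is not the test that was actually added.

func.func @warp_scf_for_non_distributable_operand(%laneid: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c128 = arith.constant 128 : index
  vector.warp_execute_on_lane_0(%laneid)[32] {
    %init = "some_def"() : () -> (vector<128xf32>)
    // Defined in the warp region and used inside the loop, i.e. it escapes
    // into the loop body; vector<3xf32> is assumed not to distribute evenly
    // over 32 lanes, so the computed distributed type would be null.
    %escaping = "some_def"() : () -> (vector<3xf32>)
    %0 = scf.for %i = %c0 to %c128 step %c1 iter_args(%acc = %init) -> (vector<128xf32>) {
      %1 = "some_op"(%acc, %escaping) : (vector<128xf32>, vector<3xf32>) -> (vector<128xf32>)
      scf.yield %1 : vector<128xf32>
    }
    "some_use"(%0) : (vector<128xf32>) -> ()
  }
  return
}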

Contributor

@hanhanW left a comment

Thanks, just a few nits about the lit test.

Signed-off-by: Bangtian Liu <[email protected]>
Member

@kuhar left a comment

LGTM

@hanhanW merged commit b5e47d2 into llvm:main on Aug 14, 2024
8 checks passed
@bangtianliu deleted the warpreduction-llvm-fixes branch on August 14, 2024 at 22:27