Skip to content

Conversation

lpy
Copy link
Contributor

@lpy lpy commented May 7, 2025

Now that MLIR accepts nuw and nusw in getelementptr, this patch emits
the inbounds and nuw attributes when lower memref to LLVM in load and
store operators.

This patch also strengthens the memref.load and memref.store spec about
undefined behaviour during lowering.

This patch also lifts the |rewriter| parameter in getStridedElementPtr
ahead so that LLVM::GEPNoWrapFlags can be added at the end with a
default value and grouped together with other operators' parameters.

@llvmbot
Copy link
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-amx
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Peiyong Lin (lpy)

Changes

Now that MLIR accepts nuw and nusw in getelementptr, this patch emits the inbounds and nuw attributes when lower memref to LLVM in load and store operators. It is guaranteed that memref.load and memref.store must be inbounds: 0 <= idx < dim_size.

This patch also lifts the |rewriter| parameter in getStridedElementPtr ahead so that LLVM::GEPNoWrapFlags can be added at the end with a default value and grouped together with other operators' parameters.


Patch is 30.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138984.diff

13 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+4-3)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+6-4)
  • (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+8-8)
  • (modified) mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp (+5-4)
  • (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+4-3)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+15-13)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+14-13)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+10-10)
  • (modified) mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp (+4-4)
  • (modified) mlir/test/Conversion/FuncToLLVM/calling-convention.mlir (+3-3)
  • (modified) mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir (+4-4)
  • (modified) mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir (+5-5)
  • (modified) mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir (+1-1)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 7a58e4fc2f984..66d0fc624e8f1 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -83,9 +83,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
 
   // This is a strided getElementPtr variant that linearizes subscripts as:
   //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
-  Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
-                             ValueRange indices,
-                             ConversionPatternRewriter &rewriter) const;
+  Value getStridedElementPtr(
+      ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
+      Value memRefDesc, ValueRange indices,
+      LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none) const;
 
   /// Returns if the given memref type is convertible to LLVM and has an
   /// identity layout map.
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 6e596485cbb58..ff462033462b2 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1118,10 +1118,12 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
     if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
       return op.emitOpError("chipset unsupported element size");
 
-    Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
-                                        (adaptor.getSrcIndices()), rewriter);
-    Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
-                                        (adaptor.getDstIndices()), rewriter);
+    Value srcPtr =
+        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+                             (adaptor.getSrcIndices()));
+    Value dstPtr =
+        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
+                             (adaptor.getDstIndices()));
 
     rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
         op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 417555792b44f..0c3f942b5cbd9 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -299,9 +299,9 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
     auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
         loc, rewriter.getI64Type(), sliceIndex);
     return getStridedElementPtr(
-        loc, llvm::cast<MemRefType>(tileMemory.getType()),
-        descriptor.getResult(0), {sliceIndexI64, zero},
-        static_cast<ConversionPatternRewriter &>(rewriter));
+        static_cast<ConversionPatternRewriter &>(rewriter), loc,
+        llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
+        {sliceIndexI64, zero});
   }
 
   /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
@@ -507,9 +507,9 @@ struct LoadTileSliceConversion
     if (!tileId)
       return failure();
 
-    Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
-                                           adaptor.getBase(),
-                                           adaptor.getIndices(), rewriter);
+    Value ptr = this->getStridedElementPtr(
+        rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
+        adaptor.getIndices());
 
     auto tileSlice = loadTileSliceOp.getTileSliceIndex();
 
@@ -554,8 +554,8 @@ struct StoreTileSliceConversion
 
     // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
     Value ptr = this->getStridedElementPtr(
-        loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
-        adaptor.getIndices(), rewriter);
+        rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
+        adaptor.getIndices());
 
     auto tileSlice = storeTileSliceOp.getTileSliceIndex();
 
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..45fd933d58857 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -122,8 +122,9 @@ struct WmmaLoadOpToNVVMLowering
 
     // Create nvvm.mma_load op according to the operand types.
     Value dataPtr = getStridedElementPtr(
-        loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
-        adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);
+        rewriter, loc,
+        cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
+        adaptor.getSrcMemref(), adaptor.getIndices());
 
     Value leadingDim = rewriter.create<LLVM::ConstantOp>(
         loc, rewriter.getI32Type(),
@@ -177,9 +178,9 @@ struct WmmaStoreOpToNVVMLowering
     }
 
     Value dataPtr = getStridedElementPtr(
-        loc,
+        rewriter, loc,
         cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
-        adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
+        adaptor.getDstMemref(), adaptor.getIndices());
     Value leadingDim = rewriter.create<LLVM::ConstantOp>(
         loc, rewriter.getI32Type(),
         subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 0505214de2015..6942a64048722 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -59,8 +59,9 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
 }
 
 Value ConvertToLLVMPattern::getStridedElementPtr(
-    Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
-    ConversionPatternRewriter &rewriter) const {
+    ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
+    Value memRefDesc, ValueRange indices,
+    LLVM::GEPNoWrapFlags noWrapFlags) const {
 
   auto [strides, offset] = type.getStridesAndOffset();
 
@@ -91,7 +92,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
   return index ? rewriter.create<LLVM::GEPOp>(
                      loc, elementPtrType,
                      getTypeConverter()->convertType(type.getElementType()),
-                     base, index)
+                     base, index, noWrapFlags)
                : base;
 }
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index c8b2c0bdc6c20..8753505b6db46 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -420,8 +420,8 @@ struct AssumeAlignmentOpLowering
     auto loc = op.getLoc();
 
     auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
-    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
-                                     rewriter);
+    Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
+                                     /*indices=*/{});
 
     // Emit llvm.assume(true) ["align"(memref, alignment)].
     // This is more direct than ptrtoint-based checks, is explicitly supported,
@@ -644,8 +644,8 @@ struct GenericAtomicRMWOpLowering
     // Compute the loaded value and branch to the loop block.
     rewriter.setInsertionPointToEnd(initBlock);
     auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
-    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
-                                        adaptor.getIndices(), rewriter);
+    auto dataPtr = getStridedElementPtr(
+        rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
     Value init = rewriter.create<LLVM::LoadOp>(
         loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
@@ -829,9 +829,10 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto type = loadOp.getMemRefType();
 
-    Value dataPtr =
-        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
-                             adaptor.getIndices(), rewriter);
+    Value dataPtr = getStridedElementPtr(
+        rewriter, loadOp.getLoc(), type, adaptor.getMemref(),
+        adaptor.getIndices(),
+        LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw);
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
         loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
         false, loadOp.getNontemporal());
@@ -849,8 +850,9 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto type = op.getMemRefType();
 
-    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
-                                         adaptor.getIndices(), rewriter);
+    Value dataPtr = getStridedElementPtr(
+        rewriter, op.getLoc(), type, adaptor.getMemref(), adaptor.getIndices(),
+        LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw);
     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                                0, false, op.getNontemporal());
     return success();
@@ -868,8 +870,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
     auto type = prefetchOp.getMemRefType();
     auto loc = prefetchOp.getLoc();
 
-    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
-                                         adaptor.getIndices(), rewriter);
+    Value dataPtr = getStridedElementPtr(
+        rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
 
     // Replace with llvm.prefetch.
     IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
@@ -1809,8 +1811,8 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
     if (failed(memRefType.getStridesAndOffset(strides, offset)))
       return failure();
     auto dataPtr =
-        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
-                             adaptor.getIndices(), rewriter);
+        getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
+                             adaptor.getMemref(), adaptor.getIndices());
     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
         atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
         LLVM::AtomicOrdering::acq_rel);
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 69fa62c8196e4..eb3558d2460e4 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -283,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
 
     auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
     Value srcPtr =
-        getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
-                             adaptor.getIndices(), rewriter);
+        getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
+                             adaptor.getSrcMemref(), adaptor.getIndices());
     Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
         ldMatrixResultType, srcPtr,
         /*num=*/op.getNumTiles(),
@@ -661,8 +661,8 @@ struct NVGPUAsyncCopyLowering
     Location loc = op.getLoc();
     auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
     Value dstPtr =
-        getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
-                             adaptor.getDstIndices(), rewriter);
+        getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
+                             adaptor.getDst(), adaptor.getDstIndices());
     FailureOr<unsigned> dstAddressSpace =
         getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
     if (failed(dstAddressSpace))
@@ -676,8 +676,9 @@ struct NVGPUAsyncCopyLowering
       return rewriter.notifyMatchFailure(
           loc, "source memref address space not convertible to integer");
 
-    Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
-                                        adaptor.getSrcIndices(), rewriter);
+    Value scrPtr =
+        getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
+                             adaptor.getSrcIndices());
     // Intrinsics takes a global pointer so we need an address space cast.
     auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
         op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
@@ -814,7 +815,7 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
     MemRefType mbarrierMemrefType =
         nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
     return ConvertToLLVMPattern::getStridedElementPtr(
-        b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
+        rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
   }
 };
 
@@ -995,8 +996,8 @@ struct NVGPUTmaAsyncLoadOpLowering
                   ConversionPatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
-    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
-                                      adaptor.getDst(), {}, rewriter);
+    Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
+                                      adaptor.getDst(), {});
     Value barrier =
         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
@@ -1021,8 +1022,8 @@ struct NVGPUTmaAsyncStoreOpLowering
                   ConversionPatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
-    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
-                                      adaptor.getSrc(), {}, rewriter);
+    Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
+                                      adaptor.getSrc(), {});
     SmallVector<Value> coords = adaptor.getCoordinates();
     for (auto [index, value] : llvm::enumerate(coords)) {
       coords[index] = truncToI32(b, value);
@@ -1083,8 +1084,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
     Value leadDim = makeConst(leadDimVal);
 
     Value baseAddr = getStridedElementPtr(
-        op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
-        adaptor.getTensor(), {}, rewriter);
+        rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
+        adaptor.getTensor(), {});
     Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
     // Just use 14 bits for base address
     Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5296013189b9e..154b989ae5a12 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -289,8 +289,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
     // Resolve address.
     auto vtype = cast<VectorType>(
         this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
-    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
-                                               adaptor.getIndices(), rewriter);
+    Value dataPtr = this->getStridedElementPtr(
+        rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
                          rewriter);
     return success();
@@ -337,8 +337,8 @@ class VectorGatherOpConversion
       return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
 
     // Resolve address.
-    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
+                                     adaptor.getBase(), adaptor.getIndices());
     Value base = adaptor.getBase();
     Value ptrs =
         getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
@@ -393,8 +393,8 @@ class VectorScatterOpConversion
                                          "could not resolve alignment");
 
     // Resolve address.
-    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
+                                     adaptor.getBase(), adaptor.getIndices());
     Value ptrs =
         getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                        adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
@@ -428,8 +428,8 @@ class VectorExpandLoadOpConversion
 
     // Resolve address.
     auto vtype = typeConverter->convertType(expand.getVectorType());
-    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
+                                     adaptor.getBase(), adaptor.getIndices());
 
     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
         expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
@@ -450,8 +450,8 @@ class VectorCompressStoreOpConversion
     MemRefType memRefType = compress.getMemRefType();
 
     // Resolve address.
-    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
+                                     adaptor.getBase(), adaptor.getIndices());
 
     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
         compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 4cb777b03b196..2168409184549 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -105,8 +105,8 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
     if (failed(stride))
       return failure();
     // Replace operation with intrinsic.
-    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+    Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
+                                     adaptor.getBase(), adaptor.getIndices());
     Type resType = typeConverter->convertType(tType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
         op, resType, tsz.first, tsz.second, ptr, stride.value());
@@ -131,8 +131,8 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
     if (failed(stride))
       return failure();
     // Replace operation with intrinsic.
-    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+    Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
+                                     adaptor.getBase(), adaptor.getIndices());
     rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
         op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
     return success();
diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
index 058b69b8e3596..3b52d8fd76464 100644
--- a/mlir/test/Conversion/FuncTo...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Peiyong Lin (lpy)

Changes

Now that MLIR accepts nuw and nusw in getelementptr, this patch emits the inbounds and nuw attributes when lower memref to LLVM in load and store operators. It is guaranteed that memref.load and memref.store must be inbounds: 0 &lt;= idx &lt; dim_size.

This patch also lifts the |rewriter| parameter in getStridedElementPtr ahead so that LLVM::GEPNoWrapFlags can be added at the end with a default value and grouped together with other operators' parameters.


Patch is 30.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138984.diff

13 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+4-3)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+6-4)
  • (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+8-8)
  • (modified) mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp (+5-4)
  • (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+4-3)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+15-13)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+14-13)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+10-10)
  • (modified) mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp (+4-4)
  • (modified) mlir/test/Conversion/FuncToLLVM/calling-convention.mlir (+3-3)
  • (modified) mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir (+4-4)
  • (modified) mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir (+5-5)
  • (modified) mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir (+1-1)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 7a58e4fc2f984..66d0fc624e8f1 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -83,9 +83,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
 
   // This is a strided getElementPtr variant that linearizes subscripts as:
   //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
-  Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
-                             ValueRange indices,
-                             ConversionPatternRewriter &rewriter) const;
+  Value getStridedElementPtr(
+      ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
+      Value memRefDesc, ValueRange indices,
+      LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none) const;
 
   /// Returns if the given memref type is convertible to LLVM and has an
   /// identity layout map.
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 6e596485cbb58..ff462033462b2 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1118,10 +1118,12 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
     if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
       return op.emitOpError("chipset unsupported element size");
 
-    Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
-                                        (adaptor.getSrcIndices()), rewriter);
-    Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
-                                        (adaptor.getDstIndices()), rewriter);
+    Value srcPtr =
+        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+                             (adaptor.getSrcIndices()));
+    Value dstPtr =
+        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
+                             (adaptor.getDstIndices()));
 
     rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
         op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 417555792b44f..0c3f942b5cbd9 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -299,9 +299,9 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
     auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
         loc, rewriter.getI64Type(), sliceIndex);
     return getStridedElementPtr(
-        loc, llvm::cast<MemRefType>(tileMemory.getType()),
-        descriptor.getResult(0), {sliceIndexI64, zero},
-        static_cast<ConversionPatternRewriter &>(rewriter));
+        static_cast<ConversionPatternRewriter &>(rewriter), loc,
+        llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
+        {sliceIndexI64, zero});
   }
 
   /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
@@ -507,9 +507,9 @@ struct LoadTileSliceConversion
     if (!tileId)
       return failure();
 
-    Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
-                                           adaptor.getBase(),
-                                           adaptor.getIndices(), rewriter);
+    Value ptr = this->getStridedElementPtr(
+        rewriter, loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
+        adaptor.getIndices());
 
     auto tileSlice = loadTileSliceOp.getTileSliceIndex();
 
@@ -554,8 +554,8 @@ struct StoreTileSliceConversion
 
     // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
     Value ptr = this->getStridedElementPtr(
-        loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
-        adaptor.getIndices(), rewriter);
+        rewriter, loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
+        adaptor.getIndices());
 
     auto tileSlice = storeTileSliceOp.getTileSliceIndex();
 
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..45fd933d58857 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -122,8 +122,9 @@ struct WmmaLoadOpToNVVMLowering
 
     // Create nvvm.mma_load op according to the operand types.
     Value dataPtr = getStridedElementPtr(
-        loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
-        adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);
+        rewriter, loc,
+        cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
+        adaptor.getSrcMemref(), adaptor.getIndices());
 
     Value leadingDim = rewriter.create<LLVM::ConstantOp>(
         loc, rewriter.getI32Type(),
@@ -177,9 +178,9 @@ struct WmmaStoreOpToNVVMLowering
     }
 
     Value dataPtr = getStridedElementPtr(
-        loc,
+        rewriter, loc,
         cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
-        adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
+        adaptor.getDstMemref(), adaptor.getIndices());
     Value leadingDim = rewriter.create<LLVM::ConstantOp>(
         loc, rewriter.getI32Type(),
         subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 0505214de2015..6942a64048722 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -59,8 +59,9 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
 }
 
 Value ConvertToLLVMPattern::getStridedElementPtr(
-    Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
-    ConversionPatternRewriter &rewriter) const {
+    ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
+    Value memRefDesc, ValueRange indices,
+    LLVM::GEPNoWrapFlags noWrapFlags) const {
 
   auto [strides, offset] = type.getStridesAndOffset();
 
@@ -91,7 +92,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
   return index ? rewriter.create<LLVM::GEPOp>(
                      loc, elementPtrType,
                      getTypeConverter()->convertType(type.getElementType()),
-                     base, index)
+                     base, index, noWrapFlags)
                : base;
 }
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index c8b2c0bdc6c20..8753505b6db46 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -420,8 +420,8 @@ struct AssumeAlignmentOpLowering
     auto loc = op.getLoc();
 
     auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
-    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
-                                     rewriter);
+    Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
+                                     /*indices=*/{});
 
     // Emit llvm.assume(true) ["align"(memref, alignment)].
     // This is more direct than ptrtoint-based checks, is explicitly supported,
@@ -644,8 +644,8 @@ struct GenericAtomicRMWOpLowering
     // Compute the loaded value and branch to the loop block.
     rewriter.setInsertionPointToEnd(initBlock);
     auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
-    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
-                                        adaptor.getIndices(), rewriter);
+    auto dataPtr = getStridedElementPtr(
+        rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
     Value init = rewriter.create<LLVM::LoadOp>(
         loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
@@ -829,9 +829,10 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto type = loadOp.getMemRefType();
 
-    Value dataPtr =
-        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
-                             adaptor.getIndices(), rewriter);
+    Value dataPtr = getStridedElementPtr(
+        rewriter, loadOp.getLoc(), type, adaptor.getMemref(),
+        adaptor.getIndices(),
+        LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw);
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
         loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
         false, loadOp.getNontemporal());
@@ -849,8 +850,9 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto type = op.getMemRefType();
 
-    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
-                                         adaptor.getIndices(), rewriter);
+    Value dataPtr = getStridedElementPtr(
+        rewriter, op.getLoc(), type, adaptor.getMemref(), adaptor.getIndices(),
+        LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw);
     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                                0, false, op.getNontemporal());
     return success();
@@ -868,8 +870,8 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
     auto type = prefetchOp.getMemRefType();
     auto loc = prefetchOp.getLoc();
 
-    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
-                                         adaptor.getIndices(), rewriter);
+    Value dataPtr = getStridedElementPtr(
+        rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
 
     // Replace with llvm.prefetch.
     IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
@@ -1809,8 +1811,8 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
     if (failed(memRefType.getStridesAndOffset(strides, offset)))
       return failure();
     auto dataPtr =
-        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
-                             adaptor.getIndices(), rewriter);
+        getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
+                             adaptor.getMemref(), adaptor.getIndices());
     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
         atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
         LLVM::AtomicOrdering::acq_rel);
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 69fa62c8196e4..eb3558d2460e4 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -283,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
 
     auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
     Value srcPtr =
-        getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
-                             adaptor.getIndices(), rewriter);
+        getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
+                             adaptor.getSrcMemref(), adaptor.getIndices());
     Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
         ldMatrixResultType, srcPtr,
         /*num=*/op.getNumTiles(),
@@ -661,8 +661,8 @@ struct NVGPUAsyncCopyLowering
     Location loc = op.getLoc();
     auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
     Value dstPtr =
-        getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
-                             adaptor.getDstIndices(), rewriter);
+        getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
+                             adaptor.getDst(), adaptor.getDstIndices());
     FailureOr<unsigned> dstAddressSpace =
         getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
     if (failed(dstAddressSpace))
@@ -676,8 +676,9 @@ struct NVGPUAsyncCopyLowering
       return rewriter.notifyMatchFailure(
           loc, "source memref address space not convertible to integer");
 
-    Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
-                                        adaptor.getSrcIndices(), rewriter);
+    Value scrPtr =
+        getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
+                             adaptor.getSrcIndices());
     // Intrinsics takes a global pointer so we need an address space cast.
     auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
         op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
@@ -814,7 +815,7 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
     MemRefType mbarrierMemrefType =
         nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
     return ConvertToLLVMPattern::getStridedElementPtr(
-        b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
+        rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
   }
 };
 
@@ -995,8 +996,8 @@ struct NVGPUTmaAsyncLoadOpLowering
                   ConversionPatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
-    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
-                                      adaptor.getDst(), {}, rewriter);
+    Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
+                                      adaptor.getDst(), {});
     Value barrier =
         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
@@ -1021,8 +1022,8 @@ struct NVGPUTmaAsyncStoreOpLowering
                   ConversionPatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
-    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
-                                      adaptor.getSrc(), {}, rewriter);
+    Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
+                                      adaptor.getSrc(), {});
     SmallVector<Value> coords = adaptor.getCoordinates();
     for (auto [index, value] : llvm::enumerate(coords)) {
       coords[index] = truncToI32(b, value);
@@ -1083,8 +1084,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
     Value leadDim = makeConst(leadDimVal);
 
     Value baseAddr = getStridedElementPtr(
-        op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
-        adaptor.getTensor(), {}, rewriter);
+        rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
+        adaptor.getTensor(), {});
     Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
     // Just use 14 bits for base address
     Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5296013189b9e..154b989ae5a12 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -289,8 +289,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
     // Resolve address.
     auto vtype = cast<VectorType>(
         this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
-    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
-                                               adaptor.getIndices(), rewriter);
+    Value dataPtr = this->getStridedElementPtr(
+        rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
                          rewriter);
     return success();
@@ -337,8 +337,8 @@ class VectorGatherOpConversion
       return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
 
     // Resolve address.
-    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
+                                     adaptor.getBase(), adaptor.getIndices());
     Value base = adaptor.getBase();
     Value ptrs =
         getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
@@ -393,8 +393,8 @@ class VectorScatterOpConversion
                                          "could not resolve alignment");
 
     // Resolve address.
-    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
+                                     adaptor.getBase(), adaptor.getIndices());
     Value ptrs =
         getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                        adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
@@ -428,8 +428,8 @@ class VectorExpandLoadOpConversion
 
     // Resolve address.
     auto vtype = typeConverter->convertType(expand.getVectorType());
-    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
+                                     adaptor.getBase(), adaptor.getIndices());
 
     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
         expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
@@ -450,8 +450,8 @@ class VectorCompressStoreOpConversion
     MemRefType memRefType = compress.getMemRefType();
 
     // Resolve address.
-    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
+                                     adaptor.getBase(), adaptor.getIndices());
 
     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
         compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 4cb777b03b196..2168409184549 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -105,8 +105,8 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
     if (failed(stride))
       return failure();
     // Replace operation with intrinsic.
-    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+    Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
+                                     adaptor.getBase(), adaptor.getIndices());
     Type resType = typeConverter->convertType(tType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
         op, resType, tsz.first, tsz.second, ptr, stride.value());
@@ -131,8 +131,8 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
     if (failed(stride))
       return failure();
     // Replace operation with intrinsic.
-    Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+    Value ptr = getStridedElementPtr(rewriter, op.getLoc(), mType,
+                                     adaptor.getBase(), adaptor.getIndices());
     rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
         op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
     return success();
diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
index 058b69b8e3596..3b52d8fd76464 100644
--- a/mlir/test/Conversion/FuncTo...
[truncated]

@lpy
Copy link
Contributor Author

lpy commented May 8, 2025

PTAL :)

@krzysz00 krzysz00 self-requested a review May 12, 2025 16:32
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, going to wait a bit for other comments before landing

@banach-space
Copy link
Contributor

It is guaranteed that memref.load and memref.store must be inbounds: 0 <= idx < dim_size.

I don't quite follow this statement - is it really guaranteed? Looking at one of the tests that has been updated:

func.func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
  memref.store %val, %static[%i, %j] : memref<10x42xf32>
  return
}

How can we be sure that i and j are in bounds? I appreciate that the docs ( memref.store) state that:

The indices must be in-bounds: 0 <= idx < dim_size

but there's just nothing to enforce that, is there? It would be good to somehow document that this assumption is used when lowering to llvm.getelemptr.

Also, why nuw instead of nsuw?

@lpy
Copy link
Contributor Author

lpy commented May 15, 2025

It is guaranteed that memref.load and memref.store must be inbounds: 0 <= idx < dim_size.

I don't quite follow this statement - is it really guaranteed? Looking at one of the tests that has been updated:

func.func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
  memref.store %val, %static[%i, %j] : memref<10x42xf32>
  return
}

How can we be sure that i and j are in bounds? I appreciate that the docs ( memref.store) state that:

The indices must be in-bounds: 0 <= idx < dim_size

but there's just nothing to enforce that, is there? It would be good to somehow document that this assumption is used when lowering to llvm.getelemptr.

Thanks for pointing this out. I checked a bit and if I understand correctly this is not enforced in the code. I can update the memref.load and memref.store doc to capture this information. @krzysz00 wdyt?

Also, why nuw instead of nsuw?

If I understand correctly inbounds already implies nusw.

@krzysz00
Copy link
Contributor

  1. Yeah, there's no static enforecement of the 0 <= [index] < [dimension length] requirement, but it's documented on memref.load and memref.store. It might be worth strengthening the language in the op description about undefined behavior
  2. inbounds implies nusw, and we additionally know nuw because all of the offsets we're adding are non-negative

@lpy lpy force-pushed the 20483 branch 2 times, most recently from ffe1f3b to e92c1f5 Compare May 15, 2025 22:26
@lpy
Copy link
Contributor Author

lpy commented May 15, 2025

I updated the spec language to call out the undefined behaviour, please take a look :)

@banach-space
Copy link
Contributor

inbounds implies nusw, and we additionally know nuw because all of the offsets we're adding are non-negative

OK, IIUC, we make this assumption based on the docs, but that's not something that we can enforce, right? Just making sure I follow the rationale :)

Similarly, inbounds is added based on what the docs for memref.load + memref.store say?

Could you document this rationale in MemrefToLLVM.cpp? Otherwise, LLVM::GEPNoWrapFlags::nuw is no different to a magic number. Thanks!

@lpy
Copy link
Contributor Author

lpy commented May 16, 2025

I added comments to talk about why inbounds and nuw are added. Please take another look :)

@krzysz00
Copy link
Contributor

OK, IIUC, we make this assumption based on the docs, but that's not something that we can enforce, right? Just making sure I follow the rationale :)

We can actually enforce it if you're willing to run https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing my comments, LGTM

I've left a couple of small suggestion, could you incorporate them before landing? Thanks!

Now that MLIR accepts nuw and nusw in getelementptr, this patch emits
the inbounds and nuw attributes when lower memref to LLVM in load and
store operators.

This patch also strengthens the memref.load and memref.store spec about
undefined behaviour during lowering.

This patch also lifts the |rewriter| parameter in getStridedElementPtr
ahead so that LLVM::GEPNoWrapFlags can be added at the end with a
default value and grouped together with other operators' parameters.

fixes: iree-org/iree#20483

Signed-off-by: Lin, Peiyong <[email protected]>
@krzysz00 krzysz00 merged commit 04ad8d4 into llvm:main May 20, 2025
9 of 10 checks passed
@lpy lpy deleted the 20483 branch May 20, 2025 21:19
paul0403 added a commit to PennyLaneAI/catalyst that referenced this pull request Jul 28, 2025
**Context:**
Update llvm, mhlo and enzyme, 2025 Q3.
The latest pair of good versions, indicated by mhlo, is
tensorflow/mlir-hlo@1dd2e71
```
mhlo=1dd2e71331014ae0373f6bf900ce6be393357190
llvm=f8cb7987c64dcffb72414a40560055cb717dbf74
```

For Enzyme, we go to the latest release
https://github.com/EnzymeAD/Enzyme/releases/tag/v0.0.186
```
enzyme=v0.0.186
```
with commit `8c1a596158f6194f10e8ffd56a1660a61c54337e`

**Description of the Change:**
Miscellaneous:
1. `GreedyRewriteConfig.stuff = blah` ->
`GreedyRewriteConfig.setStuff(blah)`
llvm/llvm-project#137122
2. llvm gep op `inbounds` attribute is subsumed under a gep sign wrap
enum flag llvm/llvm-project#137272
3. `arith::Constant[Int, Float]Op` builders now have the same argument
order as other ops (output type first, then arguments)
llvm/llvm-project#144636 (note that Enzyme also
noticed this EnzymeAD/Enzyme#2379 😆 )
4. The `lookupOrCreateFn` functions now take in a builder instead of
instantiating a new one llvm/llvm-project#136421
5. `getStridedElementPtr` now takes in `rewriter` as the first argument
(instead of the last), like all the other utils
llvm/llvm-project#138984
6. The following functions now return a `LogicalResult`, and will be
caught by warnings as errors as `-Wunused-result`:
- `func::FuncOp.[insert, erase]Argument(s)`
llvm/llvm-project#137130
- `getBackwardSlice()` llvm/llvm-project#140961

Things related to `transform.apply_registered_pass` op:
1. It now takes in a `dynamic_options`
llvm/llvm-project#142683. We don't need to use
this as all our pass options are static.
2. The options it takes in are now dictionaries instead of strings
llvm/llvm-project#143159

Bufferization:
1. `bufferization.to_memref` op is renamed to `bufferization.to_buffer`
llvm/llvm-project#137180
3. `bufferization.to_tensor` op's builder now needs the result type to
be explicit llvm/llvm-project#142986. This is
also needed by a patched mhlo pass.
4. The `getBuffer()` methods take in a new arg for `BufferizationState`
llvm/llvm-project#141019,
llvm/llvm-project#141466
5. `UnknownTypeConverterFn` in bufferization options now takes in just a
type instead of a full value
llvm/llvm-project#144658

**Related GitHub Issues:** 
[sc-95176]
[sc-95664]

---------

Co-authored-by: Mehrdad Malek <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants