Skip to content

[MLIR][LLVM] Change addressof builders to use opaque pointers #69215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
/// * `i32 (i8*, ...)`
static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
auto llvmI32Ty = IntegerType::get(context, 32);
auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
/*isVarArg=*/true);
return llvmFnType;
}
Expand Down Expand Up @@ -162,9 +162,9 @@ class PrintOpLowering : public ConversionPattern {
Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
builder.getIndexAttr(0));
return builder.create<LLVM::GEPOp>(
loc,
LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
globalPtr, ArrayRef<Value>({cst0, cst0}));
loc, LLVM::LLVMPointerType::get(builder.getContext()),
IntegerType::get(builder.getContext(), 8), globalPtr,
ArrayRef<Value>({cst0, cst0}));
}
};
} // namespace
Expand Down
10 changes: 5 additions & 5 deletions mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
/// * `i32 (i8*, ...)`
static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
auto llvmI32Ty = IntegerType::get(context, 32);
auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
/*isVarArg=*/true);
return llvmFnType;
}
Expand Down Expand Up @@ -162,9 +162,9 @@ class PrintOpLowering : public ConversionPattern {
Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
builder.getIndexAttr(0));
return builder.create<LLVM::GEPOp>(
loc,
LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
globalPtr, ArrayRef<Value>({cst0, cst0}));
loc, LLVM::LLVMPointerType::get(builder.getContext()),
IntegerType::get(builder.getContext(), 8), globalPtr,
ArrayRef<Value>({cst0, cst0}));
}
};
} // namespace
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1071,15 +1071,15 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
build($_builder, $_state,
LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()),
LLVM::LLVMPointerType::get($_builder.getContext(), global.getAddrSpace()),
global.getSymName());
$_state.addAttributes(attrs);
}]>,
OpBuilder<(ins "LLVMFuncOp":$func,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
build($_builder, $_state,
LLVM::LLVMPointerType::get(func.getFunctionType()), func.getName());
LLVM::LLVMPointerType::get($_builder.getContext()), func.getName());
$_state.addAttributes(attrs);
}]>
];
Expand Down
19 changes: 9 additions & 10 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,15 +441,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
Location loc = gpuPrintfOp->getLoc();

mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());

// Note: this is the GPUModule op, not the ModuleOp that surrounds it
// This ensures that global constants and declarations are placed within
// the device code, not the host code
auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();

auto vprintfType =
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr});
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
LLVM::LLVMFuncOp vprintfDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);

Expand All @@ -473,7 +473,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
Value stringStart = rewriter.create<LLVM::GEPOp>(
loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
loc, ptrType, ptrType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
SmallVector<Type> types;
SmallVector<Value> args;
// Promote and pack the arguments into a stack allocation.
Expand All @@ -490,18 +490,17 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
}
Type structType =
LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
Type structPtrType = LLVM::LLVMPointerType::get(structType);
Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
rewriter.getIndexAttr(1));
Value tempAlloc = rewriter.create<LLVM::AllocaOp>(loc, structPtrType, one,
/*alignment=*/0);
Value tempAlloc =
rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
/*alignment=*/0);
for (auto [index, arg] : llvm::enumerate(args)) {
Value ptr = rewriter.create<LLVM::GEPOp>(
loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc,
ArrayRef<LLVM::GEPArg>{0, index});
Value ptr =
rewriter.create<LLVM::GEPOp>(loc, ptrType, arg.getType(), tempAlloc,
ArrayRef<LLVM::GEPArg>{0, index});
rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
}
tempAlloc = rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc);
std::array<Value, 2> printfArgs = {stringStart, tempAlloc};

rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
Expand Down
28 changes: 13 additions & 15 deletions mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -542,34 +542,32 @@ gpu.module @test_module_28 {
gpu.module @test_module_29 {
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
// CHECK-DAG: llvm.func @vprintf(!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
// CHECK-DAG: llvm.func @vprintf(!llvm.ptr, !llvm.ptr) -> i32

// CHECK-LABEL: func @test_const_printf
gpu.func @test_const_printf() {
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr<array<14 x i8>>
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<14 x i8>>) -> !llvm.ptr<i8>
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr
// CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr<struct<()>>
// CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
gpu.printf "Hello, world\n"
gpu.return
}

// CHECK-LABEL: func @test_printf
// CHECK: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
gpu.func @test_printf(%arg0: i32, %arg1: f32) {
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr<array<11 x i8>>
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<11 x i8>>) -> !llvm.ptr<i8>
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr
// CHECK-NEXT: %[[EXT:.+]] = llvm.fpext %[[ARG1]] : f32 to f64
// CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr<struct<(i32, f64)>>
// CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<i32>
// CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : !llvm.ptr<i32>
// CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<f64>
// CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : !llvm.ptr<f64>
// CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<(i32, f64)>> to !llvm.ptr<i8>
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr
// CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr) -> !llvm.ptr
// CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : i32, !llvm.ptr
// CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr) -> !llvm.ptr
// CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : f64, !llvm.ptr
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
gpu.printf "Hello: %d\n" %arg0, %arg1 : i32, f32
gpu.return
}
Expand Down