Skip to content

Commit 2b8b1c9

Browse files
committed
Use puts for cf.assert and printString for vector.print
1 parent c1455b4 commit 2b8b1c9

File tree

8 files changed

+27
-19
lines changed

8 files changed

+27
-19
lines changed

mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1313
#include "llvm/ADT/StringRef.h"
14+
#include <optional>
1415

1516
namespace mlir {
1617

@@ -20,10 +21,13 @@ class LLVMTypeConverter;
2021
namespace LLVM {
2122

2223
/// Generate IR that prints the given string to stdout.
24+
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
25+
/// have the signature void(char const*). The default function is `printString`.
2326
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
2427
StringRef symbolName, StringRef string,
2528
const LLVMTypeConverter &typeConverter,
26-
bool addNewline = true);
29+
bool addNewline = true,
30+
std::optional<StringRef> runtimeFunctionName = {});
2731
} // namespace LLVM
2832

2933
} // namespace mlir

mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/IR/Operation.h"
1818
#include "mlir/Support/LLVM.h"
19+
#include <optional>
1920

2021
namespace mlir {
2122
class Location;
@@ -38,8 +39,12 @@ LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
3839
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
3940
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
4041
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
41-
LLVM::LLVMFuncOp lookupOrCreatePrintStringFn(ModuleOp moduleOp,
42-
bool opaquePointers);
42+
/// Declares a function to print a C-string.
43+
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
44+
/// have the signature void(char const*). The default function is `printString`.
45+
LLVM::LLVMFuncOp
46+
lookupOrCreatePrintStringFn(ModuleOp moduleOp, bool opaquePointers,
47+
std::optional<StringRef> runtimeFunctionName = {});
4348
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
4449
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
4550
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);

mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
6262
// Failed block: Generate IR to print the message and call `abort`.
6363
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
6464
LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
65-
*getTypeConverter());
65+
*getTypeConverter(), /*addNewLine=*/false,
66+
/*runtimeFunctionName=*/"puts");
6667
if (abortOnFailedAssert) {
6768
// Insert the `abort` declaration if necessary.
6869
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");

mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,10 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
2727
return uniqueName;
2828
}
2929

30-
void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
31-
ModuleOp moduleOp, StringRef symbolName,
32-
StringRef string,
33-
const LLVMTypeConverter &typeConverter,
34-
bool addNewline) {
30+
void mlir::LLVM::createPrintStrCall(
31+
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
32+
StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
33+
std::optional<StringRef> runtimeFunctionName) {
3534
auto ip = builder.saveInsertionPoint();
3635
builder.setInsertionPointToStart(moduleOp.getBody());
3736
MLIRContext *ctx = builder.getContext();
@@ -61,7 +60,7 @@ void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
6160
loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
6261
indices);
6362
Operation *printer = LLVM::lookupOrCreatePrintStringFn(
64-
moduleOp, typeConverter.useOpaquePointers());
63+
moduleOp, typeConverter.useOpaquePointers(), runtimeFunctionName);
6564
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
6665
gep);
6766
}

mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,10 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context,
107107
return getCharPtr(context, opaquePointers);
108108
}
109109

110-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(ModuleOp moduleOp,
111-
bool opaquePointers) {
112-
return lookupOrCreateFn(moduleOp, kPrintString,
110+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
111+
ModuleOp moduleOp, bool opaquePointers,
112+
std::optional<StringRef> runtimeFunctionName) {
113+
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
113114
getCharPtr(moduleOp->getContext(), opaquePointers),
114115
LLVM::LLVMVoidType::get(moduleOp->getContext()));
115116
}

mlir/lib/ExecutionEngine/RunnerUtils.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,8 @@ extern "C" void printMemrefC64(int64_t rank, void *ptr) {
158158
_mlir_ciface_printMemrefC64(&descriptor);
159159
}
160160

161+
/// Deprecated. This should be unified with printString from CRunnerUtils.
161162
extern "C" void printCString(char *str) { fputs(str, stdout); }
162-
extern "C" void printString(char const *str) {
163-
printCString(const_cast<char *>(str));
164-
}
165163

166164
extern "C" void _mlir_ciface_printMemref0dF32(StridedMemRefType<float, 0> *M) {
167165
impl::printMemRef(*M);

mlir/test/Conversion/ControlFlowToLLVM/assert.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ func.func @main() {
1010
return
1111
}
1212

13-
// CHECK: llvm.func @printString(!llvm.ptr)
13+
// CHECK: llvm.func @puts(!llvm.ptr)
1414

1515
// CHECK-LABEL: @main
1616
// CHECK: llvm.cond_br %{{.*}}, ^{{.*}}, ^[[FALSE_BRANCH:[[:alnum:]]+]]
1717

1818
// CHECK: ^[[FALSE_BRANCH]]:
1919
// CHECK: %[[ADDRESS_OF:.*]] = llvm.mlir.addressof @{{.*}} : !llvm.ptr{{$}}
2020
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDRESS_OF]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{[0-9]+}} x i8>
21-
// CHECK: llvm.call @printString(%[[GEP]]) : (!llvm.ptr) -> ()
21+
// CHECK: llvm.call @puts(%[[GEP]]) : (!llvm.ptr) -> ()

mlir/test/Integration/Dialect/ControlFlow/assert.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: mlir-opt %s -test-cf-assert \
22
// RUN: -convert-func-to-llvm | \
3-
// RUN: mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils | \
3+
// RUN: mlir-cpu-runner -e main -entry-point-result=void | \
44
// RUN: FileCheck %s
55

66
func.func @main() {

0 commit comments

Comments
 (0)