From 3b22fb7f36be3a3b32d4d6cb734a24ac757beb73 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Tue, 10 Oct 2023 11:38:12 +0000 Subject: [PATCH 1/6] [mlir][VectorOps] Support string literals in `vector.print` Printing strings within integration tests is currently quite annoyingly verbose, and can't be tucked into shared helpers as the types depend on the length of the string: ``` llvm.mlir.global internal constant @hello_world("Hello, World!\0") func.func @entry() { %0 = llvm.mlir.addressof @hello_world : !llvm.ptr> %1 = llvm.mlir.constant(0 : index) : i64 %2 = llvm.getelementptr %0[%1, %1] : (!llvm.ptr>, i64, i64) -> !llvm.ptr llvm.call @printCString(%2) : (!llvm.ptr) -> () return } `` So this patch adds a simple extension to `vector.print` to simplify this: ``` func.func @entry() { // Print a vector of characters ;) vector.print str "Hello, World!" return } ``` Most of the logic for this is now shared with `cf.assert` which already does something similar. --- .../Conversion/LLVMCommon/PrintCallHelper.h | 36 ++++++++++ .../mlir/Dialect/Vector/IR/VectorOps.td | 37 +++++++++-- .../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 49 +------------- mlir/lib/Conversion/LLVMCommon/CMakeLists.txt | 1 + .../Conversion/LLVMCommon/PrintCallHelper.cpp | 66 +++++++++++++++++++ .../VectorToLLVM/ConvertVectorToLLVM.cpp | 6 +- .../VectorToLLVM/vector-to-llvm.mlir | 14 ++++ .../Dialect/Vector/CPU/test-hello-world.mlir | 10 +++ 8 files changed, 168 insertions(+), 51 deletions(-) create mode 100644 mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h create mode 100644 mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h new file mode 100644 index 0000000000000..7e26858589f27 --- /dev/null +++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h @@ -0,0 +1,36 @@ + +//===- PrintCallHelper.h - LLVM Interfaces ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_ +#define MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { + +class Location; +class ModuleOp; +class OpBuilder; +class Operation; +class Type; +class ValueRange; +class LLVMTypeConverter; + +namespace LLVM { + +/// Generate IR that prints the given string to stdout. +void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, + StringRef symbolName, StringRef string, + const LLVMTypeConverter &typeConverter); +} // namespace LLVM + +} // namespace mlir + +#endif diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 917b27a40f26f..0da4ca617a94c 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" +include "mlir/IR/BuiltinAttributes.td" // TODO: Add an attribute to specify a different algebra with operators other // than the current set: {*, +}. @@ -2477,12 +2478,18 @@ def Vector_TransposeOp : } def Vector_PrintOp : - Vector_Op<"print", []>, + Vector_Op<"print", [ + PredOpTrait< + "`source` or `punctuation` are not set printing strings", + CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)"> + >, + ]>, Arguments<(ins Optional>>:$source, DefaultValuedAttr:$punctuation) + "::mlir::vector::PrintPunctuation::NewLine">:$punctuation, + OptionalAttr:$stringLiteral) > { let summary = "print operation (for testing and debugging)"; let description = [{ @@ -2521,6 +2528,13 @@ def Vector_PrintOp : ```mlir vector.print punctuation ``` + + Additionally, to aid with debugging and testing `vector.print` can also + print constant strings: + + ```mlir + vector.print str "Hello, World!" + ``` }]; let extraClassDeclaration = [{ Type getPrintType() { @@ -2529,11 +2543,26 @@ def Vector_PrintOp : }]; let builders = [ OpBuilder<(ins "PrintPunctuation":$punctuation), [{ - build($_builder, $_state, {}, punctuation); + build($_builder, $_state, {}, punctuation, {}); + }]>, + OpBuilder<(ins "::mlir::Value":$source), [{ + build($_builder, $_state, source, PrintPunctuation::NewLine); + }]>, + OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{ + build($_builder, $_state, source, punctuation, {}); + }]>, + OpBuilder<(ins "::llvm::StringRef":$string), [{ + build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string)); }]>, ]; - let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict"; + let assemblyFormat = [{ + ($source^ `:` type($source))? + oilist( + `str` $stringLiteral + | `punctuation` $punctuation) + attr-dict + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index a4f146bbe475c..6b7647b038f1d 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -16,6 +16,7 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" @@ -36,51 +37,6 @@ using namespace mlir; #define PASS_NAME "convert-cf-to-llvm" -static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) { - std::string prefix = "assert_msg_"; - int counter = 0; - while (moduleOp.lookupSymbol(prefix + std::to_string(counter))) - ++counter; - return prefix + std::to_string(counter); -} - -/// Generate IR that prints the given string to stderr. -static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp, - StringRef msg, - const LLVMTypeConverter &typeConverter) { - auto ip = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(moduleOp.getBody()); - MLIRContext *ctx = builder.getContext(); - - // Create a zero-terminated byte representation and allocate global symbol. - SmallVector elementVals; - elementVals.append(msg.begin(), msg.end()); - elementVals.push_back(0); - auto dataAttrType = RankedTensorType::get( - {static_cast(elementVals.size())}, builder.getI8Type()); - auto dataAttr = - DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals)); - auto arrayTy = - LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size()); - std::string symbolName = generateGlobalMsgSymbolName(moduleOp); - auto globalOp = builder.create( - loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName, - dataAttr); - - // Emit call to `printStr` in runtime library. - builder.restoreInsertionPoint(ip); - auto msgAddr = builder.create( - loc, typeConverter.getPointerType(arrayTy), globalOp.getName()); - SmallVector indices(1, 0); - Value gep = builder.create( - loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr, - indices); - Operation *printer = LLVM::lookupOrCreatePrintStrFn( - moduleOp, typeConverter.useOpaquePointers()); - builder.create(loc, TypeRange(), SymbolRefAttr::get(printer), - gep); -} - namespace { /// Lower `cf.assert`. The default lowering calls the `abort` function if the /// assertion is violated and has no effect otherwise. The failure message is @@ -105,7 +61,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { // Failed block: Generate IR to print the message and call `abort`. Block *failureBlock = rewriter.createBlock(opBlock->getParent()); - createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter()); + LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(), + *getTypeConverter()); if (abortOnFailedAssert) { // Insert the `abort` declaration if necessary. auto abortFunc = module.lookupSymbol("abort"); diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt index 091cd539f0ae0..568d9339aaabc 100644 --- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt +++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion LoweringOptions.cpp MemRefBuilder.cpp Pattern.cpp + PrintCallHelper.cpp StructBuilder.cpp TypeConverter.cpp VectorPattern.cpp diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp new file mode 100644 index 0000000000000..487abb435d10a --- /dev/null +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -0,0 +1,66 @@ + +//===- PrintCallHelper.cpp - LLVM Interfaces --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/ArrayRef.h" + +using namespace mlir; +using namespace llvm; + +static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp, + StringRef symbolName) { + static int counter = 0; + std::string uniqueName = std::string(symbolName); + while (moduleOp.lookupSymbol(uniqueName)) { + uniqueName = std::string(symbolName) + "_" + std::to_string(counter++); + } + return uniqueName; +} + +void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc, + ModuleOp moduleOp, StringRef symbolName, + StringRef string, + const LLVMTypeConverter &typeConverter) { + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(moduleOp.getBody()); + MLIRContext *ctx = builder.getContext(); + + // Create a zero-terminated byte representation and allocate global symbol. + SmallVector elementVals; + elementVals.append(string.begin(), string.end()); + elementVals.push_back(0); + auto dataAttrType = RankedTensorType::get( + {static_cast(elementVals.size())}, builder.getI8Type()); + auto dataAttr = + DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals)); + auto arrayTy = + LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size()); + auto globalOp = builder.create( + loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, + ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr); + + // Emit call to `printStr` in runtime library. + builder.restoreInsertionPoint(ip); + auto msgAddr = builder.create( + loc, typeConverter.getPointerType(arrayTy), globalOp.getName()); + SmallVector indices(1, 0); + Value gep = builder.create( + loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr, + indices); + Operation *printer = LLVM::lookupOrCreatePrintStrFn( + moduleOp, typeConverter.useOpaquePointers()); + builder.create(loc, TypeRange(), SymbolRefAttr::get(printer), + gep); +} diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 8427d60f14c0b..4af58653c8227 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" +#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -1548,7 +1549,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { } auto punct = printOp.getPunctuation(); - if (punct != PrintPunctuation::NoPunctuation) { + if (auto stringLiteral = printOp.getStringLiteral()) { + LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str", + *stringLiteral, *getTypeConverter()); + } else if (punct != PrintPunctuation::NoPunctuation) { emitCall(rewriter, printOp->getLoc(), [&] { switch (punct) { case PrintPunctuation::Close: diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 9aa4d735681f5..65b3a78e295f0 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1068,6 +1068,20 @@ func.func @vector_print_scalar_f64(%arg0: f64) { // ----- +// CHECK-LABEL: module { +// CHECK: llvm.func @puts(!llvm.ptr) +// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]](dense<[72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33, 0]> : tensor<14xi8>) {addr_space = 0 : i32} : !llvm.array<14 x i8> +// CHECK: @vector_print_string +// CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr +// CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8> +// CHECK-NEXT: llvm.call @puts(%[[STR_PTR]]) : (!llvm.ptr) -> () +func.func @vector_print_string() { + vector.print str "Hello, World!" + return +} + +// ----- + func.func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> { %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> return %0 : vector<2xf32> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir new file mode 100644 index 0000000000000..c4076e65151ac --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-opt %s -test-lower-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_c_runner_utils | \ +// RUN: FileCheck %s + +func.func @entry() { + // CHECK: Hello, World! + vector.print str "Hello, World!" + return +} From 08bc8890c650b2529b947afde56f286ec52b3041 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 11 Oct 2023 15:01:35 +0000 Subject: [PATCH 2/6] Fixups --- .../mlir/Conversion/LLVMCommon/PrintCallHelper.h | 8 +------- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +- .../Conversion/LLVMCommon/PrintCallHelper.cpp | 4 +--- mlir/test/Dialect/Vector/invalid.mlir | 16 ++++++++++++++++ ...test-hello-world.mlir => test-print-str.mlir} | 4 ++++ 5 files changed, 23 insertions(+), 11 deletions(-) rename mlir/test/Integration/Dialect/Vector/CPU/{test-hello-world.mlir => test-print-str.mlir} (71%) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h index 7e26858589f27..457cd98ca3dc2 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h @@ -1,5 +1,4 @@ - -//===- PrintCallHelper.h - LLVM Interfaces ----------------------*- C++ -*-===// +//===- PrintCallHelper.h - Helper to emit runtime print calls ---*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -15,12 +14,7 @@ namespace mlir { -class Location; -class ModuleOp; class OpBuilder; -class Operation; -class Type; -class ValueRange; class LLVMTypeConverter; namespace LLVM { diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 0da4ca617a94c..168ff45ca6154 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2480,7 +2480,7 @@ def Vector_TransposeOp : def Vector_PrintOp : Vector_Op<"print", [ PredOpTrait< - "`source` or `punctuation` are not set printing strings", + "`source` or `punctuation` are not set when printing strings", CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)"> >, ]>, diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index 487abb435d10a..40b9382452fbb 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -1,5 +1,4 @@ - -//===- PrintCallHelper.cpp - LLVM Interfaces --------------------*- C++ -*-===// +//===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -8,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" -#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 504ac89659fdb..edb2689364a98 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1016,6 +1016,22 @@ func.func private @print_needs_vector(%arg0: tensor<8xf32>) { // ----- +func.func @cannot_print_string_with_punctuation_set() { + // expected-error@+1 {{`source` or `punctuation` are not set when printing strings}} + vector.print str "Whoops!" punctuation + return +} + +// ----- + +func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) { + // expected-error@+1 {{`source` or `punctuation` are not set when printing strings}} + vector.print %vec: vector<[4]xf32> str "Yay!" + return +} + +// ----- + func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) { %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir similarity index 71% rename from mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir rename to mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir index c4076e65151ac..4a11987121b33 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir @@ -3,8 +3,12 @@ // RUN: -shared-libs=%mlir_c_runner_utils | \ // RUN: FileCheck %s +/// This tests printing (multiple) string literals works. + func.func @entry() { // CHECK: Hello, World! vector.print str "Hello, World!" + // CHECK-NEXT: Bye! + vector.print str "Bye!" return } From 5195d2b1a9c9d0d051db618e15786a9206238088 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Thu, 19 Oct 2023 07:59:32 +0000 Subject: [PATCH 3/6] Use `printCStr()` rather than `puts()` PrintCallHelper is the only use of this, so we can safely switch. --- mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h | 3 ++- mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp | 7 +++++-- mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 2 +- .../Integration/Dialect/Vector/CPU/test-print-str.mlir | 2 +- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h index 457cd98ca3dc2..ca30553e5de38 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h @@ -22,7 +22,8 @@ namespace LLVM { /// Generate IR that prints the given string to stdout. void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, - const LLVMTypeConverter &typeConverter); + const LLVMTypeConverter &typeConverter, + bool addNewline = true); } // namespace LLVM } // namespace mlir diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index 40b9382452fbb..03dbc65240a1a 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -30,7 +30,8 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp, void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, - const LLVMTypeConverter &typeConverter) { + const LLVMTypeConverter &typeConverter, + bool addNewline) { auto ip = builder.saveInsertionPoint(); builder.setInsertionPointToStart(moduleOp.getBody()); MLIRContext *ctx = builder.getContext(); @@ -38,7 +39,9 @@ void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc, // Create a zero-terminated byte representation and allocate global symbol. SmallVector elementVals; elementVals.append(string.begin(), string.end()); - elementVals.push_back(0); + if (addNewline) + elementVals.push_back('\n'); + elementVals.push_back('\0'); auto dataAttrType = RankedTensorType::get( {static_cast(elementVals.size())}, builder.getI8Type()); auto dataAttr = diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index aef3a5a87e9bf..55a644bca3173 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -30,7 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16"; static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; -static constexpr llvm::StringRef kPrintStr = "puts"; +static constexpr llvm::StringRef kPrintStr = "printCString"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; static constexpr llvm::StringRef kPrintComma = "printComma"; diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir index 4a11987121b33..78d6609ccaf9a 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s -test-lower-to-llvm | \ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_c_runner_utils | \ +// RUN: -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils | \ // RUN: FileCheck %s /// This tests printing (multiple) string literals works. From c2f93b7944ad343d3558d80ab214688338a14bbd Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Thu, 19 Oct 2023 08:19:50 +0000 Subject: [PATCH 4/6] Fixup test checks and naming --- mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h | 4 ++-- mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp | 2 +- mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 8 ++++---- mlir/test/Conversion/ControlFlowToLLVM/assert.mlir | 4 ++-- mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 8 ++++---- mlir/test/Integration/Dialect/ControlFlow/assert.mlir | 2 +- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 17aa9a3c831c2..4a86edfdf8e1a 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -38,8 +38,8 @@ LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp, - bool opaquePointers); +LLVM::LLVMFuncOp lookupOrCreatePrintCStringFn(ModuleOp moduleOp, + bool opaquePointers); LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp); diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index 03dbc65240a1a..4017fd9ad8c01 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -60,7 +60,7 @@ void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc, Value gep = builder.create( loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr, indices); - Operation *printer = LLVM::lookupOrCreatePrintStrFn( + Operation *printer = LLVM::lookupOrCreatePrintCStringFn( moduleOp, typeConverter.useOpaquePointers()); builder.create(loc, TypeRange(), SymbolRefAttr::get(printer), gep); diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 55a644bca3173..228d85d96cd4f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -30,7 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16"; static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; -static constexpr llvm::StringRef kPrintStr = "printCString"; +static constexpr llvm::StringRef kPrintCString = "printCString"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; static constexpr llvm::StringRef kPrintComma = "printComma"; @@ -107,9 +107,9 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context, return getCharPtr(context, opaquePointers); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStrFn(ModuleOp moduleOp, - bool opaquePointers) { - return lookupOrCreateFn(moduleOp, kPrintStr, +LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCStringFn(ModuleOp moduleOp, + bool opaquePointers) { + return lookupOrCreateFn(moduleOp, kPrintCString, getCharPtr(moduleOp->getContext(), opaquePointers), LLVM::LLVMVoidType::get(moduleOp->getContext())); } diff --git a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir index dc5ba0680acb2..1642a6fb5bb9b 100644 --- a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir +++ b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir @@ -10,7 +10,7 @@ func.func @main() { return } -// CHECK: llvm.func @puts(!llvm.ptr) +// CHECK: llvm.func @printCString(!llvm.ptr) // CHECK-LABEL: @main // CHECK: llvm.cond_br %{{.*}}, ^{{.*}}, ^[[FALSE_BRANCH:[[:alnum:]]+]] @@ -18,4 +18,4 @@ func.func @main() { // CHECK: ^[[FALSE_BRANCH]]: // CHECK: %[[ADDRESS_OF:.*]] = llvm.mlir.addressof @{{.*}} : !llvm.ptr{{$}} // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDRESS_OF]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{[0-9]+}} x i8> -// CHECK: llvm.call @puts(%[[GEP]]) : (!llvm.ptr) -> () +// CHECK: llvm.call @printCString(%[[GEP]]) : (!llvm.ptr) -> () diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 65b3a78e295f0..ef7260c5bb57a 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1069,12 +1069,12 @@ func.func @vector_print_scalar_f64(%arg0: f64) { // ----- // CHECK-LABEL: module { -// CHECK: llvm.func @puts(!llvm.ptr) -// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]](dense<[72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33, 0]> : tensor<14xi8>) {addr_space = 0 : i32} : !llvm.array<14 x i8> +// CHECK: llvm.func @printCString(!llvm.ptr) +// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]]({{.*}}) // CHECK: @vector_print_string // CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr -// CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8> -// CHECK-NEXT: llvm.call @puts(%[[STR_PTR]]) : (!llvm.ptr) -> () +// CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr +// CHECK-NEXT: llvm.call @printCString(%[[STR_PTR]]) : (!llvm.ptr) -> () func.func @vector_print_string() { vector.print str "Hello, World!" return diff --git a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir index 42130250daf1b..63ce092818627 100644 --- a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir +++ b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s -test-cf-assert \ // RUN: -convert-func-to-llvm | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils | \ // RUN: FileCheck %s func.func @main() { From c1455b49ba4fe256c0e754d11bd93fbbfaaac77a Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Thu, 19 Oct 2023 17:36:12 +0000 Subject: [PATCH 5/6] Add printString to CRunnerUtils and use that instead I've also defined this in RunnerUtils so linking either or both gives you printString, this avoids the need to update a bunch of tests that use cf.assert. --- mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h | 4 ++-- mlir/include/mlir/ExecutionEngine/CRunnerUtils.h | 1 + mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp | 2 +- mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 8 ++++---- mlir/lib/ExecutionEngine/CRunnerUtils.cpp | 1 + mlir/lib/ExecutionEngine/RunnerUtils.cpp | 5 ++++- mlir/test/Conversion/ControlFlowToLLVM/assert.mlir | 4 ++-- mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 4 ++-- 8 files changed, 17 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 4a86edfdf8e1a..c0806b64d25f3 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -38,8 +38,8 @@ LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintCStringFn(ModuleOp moduleOp, - bool opaquePointers); +LLVM::LLVMFuncOp lookupOrCreatePrintStringFn(ModuleOp moduleOp, + bool opaquePointers); LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp); diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h index e8f429463cb0b..76b04145b482e 100644 --- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h +++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h @@ -465,6 +465,7 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i); extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u); extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f); extern "C" MLIR_CRUNNERUTILS_EXPORT void printF64(double d); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printString(char const *s); extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen(); extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose(); extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma(); diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index 4017fd9ad8c01..8fecd4ca6c298 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -60,7 +60,7 @@ void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc, Value gep = builder.create( loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr, indices); - Operation *printer = LLVM::lookupOrCreatePrintCStringFn( + Operation *printer = LLVM::lookupOrCreatePrintStringFn( moduleOp, typeConverter.useOpaquePointers()); builder.create(loc, TypeRange(), SymbolRefAttr::get(printer), gep); diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 228d85d96cd4f..83540c83df3d1 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -30,7 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16"; static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; -static constexpr llvm::StringRef kPrintCString = "printCString"; +static constexpr llvm::StringRef kPrintString = "printString"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; static constexpr llvm::StringRef kPrintComma = "printComma"; @@ -107,9 +107,9 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context, return getCharPtr(context, opaquePointers); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCStringFn(ModuleOp moduleOp, - bool opaquePointers) { - return lookupOrCreateFn(moduleOp, kPrintCString, +LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(ModuleOp moduleOp, + bool opaquePointers) { + return lookupOrCreateFn(moduleOp, kPrintString, getCharPtr(moduleOp->getContext(), opaquePointers), LLVM::LLVMVoidType::get(moduleOp->getContext())); } diff --git a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp index c31ae3a1c7ce1..e28e75eb11030 100644 --- a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp +++ b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp @@ -52,6 +52,7 @@ extern "C" void printI64(int64_t i) { fprintf(stdout, "%" PRId64, i); } extern "C" void printU64(uint64_t u) { fprintf(stdout, "%" PRIu64, u); } extern "C" void printF32(float f) { fprintf(stdout, "%g", f); } extern "C" void printF64(double d) { fprintf(stdout, "%lg", d); } +extern "C" void printString(char const *s) { fputs(s, stdout); } extern "C" void printOpen() { fputs("( ", stdout); } extern "C" void printClose() { fputs(" )", stdout); } extern "C" void printComma() { fputs(", ", stdout); } diff --git a/mlir/lib/ExecutionEngine/RunnerUtils.cpp b/mlir/lib/ExecutionEngine/RunnerUtils.cpp index ccf5309487637..72056587a1e09 100644 --- a/mlir/lib/ExecutionEngine/RunnerUtils.cpp +++ b/mlir/lib/ExecutionEngine/RunnerUtils.cpp @@ -158,7 +158,10 @@ extern "C" void printMemrefC64(int64_t rank, void *ptr) { _mlir_ciface_printMemrefC64(&descriptor); } -extern "C" void printCString(char *str) { printf("%s", str); } +extern "C" void printCString(char *str) { fputs(str, stdout); } +extern "C" void printString(char const *str) { + printCString(const_cast(str)); +} extern "C" void _mlir_ciface_printMemref0dF32(StridedMemRefType *M) { impl::printMemRef(*M); diff --git a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir index 1642a6fb5bb9b..a432cdfee2e69 100644 --- a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir +++ b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir @@ -10,7 +10,7 @@ func.func @main() { return } -// CHECK: llvm.func @printCString(!llvm.ptr) +// CHECK: llvm.func @printString(!llvm.ptr) // CHECK-LABEL: @main // CHECK: llvm.cond_br %{{.*}}, ^{{.*}}, ^[[FALSE_BRANCH:[[:alnum:]]+]] @@ -18,4 +18,4 @@ func.func @main() { // CHECK: ^[[FALSE_BRANCH]]: // CHECK: %[[ADDRESS_OF:.*]] = llvm.mlir.addressof @{{.*}} : !llvm.ptr{{$}} // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDRESS_OF]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{[0-9]+}} x i8> -// CHECK: llvm.call @printCString(%[[GEP]]) : (!llvm.ptr) -> () +// CHECK: llvm.call @printString(%[[GEP]]) : (!llvm.ptr) -> () diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index ef7260c5bb57a..05733214bc3ae 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1069,12 +1069,12 @@ func.func @vector_print_scalar_f64(%arg0: f64) { // ----- // CHECK-LABEL: module { -// CHECK: llvm.func @printCString(!llvm.ptr) +// CHECK: llvm.func @printString(!llvm.ptr) // CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]]({{.*}}) // CHECK: @vector_print_string // CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr // CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr -// CHECK-NEXT: llvm.call @printCString(%[[STR_PTR]]) : (!llvm.ptr) -> () +// CHECK-NEXT: llvm.call @printString(%[[STR_PTR]]) : (!llvm.ptr) -> () func.func @vector_print_string() { vector.print str "Hello, World!" return From 2b8b1c9a3515099a83d94cab9d6383c686361e81 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Mon, 23 Oct 2023 10:59:53 +0000 Subject: [PATCH 6/6] Use puts for cf.assert and printString for vector.print --- .../mlir/Conversion/LLVMCommon/PrintCallHelper.h | 6 +++++- mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h | 9 +++++++-- .../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 3 ++- mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp | 11 +++++------ mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 7 ++++--- mlir/lib/ExecutionEngine/RunnerUtils.cpp | 4 +--- mlir/test/Conversion/ControlFlowToLLVM/assert.mlir | 4 ++-- mlir/test/Integration/Dialect/ControlFlow/assert.mlir | 2 +- 8 files changed, 27 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h index ca30553e5de38..c2742b6fc1d73 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h @@ -11,6 +11,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "llvm/ADT/StringRef.h" +#include namespace mlir { @@ -20,10 +21,13 @@ class LLVMTypeConverter; namespace LLVM { /// Generate IR that prints the given string to stdout. +/// If a custom runtime function is defined via `runtimeFunctionName`, it must +/// have the signature void(char const*). The default function is `printString`. void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, - bool addNewline = true); + bool addNewline = true, + std::optional runtimeFunctionName = {}); } // namespace LLVM } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index c0806b64d25f3..9e69717f471bc 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -16,6 +16,7 @@ #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" +#include namespace mlir { class Location; @@ -38,8 +39,12 @@ LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintStringFn(ModuleOp moduleOp, - bool opaquePointers); +/// Declares a function to print a C-string. +/// If a custom runtime function is defined via `runtimeFunctionName`, it must +/// have the signature void(char const*). The default function is `printString`. +LLVM::LLVMFuncOp +lookupOrCreatePrintStringFn(ModuleOp moduleOp, bool opaquePointers, + std::optional runtimeFunctionName = {}); LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp); diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index 6b7647b038f1d..433d8a01a1ac8 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -62,7 +62,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { // Failed block: Generate IR to print the message and call `abort`. Block *failureBlock = rewriter.createBlock(opBlock->getParent()); LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(), - *getTypeConverter()); + *getTypeConverter(), /*addNewLine=*/false, + /*runtimeFunctionName=*/"puts"); if (abortOnFailedAssert) { // Insert the `abort` declaration if necessary. auto abortFunc = module.lookupSymbol("abort"); diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index 8fecd4ca6c298..6293643ac6f03 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -27,11 +27,10 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp, return uniqueName; } -void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc, - ModuleOp moduleOp, StringRef symbolName, - StringRef string, - const LLVMTypeConverter &typeConverter, - bool addNewline) { +void mlir::LLVM::createPrintStrCall( + OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, + StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline, + std::optional runtimeFunctionName) { auto ip = builder.saveInsertionPoint(); builder.setInsertionPointToStart(moduleOp.getBody()); MLIRContext *ctx = builder.getContext(); @@ -61,7 +60,7 @@ void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc, loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr, indices); Operation *printer = LLVM::lookupOrCreatePrintStringFn( - moduleOp, typeConverter.useOpaquePointers()); + moduleOp, typeConverter.useOpaquePointers(), runtimeFunctionName); builder.create(loc, TypeRange(), SymbolRefAttr::get(printer), gep); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 83540c83df3d1..7ed8296a22a45 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -107,9 +107,10 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context, return getCharPtr(context, opaquePointers); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(ModuleOp moduleOp, - bool opaquePointers) { - return lookupOrCreateFn(moduleOp, kPrintString, +LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn( + ModuleOp moduleOp, bool opaquePointers, + std::optional runtimeFunctionName) { + return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString), getCharPtr(moduleOp->getContext(), opaquePointers), LLVM::LLVMVoidType::get(moduleOp->getContext())); } diff --git a/mlir/lib/ExecutionEngine/RunnerUtils.cpp b/mlir/lib/ExecutionEngine/RunnerUtils.cpp index 72056587a1e09..4618866f68a44 100644 --- a/mlir/lib/ExecutionEngine/RunnerUtils.cpp +++ b/mlir/lib/ExecutionEngine/RunnerUtils.cpp @@ -158,10 +158,8 @@ extern "C" void printMemrefC64(int64_t rank, void *ptr) { _mlir_ciface_printMemrefC64(&descriptor); } +/// Deprecated. This should be unified with printString from CRunnerUtils. extern "C" void printCString(char *str) { fputs(str, stdout); } -extern "C" void printString(char const *str) { - printCString(const_cast(str)); -} extern "C" void _mlir_ciface_printMemref0dF32(StridedMemRefType *M) { impl::printMemRef(*M); diff --git a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir index a432cdfee2e69..dc5ba0680acb2 100644 --- a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir +++ b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir @@ -10,7 +10,7 @@ func.func @main() { return } -// CHECK: llvm.func @printString(!llvm.ptr) +// CHECK: llvm.func @puts(!llvm.ptr) // CHECK-LABEL: @main // CHECK: llvm.cond_br %{{.*}}, ^{{.*}}, ^[[FALSE_BRANCH:[[:alnum:]]+]] @@ -18,4 +18,4 @@ func.func @main() { // CHECK: ^[[FALSE_BRANCH]]: // CHECK: %[[ADDRESS_OF:.*]] = llvm.mlir.addressof @{{.*}} : !llvm.ptr{{$}} // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDRESS_OF]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{[0-9]+}} x i8> -// CHECK: llvm.call @printString(%[[GEP]]) : (!llvm.ptr) -> () +// CHECK: llvm.call @puts(%[[GEP]]) : (!llvm.ptr) -> () diff --git a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir index 63ce092818627..42130250daf1b 100644 --- a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir +++ b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s -test-cf-assert \ // RUN: -convert-func-to-llvm | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void | \ // RUN: FileCheck %s func.func @main() {