-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][VectorOps] Support string literals in vector.print
#68695
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3b22fb7
08bc889
5195d2b
c2f93b7
c1455b4
2b8b1c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
//===- PrintCallHelper.h - Helper to emit runtime print calls ---*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_ | ||
#define MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_ | ||
|
||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "llvm/ADT/StringRef.h" | ||
#include <optional> | ||
|
||
namespace mlir { | ||
|
||
class OpBuilder; | ||
class LLVMTypeConverter; | ||
|
||
namespace LLVM { | ||
|
||
/// Generate IR that prints the given string to stdout. | ||
/// If a custom runtime function is defined via `runtimeFunctionName`, it must | ||
/// have the signature void(char const*). The default function is `printString`. | ||
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, | ||
StringRef symbolName, StringRef string, | ||
const LLVMTypeConverter &typeConverter, | ||
bool addNewline = true, | ||
std::optional<StringRef> runtimeFunctionName = {}); | ||
} // namespace LLVM | ||
|
||
} // namespace mlir | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" | |||||||||
include "mlir/Interfaces/SideEffectInterfaces.td" | ||||||||||
include "mlir/Interfaces/VectorInterfaces.td" | ||||||||||
include "mlir/Interfaces/ViewLikeInterface.td" | ||||||||||
include "mlir/IR/BuiltinAttributes.td" | ||||||||||
|
||||||||||
// TODO: Add an attribute to specify a different algebra with operators other | ||||||||||
// than the current set: {*, +}. | ||||||||||
|
@@ -2477,12 +2478,18 @@ def Vector_TransposeOp : | |||||||||
} | ||||||||||
|
||||||||||
def Vector_PrintOp : | ||||||||||
Vector_Op<"print", []>, | ||||||||||
Vector_Op<"print", [ | ||||||||||
PredOpTrait< | ||||||||||
"`source` or `punctuation` are not set when printing strings", | ||||||||||
CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)"> | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why isn't other punctuation allowed? It would be very useful to be able to print strings without the trailing newline. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Strings don't have a trailing newline. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The lowering for them in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll post a complete patch for this in a bit, but I was thinking something along the lines of:
so that:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hm that's odd, I verified no newline is printed for strings yesterday and that seems to be the case, are we seeing different behaviour? The llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp Lines 1520 to 1523 in 4014e2e
although for empty string it will print newline as that wont be entered. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, my branch is a bit behind. This has been fixed already: c1b8c6c |
||||||||||
>, | ||||||||||
]>, | ||||||||||
Arguments<(ins Optional<Type<Or<[ | ||||||||||
AnyVectorOfAnyRank.predicate, | ||||||||||
AnyInteger.predicate, Index.predicate, AnyFloat.predicate | ||||||||||
]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation, | ||||||||||
"::mlir::vector::PrintPunctuation::NewLine">:$punctuation) | ||||||||||
"::mlir::vector::PrintPunctuation::NewLine">:$punctuation, | ||||||||||
OptionalAttr<Builtin_StringAttr>:$stringLiteral) | ||||||||||
> { | ||||||||||
let summary = "print operation (for testing and debugging)"; | ||||||||||
let description = [{ | ||||||||||
|
@@ -2521,6 +2528,13 @@ def Vector_PrintOp : | |||||||||
```mlir | ||||||||||
vector.print punctuation <newline> | ||||||||||
``` | ||||||||||
|
||||||||||
Additionally, to aid with debugging and testing `vector.print` can also | ||||||||||
print constant strings: | ||||||||||
|
||||||||||
```mlir | ||||||||||
vector.print str "Hello, World!" | ||||||||||
``` | ||||||||||
}]; | ||||||||||
let extraClassDeclaration = [{ | ||||||||||
Type getPrintType() { | ||||||||||
|
@@ -2529,11 +2543,26 @@ def Vector_PrintOp : | |||||||||
}]; | ||||||||||
let builders = [ | ||||||||||
OpBuilder<(ins "PrintPunctuation":$punctuation), [{ | ||||||||||
build($_builder, $_state, {}, punctuation); | ||||||||||
build($_builder, $_state, {}, punctuation, {}); | ||||||||||
}]>, | ||||||||||
OpBuilder<(ins "::mlir::Value":$source), [{ | ||||||||||
build($_builder, $_state, source, PrintPunctuation::NewLine); | ||||||||||
}]>, | ||||||||||
OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{ | ||||||||||
build($_builder, $_state, source, punctuation, {}); | ||||||||||
}]>, | ||||||||||
OpBuilder<(ins "::llvm::StringRef":$string), [{ | ||||||||||
build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string)); | ||||||||||
}]>, | ||||||||||
]; | ||||||||||
|
||||||||||
let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict"; | ||||||||||
let assemblyFormat = [{ | ||||||||||
($source^ `:` type($source))? | ||||||||||
oilist( | ||||||||||
`str` $stringLiteral | ||||||||||
| `punctuation` $punctuation) | ||||||||||
attr-dict | ||||||||||
}]; | ||||||||||
} | ||||||||||
|
||||||||||
//===----------------------------------------------------------------------===// | ||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
//===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" | ||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" | ||
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "llvm/ADT/ArrayRef.h" | ||
|
||
using namespace mlir; | ||
using namespace llvm; | ||
|
||
static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp, | ||
StringRef symbolName) { | ||
static int counter = 0; | ||
std::string uniqueName = std::string(symbolName); | ||
while (moduleOp.lookupSymbol(uniqueName)) { | ||
uniqueName = std::string(symbolName) + "_" + std::to_string(counter++); | ||
} | ||
return uniqueName; | ||
} | ||
|
||
void mlir::LLVM::createPrintStrCall( | ||
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, | ||
StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline, | ||
std::optional<StringRef> runtimeFunctionName) { | ||
auto ip = builder.saveInsertionPoint(); | ||
builder.setInsertionPointToStart(moduleOp.getBody()); | ||
MLIRContext *ctx = builder.getContext(); | ||
|
||
// Create a zero-terminated byte representation and allocate global symbol. | ||
SmallVector<uint8_t> elementVals; | ||
elementVals.append(string.begin(), string.end()); | ||
if (addNewline) | ||
elementVals.push_back('\n'); | ||
elementVals.push_back('\0'); | ||
auto dataAttrType = RankedTensorType::get( | ||
{static_cast<int64_t>(elementVals.size())}, builder.getI8Type()); | ||
auto dataAttr = | ||
DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals)); | ||
auto arrayTy = | ||
LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size()); | ||
auto globalOp = builder.create<LLVM::GlobalOp>( | ||
loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, | ||
ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr); | ||
|
||
// Emit call to `printStr` in runtime library. | ||
builder.restoreInsertionPoint(ip); | ||
auto msgAddr = builder.create<LLVM::AddressOfOp>( | ||
loc, typeConverter.getPointerType(arrayTy), globalOp.getName()); | ||
SmallVector<LLVM::GEPArg> indices(1, 0); | ||
Value gep = builder.create<LLVM::GEPOp>( | ||
loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr, | ||
indices); | ||
Operation *printer = LLVM::lookupOrCreatePrintStringFn( | ||
moduleOp, typeConverter.useOpaquePointers(), runtimeFunctionName); | ||
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer), | ||
gep); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -158,7 +158,8 @@ extern "C" void printMemrefC64(int64_t rank, void *ptr) { | |
_mlir_ciface_printMemrefC64(&descriptor); | ||
} | ||
|
||
extern "C" void printCString(char *str) { printf("%s", str); } | ||
/// Deprecated. This should be unified with printString from CRunnerUtils. | ||
extern "C" void printCString(char *str) { fputs(str, stdout); } | ||
Comment on lines
-161
to
+162
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nit] I would just skip these changes altoghether. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It follows on for #68973 (which removes all uses of |
||
|
||
extern "C" void _mlir_ciface_printMemref0dF32(StridedMemRefType<float, 0> *M) { | ||
impl::printMemRef(*M); | ||
|
Uh oh!
There was an error while loading. Please reload this page.