-
Notifications
You must be signed in to change notification settings - Fork 13.4k
EmitC: Add emitc.global and emitc.get_global (#145) #88701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-emitc @llvm/pr-subscribers-mlir Author: Matthias Gehre (mgehre-amd) ChangesThis adds
Patch is 20.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88701.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index d746222ff37a4b..ee5fc0b09a1611 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1016,6 +1016,75 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
let hasVerifier = 1;
}
+def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
+ let summary = "A global variable";
+ let description = [{
+ The `emitc.global` operation declares or defines a named global variable.
+ The backing memory for the variable is allocated statically and is
+ described by the type of the variable.
+ Optionally, and `initial_value` can be provided.
+ Internal linkage can be specified using the `staticSpecifier` unit attribute
+ and external linkage can be specified using the `externSpecifier` unit attribute.
+ Note that the default linkage without those two keywords depends on whether
+ the target is C or C++ and whether the global variable is `const`.
+ The global variable can also be marked constant using the `constSpecifier`
+ unit attribute. Writing to such constant global variables is
+ undefined.
+
+ The global variable can be accessed by using the `emitc.get_global` to
+ retrieve the value for the global variable.
+
+ Example:
+
+ ```mlir
+ // Global variable with an initial value.
+ emitc.global @x : emitc.array<2xf32> = dense<0.0, 2.0>
+ // External global variable
+ emitc.global extern @x : emitc.array<2xf32>
+ // Constant global variable with internal linkage
+ emitc.global static const @x : i32 = 0
+ ```
+ }];
+
+ let arguments = (ins SymbolNameAttr:$sym_name,
+ TypeAttr:$type,
+ OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value,
+ UnitAttr:$externSpecifier,
+ UnitAttr:$staticSpecifier,
+ UnitAttr:$constSpecifier);
+
+ let assemblyFormat = [{
+ (`extern` $externSpecifier^)?
+ (`static` $staticSpecifier^)?
+ (`const` $constSpecifier^)?
+ $sym_name
+ `:` custom<EmitCGlobalOpTypeAndInitialValue>($type, $initial_value)
+ attr-dict
+ }];
+
+ let hasVerifier = 1;
+}
+
+def EmitC_GetGlobalOp : EmitC_Op<"get_global",
+ [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ let summary = "Obtain access to a global variable";
+ let description = [{
+ The `emitc.get_global` operation retrieves the lvalue of a
+ named global variable. If the global variable is marked constant, assigning
+ to that lvalue is undefined.
+
+ Example:
+
+ ```mlir
+ %x = emitc.get_global @foo : !emitc.array<2xf32>
+ ```
+ }];
+
+ let arguments = (ins FlatSymbolRefAttr:$name);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "$name `:` type($result) attr-dict";
+}
+
def EmitC_VerbatimOp : EmitC_Op<"verbatim"> {
let summary = "Verbatim operation";
let description = [{
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0e3b6469212640..d3e7f233c08412 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -50,6 +50,68 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
}
};
+struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!op.getType().hasStaticShape()) {
+ return rewriter.notifyMatchFailure(
+ op.getLoc(), "cannot transform global with dynamic shape");
+ }
+
+ if (op.getAlignment().value_or(1) > 1) {
+ // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
+ return rewriter.notifyMatchFailure(
+ op.getLoc(), "global variable with alignment requirement is "
+ "currently not supported");
+ }
+ auto resultTy = getTypeConverter()->convertType(op.getType());
+ if (!resultTy) {
+ return rewriter.notifyMatchFailure(op.getLoc(),
+ "cannot convert result type");
+ }
+
+ SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
+ if (visibility != SymbolTable::Visibility::Public &&
+ visibility != SymbolTable::Visibility::Private) {
+ return rewriter.notifyMatchFailure(
+ op.getLoc(),
+ "only public and private visibility is currently supported");
+ }
+ // We are explicit in specifier the linkage because the default linkage
+ // for constants is different in C and C++.
+ bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
+ bool externSpecifier = !staticSpecifier;
+
+ rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
+ op, operands.getSymName(), resultTy, operands.getInitialValueAttr(),
+ externSpecifier, staticSpecifier, operands.getConstant());
+ return success();
+ }
+};
+
+struct ConvertGetGlobal final
+ : public OpConversionPattern<memref::GetGlobalOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto resultTy = getTypeConverter()->convertType(op.getType());
+ if (!resultTy) {
+ return rewriter.notifyMatchFailure(op.getLoc(),
+ "cannot convert result type");
+ }
+ rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
+ operands.getNameAttr());
+ return success();
+ }
+};
+
struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -109,6 +171,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
- patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
+ ConvertStore>(converter, patterns.getContext());
}
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ab5c418e844fbf..aacc38a4dfca48 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -790,13 +790,6 @@ LogicalResult emitc::SubscriptOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// TableGen'd op method definitions
-//===----------------------------------------------------------------------===//
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
-
//===----------------------------------------------------------------------===//
// EmitC Enums
//===----------------------------------------------------------------------===//
@@ -896,3 +889,113 @@ LogicalResult mlir::emitc::OpaqueType::verify(
}
return success();
}
+
+//===----------------------------------------------------------------------===//
+// GlobalOp
+//===----------------------------------------------------------------------===//
+static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
+ TypeAttr type,
+ Attribute initialValue) {
+ p << type;
+ if (initialValue) {
+ p << " = ";
+ p.printAttributeWithoutType(initialValue);
+ }
+}
+
+static Type getInitializerTypeForGlobal(Type type) {
+ if (auto array = llvm::dyn_cast<ArrayType>(type))
+ return RankedTensorType::get(array.getShape(), array.getElementType());
+ return type;
+}
+
+static ParseResult
+parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
+ Attribute &initialValue) {
+ Type type;
+ if (parser.parseType(type))
+ return failure();
+
+ typeAttr = TypeAttr::get(type);
+
+ if (parser.parseOptionalEqual())
+ return success();
+
+ if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
+ return failure();
+
+ if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr>(initialValue))
+ return parser.emitError(parser.getNameLoc())
+ << "initial value should be a unit, integer, float or elements "
+ "attribute";
+ return success();
+}
+
+LogicalResult GlobalOp::verify() {
+ // Verify that the initial value, if present, is either a unit attribute or
+ // an elements attribute.
+ if (getInitialValue().has_value()) {
+ Attribute initValue = getInitialValue().value();
+ // Check that the type of the initial value is compatible with the type of
+ // the global variable.
+ if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
+ auto arrayType = llvm::dyn_cast<ArrayType>(getType());
+ if (!arrayType)
+ return emitOpError("expected array type, but got ") << getType();
+
+ Type initType = elementsAttr.getType();
+ Type tensorType = getInitializerTypeForGlobal(getType());
+ if (initType != tensorType) {
+ return emitOpError("initial value expected to be of type ")
+ << getType() << ", but was of type " << initType;
+ }
+ } else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
+ if (intAttr.getType() != getType()) {
+ return emitOpError("initial value expected to be of type ")
+ << getType() << ", but was of type " << intAttr.getType();
+ }
+ } else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
+ if (floatAttr.getType() != getType()) {
+ return emitOpError("initial value expected to be of type ")
+ << getType() << ", but was of type " << floatAttr.getType();
+ }
+ } else {
+ return emitOpError(
+ "initial value should be a unit, integer, float or elements "
+ "attribute, but got ")
+ << initValue;
+ }
+ }
+ if (getStaticSpecifier() && getExternSpecifier()) {
+ return emitOpError("cannot have both static and extern specifiers");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetGlobalOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ // Verify that the type matches the type of the global variable.
+ auto global =
+ symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
+ if (!global)
+ return emitOpError("'")
+ << getName() << "' does not reference a valid emitc.global";
+
+ Type resultType = getResult().getType();
+ if (global.getType() != resultType)
+ return emitOpError("result type ")
+ << resultType << " does not match type " << global.getType()
+ << " of the global @" << getName();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 95c7af2f07be46..820bb65dff0ac9 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -154,6 +154,9 @@ struct CppEmitter {
/// any result type could not be converted.
LogicalResult emitAssignPrefix(Operation &op);
+ /// Emits a global variable declaration or definition.
+ LogicalResult emitGlobalVariable(GlobalOp op);
+
/// Emits a label for the block.
LogicalResult emitLabel(Block &block);
@@ -344,6 +347,12 @@ static LogicalResult printOperation(CppEmitter &emitter,
return printConstantOp(emitter, operation, value);
}
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::GlobalOp globalOp) {
+
+ return emitter.emitGlobalVariable(globalOp);
+}
+
static LogicalResult printOperation(CppEmitter &emitter,
emitc::AssignOp assignOp) {
OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
@@ -354,6 +363,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
return emitter.emitOperand(assignOp.getValue());
}
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::GetGlobalOp op) {
+ // Add name to cache so that `hasValueInScope` works.
+ emitter.getOrCreateName(op.getResult());
+ return success();
+}
+
static LogicalResult printOperation(CppEmitter &emitter,
emitc::SubscriptOp subscriptOp) {
// Add name to cache so that `hasValueInScope` works.
@@ -1120,6 +1136,9 @@ StringRef CppEmitter::getOrCreateName(Value val) {
if (auto subscript =
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
valueMapper.insert(val, getSubscriptName(subscript));
+ } else if (auto getGlobal = dyn_cast_if_present<emitc::GetGlobalOp>(
+ val.getDefiningOp())) {
+ valueMapper.insert(val, getGlobal.getName().str());
} else {
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
}
@@ -1385,6 +1404,30 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
return success();
}
+LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
+ if (op.getExternSpecifier())
+ os << "extern ";
+ else if (op.getStaticSpecifier())
+ os << "static ";
+ if (op.getConstSpecifier())
+ os << "const ";
+
+ if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
+ op.getSymName()))) {
+ return failure();
+ }
+
+ std::optional<Attribute> initialValue = op.getInitialValue();
+ if (initialValue && !isa<UnitAttr>(*initialValue)) {
+ os << " = ";
+ if (failed(emitAttribute(op->getLoc(), *initialValue)))
+ return failure();
+ }
+
+ os << ";";
+ return success();
+}
+
LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
// If op is being emitted as part of an expression, bail out.
if (getEmittedExpression())
@@ -1445,11 +1488,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
- emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
- emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
- emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
- emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
- emitc::VerbatimOp>(
+ emitc::GlobalOp, emitc::GetGlobalOp, emitc::IfOp,
+ emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
+ emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
+ emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
+ emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
@@ -1462,7 +1505,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
if (failed(status))
return failure();
- if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
+ if (isa<emitc::LiteralOp, emitc::SubscriptOp, emitc::GetGlobalOp>(op))
return success();
if (getEmittedExpression() ||
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index 390190d341e5ae..89dafa7529ed53 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -38,3 +38,8 @@ func.func @zero_rank() {
%0 = memref.alloca() : memref<f32>
return
}
+
+// -----
+
+// expected-error@+1 {{failed to legalize operation 'memref.global'}}
+memref.global "nested" constant @nested_global : memref<3x7xf32>
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 9793b2d6d7832f..54129f4f6cbc8e 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -11,6 +11,7 @@ func.func @memref_store(%v : f32, %i: index, %j: index) {
memref.store %v, %0[%i, %j] : memref<4x8xf32>
return
}
+
// -----
// CHECK-LABEL: memref_load
@@ -26,3 +27,19 @@ func.func @memref_load(%i: index, %j: index) -> f32 {
// CHECK: return %[[VAR]] : f32
return %1 : f32
}
+
+// -----
+
+// CHECK-LABEL: globals
+module @globals {
+ memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
+ // CHECK: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
+ memref.global @public_global : memref<3x7xf32>
+ // CHECK: emitc.global extern @public_global : !emitc.array<3x7xf32>
+
+ func.func @use_global() {
+ // CHECK: emitc.get_global @public_global : !emitc.array<3x7xf32>
+ %0 = memref.get_global @public_global : memref<3x7xf32>
+ return
+ }
+}
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 22423cf61b5556..82fa459a5c9270 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -395,3 +395,18 @@ func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2:
%0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
return
}
+
+// -----
+
+// expected-error @+1 {{'emitc.global' op cannot have both static and extern specifiers}}
+emitc.global extern static @uninit : i32
+
+// -----
+
+emitc.global @myglobal : !emitc.array<2xf32>
+
+func.func @use_global() {
+ // expected-error @+1 {{'emitc.get_global' op result type 'f32' does not match type '!emitc.array<2xf32>' of the global @myglobal}}
+ %0 = emitc.get_global @myglobal : f32
+ return
+}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 5f00a295ed740e..3c987937f17212 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -224,3 +224,17 @@ emitc.verbatim "#endif // __cplusplus"
emitc.verbatim "typedef int32_t i32;"
emitc.verbatim "typedef float f32;"
+
+
+emitc.global @uninit : i32
+emitc.global @myglobal_int : i32 = 4
+emitc.global extern @external_linkage : i32
+emitc.global static @internal_linkage : i32
+emitc.global @myglobal : !emitc.array<2xf32> = dense<4.000000e+00>
+emitc.global const @myconstant : !emitc.array<2xi16> = dense<2>
+
+func.func @use_global(%i: index) -> f32 {
+ %0 = emitc.get_global @myglobal : !emitc.array<2xf32>
+ %1 = emitc.subscript %0[%i] : <2xf32>, index
+ return %1 : f32
+}
diff --git a/mlir/test/Target/Cpp/global.mlir b/mlir/test/Target/Cpp/global.mlir
new file mode 100644
index 00000000000000..730d5e0337336f
--- /dev/null
+++ b/mlir/test/Target/Cpp/global.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
+
+emitc.global extern @decl : i8
+// CHECK: extern int8_t decl;
+
+emitc.global @uninit : i32
+// CHECK: int32_t uninit;
+
+emitc.global @myglobal_int : i32 = 4
+// CHECK: int32_t myglobal_int = 4;
+
+emitc.global @myglobal : !emitc.array<2xf32> = dense<4.000000e+00>
+// CHECK: float myglobal[2] = {4.000000000e+00f, 4.000000000e+00f};
+
+emitc.global const @myconstant : !emitc.array<2xi16> = dense<2>
+// CHECK: const int16_t myconstant[2]...
[truncated]
|
This adds - `emitc.global` and `emitc.get_global` ops to model global variables similar to how `memref.global` and `memref.get_global` work. - translation of those ops to C++ - lowering of `memref.global` and `memref.get_global` into those ops
d34f5b2
to
a711c68
Compare
I've got a high level question before reviewing thoroughly. Should the |
I view emitc as a mostly syntactic dialect, so What do you mean by |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have two two questions/suggestions. Otherwise this looks great and adds one of the big missing features, thanks for working on this.
You are right, I thought we were verifying operand type and pointee type in the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice to see globals support rising! Some minor remarks after going though part of the newly added code.
Co-authored-by: Simon Camphausen <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks.
This adds - `emitc.global` and `emitc.get_global` ops to model global variables similar to how `memref.global` and `memref.get_global` work. - translation of those ops to C++ - lowering of `memref.global` and `memref.get_global` into those ops --------- Co-authored-by: Simon Camphausen <[email protected]>
This adds
emitc.global
andemitc.get_global
ops to model global variables similar to howmemref.global
andmemref.get_global
work.memref.global
andmemref.get_global
into those ops