Skip to content

Commit 6548465

Browse files
mgehre-amdSimon Camphausen
and
Simon Camphausen
authored
EmitC: Add emitc.global and emitc.get_global (#145) (llvm#88701)
This adds - `emitc.global` and `emitc.get_global` ops to model global variables similar to how `memref.global` and `memref.get_global` work. - translation of those ops to C++ - lowering of `memref.global` and `memref.get_global` into those ops --------- Co-authored-by: Simon Camphausen <[email protected]>
1 parent 3ea9ed4 commit 6548465

File tree

9 files changed

+388
-15
lines changed

9 files changed

+388
-15
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

+69
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,75 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
10141014
let hasVerifier = 1;
10151015
}
10161016

1017+
def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
1018+
let summary = "A global variable";
1019+
let description = [{
1020+
The `emitc.global` operation declares or defines a named global variable.
1021+
The backing memory for the variable is allocated statically and is
1022+
described by the type of the variable.
1023+
Optionally, an `initial_value` can be provided.
1024+
Internal linkage can be specified using the `staticSpecifier` unit attribute
1025+
and external linkage can be specified using the `externSpecifier` unit attribute.
1026+
Note that the default linkage without those two keywords depends on whether
1027+
the target is C or C++ and whether the global variable is `const`.
1028+
The global variable can also be marked constant using the `constSpecifier`
1029+
unit attribute. Writing to such constant global variables is
1030+
undefined.
1031+
1032+
The global variable can be accessed by using the `emitc.get_global` to
1033+
retrieve the value for the global variable.
1034+
1035+
Example:
1036+
1037+
```mlir
1038+
// Global variable with an initial value.
1039+
emitc.global @x : emitc.array<2xf32> = dense<0.0, 2.0>
1040+
// External global variable
1041+
emitc.global extern @x : emitc.array<2xf32>
1042+
// Constant global variable with internal linkage
1043+
emitc.global static const @x : i32 = 0
1044+
```
1045+
}];
1046+
1047+
let arguments = (ins SymbolNameAttr:$sym_name,
1048+
TypeAttr:$type,
1049+
OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value,
1050+
UnitAttr:$extern_specifier,
1051+
UnitAttr:$static_specifier,
1052+
UnitAttr:$const_specifier);
1053+
1054+
let assemblyFormat = [{
1055+
(`extern` $extern_specifier^)?
1056+
(`static` $static_specifier^)?
1057+
(`const` $const_specifier^)?
1058+
$sym_name
1059+
`:` custom<EmitCGlobalOpTypeAndInitialValue>($type, $initial_value)
1060+
attr-dict
1061+
}];
1062+
1063+
let hasVerifier = 1;
1064+
}
1065+
1066+
def EmitC_GetGlobalOp : EmitC_Op<"get_global",
1067+
[Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
1068+
let summary = "Obtain access to a global variable";
1069+
let description = [{
1070+
The `emitc.get_global` operation retrieves the lvalue of a
1071+
named global variable. If the global variable is marked constant, assigning
1072+
to that lvalue is undefined.
1073+
1074+
Example:
1075+
1076+
```mlir
1077+
%x = emitc.get_global @foo : !emitc.array<2xf32>
1078+
```
1079+
}];
1080+
1081+
let arguments = (ins FlatSymbolRefAttr:$name);
1082+
let results = (outs EmitCType:$result);
1083+
let assemblyFormat = "$name `:` type($result) attr-dict";
1084+
}
1085+
10171086
def EmitC_VerbatimOp : EmitC_Op<"verbatim"> {
10181087
let summary = "Verbatim operation";
10191088
let description = [{

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

+68-2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,72 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
5050
}
5151
};
5252

53+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
54+
using OpConversionPattern::OpConversionPattern;
55+
56+
LogicalResult
57+
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
58+
ConversionPatternRewriter &rewriter) const override {
59+
60+
if (!op.getType().hasStaticShape()) {
61+
return rewriter.notifyMatchFailure(
62+
op.getLoc(), "cannot transform global with dynamic shape");
63+
}
64+
65+
if (op.getAlignment().value_or(1) > 1) {
66+
// TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
67+
return rewriter.notifyMatchFailure(
68+
op.getLoc(), "global variable with alignment requirement is "
69+
"currently not supported");
70+
}
71+
auto resultTy = getTypeConverter()->convertType(op.getType());
72+
if (!resultTy) {
73+
return rewriter.notifyMatchFailure(op.getLoc(),
74+
"cannot convert result type");
75+
}
76+
77+
SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
78+
if (visibility != SymbolTable::Visibility::Public &&
79+
visibility != SymbolTable::Visibility::Private) {
80+
return rewriter.notifyMatchFailure(
81+
op.getLoc(),
82+
"only public and private visibility is currently supported");
83+
}
84+
// We are explicit in specifing the linkage because the default linkage
85+
// for constants is different in C and C++.
86+
bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
87+
bool externSpecifier = !staticSpecifier;
88+
89+
Attribute initialValue = operands.getInitialValueAttr();
90+
if (isa_and_present<UnitAttr>(initialValue))
91+
initialValue = {};
92+
93+
rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
94+
op, operands.getSymName(), resultTy, initialValue, externSpecifier,
95+
staticSpecifier, operands.getConstant());
96+
return success();
97+
}
98+
};
99+
100+
struct ConvertGetGlobal final
101+
: public OpConversionPattern<memref::GetGlobalOp> {
102+
using OpConversionPattern::OpConversionPattern;
103+
104+
LogicalResult
105+
matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
106+
ConversionPatternRewriter &rewriter) const override {
107+
108+
auto resultTy = getTypeConverter()->convertType(op.getType());
109+
if (!resultTy) {
110+
return rewriter.notifyMatchFailure(op.getLoc(),
111+
"cannot convert result type");
112+
}
113+
rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
114+
operands.getNameAttr());
115+
return success();
116+
}
117+
};
118+
53119
struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
54120
using OpConversionPattern::OpConversionPattern;
55121

@@ -120,6 +186,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
120186

121187
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
122188
TypeConverter &converter) {
123-
patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
124-
patterns.getContext());
189+
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
190+
ConvertStore>(converter, patterns.getContext());
125191
}

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

+111-7
Original file line numberDiff line numberDiff line change
@@ -881,13 +881,6 @@ LogicalResult emitc::SubscriptOp::verify() {
881881
return success();
882882
}
883883

884-
//===----------------------------------------------------------------------===//
885-
// TableGen'd op method definitions
886-
//===----------------------------------------------------------------------===//
887-
888-
#define GET_OP_CLASSES
889-
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
890-
891884
//===----------------------------------------------------------------------===//
892885
// EmitC Enums
893886
//===----------------------------------------------------------------------===//
@@ -987,3 +980,114 @@ LogicalResult mlir::emitc::OpaqueType::verify(
987980
}
988981
return success();
989982
}
983+
984+
//===----------------------------------------------------------------------===//
985+
// GlobalOp
986+
//===----------------------------------------------------------------------===//
987+
static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
988+
TypeAttr type,
989+
Attribute initialValue) {
990+
p << type;
991+
if (initialValue) {
992+
p << " = ";
993+
p.printAttributeWithoutType(initialValue);
994+
}
995+
}
996+
997+
static Type getInitializerTypeForGlobal(Type type) {
998+
if (auto array = llvm::dyn_cast<ArrayType>(type))
999+
return RankedTensorType::get(array.getShape(), array.getElementType());
1000+
return type;
1001+
}
1002+
1003+
static ParseResult
1004+
parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1005+
Attribute &initialValue) {
1006+
Type type;
1007+
if (parser.parseType(type))
1008+
return failure();
1009+
1010+
typeAttr = TypeAttr::get(type);
1011+
1012+
if (parser.parseOptionalEqual())
1013+
return success();
1014+
1015+
if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
1016+
return failure();
1017+
1018+
if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1019+
initialValue))
1020+
return parser.emitError(parser.getNameLoc())
1021+
<< "initial value should be a integer, float, elements or opaque "
1022+
"attribute";
1023+
return success();
1024+
}
1025+
1026+
LogicalResult GlobalOp::verify() {
1027+
if (!isSupportedEmitCType(getType())) {
1028+
return emitOpError("expected valid emitc type");
1029+
}
1030+
if (getInitialValue().has_value()) {
1031+
Attribute initValue = getInitialValue().value();
1032+
// Check that the type of the initial value is compatible with the type of
1033+
// the global variable.
1034+
if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1035+
auto arrayType = llvm::dyn_cast<ArrayType>(getType());
1036+
if (!arrayType)
1037+
return emitOpError("expected array type, but got ") << getType();
1038+
1039+
Type initType = elementsAttr.getType();
1040+
Type tensorType = getInitializerTypeForGlobal(getType());
1041+
if (initType != tensorType) {
1042+
return emitOpError("initial value expected to be of type ")
1043+
<< getType() << ", but was of type " << initType;
1044+
}
1045+
} else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
1046+
if (intAttr.getType() != getType()) {
1047+
return emitOpError("initial value expected to be of type ")
1048+
<< getType() << ", but was of type " << intAttr.getType();
1049+
}
1050+
} else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
1051+
if (floatAttr.getType() != getType()) {
1052+
return emitOpError("initial value expected to be of type ")
1053+
<< getType() << ", but was of type " << floatAttr.getType();
1054+
}
1055+
} else if (!isa<emitc::OpaqueAttr>(initValue)) {
1056+
return emitOpError("initial value should be a integer, float, elements "
1057+
"or opaque attribute, but got ")
1058+
<< initValue;
1059+
}
1060+
}
1061+
if (getStaticSpecifier() && getExternSpecifier()) {
1062+
return emitOpError("cannot have both static and extern specifiers");
1063+
}
1064+
return success();
1065+
}
1066+
1067+
//===----------------------------------------------------------------------===//
1068+
// GetGlobalOp
1069+
//===----------------------------------------------------------------------===//
1070+
1071+
LogicalResult
1072+
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1073+
// Verify that the type matches the type of the global variable.
1074+
auto global =
1075+
symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1076+
if (!global)
1077+
return emitOpError("'")
1078+
<< getName() << "' does not reference a valid emitc.global";
1079+
1080+
Type resultType = getResult().getType();
1081+
if (global.getType() != resultType)
1082+
return emitOpError("result type ")
1083+
<< resultType << " does not match type " << global.getType()
1084+
<< " of the global @" << getName();
1085+
return success();
1086+
}
1087+
1088+
//===----------------------------------------------------------------------===//
1089+
// TableGen'd op method definitions
1090+
//===----------------------------------------------------------------------===//
1091+
1092+
#define GET_OP_CLASSES
1093+
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"

mlir/lib/Target/Cpp/TranslateToCpp.cpp

+49-6
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ struct CppEmitter {
154154
/// any result type could not be converted.
155155
LogicalResult emitAssignPrefix(Operation &op);
156156

157+
/// Emits a global variable declaration or definition.
158+
LogicalResult emitGlobalVariable(GlobalOp op);
159+
157160
/// Emits a label for the block.
158161
LogicalResult emitLabel(Block &block);
159162

@@ -344,6 +347,12 @@ static LogicalResult printOperation(CppEmitter &emitter,
344347
return printConstantOp(emitter, operation, value);
345348
}
346349

350+
static LogicalResult printOperation(CppEmitter &emitter,
351+
emitc::GlobalOp globalOp) {
352+
353+
return emitter.emitGlobalVariable(globalOp);
354+
}
355+
347356
static LogicalResult printOperation(CppEmitter &emitter,
348357
emitc::AssignOp assignOp) {
349358
OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
@@ -354,6 +363,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
354363
return emitter.emitOperand(assignOp.getValue());
355364
}
356365

366+
static LogicalResult printOperation(CppEmitter &emitter,
367+
emitc::GetGlobalOp op) {
368+
// Add name to cache so that `hasValueInScope` works.
369+
emitter.getOrCreateName(op.getResult());
370+
return success();
371+
}
372+
357373
static LogicalResult printOperation(CppEmitter &emitter,
358374
emitc::SubscriptOp subscriptOp) {
359375
// Add name to cache so that `hasValueInScope` works.
@@ -1119,6 +1135,9 @@ StringRef CppEmitter::getOrCreateName(Value val) {
11191135
if (auto subscript =
11201136
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
11211137
valueMapper.insert(val, getSubscriptName(subscript));
1138+
} else if (auto getGlobal = dyn_cast_if_present<emitc::GetGlobalOp>(
1139+
val.getDefiningOp())) {
1140+
valueMapper.insert(val, getGlobal.getName().str());
11221141
} else {
11231142
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
11241143
}
@@ -1384,6 +1403,30 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
13841403
return success();
13851404
}
13861405

1406+
LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
1407+
if (op.getExternSpecifier())
1408+
os << "extern ";
1409+
else if (op.getStaticSpecifier())
1410+
os << "static ";
1411+
if (op.getConstSpecifier())
1412+
os << "const ";
1413+
1414+
if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
1415+
op.getSymName()))) {
1416+
return failure();
1417+
}
1418+
1419+
std::optional<Attribute> initialValue = op.getInitialValue();
1420+
if (initialValue) {
1421+
os << " = ";
1422+
if (failed(emitAttribute(op->getLoc(), *initialValue)))
1423+
return failure();
1424+
}
1425+
1426+
os << ";";
1427+
return success();
1428+
}
1429+
13871430
LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
13881431
// If op is being emitted as part of an expression, bail out.
13891432
if (getEmittedExpression())
@@ -1444,11 +1487,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
14441487
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
14451488
emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
14461489
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
1447-
emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
1448-
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
1449-
emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
1450-
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
1451-
emitc::VerbatimOp>(
1490+
emitc::GlobalOp, emitc::GetGlobalOp, emitc::IfOp,
1491+
emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
1492+
emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
1493+
emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
1494+
emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
14521495
[&](auto op) { return printOperation(*this, op); })
14531496
// Func ops.
14541497
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
@@ -1461,7 +1504,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
14611504
if (failed(status))
14621505
return failure();
14631506

1464-
if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
1507+
if (isa<emitc::LiteralOp, emitc::SubscriptOp, emitc::GetGlobalOp>(op))
14651508
return success();
14661509

14671510
if (getEmittedExpression() ||

mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir

+5
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,8 @@ func.func @zero_rank() {
3838
%0 = memref.alloca() : memref<f32>
3939
return
4040
}
41+
42+
// -----
43+
44+
// expected-error@+1 {{failed to legalize operation 'memref.global'}}
45+
memref.global "nested" constant @nested_global : memref<3x7xf32>

0 commit comments

Comments
 (0)