Skip to content

Commit a711c68

Browse files
committed
EmitC: Add emitc.global and emitc.get_global (#145)
This adds - `emitc.global` and `emitc.get_global` ops to model global variables similar to how `memref.global` and `memref.get_global` work. - translation of those ops to C++ - lowering of `memref.global` and `memref.get_global` into those ops
1 parent f82d018 commit a711c68

File tree

9 files changed

+376
-15
lines changed

9 files changed

+376
-15
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

+69
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,75 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
10161016
let hasVerifier = 1;
10171017
}
10181018

1019+
def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
1020+
let summary = "A global variable";
1021+
let description = [{
1022+
The `emitc.global` operation declares or defines a named global variable.
1023+
The backing memory for the variable is allocated statically and is
1024+
described by the type of the variable.
1025+
Optionally, and `initial_value` can be provided.
1026+
Internal linkage can be specified using the `staticSpecifier` unit attribute
1027+
and external linkage can be specified using the `externSpecifier` unit attribute.
1028+
Note that the default linkage without those two keywords depends on whether
1029+
the target is C or C++ and whether the global variable is `const`.
1030+
The global variable can also be marked constant using the `constSpecifier`
1031+
unit attribute. Writing to such constant global variables is
1032+
undefined.
1033+
1034+
The global variable can be accessed by using the `emitc.get_global` to
1035+
retrieve the value for the global variable.
1036+
1037+
Example:
1038+
1039+
```mlir
1040+
// Global variable with an initial value.
1041+
emitc.global @x : emitc.array<2xf32> = dense<0.0, 2.0>
1042+
// External global variable
1043+
emitc.global extern @x : emitc.array<2xf32>
1044+
// Constant global variable with internal linkage
1045+
emitc.global static const @x : i32 = 0
1046+
```
1047+
}];
1048+
1049+
let arguments = (ins SymbolNameAttr:$sym_name,
1050+
TypeAttr:$type,
1051+
OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value,
1052+
UnitAttr:$externSpecifier,
1053+
UnitAttr:$staticSpecifier,
1054+
UnitAttr:$constSpecifier);
1055+
1056+
let assemblyFormat = [{
1057+
(`extern` $externSpecifier^)?
1058+
(`static` $staticSpecifier^)?
1059+
(`const` $constSpecifier^)?
1060+
$sym_name
1061+
`:` custom<EmitCGlobalOpTypeAndInitialValue>($type, $initial_value)
1062+
attr-dict
1063+
}];
1064+
1065+
let hasVerifier = 1;
1066+
}
1067+
1068+
def EmitC_GetGlobalOp : EmitC_Op<"get_global",
1069+
[Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
1070+
let summary = "Obtain access to a global variable";
1071+
let description = [{
1072+
The `emitc.get_global` operation retrieves the lvalue of a
1073+
named global variable. If the global variable is marked constant, assigning
1074+
to that lvalue is undefined.
1075+
1076+
Example:
1077+
1078+
```mlir
1079+
%x = emitc.get_global @foo : !emitc.array<2xf32>
1080+
```
1081+
}];
1082+
1083+
let arguments = (ins FlatSymbolRefAttr:$name);
1084+
let results = (outs AnyType:$result);
1085+
let assemblyFormat = "$name `:` type($result) attr-dict";
1086+
}
1087+
10191088
def EmitC_VerbatimOp : EmitC_Op<"verbatim"> {
10201089
let summary = "Verbatim operation";
10211090
let description = [{

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

+64-2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,68 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
5050
}
5151
};
5252

53+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
54+
using OpConversionPattern::OpConversionPattern;
55+
56+
LogicalResult
57+
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
58+
ConversionPatternRewriter &rewriter) const override {
59+
60+
if (!op.getType().hasStaticShape()) {
61+
return rewriter.notifyMatchFailure(
62+
op.getLoc(), "cannot transform global with dynamic shape");
63+
}
64+
65+
if (op.getAlignment().value_or(1) > 1) {
66+
// TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
67+
return rewriter.notifyMatchFailure(
68+
op.getLoc(), "global variable with alignment requirement is "
69+
"currently not supported");
70+
}
71+
auto resultTy = getTypeConverter()->convertType(op.getType());
72+
if (!resultTy) {
73+
return rewriter.notifyMatchFailure(op.getLoc(),
74+
"cannot convert result type");
75+
}
76+
77+
SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
78+
if (visibility != SymbolTable::Visibility::Public &&
79+
visibility != SymbolTable::Visibility::Private) {
80+
return rewriter.notifyMatchFailure(
81+
op.getLoc(),
82+
"only public and private visibility is currently supported");
83+
}
84+
// We are explicit in specifier the linkage because the default linkage
85+
// for constants is different in C and C++.
86+
bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
87+
bool externSpecifier = !staticSpecifier;
88+
89+
rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
90+
op, operands.getSymName(), resultTy, operands.getInitialValueAttr(),
91+
externSpecifier, staticSpecifier, operands.getConstant());
92+
return success();
93+
}
94+
};
95+
96+
struct ConvertGetGlobal final
97+
: public OpConversionPattern<memref::GetGlobalOp> {
98+
using OpConversionPattern::OpConversionPattern;
99+
100+
LogicalResult
101+
matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
102+
ConversionPatternRewriter &rewriter) const override {
103+
104+
auto resultTy = getTypeConverter()->convertType(op.getType());
105+
if (!resultTy) {
106+
return rewriter.notifyMatchFailure(op.getLoc(),
107+
"cannot convert result type");
108+
}
109+
rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
110+
operands.getNameAttr());
111+
return success();
112+
}
113+
};
114+
53115
struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
54116
using OpConversionPattern::OpConversionPattern;
55117

@@ -109,6 +171,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
109171

110172
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
111173
TypeConverter &converter) {
112-
patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
113-
patterns.getContext());
174+
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
175+
ConvertStore>(converter, patterns.getContext());
114176
}

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

+108-7
Original file line numberDiff line numberDiff line change
@@ -790,13 +790,6 @@ LogicalResult emitc::SubscriptOp::verify() {
790790
return success();
791791
}
792792

793-
//===----------------------------------------------------------------------===//
794-
// TableGen'd op method definitions
795-
//===----------------------------------------------------------------------===//
796-
797-
#define GET_OP_CLASSES
798-
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
799-
800793
//===----------------------------------------------------------------------===//
801794
// EmitC Enums
802795
//===----------------------------------------------------------------------===//
@@ -896,3 +889,111 @@ LogicalResult mlir::emitc::OpaqueType::verify(
896889
}
897890
return success();
898891
}
892+
893+
//===----------------------------------------------------------------------===//
894+
// GlobalOp
895+
//===----------------------------------------------------------------------===//
896+
static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
897+
TypeAttr type,
898+
Attribute initialValue) {
899+
p << type;
900+
if (initialValue) {
901+
p << " = ";
902+
p.printAttributeWithoutType(initialValue);
903+
}
904+
}
905+
906+
static Type getInitializerTypeForGlobal(Type type) {
907+
if (auto array = llvm::dyn_cast<ArrayType>(type))
908+
return RankedTensorType::get(array.getShape(), array.getElementType());
909+
return type;
910+
}
911+
912+
static ParseResult
913+
parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
914+
Attribute &initialValue) {
915+
Type type;
916+
if (parser.parseType(type))
917+
return failure();
918+
919+
typeAttr = TypeAttr::get(type);
920+
921+
if (parser.parseOptionalEqual())
922+
return success();
923+
924+
if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
925+
return failure();
926+
927+
if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr>(initialValue))
928+
return parser.emitError(parser.getNameLoc())
929+
<< "initial value should be a unit, integer, float or elements "
930+
"attribute";
931+
return success();
932+
}
933+
934+
LogicalResult GlobalOp::verify() {
935+
if (getInitialValue().has_value()) {
936+
Attribute initValue = getInitialValue().value();
937+
// Check that the type of the initial value is compatible with the type of
938+
// the global variable.
939+
if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
940+
auto arrayType = llvm::dyn_cast<ArrayType>(getType());
941+
if (!arrayType)
942+
return emitOpError("expected array type, but got ") << getType();
943+
944+
Type initType = elementsAttr.getType();
945+
Type tensorType = getInitializerTypeForGlobal(getType());
946+
if (initType != tensorType) {
947+
return emitOpError("initial value expected to be of type ")
948+
<< getType() << ", but was of type " << initType;
949+
}
950+
} else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
951+
if (intAttr.getType() != getType()) {
952+
return emitOpError("initial value expected to be of type ")
953+
<< getType() << ", but was of type " << intAttr.getType();
954+
}
955+
} else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
956+
if (floatAttr.getType() != getType()) {
957+
return emitOpError("initial value expected to be of type ")
958+
<< getType() << ", but was of type " << floatAttr.getType();
959+
}
960+
} else {
961+
return emitOpError(
962+
"initial value should be a unit, integer, float or elements "
963+
"attribute, but got ")
964+
<< initValue;
965+
}
966+
}
967+
if (getStaticSpecifier() && getExternSpecifier()) {
968+
return emitOpError("cannot have both static and extern specifiers");
969+
}
970+
return success();
971+
}
972+
973+
//===----------------------------------------------------------------------===//
974+
// GetGlobalOp
975+
//===----------------------------------------------------------------------===//
976+
977+
LogicalResult
978+
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
979+
// Verify that the type matches the type of the global variable.
980+
auto global =
981+
symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
982+
if (!global)
983+
return emitOpError("'")
984+
<< getName() << "' does not reference a valid emitc.global";
985+
986+
Type resultType = getResult().getType();
987+
if (global.getType() != resultType)
988+
return emitOpError("result type ")
989+
<< resultType << " does not match type " << global.getType()
990+
<< " of the global @" << getName();
991+
return success();
992+
}
993+
994+
//===----------------------------------------------------------------------===//
995+
// TableGen'd op method definitions
996+
//===----------------------------------------------------------------------===//
997+
998+
#define GET_OP_CLASSES
999+
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"

mlir/lib/Target/Cpp/TranslateToCpp.cpp

+49-6
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ struct CppEmitter {
154154
/// any result type could not be converted.
155155
LogicalResult emitAssignPrefix(Operation &op);
156156

157+
/// Emits a global variable declaration or definition.
158+
LogicalResult emitGlobalVariable(GlobalOp op);
159+
157160
/// Emits a label for the block.
158161
LogicalResult emitLabel(Block &block);
159162

@@ -344,6 +347,12 @@ static LogicalResult printOperation(CppEmitter &emitter,
344347
return printConstantOp(emitter, operation, value);
345348
}
346349

350+
static LogicalResult printOperation(CppEmitter &emitter,
351+
emitc::GlobalOp globalOp) {
352+
353+
return emitter.emitGlobalVariable(globalOp);
354+
}
355+
347356
static LogicalResult printOperation(CppEmitter &emitter,
348357
emitc::AssignOp assignOp) {
349358
OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
@@ -354,6 +363,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
354363
return emitter.emitOperand(assignOp.getValue());
355364
}
356365

366+
static LogicalResult printOperation(CppEmitter &emitter,
367+
emitc::GetGlobalOp op) {
368+
// Add name to cache so that `hasValueInScope` works.
369+
emitter.getOrCreateName(op.getResult());
370+
return success();
371+
}
372+
357373
static LogicalResult printOperation(CppEmitter &emitter,
358374
emitc::SubscriptOp subscriptOp) {
359375
// Add name to cache so that `hasValueInScope` works.
@@ -1120,6 +1136,9 @@ StringRef CppEmitter::getOrCreateName(Value val) {
11201136
if (auto subscript =
11211137
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
11221138
valueMapper.insert(val, getSubscriptName(subscript));
1139+
} else if (auto getGlobal = dyn_cast_if_present<emitc::GetGlobalOp>(
1140+
val.getDefiningOp())) {
1141+
valueMapper.insert(val, getGlobal.getName().str());
11231142
} else {
11241143
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
11251144
}
@@ -1385,6 +1404,30 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
13851404
return success();
13861405
}
13871406

1407+
LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
1408+
if (op.getExternSpecifier())
1409+
os << "extern ";
1410+
else if (op.getStaticSpecifier())
1411+
os << "static ";
1412+
if (op.getConstSpecifier())
1413+
os << "const ";
1414+
1415+
if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
1416+
op.getSymName()))) {
1417+
return failure();
1418+
}
1419+
1420+
std::optional<Attribute> initialValue = op.getInitialValue();
1421+
if (initialValue && !isa<UnitAttr>(*initialValue)) {
1422+
os << " = ";
1423+
if (failed(emitAttribute(op->getLoc(), *initialValue)))
1424+
return failure();
1425+
}
1426+
1427+
os << ";";
1428+
return success();
1429+
}
1430+
13881431
LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
13891432
// If op is being emitted as part of an expression, bail out.
13901433
if (getEmittedExpression())
@@ -1445,11 +1488,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
14451488
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
14461489
emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
14471490
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
1448-
emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
1449-
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
1450-
emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
1451-
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
1452-
emitc::VerbatimOp>(
1491+
emitc::GlobalOp, emitc::GetGlobalOp, emitc::IfOp,
1492+
emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
1493+
emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
1494+
emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
1495+
emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
14531496
[&](auto op) { return printOperation(*this, op); })
14541497
// Func ops.
14551498
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
@@ -1462,7 +1505,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
14621505
if (failed(status))
14631506
return failure();
14641507

1465-
if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
1508+
if (isa<emitc::LiteralOp, emitc::SubscriptOp, emitc::GetGlobalOp>(op))
14661509
return success();
14671510

14681511
if (getEmittedExpression() ||

mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir

+5
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,8 @@ func.func @zero_rank() {
3838
%0 = memref.alloca() : memref<f32>
3939
return
4040
}
41+
42+
// -----
43+
44+
// expected-error@+1 {{failed to legalize operation 'memref.global'}}
45+
memref.global "nested" constant @nested_global : memref<3x7xf32>

0 commit comments

Comments
 (0)