Skip to content

EmitC: Add emitc.global and emitc.get_global (#145) #88701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,75 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
let hasVerifier = 1;
}

def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
let summary = "A global variable";
let description = [{
The `emitc.global` operation declares or defines a named global variable.
The backing memory for the variable is allocated statically and is
described by the type of the variable.
Optionally, an `initial_value` can be provided.
Internal linkage can be specified using the `staticSpecifier` unit attribute
and external linkage can be specified using the `externSpecifier` unit attribute.
Note that the default linkage without those two keywords depends on whether
the target is C or C++ and whether the global variable is `const`.
The global variable can also be marked constant using the `constSpecifier`
unit attribute. Writing to such constant global variables is
undefined.

The global variable can be accessed by using the `emitc.get_global` to
retrieve the value for the global variable.

Example:

```mlir
// Global variable with an initial value.
emitc.global @x : emitc.array<2xf32> = dense<0.0, 2.0>
// External global variable
emitc.global extern @x : emitc.array<2xf32>
// Constant global variable with internal linkage
emitc.global static const @x : i32 = 0
```
}];

let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttr:$type,
OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value,
UnitAttr:$extern_specifier,
UnitAttr:$static_specifier,
UnitAttr:$const_specifier);

let assemblyFormat = [{
(`extern` $extern_specifier^)?
(`static` $static_specifier^)?
(`const` $const_specifier^)?
$sym_name
`:` custom<EmitCGlobalOpTypeAndInitialValue>($type, $initial_value)
attr-dict
}];

let hasVerifier = 1;
}

def EmitC_GetGlobalOp : EmitC_Op<"get_global",
[Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Obtain access to a global variable";
let description = [{
The `emitc.get_global` operation retrieves the lvalue of a
named global variable. If the global variable is marked constant, assigning
to that lvalue is undefined.

Example:

```mlir
%x = emitc.get_global @foo : !emitc.array<2xf32>
```
}];

let arguments = (ins FlatSymbolRefAttr:$name);
let results = (outs EmitCType:$result);
let assemblyFormat = "$name `:` type($result) attr-dict";
}

def EmitC_VerbatimOp : EmitC_Op<"verbatim"> {
let summary = "Verbatim operation";
let description = [{
Expand Down
70 changes: 68 additions & 2 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,72 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
}
};

struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

if (!op.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform global with dynamic shape");
}

if (op.getAlignment().value_or(1) > 1) {
// TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
return rewriter.notifyMatchFailure(
op.getLoc(), "global variable with alignment requirement is "
"currently not supported");
}
auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}

SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
if (visibility != SymbolTable::Visibility::Public &&
visibility != SymbolTable::Visibility::Private) {
return rewriter.notifyMatchFailure(
op.getLoc(),
"only public and private visibility is currently supported");
}
// We are explicit in specifing the linkage because the default linkage
// for constants is different in C and C++.
bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
bool externSpecifier = !staticSpecifier;

Attribute initialValue = operands.getInitialValueAttr();
if (isa_and_present<UnitAttr>(initialValue))
initialValue = {};

rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
op, operands.getSymName(), resultTy, initialValue, externSpecifier,
staticSpecifier, operands.getConstant());
return success();
}
};

struct ConvertGetGlobal final
: public OpConversionPattern<memref::GetGlobalOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}
rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
operands.getNameAttr());
return success();
}
};

struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
using OpConversionPattern::OpConversionPattern;

Expand Down Expand Up @@ -120,6 +186,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {

void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
patterns.getContext());
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
ConvertStore>(converter, patterns.getContext());
}
118 changes: 111 additions & 7 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -881,13 +881,6 @@ LogicalResult emitc::SubscriptOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"

//===----------------------------------------------------------------------===//
// EmitC Enums
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -987,3 +980,114 @@ LogicalResult mlir::emitc::OpaqueType::verify(
}
return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
TypeAttr type,
Attribute initialValue) {
p << type;
if (initialValue) {
p << " = ";
p.printAttributeWithoutType(initialValue);
}
}

static Type getInitializerTypeForGlobal(Type type) {
if (auto array = llvm::dyn_cast<ArrayType>(type))
return RankedTensorType::get(array.getShape(), array.getElementType());
return type;
}

static ParseResult
parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
Attribute &initialValue) {
Type type;
if (parser.parseType(type))
return failure();

typeAttr = TypeAttr::get(type);

if (parser.parseOptionalEqual())
return success();

if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
return failure();

if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
initialValue))
return parser.emitError(parser.getNameLoc())
<< "initial value should be a integer, float, elements or opaque "
"attribute";
return success();
}

LogicalResult GlobalOp::verify() {
if (!isSupportedEmitCType(getType())) {
return emitOpError("expected valid emitc type");
}
if (getInitialValue().has_value()) {
Attribute initValue = getInitialValue().value();
// Check that the type of the initial value is compatible with the type of
// the global variable.
if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
auto arrayType = llvm::dyn_cast<ArrayType>(getType());
if (!arrayType)
return emitOpError("expected array type, but got ") << getType();

Type initType = elementsAttr.getType();
Type tensorType = getInitializerTypeForGlobal(getType());
if (initType != tensorType) {
return emitOpError("initial value expected to be of type ")
<< getType() << ", but was of type " << initType;
}
} else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
if (intAttr.getType() != getType()) {
return emitOpError("initial value expected to be of type ")
<< getType() << ", but was of type " << intAttr.getType();
}
} else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
if (floatAttr.getType() != getType()) {
return emitOpError("initial value expected to be of type ")
<< getType() << ", but was of type " << floatAttr.getType();
}
} else if (!isa<emitc::OpaqueAttr>(initValue)) {
return emitOpError("initial value should be a integer, float, elements "
"or opaque attribute, but got ")
<< initValue;
}
}
if (getStaticSpecifier() && getExternSpecifier()) {
return emitOpError("cannot have both static and extern specifiers");
}
return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Verify that the type matches the type of the global variable.
auto global =
symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
if (!global)
return emitOpError("'")
<< getName() << "' does not reference a valid emitc.global";

Type resultType = getResult().getType();
if (global.getType() != resultType)
return emitOpError("result type ")
<< resultType << " does not match type " << global.getType()
<< " of the global @" << getName();
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
55 changes: 49 additions & 6 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ struct CppEmitter {
/// any result type could not be converted.
LogicalResult emitAssignPrefix(Operation &op);

/// Emits a global variable declaration or definition.
LogicalResult emitGlobalVariable(GlobalOp op);

/// Emits a label for the block.
LogicalResult emitLabel(Block &block);

Expand Down Expand Up @@ -344,6 +347,12 @@ static LogicalResult printOperation(CppEmitter &emitter,
return printConstantOp(emitter, operation, value);
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::GlobalOp globalOp) {

return emitter.emitGlobalVariable(globalOp);
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::AssignOp assignOp) {
OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
Expand All @@ -354,6 +363,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
return emitter.emitOperand(assignOp.getValue());
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::GetGlobalOp op) {
// Add name to cache so that `hasValueInScope` works.
emitter.getOrCreateName(op.getResult());
return success();
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::SubscriptOp subscriptOp) {
// Add name to cache so that `hasValueInScope` works.
Expand Down Expand Up @@ -1119,6 +1135,9 @@ StringRef CppEmitter::getOrCreateName(Value val) {
if (auto subscript =
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
valueMapper.insert(val, getSubscriptName(subscript));
} else if (auto getGlobal = dyn_cast_if_present<emitc::GetGlobalOp>(
val.getDefiningOp())) {
valueMapper.insert(val, getGlobal.getName().str());
} else {
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
}
Expand Down Expand Up @@ -1384,6 +1403,30 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
return success();
}

LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
if (op.getExternSpecifier())
os << "extern ";
else if (op.getStaticSpecifier())
os << "static ";
if (op.getConstSpecifier())
os << "const ";

if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
op.getSymName()))) {
return failure();
}

std::optional<Attribute> initialValue = op.getInitialValue();
if (initialValue) {
os << " = ";
if (failed(emitAttribute(op->getLoc(), *initialValue)))
return failure();
}

os << ";";
return success();
}

LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
// If op is being emitted as part of an expression, bail out.
if (getEmittedExpression())
Expand Down Expand Up @@ -1444,11 +1487,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
emitc::VerbatimOp>(
emitc::GlobalOp, emitc::GetGlobalOp, emitc::IfOp,
emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
Expand All @@ -1461,7 +1504,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
if (failed(status))
return failure();

if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
if (isa<emitc::LiteralOp, emitc::SubscriptOp, emitc::GetGlobalOp>(op))
return success();

if (getEmittedExpression() ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,8 @@ func.func @zero_rank() {
%0 = memref.alloca() : memref<f32>
return
}

// -----

// expected-error@+1 {{failed to legalize operation 'memref.global'}}
memref.global "nested" constant @nested_global : memref<3x7xf32>
Loading
Loading