diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index d746222ff37a4..ee5fc0b09a161 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -1016,6 +1016,75 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> { let hasVerifier = 1; } +def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> { + let summary = "A global variable"; + let description = [{ + The `emitc.global` operation declares or defines a named global variable. + The backing memory for the variable is allocated statically and is + described by the type of the variable. + Optionally, and `initial_value` can be provided. + Internal linkage can be specified using the `staticSpecifier` unit attribute + and external linkage can be specified using the `externSpecifier` unit attribute. + Note that the default linkage without those two keywords depends on whether + the target is C or C++ and whether the global variable is `const`. + The global variable can also be marked constant using the `constSpecifier` + unit attribute. Writing to such constant global variables is + undefined. + + The global variable can be accessed by using the `emitc.get_global` to + retrieve the value for the global variable. + + Example: + + ```mlir + // Global variable with an initial value. + emitc.global @x : emitc.array<2xf32> = dense<0.0, 2.0> + // External global variable + emitc.global extern @x : emitc.array<2xf32> + // Constant global variable with internal linkage + emitc.global static const @x : i32 = 0 + ``` + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttr:$type, + OptionalAttr:$initial_value, + UnitAttr:$externSpecifier, + UnitAttr:$staticSpecifier, + UnitAttr:$constSpecifier); + + let assemblyFormat = [{ + (`extern` $externSpecifier^)? + (`static` $staticSpecifier^)? + (`const` $constSpecifier^)? + $sym_name + `:` custom($type, $initial_value) + attr-dict + }]; + + let hasVerifier = 1; +} + +def EmitC_GetGlobalOp : EmitC_Op<"get_global", + [Pure, DeclareOpInterfaceMethods]> { + let summary = "Obtain access to a global variable"; + let description = [{ + The `emitc.get_global` operation retrieves the lvalue of a + named global variable. If the global variable is marked constant, assigning + to that lvalue is undefined. + + Example: + + ```mlir + %x = emitc.get_global @foo : !emitc.array<2xf32> + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$name); + let results = (outs AnyType:$result); + let assemblyFormat = "$name `:` type($result) attr-dict"; +} + def EmitC_VerbatimOp : EmitC_Op<"verbatim"> { let summary = "Verbatim operation"; let description = [{ diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 0e3b646921264..d3e7f233c0841 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -50,6 +50,68 @@ struct ConvertAlloca final : public OpConversionPattern { } }; +struct ConvertGlobal final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::GlobalOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + + if (!op.getType().hasStaticShape()) { + return rewriter.notifyMatchFailure( + op.getLoc(), "cannot transform global with dynamic shape"); + } + + if (op.getAlignment().value_or(1) > 1) { + // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier. + return rewriter.notifyMatchFailure( + op.getLoc(), "global variable with alignment requirement is " + "currently not supported"); + } + auto resultTy = getTypeConverter()->convertType(op.getType()); + if (!resultTy) { + return rewriter.notifyMatchFailure(op.getLoc(), + "cannot convert result type"); + } + + SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op); + if (visibility != SymbolTable::Visibility::Public && + visibility != SymbolTable::Visibility::Private) { + return rewriter.notifyMatchFailure( + op.getLoc(), + "only public and private visibility is currently supported"); + } + // We are explicit in specifier the linkage because the default linkage + // for constants is different in C and C++. + bool staticSpecifier = visibility == SymbolTable::Visibility::Private; + bool externSpecifier = !staticSpecifier; + + rewriter.replaceOpWithNewOp( + op, operands.getSymName(), resultTy, operands.getInitialValueAttr(), + externSpecifier, staticSpecifier, operands.getConstant()); + return success(); + } +}; + +struct ConvertGetGlobal final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + + auto resultTy = getTypeConverter()->convertType(op.getType()); + if (!resultTy) { + return rewriter.notifyMatchFailure(op.getLoc(), + "cannot convert result type"); + } + rewriter.replaceOpWithNewOp(op, resultTy, + operands.getNameAttr()); + return success(); + } +}; + struct ConvertLoad final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -109,6 +171,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter) { - patterns.add(converter, - patterns.getContext()); + patterns.add(converter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 09ac30b1bf807..41e290397e3cf 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -790,13 +790,6 @@ LogicalResult emitc::SubscriptOp::verify() { return success(); } -//===----------------------------------------------------------------------===// -// TableGen'd op method definitions -//===----------------------------------------------------------------------===// - -#define GET_OP_CLASSES -#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc" - //===----------------------------------------------------------------------===// // EmitC Enums //===----------------------------------------------------------------------===// @@ -896,3 +889,113 @@ LogicalResult mlir::emitc::OpaqueType::verify( } return success(); } + +//===----------------------------------------------------------------------===// +// GlobalOp +//===----------------------------------------------------------------------===// +static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, + TypeAttr type, + Attribute initialValue) { + p << type; + if (initialValue) { + p << " = "; + p.printAttributeWithoutType(initialValue); + } +} + +static Type getInitializerTypeForGlobal(Type type) { + if (auto array = llvm::dyn_cast(type)) + return RankedTensorType::get(array.getShape(), array.getElementType()); + return type; +} + +static ParseResult +parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, + Attribute &initialValue) { + Type type; + if (parser.parseType(type)) + return failure(); + + typeAttr = TypeAttr::get(type); + + if (parser.parseOptionalEqual()) + return success(); + + if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type))) + return failure(); + + if (!llvm::isa(initialValue)) + return parser.emitError(parser.getNameLoc()) + << "initial value should be a unit, integer, float or elements " + "attribute"; + return success(); +} + +LogicalResult GlobalOp::verify() { + // Verify that the initial value, if present, is either a unit attribute or + // an elements attribute. + if (getInitialValue().has_value()) { + Attribute initValue = getInitialValue().value(); + // Check that the type of the initial value is compatible with the type of + // the global variable. + if (auto elementsAttr = llvm::dyn_cast(initValue)) { + auto arrayType = llvm::dyn_cast(getType()); + if (!arrayType) + return emitOpError("expected array type, but got ") << getType(); + + Type initType = elementsAttr.getType(); + Type tensorType = getInitializerTypeForGlobal(getType()); + if (initType != tensorType) { + return emitOpError("initial value expected to be of type ") + << getType() << ", but was of type " << initType; + } + } else if (auto intAttr = dyn_cast(initValue)) { + if (intAttr.getType() != getType()) { + return emitOpError("initial value expected to be of type ") + << getType() << ", but was of type " << intAttr.getType(); + } + } else if (auto floatAttr = dyn_cast(initValue)) { + if (floatAttr.getType() != getType()) { + return emitOpError("initial value expected to be of type ") + << getType() << ", but was of type " << floatAttr.getType(); + } + } else { + return emitOpError( + "initial value should be a unit, integer, float or elements " + "attribute, but got ") + << initValue; + } + } + if (getStaticSpecifier() && getExternSpecifier()) { + return emitOpError("cannot have both static and extern specifiers"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// GetGlobalOp +//===----------------------------------------------------------------------===// + +LogicalResult +GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Verify that the type matches the type of the global variable. + auto global = + symbolTable.lookupNearestSymbolFrom(*this, getNameAttr()); + if (!global) + return emitOpError("'") + << getName() << "' does not reference a valid emitc.global"; + + Type resultType = getResult().getType(); + if (global.getType() != resultType) + return emitOpError("result type ") + << resultType << " does not match type " << global.getType() + << " of the global @" << getName(); + return success(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc" diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 95c7af2f07be4..820bb65dff0ac 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -154,6 +154,9 @@ struct CppEmitter { /// any result type could not be converted. LogicalResult emitAssignPrefix(Operation &op); + /// Emits a global variable declaration or definition. + LogicalResult emitGlobalVariable(GlobalOp op); + /// Emits a label for the block. LogicalResult emitLabel(Block &block); @@ -344,6 +347,12 @@ static LogicalResult printOperation(CppEmitter &emitter, return printConstantOp(emitter, operation, value); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::GlobalOp globalOp) { + + return emitter.emitGlobalVariable(globalOp); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::AssignOp assignOp) { OpResult result = assignOp.getVar().getDefiningOp()->getResult(0); @@ -354,6 +363,13 @@ static LogicalResult printOperation(CppEmitter &emitter, return emitter.emitOperand(assignOp.getValue()); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::GetGlobalOp op) { + // Add name to cache so that `hasValueInScope` works. + emitter.getOrCreateName(op.getResult()); + return success(); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::SubscriptOp subscriptOp) { // Add name to cache so that `hasValueInScope` works. @@ -1120,6 +1136,9 @@ StringRef CppEmitter::getOrCreateName(Value val) { if (auto subscript = dyn_cast_if_present(val.getDefiningOp())) { valueMapper.insert(val, getSubscriptName(subscript)); + } else if (auto getGlobal = dyn_cast_if_present( + val.getDefiningOp())) { + valueMapper.insert(val, getGlobal.getName().str()); } else { valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); } @@ -1385,6 +1404,30 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, return success(); } +LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) { + if (op.getExternSpecifier()) + os << "extern "; + else if (op.getStaticSpecifier()) + os << "static "; + if (op.getConstSpecifier()) + os << "const "; + + if (failed(emitVariableDeclaration(op->getLoc(), op.getType(), + op.getSymName()))) { + return failure(); + } + + std::optional initialValue = op.getInitialValue(); + if (initialValue && !isa(*initialValue)) { + os << " = "; + if (failed(emitAttribute(op->getLoc(), *initialValue))) + return failure(); + } + + os << ";"; + return success(); +} + LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { // If op is being emitted as part of an expression, bail out. if (getEmittedExpression()) @@ -1445,11 +1488,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, - emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp, - emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp, - emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp, - emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp, - emitc::VerbatimOp>( + emitc::GlobalOp, emitc::GetGlobalOp, emitc::IfOp, + emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp, + emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp, + emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp, + emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( @@ -1462,7 +1505,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { if (failed(status)) return failure(); - if (isa(op)) + if (isa(op)) return success(); if (getEmittedExpression() || diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir index 390190d341e5a..89dafa7529ed5 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir @@ -38,3 +38,8 @@ func.func @zero_rank() { %0 = memref.alloca() : memref return } + +// ----- + +// expected-error@+1 {{failed to legalize operation 'memref.global'}} +memref.global "nested" constant @nested_global : memref<3x7xf32> diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index 9793b2d6d7832..54129f4f6cbc8 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -11,6 +11,7 @@ func.func @memref_store(%v : f32, %i: index, %j: index) { memref.store %v, %0[%i, %j] : memref<4x8xf32> return } + // ----- // CHECK-LABEL: memref_load @@ -26,3 +27,19 @@ func.func @memref_load(%i: index, %j: index) -> f32 { // CHECK: return %[[VAR]] : f32 return %1 : f32 } + +// ----- + +// CHECK-LABEL: globals +module @globals { + memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0> + // CHECK: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00> + memref.global @public_global : memref<3x7xf32> + // CHECK: emitc.global extern @public_global : !emitc.array<3x7xf32> + + func.func @use_global() { + // CHECK: emitc.get_global @public_global : !emitc.array<3x7xf32> + %0 = memref.get_global @public_global : memref<3x7xf32> + return + } +} diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 22423cf61b555..82fa459a5c927 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -395,3 +395,18 @@ func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index return } + +// ----- + +// expected-error @+1 {{'emitc.global' op cannot have both static and extern specifiers}} +emitc.global extern static @uninit : i32 + +// ----- + +emitc.global @myglobal : !emitc.array<2xf32> + +func.func @use_global() { + // expected-error @+1 {{'emitc.get_global' op result type 'f32' does not match type '!emitc.array<2xf32>' of the global @myglobal}} + %0 = emitc.get_global @myglobal : f32 + return +} diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 5f00a295ed740..3c987937f1721 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -224,3 +224,17 @@ emitc.verbatim "#endif // __cplusplus" emitc.verbatim "typedef int32_t i32;" emitc.verbatim "typedef float f32;" + + +emitc.global @uninit : i32 +emitc.global @myglobal_int : i32 = 4 +emitc.global extern @external_linkage : i32 +emitc.global static @internal_linkage : i32 +emitc.global @myglobal : !emitc.array<2xf32> = dense<4.000000e+00> +emitc.global const @myconstant : !emitc.array<2xi16> = dense<2> + +func.func @use_global(%i: index) -> f32 { + %0 = emitc.get_global @myglobal : !emitc.array<2xf32> + %1 = emitc.subscript %0[%i] : <2xf32>, index + return %1 : f32 +} diff --git a/mlir/test/Target/Cpp/global.mlir b/mlir/test/Target/Cpp/global.mlir new file mode 100644 index 0000000000000..730d5e0337336 --- /dev/null +++ b/mlir/test/Target/Cpp/global.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s +// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s + +emitc.global extern @decl : i8 +// CHECK: extern int8_t decl; + +emitc.global @uninit : i32 +// CHECK: int32_t uninit; + +emitc.global @myglobal_int : i32 = 4 +// CHECK: int32_t myglobal_int = 4; + +emitc.global @myglobal : !emitc.array<2xf32> = dense<4.000000e+00> +// CHECK: float myglobal[2] = {4.000000000e+00f, 4.000000000e+00f}; + +emitc.global const @myconstant : !emitc.array<2xi16> = dense<2> +// CHECK: const int16_t myconstant[2] = {2, 2}; + +emitc.global extern const @extern_constant : !emitc.array<2xi16> +// CHECK: extern const int16_t extern_constant[2]; + +emitc.global static @static_var : f32 +// CHECK: static float static_var; + +emitc.global static @static_const : f32 = 3.0 +// CHECK: static float static_const = 3.000000000e+00f; + +func.func @use_global(%i: index) -> f32 { + %0 = emitc.get_global @myglobal : !emitc.array<2xf32> + %1 = emitc.subscript %0[%i] : <2xf32>, index + return %1 : f32 + // CHECK-LABEL: use_global + // CHECK-SAME: (size_t [[V1:.*]]) + // CHECK: return myglobal[[[V1]]]; +}