[mlir][emitc] Support scalar MemRef types #92684
Conversation
Extend MemRefToEmitC to map zero-ranked memrefs into scalar C variables and replace the use of emitc.{variable, assign} in SCFToEmitC with memref.{alloca, load, store}, leaving it to the memref dialect lowering to handle memory allocation and accesses.
Author: Gil Rapaport (aniragil)

Patch is 27.82 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/92684.diff

9 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index e6d678dc1b12b..18d7abe2d707c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -976,7 +976,7 @@ def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
def SCFToEmitC : Pass<"convert-scf-to-emitc"> {
let summary = "Convert SCF dialect to EmitC dialect, maintaining structured"
" control flow";
- let dependentDialects = ["emitc::EmitCDialect"];
+ let dependentDialects = ["emitc::EmitCDialect", "memref::MemRefDialect"];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index e0c421741b305..2d2c8f988fdc5 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -90,6 +91,12 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
if (isa_and_present<UnitAttr>(initialValue))
initialValue = {};
+ // If converted type is a scalar, extract the splatted initial value.
+ if (initialValue && !isa<emitc::ArrayType>(resultTy)) {
+ auto elementsAttr = llvm::cast<ElementsAttr>(initialValue);
+ initialValue = elementsAttr.getSplatValue<Attribute>();
+ }
+
rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
op, operands.getSymName(), resultTy, initialValue, externSpecifier,
staticSpecifier, operands.getConstant());
@@ -116,6 +123,19 @@ struct ConvertGetGlobal final
}
};
+template <typename T>
+static Value getMemoryAccess(Value memref, Location loc,
+ typename T::Adaptor operands,
+ ConversionPatternRewriter &rewriter) {
+ // If MemRef is an array, access location using array subscripts.
+ if (auto arrayValue = dyn_cast<TypedValue<emitc::ArrayType>>(memref))
+ return rewriter.create<emitc::SubscriptOp>(loc, arrayValue,
+ operands.getIndices());
+
+ // MemRef is a scalar, access location using variable's name.
+ return memref;
+}
+
struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -128,20 +148,13 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
}
- auto arrayValue =
- dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
- if (!arrayValue) {
- return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
- }
-
- auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), arrayValue, operands.getIndices());
-
+ Value lvalue = getMemoryAccess<memref::LoadOp>(
+ operands.getMemref(), op.getLoc(), operands, rewriter);
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
auto var =
rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit);
- rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript);
+ rewriter.create<emitc::AssignOp>(op.getLoc(), var, lvalue);
rewriter.replaceOp(op, var);
return success();
}
@@ -153,15 +166,10 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
- auto arrayValue =
- dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
- if (!arrayValue) {
- return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
- }
+ Value lvalue = getMemoryAccess<memref::StoreOp>(
+ operands.getMemref(), op.getLoc(), operands, rewriter);
- auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), arrayValue, operands.getIndices());
- rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
+ rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, lvalue,
operands.getValue());
return success();
}
@@ -172,13 +180,15 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion(
[&](MemRefType memRefType) -> std::optional<Type> {
if (!memRefType.hasStaticShape() ||
- !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0) {
+ !memRefType.getLayout().isIdentity()) {
return {};
}
Type convertedElementType =
typeConverter.convertType(memRefType.getElementType());
if (!convertedElementType)
return {};
+ if (memRefType.getRank() == 0)
+ return convertedElementType;
return emitc::ArrayType::get(memRefType.getShape(),
convertedElementType);
});
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 367142a520742..e4aa10f4cf208 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -14,11 +14,13 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
@@ -63,21 +65,31 @@ static SmallVector<Value> createVariablesForResults(T op,
for (OpResult result : op.getResults()) {
Type resultType = result.getType();
- emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
- emitc::VariableOp var =
- rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
+ SmallVector<OpFoldResult> dimensions; // Zero rank for scalar memref.
+ memref::AllocaOp var =
+ rewriter.create<memref::AllocaOp>(loc, dimensions, resultType);
resultVariables.push_back(var);
}
return resultVariables;
}
-// Create a series of assign ops assigning given values to given variables at
+// Create a series of load ops reading the values of given variables at
+// the current insertion point of given rewriter.
+static SmallVector<Value> readValues(SmallVector<Value> &variables,
+ PatternRewriter &rewriter, Location loc) {
+ SmallVector<Value> values;
+ for (Value var : variables)
+ values.push_back(rewriter.create<memref::LoadOp>(loc, var).getResult());
+ return values;
+}
+
+// Create a series of store ops assigning given values to given variables at
// the current insertion point of given rewriter.
static void assignValues(ValueRange values, SmallVector<Value> &variables,
PatternRewriter &rewriter, Location loc) {
for (auto [value, var] : llvm::zip(values, variables))
- rewriter.create<emitc::AssignOp>(loc, var, value);
+ rewriter.create<memref::StoreOp>(loc, value, var);
}
static void lowerYield(SmallVector<Value> &resultVariables,
@@ -100,8 +112,6 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the loop body.
- SmallVector<Value> resultVariables =
- createVariablesForResults(forOp, rewriter);
SmallVector<Value> iterArgsVariables =
createVariablesForResults(forOp, rewriter);
@@ -115,18 +125,25 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
// Erase the auto-generated terminator for the lowered for op.
rewriter.eraseOp(loweredBody->getTerminator());
+ IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPointToEnd(loweredBody);
+ SmallVector<Value> iterArgsValues =
+ readValues(iterArgsVariables, rewriter, loc);
+ rewriter.restoreInsertionPoint(ip);
+
SmallVector<Value> replacingValues;
replacingValues.push_back(loweredFor.getInductionVar());
- replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());
+ replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
lowerYield(iterArgsVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));
// Copy iterArgs into results after the for loop.
- assignValues(iterArgsVariables, resultVariables, rewriter, loc);
+ SmallVector<Value> resultValues =
+ readValues(iterArgsVariables, rewriter, loc);
- rewriter.replaceOp(forOp, resultVariables);
+ rewriter.replaceOp(forOp, resultValues);
return success();
}
@@ -169,6 +186,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
auto loweredIf =
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
+ SmallVector<Value> resultValues = readValues(resultVariables, rewriter, loc);
Region &loweredThenRegion = loweredIf.getThenRegion();
lowerRegion(thenRegion, loweredThenRegion);
@@ -178,7 +196,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
lowerRegion(elseRegion, loweredElseRegion);
}
- rewriter.replaceOp(ifOp, resultVariables);
+ rewriter.replaceOp(ifOp, resultValues);
return success();
}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index 89dafa7529ed5..0cb33034680d8 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -33,13 +33,5 @@ func.func @non_identity_layout() {
// -----
-func.func @zero_rank() {
- // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
- %0 = memref.alloca() : memref<f32>
- return
-}
-
-// -----
-
// expected-error@+1 {{failed to legalize operation 'memref.global'}}
memref.global "nested" constant @nested_global : memref<3x7xf32>
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index bc40ef48268eb..aafaf9810711b 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -4,11 +4,15 @@
// CHECK-SAME: %[[v:.*]]: f32, %[[i:.*]]: index, %[[j:.*]]: index
func.func @memref_store(%v : f32, %i: index, %j: index) {
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
+ // CHECK: %[[SCALAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
%0 = memref.alloca() : memref<4x8xf32>
+ %s = memref.alloca() : memref<f32>
// CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
// CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32
memref.store %v, %0[%i, %j] : memref<4x8xf32>
+ // CHECK: emitc.assign %[[v]] : f32 to %[[SCALAR]] : f32
+ memref.store %v, %s[] : memref<f32>
return
}
@@ -16,16 +20,21 @@ func.func @memref_store(%v : f32, %i: index, %j: index) {
// CHECK-LABEL: memref_load
// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index
-func.func @memref_load(%i: index, %j: index) -> f32 {
+func.func @memref_load(%i: index, %j: index) -> (f32, f32) {
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
%0 = memref.alloca() : memref<4x8xf32>
+ // CHECK: %[[SCALAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+ %s = memref.alloca() : memref<f32>
// CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
// CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
%1 = memref.load %0[%i, %j] : memref<4x8xf32>
- // CHECK: return %[[VAR]] : f32
- return %1 : f32
+ // CHECK: %[[VAR_S:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+ // CHECK: emitc.assign %[[SCALAR]] : f32 to %[[VAR_S]] : f32
+ %sv = memref.load %s[] : memref<f32>
+ // CHECK: return %[[VAR]], %[[VAR_S]] : f32, f32
+ return %1, %sv : f32, f32
}
// -----
@@ -38,10 +47,13 @@ module @globals {
// CHECK: emitc.global extern @public_global : !emitc.array<3x7xf32>
memref.global @uninitialized_global : memref<3x7xf32> = uninitialized
// CHECK: emitc.global extern @uninitialized_global : !emitc.array<3x7xf32>
+ memref.global "private" constant @internal_global_scalar : memref<f32> = dense<4.0>
+ // CHECK: emitc.global static const @internal_global_scalar : f32 = 4.000000e+00
func.func @use_global() {
// CHECK: emitc.get_global @public_global : !emitc.array<3x7xf32>
%0 = memref.get_global @public_global : memref<3x7xf32>
+ %1 = memref.get_global @internal_global_scalar : memref<f32>
return
}
}
diff --git a/mlir/test/Conversion/SCFToEmitC/for.mlir b/mlir/test/Conversion/SCFToEmitC/for.mlir
index 7f90310af2189..d4e211b7a5950 100644
--- a/mlir/test/Conversion/SCFToEmitC/for.mlir
+++ b/mlir/test/Conversion/SCFToEmitC/for.mlir
@@ -47,20 +47,20 @@ func.func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32)
// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> (f32, f32) {
// CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT: %[[VAL_6:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT: %[[VAL_7:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT: %[[VAL_8:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT: emitc.assign %[[VAL_3]] : f32 to %[[VAL_7]] : f32
-// CHECK-NEXT: emitc.assign %[[VAL_4]] : f32 to %[[VAL_8]] : f32
-// CHECK-NEXT: emitc.for %[[VAL_9:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
-// CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
-// CHECK-NEXT: emitc.assign %[[VAL_10]] : f32 to %[[VAL_7]] : f32
-// CHECK-NEXT: emitc.assign %[[VAL_10]] : f32 to %[[VAL_8]] : f32
+// CHECK-NEXT: %[[VAL_5:.*]] = memref.alloca() : memref<f32>
+// CHECK-NEXT: %[[VAL_6:.*]] = memref.alloca() : memref<f32>
+// CHECK-NEXT: memref.store %[[VAL_3]], %[[VAL_5]][] : memref<f32>
+// CHECK-NEXT: memref.store %[[VAL_4]], %[[VAL_6]][] : memref<f32>
+// CHECK-NEXT: emitc.for %[[VAL_7:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT: %[[VAL_8:.*]] = memref.load %[[VAL_5]][] : memref<f32>
+// CHECK-NEXT: %[[VAL_9:.*]] = memref.load %[[VAL_6]][] : memref<f32>
+// CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_8]], %[[VAL_9]] : f32
+// CHECK-NEXT: memref.store %[[VAL_10]], %[[VAL_5]][] : memref<f32>
+// CHECK-NEXT: memref.store %[[VAL_10]], %[[VAL_6]][] : memref<f32>
// CHECK-NEXT: }
-// CHECK-NEXT: emitc.assign %[[VAL_7]] : f32 to %[[VAL_5]] : f32
-// CHECK-NEXT: emitc.assign %[[VAL_8]] : f32 to %[[VAL_6]] : f32
-// CHECK-NEXT: return %[[VAL_5]], %[[VAL_6]] : f32, f32
+// CHECK-NEXT: %[[VAL_11:.*]] = memref.load %[[VAL_5]][] : memref<f32>
+// CHECK-NEXT: %[[VAL_12:.*]] = memref.load %[[VAL_6]][] : memref<f32>
+// CHECK-NEXT: return %[[VAL_11]], %[[VAL_12]] : f32, f32
// CHECK-NEXT: }
func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
@@ -77,20 +77,20 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32
// CHECK-LABEL: func.func @nested_for_yield(
// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> f32 {
// CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT: %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT: emitc.assign %[[VAL_3]] : f32 to %[[VAL_5]] : f32
-// CHECK-NEXT: emitc.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
-// CHECK-NEXT: %[[VAL_7:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT: %[[VAL_8:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT: emitc.assign %[[VAL_5]] : f32 to %[[VAL_8]] : f32
-// CHECK-NEXT: emitc.for %[[VAL_9:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
-// CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_8]], %[[VAL_8]] : f32
-// CHECK-NEXT: emitc.assign %[[VAL_10]] : f32 to %[[VAL_8]] : f32
+// CHECK-NEXT: %[[VAL_4:.*]] = memref.alloca() : memref<f32>
+// CHECK-NEXT: memref.store %[[VAL_3]], %[[VAL_4]][] : memref<f32>
+// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT: %[[VAL_6:.*]] = memref.load %[[VAL_4]][] : memref<f32>
+// CHECK-NEXT: %[[VAL_7:.*]] = memref.alloca() : memref<f32>
+// CHECK-NEXT: memref.store %[[VAL_6]], %[[VAL_7]][] : memref<f32>
+// CHECK-NEXT: emitc.for %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT: %[[VAL_9:.*]] = memref.load %[[VAL_7]][] : memref<f32>
+// CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_9]], %[[VAL_9]] : f32
+// CHECK-NEXT: memref.store %[[VAL_10]], %[[VAL_7]][] : memref<f32>
// CHECK-NEXT: }
-// CHECK-NEXT: emitc.assign %[[VAL_8]] : f32 to %[[VAL_7]] : f32
-// CHECK-NEXT: emitc.assign %[[VAL_7]] : f32 to %[[VAL_5]] : f32
+// CHECK-NEXT: %[[VAL_11:.*]] = memref.load %[[VAL_7]][] : memref<f32>
+// CHECK-NEXT: memref.store %[[VAL_11]], %[[VAL_4]][] : memref<f32>
// CHECK-NEXT: }
-// CHECK-NEXT: emitc.assign %[[VAL_5]] : f32 to %[[VAL_4]] : f32
-// CHECK-NEXT: return %[[VAL_4]] : f32
+// CHECK-NEXT: %[[VAL_12:.*]] = memref.load %[[VAL_4]][] : memref<f32>
+// CHECK-NEXT: return %[[VAL_12]] : f32
// CHECK-NEXT: }
diff --git a/mlir/test/Conversion/SCFToEmitC/if.mlir b/mlir/test/Conversion/SCFToEmitC/if.mlir
index afc9abc761eb4..0753d9eda3283 100644
--- a/mlir/test/Conversion/SCFToEmitC/if.mlir
+++ b/mlir/test/Conversion/SCFToEmitC/if.mlir
@@ -53,18 +53,20 @@ func.func @test_if_yield(%arg0: i1, %arg1: f32) {
// CHECK-SAME: %[[VAL_0:.*]]: i1,
// CHECK-SAME: %[[VAL_1:.*]]: f32) {
// CHECK-NEXT: %[[VAL_2:.*]] = arith.constant 0 : i8
-// CHECK-NEXT: %[[VAL_3:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
-// CHECK-NEXT: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f64
+// CHECK-NEXT: %[[VAL_3:.*]] = memref.alloca() : memref<i32>
+// CHECK-NEXT: %[[VAL_4:.*]] = memref.alloca() : memref<f64>
// CHECK-NEXT: emitc.if %[[VAL_0]] {
// CHECK-NEXT: %[[VAL_5:.*]] = emitc.call_opaque "func_true_1"(%[[VAL_1]]) : (f32) -> i32
// CHECK-NEXT: %[[VAL_6:.*]] = emitc.call_opaque "func_true_2"(%[[VAL_1]]) : (f32) -> f64
-// CHECK-NEXT: emitc.assign %[[VAL_5]] : i32 to %[[VAL_3]] : i32
-// CHECK-NEXT: emitc.assign %[[VAL_6]] : f64 to %[[VAL_4]] : f64
+// CHECK-NEXT: memref.store %[[VAL_5]], %[[VAL_3]][] : memref<i32>
+// CHECK-NEXT: memref.store %[[VAL_6]], %[[VAL_4]][] : memref<f64>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[VAL_7:.*]] = emitc.call_opaque "func_false_1"(%[[VAL_1]]) : (f32) -> i32
// CHECK-NEXT: %[[VAL_8:.*]] = emitc.call_opaque "func_false_2"(%[[VAL_1]]) : (f32) -> f64
-// CHECK-NEXT: emitc.assign %[[VAL_7]] : i32 to %[[VAL_3]] : i32
-// CHECK-NEXT: emitc.assign %[[VAL_8]] : f64 to %[[VAL_4]] : f64
+// CHECK-NEXT: memref.store %[[VAL_7]], %[[VAL_3]][] : memref<i32>
+// CHECK-NEXT: ...
[truncated]
This patch hopefully simplifies #91475 by consolidating scf's lowering memory semantics into the memref dialect. |
I'm not convinced that this is the canonical way to handle rank-0 memrefs. For example, when you have a rank-0 memref as a function argument and the function body writes into it, you cannot turn the argument into a scalar. There are alternative ways to lower this (e.g., promoting rank-0 memrefs to rank-1, or turning rank-0 memrefs into pointers to scalars), so this doesn't seem to be the canonical choice. Converting rank-0 memrefs to scalars sounds like a transform that is not specific to emitc, so I'm more in favor of having this as a separate transform (in the memref dialect) instead of merging it into the memref-to-emitc conversion pass.
Good catch! Function arguments completely slipped my mind. We can perhaps temporarily exclude support for them using the signature conversion mechanisms, but it's better to have a complete design first. For now we can use rank-1 memrefs in the scf lowering pass. While passing variables by reference in C requires conversion to a pointer/array, C++ has native support for that. Both representations differ from the way C/C++ programmers define and use scalars within functions, and it would be good to emit natural C/C++ code where possible. There are also other cases where different C variants have distinct natural constructs, e.g.
Modeling variables consistently using
Doing Mem2Reg/SROA within the memref dialect where possible is indeed desired and not limited to rank-0 memrefs. Regardless, rank-0 memrefs model the concept of scalar variables, which also lowers naturally and directly to LLVM IR. For example, the following

```mlir
memref.global @scalar : memref<f32> = dense<1.0>
memref.global @array : memref<1xf32> = dense<1.0>
```

gets lowered to

```mlir
llvm.mlir.global external @scalar(1.000000e+00 : f32) {addr_space = 0 : i32} : f32
llvm.mlir.global external @array(dense<1.000000e+00> : tensor<1xf32>) {addr_space = 0 : i32} : !llvm.array<1 x f32>
```

which translates to

```llvm
@scalar = global float 1.000000e+00
@array = global [1 x float] [float 1.000000e+00]
```
SCF if/for lowering using 1-D memrefs is up for review as #93371. |