Skip to content

Commit b95dfa3

Browse files
[mlir][spirv] Fix LowerABIAttributesPass to generate EntryPoints for SPV1.4 (#118994)
- Extend the SPIRV::LowerABIAttributesPass to detect when the target env is using SPIR-V ver >= 1.4, and in this case add all the functions' interface storage variables to the spirv.EntryPoint calls, as required by the spec of OpEntryPoint: "_Before version 1.4, the interface’s storage classes are limited to the Input and Output storage classes. Starting with version 1.4, the interface’s storage classes are all storage classes used in declaring all global variables referenced by the entry point’s call tree_." - Fix: generate the replacement ops (spirv.AddressOf and .AccessChain) in the order in which the associated variable appears in the function signature Signed-off-by: Fabrizio Indirli <[email protected]>
1 parent 3dfc1d9 commit b95dfa3

File tree

3 files changed

+65
-32
lines changed

3 files changed

+65
-32
lines changed

mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ getInterfaceVariables(spirv::FuncOp funcOp,
8585
if (!module) {
8686
return failure();
8787
}
88+
spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
89+
spirv::TargetEnv targetEnv(targetEnvAttr);
90+
8891
SetVector<Operation *> interfaceVarSet;
8992

9093
// TODO: This should in reality traverse the entry function
@@ -93,18 +96,18 @@ getInterfaceVariables(spirv::FuncOp funcOp,
9396
funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
9497
auto var =
9598
module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
96-
// TODO: Per SPIR-V spec: "Before version 1.4, the interfaces
99+
// Per SPIR-V spec: "Before version 1.4, the interface's
97100
// storage classes are limited to the Input and Output storage classes.
98-
// Starting with version 1.4, the interfaces storage classes are all
101+
// Starting with version 1.4, the interface's storage classes are all
99102
// storage classes used in declaring all global variables referenced by the
100-
// entry point’s call tree." We should consider the target environment here.
101-
switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) {
102-
case spirv::StorageClass::Input:
103-
case spirv::StorageClass::Output:
103+
// entry point’s call tree."
104+
const spirv::StorageClass storageClass =
105+
cast<spirv::PointerType>(var.getType()).getStorageClass();
106+
if ((targetEnvAttr && targetEnv.getVersion() >= spirv::Version::V_1_4) ||
107+
(llvm::is_contained(
108+
{spirv::StorageClass::Input, spirv::StorageClass::Output},
109+
storageClass))) {
104110
interfaceVarSet.insert(var.getOperation());
105-
break;
106-
default:
107-
break;
108111
}
109112
});
110113
for (auto &var : interfaceVarSet) {
@@ -124,6 +127,9 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
124127
return failure();
125128
}
126129

130+
spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
131+
spirv::TargetEnv targetEnv(targetEnvAttr);
132+
127133
OpBuilder::InsertionGuard moduleInsertionGuard(builder);
128134
auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
129135
builder.setInsertionPointToEnd(spirvModule.getBody());
@@ -135,8 +141,6 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
135141
return failure();
136142
}
137143

138-
spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
139-
spirv::TargetEnv targetEnv(targetEnvAttr);
140144
FailureOr<spirv::ExecutionModel> executionModel =
141145
spirv::getExecutionModel(targetEnvAttr);
142146
if (failed(executionModel))
@@ -234,6 +238,10 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
234238
auto indexType = typeConverter.getIndexType();
235239

236240
auto attrName = spirv::getInterfaceVarABIAttrName();
241+
242+
OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
243+
rewriter.setInsertionPointToStart(&funcOp.front());
244+
237245
for (const auto &argType :
238246
llvm::enumerate(funcOp.getFunctionType().getInputs())) {
239247
auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
@@ -250,8 +258,6 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
250258
if (!var)
251259
return failure();
252260

253-
OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
254-
rewriter.setInsertionPointToStart(&funcOp.front());
255261
// Insert spirv::AddressOf and spirv::AccessChain operations.
256262
Value replacement =
257263
rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);

mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ spirv.module Logical GLSL450 {
1919
%arg1: !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32>)>, StorageBuffer>
2020
{spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None"
2121
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1], subgroup_size = 64>} {
22-
// CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
2322
// CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
2423
// CHECK: [[CONST0:%.*]] = spirv.Constant 0 : i32
2524
// CHECK: [[ARG0PTR:%.*]] = spirv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
2625
// CHECK: [[ARG0:%.*]] = spirv.Load "StorageBuffer" [[ARG0PTR]]
26+
// CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
2727
// CHECK: spirv.Return
2828
spirv.Return
2929
}
@@ -39,3 +39,30 @@ module {
3939
// expected-error@+1 {{'spirv.module' op missing SPIR-V target env attribute}}
4040
spirv.module Logical GLSL450 {}
4141
} // end module
42+
43+
// -----
44+
45+
// CHECK-LABEL: spirv.module
46+
// Test case with SPIRV version 1.4: all the interface's storage variables are passed to OpEntryPoint
47+
spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
48+
// CHECK-DAG: spirv.GlobalVariable [[VAR0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>
49+
// CHECK-DAG: spirv.GlobalVariable [[VAR1:@.*]] bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32, stride=4> [0])>, StorageBuffer>
50+
// CHECK: spirv.func [[FN:@.*]]()
51+
// CHECK-SAME: #spirv.entry_point_abi<subgroup_size = 64>
52+
spirv.func @kernel(
53+
%arg0: f32
54+
{spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0), StorageBuffer>},
55+
%arg1: !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32>)>, StorageBuffer>
56+
{spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None"
57+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1], subgroup_size = 64>} {
58+
// CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
59+
// CHECK: [[CONST0:%.*]] = spirv.Constant 0 : i32
60+
// CHECK: [[ARG0PTR:%.*]] = spirv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
61+
// CHECK: [[ARG0:%.*]] = spirv.Load "StorageBuffer" [[ARG0PTR]]
62+
// CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
63+
// CHECK: spirv.Return
64+
spirv.Return
65+
}
66+
// CHECK: spirv.EntryPoint "GLCompute" [[FN]], [[VAR0]], [[VAR1]]
67+
// CHECK: spirv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
68+
} // end spirv.module

mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,28 +39,28 @@ spirv.module Logical GLSL450 {
3939
%arg6: i32
4040
{spirv.interface_var_abi = #spirv.interface_var_abi<(0, 6), StorageBuffer>}) "None"
4141
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1]>} {
42-
// CHECK: [[ADDRESSARG6:%.*]] = spirv.mlir.addressof [[VAR6]]
43-
// CHECK: [[CONST6:%.*]] = spirv.Constant 0 : i32
44-
// CHECK: [[ARG6PTR:%.*]] = spirv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
45-
// CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG6PTR]]
46-
// CHECK: [[ADDRESSARG5:%.*]] = spirv.mlir.addressof [[VAR5]]
47-
// CHECK: [[CONST5:%.*]] = spirv.Constant 0 : i32
48-
// CHECK: [[ARG5PTR:%.*]] = spirv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
49-
// CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG5PTR]]
50-
// CHECK: [[ADDRESSARG4:%.*]] = spirv.mlir.addressof [[VAR4]]
51-
// CHECK: [[CONST4:%.*]] = spirv.Constant 0 : i32
52-
// CHECK: [[ARG4PTR:%.*]] = spirv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
53-
// CHECK: [[ARG4:%.*]] = spirv.Load "StorageBuffer" [[ARG4PTR]]
42+
// CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
43+
// CHECK: [[ARG0:%.*]] = spirv.Bitcast [[ADDRESSARG0]]
44+
// CHECK: [[ADDRESSARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
45+
// CHECK: [[ARG1:%.*]] = spirv.Bitcast [[ADDRESSARG1]]
46+
// CHECK: [[ADDRESSARG2:%.*]] = spirv.mlir.addressof [[VAR2]]
47+
// CHECK: [[ARG2:%.*]] = spirv.Bitcast [[ADDRESSARG2]]
5448
// CHECK: [[ADDRESSARG3:%.*]] = spirv.mlir.addressof [[VAR3]]
5549
// CHECK: [[CONST3:%.*]] = spirv.Constant 0 : i32
5650
// CHECK: [[ARG3PTR:%.*]] = spirv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]]
5751
// CHECK: [[ARG3:%.*]] = spirv.Load "StorageBuffer" [[ARG3PTR]]
58-
// CHECK: [[ADDRESSARG2:%.*]] = spirv.mlir.addressof [[VAR2]]
59-
// CHECK: [[ARG2:%.*]] = spirv.Bitcast [[ADDRESSARG2]]
60-
// CHECK: [[ADDRESSARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
61-
// CHECK: [[ARG1:%.*]] = spirv.Bitcast [[ADDRESSARG1]]
62-
// CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
63-
// CHECK: [[ARG0:%.*]] = spirv.Bitcast [[ADDRESSARG0]]
52+
// CHECK: [[ADDRESSARG4:%.*]] = spirv.mlir.addressof [[VAR4]]
53+
// CHECK: [[CONST4:%.*]] = spirv.Constant 0 : i32
54+
// CHECK: [[ARG4PTR:%.*]] = spirv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
55+
// CHECK: [[ARG4:%.*]] = spirv.Load "StorageBuffer" [[ARG4PTR]]
56+
// CHECK: [[ADDRESSARG5:%.*]] = spirv.mlir.addressof [[VAR5]]
57+
// CHECK: [[CONST5:%.*]] = spirv.Constant 0 : i32
58+
// CHECK: [[ARG5PTR:%.*]] = spirv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
59+
// CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG5PTR]]
60+
// CHECK: [[ADDRESSARG6:%.*]] = spirv.mlir.addressof [[VAR6]]
61+
// CHECK: [[CONST6:%.*]] = spirv.Constant 0 : i32
62+
// CHECK: [[ARG6PTR:%.*]] = spirv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
63+
// CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG6PTR]]
6464
%0 = spirv.mlir.addressof @__builtin_var_WorkgroupId__ : !spirv.ptr<vector<3xi32>, Input>
6565
%1 = spirv.Load "Input" %0 : vector<3xi32>
6666
%2 = spirv.CompositeExtract %1[0 : i32] : vector<3xi32>

0 commit comments

Comments
 (0)