Skip to content
This repository was archived by the owner on Feb 25, 2025. It is now read-only.

Commit c2e165e

Browse files
authored
Fix multi-function compute (#39603)
1 parent f7dfb2b commit c2e165e

File tree

5 files changed

+194
-17
lines changed

5 files changed

+194
-17
lines changed

impeller/fixtures/BUILD.gn

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,19 @@ impeller_shaders("shader_fixtures") {
2727
"mipmaps.frag",
2828
"mipmaps.vert",
2929
"sample.comp",
30+
"stage1.comp",
31+
"stage2.comp",
3032
"simple.vert",
3133
"test_texture.frag",
3234
"test_texture.vert",
3335
]
3436

3537
if (impeller_enable_opengles) {
36-
gles_exclusions = [ "sample.comp" ]
38+
gles_exclusions = [
39+
"sample.comp",
40+
"stage1.comp",
41+
"stage2.comp",
42+
]
3743
}
3844
}
3945

@@ -77,6 +83,8 @@ test_fixtures("file_fixtures") {
7783
"sample_with_binding.vert",
7884
"simple.vert.hlsl",
7985
"sa%m#ple.vert",
86+
"stage1.comp",
87+
"stage2.comp",
8088
"struct_def_bug.vert",
8189
"table_mountain_nx.png",
8290
"table_mountain_ny.png",

impeller/fixtures/stage1.comp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
layout(local_size_x = 128) in;
2+
layout(std430) buffer;
3+
4+
layout(binding = 0) writeonly buffer Output {
5+
uint count;
6+
uint elements[];
7+
}
8+
output_data;
9+
10+
layout(binding = 1) readonly buffer Input {
11+
uint count;
12+
uint elements[];
13+
}
14+
input_data;
15+
16+
void main() {
17+
uint ident = gl_GlobalInvocationID.x;
18+
19+
if (ident >= input_data.count) {
20+
return;
21+
}
22+
23+
uint out_slot = ident * 2;
24+
25+
output_data.count = input_data.count * 2;
26+
27+
output_data.elements[out_slot] = input_data.elements[ident] * 2;
28+
output_data.elements[out_slot + 1] = input_data.elements[ident] * 3;
29+
}

impeller/fixtures/stage2.comp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
layout(local_size_x = 128) in;
2+
layout(std430) buffer;
3+
4+
layout(binding = 0) writeonly buffer Output {
5+
uint count;
6+
uint elements[];
7+
}
8+
output_data;
9+
10+
layout(binding = 1) readonly buffer Input {
11+
uint count;
12+
uint elements[];
13+
}
14+
input_data;
15+
16+
void main() {
17+
uint ident = gl_GlobalInvocationID.x;
18+
19+
if (ident >= input_data.count) {
20+
return;
21+
}
22+
23+
output_data.count = input_data.count;
24+
25+
output_data.elements[ident] = input_data.elements[ident] * 2;
26+
}

impeller/renderer/backend/metal/compute_pass_mtl.mm

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -241,23 +241,23 @@ static bool Bind(ComputePassBindingsCache& pass,
241241
return false;
242242
}
243243
}
244+
// TODO(dnfield): use feature detection to support non-uniform threadgroup
245+
// sizes.
246+
// https://github.com/flutter/flutter/issues/110619
247+
248+
// For now, check that the sizes are uniform.
249+
FML_DCHECK(grid_size == thread_group_size);
250+
auto width = grid_size.width;
251+
auto height = grid_size.height;
252+
while (width * height >
253+
static_cast<int64_t>(
254+
pass_bindings.GetPipeline().maxTotalThreadsPerThreadgroup)) {
255+
width /= 2;
256+
height /= 2;
257+
}
258+
auto size = MTLSizeMake(width, height, 1);
259+
[encoder dispatchThreadgroups:size threadsPerThreadgroup:size];
244260
}
245-
// TODO(dnfield): use feature detection to support non-uniform threadgroup
246-
// sizes.
247-
// https://github.com/flutter/flutter/issues/110619
248-
249-
// For now, check that the sizes are uniform.
250-
FML_DCHECK(grid_size == thread_group_size);
251-
auto width = grid_size.width;
252-
auto height = grid_size.height;
253-
while (width * height >
254-
static_cast<int64_t>(
255-
pass_bindings.GetPipeline().maxTotalThreadsPerThreadgroup)) {
256-
width /= 2;
257-
height /= 2;
258-
}
259-
auto size = MTLSizeMake(width, height, 1);
260-
[encoder dispatchThreadgroups:size threadsPerThreadgroup:size];
261261

262262
return true;
263263
}

impeller/renderer/compute_unittests.cc

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
#include "flutter/fml/synchronization/waitable_event.h"
66
#include "flutter/fml/time/time_point.h"
77
#include "flutter/testing/testing.h"
8+
#include "gmock/gmock.h"
89
#include "impeller/base/strings.h"
910
#include "impeller/fixtures/sample.comp.h"
11+
#include "impeller/fixtures/stage1.comp.h"
12+
#include "impeller/fixtures/stage2.comp.h"
1013
#include "impeller/playground/compute_playground_test.h"
1114
#include "impeller/renderer/command_buffer.h"
1215
#include "impeller/renderer/compute_command.h"
@@ -102,5 +105,116 @@ TEST_P(ComputeTest, CanCreateComputePass) {
102105
latch.Wait();
103106
}
104107

108+
TEST_P(ComputeTest, MultiStageInputAndOutput) {
109+
using CS1 = Stage1ComputeShader;
110+
using Stage1PipelineBuilder = ComputePipelineBuilder<CS1>;
111+
using CS2 = Stage2ComputeShader;
112+
using Stage2PipelineBuilder = ComputePipelineBuilder<CS2>;
113+
114+
auto context = GetContext();
115+
ASSERT_TRUE(context);
116+
117+
auto pipeline_desc_1 =
118+
Stage1PipelineBuilder::MakeDefaultPipelineDescriptor(*context);
119+
ASSERT_TRUE(pipeline_desc_1.has_value());
120+
auto compute_pipeline_1 =
121+
context->GetPipelineLibrary()->GetPipeline(pipeline_desc_1).Get();
122+
ASSERT_TRUE(compute_pipeline_1);
123+
124+
auto pipeline_desc_2 =
125+
Stage2PipelineBuilder::MakeDefaultPipelineDescriptor(*context);
126+
ASSERT_TRUE(pipeline_desc_2.has_value());
127+
auto compute_pipeline_2 =
128+
context->GetPipelineLibrary()->GetPipeline(pipeline_desc_2).Get();
129+
ASSERT_TRUE(compute_pipeline_2);
130+
131+
auto cmd_buffer = context->CreateCommandBuffer();
132+
auto pass = cmd_buffer->CreateComputePass();
133+
ASSERT_TRUE(pass && pass->IsValid());
134+
135+
static constexpr size_t kCount1 = 5;
136+
static constexpr size_t kCount2 = kCount1 * 2;
137+
138+
pass->SetGridSize(ISize(512, 1));
139+
pass->SetThreadGroupSize(ISize(512, 1));
140+
141+
CS1::Input<kCount1> input_1;
142+
input_1.count = kCount1;
143+
for (uint i = 0; i < kCount1; i++) {
144+
input_1.elements[i] = i;
145+
}
146+
147+
CS2::Input<kCount2> input_2;
148+
input_2.count = kCount2;
149+
for (uint i = 0; i < kCount2; i++) {
150+
input_2.elements[i] = i;
151+
}
152+
153+
DeviceBufferDescriptor output_desc_1;
154+
output_desc_1.storage_mode = StorageMode::kHostVisible;
155+
output_desc_1.size = sizeof(CS1::Output<kCount2>);
156+
157+
auto output_buffer_1 =
158+
context->GetResourceAllocator()->CreateBuffer(output_desc_1);
159+
output_buffer_1->SetLabel("Output Buffer Stage 1");
160+
161+
DeviceBufferDescriptor output_desc_2;
162+
output_desc_2.storage_mode = StorageMode::kHostVisible;
163+
output_desc_2.size = sizeof(CS2::Output<kCount2>);
164+
165+
auto output_buffer_2 =
166+
context->GetResourceAllocator()->CreateBuffer(output_desc_2);
167+
output_buffer_2->SetLabel("Output Buffer Stage 2");
168+
169+
{
170+
ComputeCommand cmd;
171+
cmd.label = "Compute1";
172+
cmd.pipeline = compute_pipeline_1;
173+
174+
CS1::BindInput(cmd,
175+
pass->GetTransientsBuffer().EmplaceStorageBuffer(input_1));
176+
CS1::BindOutput(cmd, output_buffer_1->AsBufferView());
177+
178+
ASSERT_TRUE(pass->AddCommand(std::move(cmd)));
179+
}
180+
181+
{
182+
ComputeCommand cmd;
183+
cmd.label = "Compute2";
184+
cmd.pipeline = compute_pipeline_2;
185+
186+
CS1::BindInput(cmd, output_buffer_1->AsBufferView());
187+
CS2::BindOutput(cmd, output_buffer_2->AsBufferView());
188+
ASSERT_TRUE(pass->AddCommand(std::move(cmd)));
189+
}
190+
191+
ASSERT_TRUE(pass->EncodeCommands());
192+
193+
fml::AutoResetWaitableEvent latch;
194+
ASSERT_TRUE(cmd_buffer->SubmitCommands([&latch, &output_buffer_1,
195+
&output_buffer_2](
196+
CommandBuffer::Status status) {
197+
EXPECT_EQ(status, CommandBuffer::Status::kCompleted);
198+
199+
CS1::Output<kCount2>* output_1 = reinterpret_cast<CS1::Output<kCount2>*>(
200+
output_buffer_1->AsBufferView().contents);
201+
EXPECT_TRUE(output_1);
202+
EXPECT_EQ(output_1->count, 10u);
203+
EXPECT_THAT(output_1->elements,
204+
::testing::ElementsAre(0, 0, 2, 3, 4, 6, 6, 9, 8, 12));
205+
206+
CS2::Output<kCount2>* output_2 = reinterpret_cast<CS2::Output<kCount2>*>(
207+
output_buffer_2->AsBufferView().contents);
208+
EXPECT_TRUE(output_2);
209+
EXPECT_EQ(output_2->count, 10u);
210+
EXPECT_THAT(output_2->elements,
211+
::testing::ElementsAre(0, 0, 4, 6, 8, 12, 12, 18, 16, 24));
212+
213+
latch.Signal();
214+
}));
215+
216+
latch.Wait();
217+
}
218+
105219
} // namespace testing
106220
} // namespace impeller

0 commit comments

Comments
 (0)