|
5 | 5 | #include "flutter/fml/synchronization/waitable_event.h"
|
6 | 6 | #include "flutter/fml/time/time_point.h"
|
7 | 7 | #include "flutter/testing/testing.h"
|
| 8 | +#include "gmock/gmock.h" |
8 | 9 | #include "impeller/base/strings.h"
|
9 | 10 | #include "impeller/fixtures/sample.comp.h"
|
| 11 | +#include "impeller/fixtures/stage1.comp.h" |
| 12 | +#include "impeller/fixtures/stage2.comp.h" |
10 | 13 | #include "impeller/playground/compute_playground_test.h"
|
11 | 14 | #include "impeller/renderer/command_buffer.h"
|
12 | 15 | #include "impeller/renderer/compute_command.h"
|
@@ -102,5 +105,116 @@ TEST_P(ComputeTest, CanCreateComputePass) {
|
102 | 105 | latch.Wait();
|
103 | 106 | }
|
104 | 107 |
|
| 108 | +TEST_P(ComputeTest, MultiStageInputAndOutput) { |
| 109 | + using CS1 = Stage1ComputeShader; |
| 110 | + using Stage1PipelineBuilder = ComputePipelineBuilder<CS1>; |
| 111 | + using CS2 = Stage2ComputeShader; |
| 112 | + using Stage2PipelineBuilder = ComputePipelineBuilder<CS2>; |
| 113 | + |
| 114 | + auto context = GetContext(); |
| 115 | + ASSERT_TRUE(context); |
| 116 | + |
| 117 | + auto pipeline_desc_1 = |
| 118 | + Stage1PipelineBuilder::MakeDefaultPipelineDescriptor(*context); |
| 119 | + ASSERT_TRUE(pipeline_desc_1.has_value()); |
| 120 | + auto compute_pipeline_1 = |
| 121 | + context->GetPipelineLibrary()->GetPipeline(pipeline_desc_1).Get(); |
| 122 | + ASSERT_TRUE(compute_pipeline_1); |
| 123 | + |
| 124 | + auto pipeline_desc_2 = |
| 125 | + Stage2PipelineBuilder::MakeDefaultPipelineDescriptor(*context); |
| 126 | + ASSERT_TRUE(pipeline_desc_2.has_value()); |
| 127 | + auto compute_pipeline_2 = |
| 128 | + context->GetPipelineLibrary()->GetPipeline(pipeline_desc_2).Get(); |
| 129 | + ASSERT_TRUE(compute_pipeline_2); |
| 130 | + |
| 131 | + auto cmd_buffer = context->CreateCommandBuffer(); |
| 132 | + auto pass = cmd_buffer->CreateComputePass(); |
| 133 | + ASSERT_TRUE(pass && pass->IsValid()); |
| 134 | + |
| 135 | + static constexpr size_t kCount1 = 5; |
| 136 | + static constexpr size_t kCount2 = kCount1 * 2; |
| 137 | + |
| 138 | + pass->SetGridSize(ISize(512, 1)); |
| 139 | + pass->SetThreadGroupSize(ISize(512, 1)); |
| 140 | + |
| 141 | + CS1::Input<kCount1> input_1; |
| 142 | + input_1.count = kCount1; |
| 143 | + for (uint i = 0; i < kCount1; i++) { |
| 144 | + input_1.elements[i] = i; |
| 145 | + } |
| 146 | + |
| 147 | + CS2::Input<kCount2> input_2; |
| 148 | + input_2.count = kCount2; |
| 149 | + for (uint i = 0; i < kCount2; i++) { |
| 150 | + input_2.elements[i] = i; |
| 151 | + } |
| 152 | + |
| 153 | + DeviceBufferDescriptor output_desc_1; |
| 154 | + output_desc_1.storage_mode = StorageMode::kHostVisible; |
| 155 | + output_desc_1.size = sizeof(CS1::Output<kCount2>); |
| 156 | + |
| 157 | + auto output_buffer_1 = |
| 158 | + context->GetResourceAllocator()->CreateBuffer(output_desc_1); |
| 159 | + output_buffer_1->SetLabel("Output Buffer Stage 1"); |
| 160 | + |
| 161 | + DeviceBufferDescriptor output_desc_2; |
| 162 | + output_desc_2.storage_mode = StorageMode::kHostVisible; |
| 163 | + output_desc_2.size = sizeof(CS2::Output<kCount2>); |
| 164 | + |
| 165 | + auto output_buffer_2 = |
| 166 | + context->GetResourceAllocator()->CreateBuffer(output_desc_2); |
| 167 | + output_buffer_2->SetLabel("Output Buffer Stage 2"); |
| 168 | + |
| 169 | + { |
| 170 | + ComputeCommand cmd; |
| 171 | + cmd.label = "Compute1"; |
| 172 | + cmd.pipeline = compute_pipeline_1; |
| 173 | + |
| 174 | + CS1::BindInput(cmd, |
| 175 | + pass->GetTransientsBuffer().EmplaceStorageBuffer(input_1)); |
| 176 | + CS1::BindOutput(cmd, output_buffer_1->AsBufferView()); |
| 177 | + |
| 178 | + ASSERT_TRUE(pass->AddCommand(std::move(cmd))); |
| 179 | + } |
| 180 | + |
| 181 | + { |
| 182 | + ComputeCommand cmd; |
| 183 | + cmd.label = "Compute2"; |
| 184 | + cmd.pipeline = compute_pipeline_2; |
| 185 | + |
| 186 | + CS1::BindInput(cmd, output_buffer_1->AsBufferView()); |
| 187 | + CS2::BindOutput(cmd, output_buffer_2->AsBufferView()); |
| 188 | + ASSERT_TRUE(pass->AddCommand(std::move(cmd))); |
| 189 | + } |
| 190 | + |
| 191 | + ASSERT_TRUE(pass->EncodeCommands()); |
| 192 | + |
| 193 | + fml::AutoResetWaitableEvent latch; |
| 194 | + ASSERT_TRUE(cmd_buffer->SubmitCommands([&latch, &output_buffer_1, |
| 195 | + &output_buffer_2]( |
| 196 | + CommandBuffer::Status status) { |
| 197 | + EXPECT_EQ(status, CommandBuffer::Status::kCompleted); |
| 198 | + |
| 199 | + CS1::Output<kCount2>* output_1 = reinterpret_cast<CS1::Output<kCount2>*>( |
| 200 | + output_buffer_1->AsBufferView().contents); |
| 201 | + EXPECT_TRUE(output_1); |
| 202 | + EXPECT_EQ(output_1->count, 10u); |
| 203 | + EXPECT_THAT(output_1->elements, |
| 204 | + ::testing::ElementsAre(0, 0, 2, 3, 4, 6, 6, 9, 8, 12)); |
| 205 | + |
| 206 | + CS2::Output<kCount2>* output_2 = reinterpret_cast<CS2::Output<kCount2>*>( |
| 207 | + output_buffer_2->AsBufferView().contents); |
| 208 | + EXPECT_TRUE(output_2); |
| 209 | + EXPECT_EQ(output_2->count, 10u); |
| 210 | + EXPECT_THAT(output_2->elements, |
| 211 | + ::testing::ElementsAre(0, 0, 4, 6, 8, 12, 12, 18, 16, 24)); |
| 212 | + |
| 213 | + latch.Signal(); |
| 214 | + })); |
| 215 | + |
| 216 | + latch.Wait(); |
| 217 | +} |
| 218 | + |
105 | 219 | } // namespace testing
|
106 | 220 | } // namespace impeller
|
0 commit comments