diff --git a/engine/src/flutter/impeller/renderer/backend/metal/compute_pass_mtl.mm b/engine/src/flutter/impeller/renderer/backend/metal/compute_pass_mtl.mm
index e4755fef471..e027cc91df8 100644
--- a/engine/src/flutter/impeller/renderer/backend/metal/compute_pass_mtl.mm
+++ b/engine/src/flutter/impeller/renderer/backend/metal/compute_pass_mtl.mm
@@ -252,8 +252,8 @@ bool ComputePassMTL::EncodeCommands(const std::shared_ptr<Allocator>& allocator,
   while (width * height >
          static_cast<int64_t>(
             pass_bindings.GetPipeline().maxTotalThreadsPerThreadgroup)) {
-    width /= 2;
-    height /= 2;
+    width = std::max(1LL, width / 2);
+    height = std::max(1LL, height / 2);
   }
   auto size = MTLSizeMake(width, height, 1);
   [encoder dispatchThreadgroups:size threadsPerThreadgroup:size];
diff --git a/engine/src/flutter/impeller/renderer/compute_unittests.cc b/engine/src/flutter/impeller/renderer/compute_unittests.cc
index ffed4bd7e85..cbb5c1b3bd2 100644
--- a/engine/src/flutter/impeller/renderer/compute_unittests.cc
+++ b/engine/src/flutter/impeller/renderer/compute_unittests.cc
@@ -204,5 +204,86 @@ TEST_P(ComputeTest, MultiStageInputAndOutput) {
   latch.Wait();
 }
 
+TEST_P(ComputeTest, CanCorrectlyDownScaleLargeGridSize) {
+  using CS = SampleComputeShader;
+  auto context = GetContext();
+  ASSERT_TRUE(context);
+  ASSERT_TRUE(context->GetDeviceCapabilities().SupportsCompute());
+
+  using SamplePipelineBuilder = ComputePipelineBuilder<CS>;
+  auto pipeline_desc =
+      SamplePipelineBuilder::MakeDefaultPipelineDescriptor(*context);
+  ASSERT_TRUE(pipeline_desc.has_value());
+  auto compute_pipeline =
+      context->GetPipelineLibrary()->GetPipeline(pipeline_desc).Get();
+  ASSERT_TRUE(compute_pipeline);
+
+  auto cmd_buffer = context->CreateCommandBuffer();
+  auto pass = cmd_buffer->CreateComputePass();
+  ASSERT_TRUE(pass && pass->IsValid());
+
+  static constexpr size_t kCount = 5;
+
+  // Intentionally making the grid size obscenely large. No GPU will tolerate
+  // this.
+  pass->SetGridSize(ISize(std::numeric_limits<int64_t>::max(), 1));
+  pass->SetThreadGroupSize(ISize(std::numeric_limits<int64_t>::max(), 1));
+
+  ComputeCommand cmd;
+  cmd.label = "Compute";
+  cmd.pipeline = compute_pipeline;
+
+  CS::Info info{.count = kCount};
+  CS::Input0<kCount> input_0;
+  CS::Input1<kCount> input_1;
+  for (size_t i = 0; i < kCount; i++) {
+    input_0.elements[i] = Vector4(2.0 + i, 3.0 + i, 4.0 + i, 5.0 * i);
+    input_1.elements[i] = Vector4(6.0, 7.0, 8.0, 9.0);
+  }
+
+  input_0.fixed_array[1] = IPoint32(2, 2);
+  input_1.fixed_array[0] = UintPoint32(3, 3);
+  input_0.some_int = 5;
+  input_1.some_struct = CS::SomeStruct{.vf = Point(3, 4), .i = 42};
+
+  auto output_buffer = CreateHostVisibleDeviceBuffer<CS::Output<kCount>>(
+      context, "Output Buffer");
+
+  CS::BindInfo(cmd, pass->GetTransientsBuffer().EmplaceUniform(info));
+  CS::BindInput0(cmd,
+                 pass->GetTransientsBuffer().EmplaceStorageBuffer(input_0));
+  CS::BindInput1(cmd,
+                 pass->GetTransientsBuffer().EmplaceStorageBuffer(input_1));
+  CS::BindOutput(cmd, output_buffer->AsBufferView());
+
+  ASSERT_TRUE(pass->AddCommand(std::move(cmd)));
+  ASSERT_TRUE(pass->EncodeCommands());
+
+  fml::AutoResetWaitableEvent latch;
+  ASSERT_TRUE(
+      cmd_buffer->SubmitCommands([&latch, output_buffer, &input_0,
+                                  &input_1](CommandBuffer::Status status) {
+        EXPECT_EQ(status, CommandBuffer::Status::kCompleted);
+
+        auto view = output_buffer->AsBufferView();
+        EXPECT_EQ(view.range.length, sizeof(CS::Output<kCount>));
+
+        CS::Output<kCount>* output =
+            reinterpret_cast<CS::Output<kCount>*>(view.contents);
+        EXPECT_TRUE(output);
+        for (size_t i = 0; i < kCount; i++) {
+          Vector4 vector = output->elements[i];
+          Vector4 computed = input_0.elements[i] * input_1.elements[i];
+          EXPECT_EQ(vector, Vector4(computed.x + 2 + input_1.some_struct.i,
+                                    computed.y + 3 + input_1.some_struct.vf.x,
+                                    computed.z + 5 + input_1.some_struct.vf.y,
+                                    computed.w));
+        }
+        latch.Signal();
+      }));
+
+  latch.Wait();
+}
+
 }  // namespace testing
 }  // namespace impeller