diff --git a/backends/vulkan/test/glsl/all_shaders.yaml b/backends/vulkan/test/glsl/all_shaders.yaml new file mode 100644 index 00000000000..5ba38dbb545 --- /dev/null +++ b/backends/vulkan/test/glsl/all_shaders.yaml @@ -0,0 +1,42 @@ +binary_op_nobroadcast__test: + parameter_names_with_default_values: + OPERATOR: X + Y + shader_variants: + - NAME: binary_add_nobroadcast__test + OPERATOR: X + Y + - NAME: binary_sub_nobroadcast__test + OPERATOR: X - Y + - NAME: binary_mul_nobroadcast__test + OPERATOR: X * Y + - NAME: binary_div_nobroadcast__test + OPERATOR: X / Y + - NAME: binary_pow_nobroadcast__test + OPERATOR: pow(X, Y) + +image_to_nchw__test: + parameter_names_with_default_values: + NDIM: 3 + DTYPE: float + PACKING: CHANNELS_PACKED + generate_variant_forall: + DTYPE: + - VALUE: "half" + SUFFIX: "half" + - VALUE: "float" + SUFFIX: "float" + shader_variants: + - NAME: image3d_to_nchw__test_C_packed + +nchw_to_image__test: + parameter_names_with_default_values: + NDIM: 3 + DTYPE: float + PACKING: CHANNELS_PACKED + generate_variant_forall: + DTYPE: + - VALUE: "half" + SUFFIX: "half" + - VALUE: "float" + SUFFIX: "float" + shader_variants: + - NAME: nchw_to_image3d__test_C_packed diff --git a/backends/vulkan/test/glsl/binary_op_nobroadcast__test.glsl b/backends/vulkan/test/glsl/binary_op_nobroadcast__test.glsl new file mode 100644 index 00000000000..dd7f7d303a4 --- /dev/null +++ b/backends/vulkan/test/glsl/binary_op_nobroadcast__test.glsl @@ -0,0 +1,43 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core +// clang-format off +#define PRECISION ${PRECISION} +#define FORMAT ${FORMAT} + +#define OP(X, Y) ${OPERATOR} +// clang-format on + +layout(std430) buffer; + +// clang-format off +layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D image_out; +// clang-format on +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; +layout(set = 0, binding = 2) uniform PRECISION sampler3D image_other; + +layout(set = 0, binding = 3) uniform PRECISION restrict OutExtents { + uvec4 data; +} +out_extents; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, out_extents.data.xyz))) { + return; + } + + vec4 in_texel = texelFetch(image_in, pos, 0); + vec4 other_texel = texelFetch(image_other, pos, 0); + + imageStore(image_out, pos, OP(in_texel, other_texel)); +} diff --git a/backends/vulkan/test/glsl/fill_texture__test.glsl b/backends/vulkan/test/glsl/fill_texture__test.glsl new file mode 100644 index 00000000000..ead168fa8e0 --- /dev/null +++ b/backends/vulkan/test/glsl/fill_texture__test.glsl @@ -0,0 +1,36 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core +#define PRECISION ${PRECISION} +#define FORMAT ${FORMAT} + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +// clang-format off +layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; +// clang-format on +layout(set = 0, binding = 1) uniform PRECISION restrict Block { + ivec3 size; + int fill; + vec4 vals; +} params; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, params.size))) { + return; + } + + imageStore(uOutput, pos, params.vals); +} diff --git a/backends/vulkan/test/glsl/image_to_nchw__test.glsl b/backends/vulkan/test/glsl/image_to_nchw__test.glsl new file mode 100644 index 00000000000..b5563b080fb --- /dev/null +++ b/backends/vulkan/test/glsl/image_to_nchw__test.glsl @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core +// clang-format off +#define PRECISION ${PRECISION} +// clang-format on + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in; +layout(set = 0, binding = 1) buffer PRECISION restrict writeonly Buffer { + ${T[DTYPE]} data[]; +} +buffer_out; + +layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { + ivec4 data; +} +gpu_sizes; + +layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes { + ivec4 data; +} +cpu_sizes; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 coord = POS_TO_COORD_${PACKING}(pos, gpu_sizes.data); + + if (any(greaterThanEqual(coord, gpu_sizes.data))) { + return; + } + + const ${VEC4_T[DTYPE]} intex = texelFetch(image_in, pos, 0); + + const int base_index = COORD_TO_BUFFER_IDX(coord, cpu_sizes.data); + const ivec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * (gpu_sizes.data.x * gpu_sizes.data.y); + + if (coord.z < cpu_sizes.data.z) { + buffer_out.data[buf_indices.x] = intex.x; + } + if (coord.z + 1 < cpu_sizes.data.z) { + buffer_out.data[buf_indices.y] = intex.y; + } + if (coord.z + 2 < cpu_sizes.data.z) { + buffer_out.data[buf_indices.z] = intex.z; + } + if (coord.z + 3 < cpu_sizes.data.z) { + buffer_out.data[buf_indices.w] = intex.w; + } +} diff --git a/backends/vulkan/test/glsl/indexing_utils.h b/backends/vulkan/test/glsl/indexing_utils.h new file mode 100644 index 00000000000..d3f005c1eea --- /dev/null +++ b/backends/vulkan/test/glsl/indexing_utils.h @@ -0,0 +1,14 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#define POS_TO_COORD_CHANNELS_PACKED(pos, sizes) \ + ivec4(pos.x, pos.y, (pos.z * 4) % sizes.z, (pos.z * 4) / sizes.z) + +#define COORD_TO_BUFFER_IDX(coord, sizes) \ + coord.x + coord.y* sizes.x + coord.z* sizes.y* sizes.x + \ + coord.w* sizes.z* sizes.y* sizes.x; diff --git a/backends/vulkan/test/glsl/nchw_to_image__test.glsl b/backends/vulkan/test/glsl/nchw_to_image__test.glsl new file mode 100644 index 00000000000..1a41fd88d0f --- /dev/null +++ b/backends/vulkan/test/glsl/nchw_to_image__test.glsl @@ -0,0 +1,64 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core +// clang-format off +#define PRECISION ${PRECISION} +// clang-format on + +#include "indexing_utils.h" + +layout(std430) buffer; + +// clang-format off +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +// clang-format on +layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { + ${T[DTYPE]} data[]; +} +buffer_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { + ivec4 data; +} +gpu_sizes; + +layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes { + ivec4 data; +} +cpu_sizes; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 coord = POS_TO_COORD_${PACKING}(pos, gpu_sizes.data); + + if (any(greaterThanEqual(coord, gpu_sizes.data))) { + return; + } + + const int base_index = COORD_TO_BUFFER_IDX(coord, cpu_sizes.data); + const ivec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * (gpu_sizes.data.x * gpu_sizes.data.y); + + ${T[DTYPE]} val_x = buffer_in.data[buf_indices.x]; + ${T[DTYPE]} val_y = buffer_in.data[buf_indices.y]; + ${T[DTYPE]} val_z = buffer_in.data[buf_indices.z]; + ${T[DTYPE]} val_w = buffer_in.data[buf_indices.w]; + + ${VEC4_T[DTYPE]} texel = ${VEC4_T[DTYPE]}(val_x, val_y, val_z, val_w); + + if (coord.z + 3 >= cpu_sizes.data.z) { + ivec4 c_ind = ivec4(coord.z) + ivec4(0, 1, 2, 3); + vec4 valid_c = vec4(lessThan(c_ind, ivec4(cpu_sizes.data.z))); + texel = texel * valid_c; + } + + imageStore(image_out, ${GET_POS[NDIM]("pos")}, texel); +} diff --git a/backends/vulkan/test/glsl/test_shader.glsl b/backends/vulkan/test/glsl/test_shader.glsl index 897b8a45486..39edc92cc62 100644 --- a/backends/vulkan/test/glsl/test_shader.glsl +++ b/backends/vulkan/test/glsl/test_shader.glsl @@ -1,3 +1,11 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + #version 450 core #define PRECISION ${PRECISION} #define FORMAT ${FORMAT} diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp new file mode 100644 index 00000000000..7fa7a9ad110 --- /dev/null +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -0,0 +1,226 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +// +// Operator Recording Functions +// + +void record_nchw_to_buffer_op( + api::Context* const context, + api::VulkanBuffer& src_buffer, + vTensor& v_dst) { + uint32_t buf_len = api::utils::safe_downcast(v_dst.gpu_numel()); + api::utils::uvec3 global_size = {buf_len, 1u, 1u}; + api::utils::uvec3 local_size = {32u, 1u, 1u}; + + api::UniformParamsBuffer cpu_buffer_metadata( + context, v_dst.get_cpu_buffer_metadata()); + api::PipelineBarrier pipeline_barrier{}; + + context->submit_compute_job( + VK_KERNEL(buffer_to_buffer), + pipeline_barrier, + global_size, + local_size, + VK_NULL_HANDLE, + v_dst.buffer( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_dst.buffer_metadata(), + src_buffer, + cpu_buffer_metadata.buffer()); +} + +bool record_buffer_to_nchw_op( + api::Context* const context, + vTensor& v_src, + api::VulkanBuffer& dst_buffer) { + uint32_t buf_len = api::utils::safe_downcast(v_src.numel()); + api::utils::uvec3 global_size = {buf_len, 1u, 1u}; + api::utils::uvec3 local_size = {4u, 1u, 1u}; + + api::UniformParamsBuffer cpu_buffer_metadata( + context, v_src.get_cpu_buffer_metadata()); + api::PipelineBarrier pipeline_barrier{}; + + return context->submit_compute_job( + VK_KERNEL(buffer_to_buffer), + pipeline_barrier, + global_size, + local_size, + VK_NULL_HANDLE, + dst_buffer, + cpu_buffer_metadata.buffer(), + v_src.buffer( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_src.buffer_metadata()); +} + +void record_nchw_to_image_op( + api::Context* const context, + api::VulkanBuffer& src_buffer, + vTensor& v_dst) { + api::PipelineBarrier pipeline_barrier{}; + api::ShaderInfo compute_shader = + VK_KERNEL(nchw_to_image3d__test_C_packed_half); + if (v_dst.image().format() == VK_FORMAT_R32G32B32A32_SFLOAT) { + compute_shader = VK_KERNEL(nchw_to_image3d__test_C_packed_float); + } + context->submit_compute_job( + compute_shader, + pipeline_barrier, + v_dst.virtual_extents(), + adaptive_work_group_size(v_dst.virtual_extents()), + VK_NULL_HANDLE, + v_dst.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + src_buffer, + v_dst.gpu_sizes_ubo()->buffer(), + v_dst.cpu_sizes_ubo()->buffer()); +} + +void record_image_to_nchw_op( + api::Context* const context, + vTensor& v_src, + api::VulkanBuffer& dst_buffer) { + api::ShaderInfo compute_shader = + VK_KERNEL(image3d_to_nchw__test_C_packed_half); + if (v_src.image().format() == VK_FORMAT_R32G32B32A32_SFLOAT) { + compute_shader = VK_KERNEL(image3d_to_nchw__test_C_packed_float); + } + api::PipelineBarrier pipeline_barrier{}; + context->submit_compute_job( + compute_shader, + pipeline_barrier, + v_src.virtual_extents(), + adaptive_work_group_size(v_src.virtual_extents()), + VK_NULL_HANDLE, + v_src.image(pipeline_barrier, api::PipelineStage::COMPUTE), + dst_buffer, + v_src.gpu_sizes_ubo()->buffer(), + v_src.cpu_sizes_ubo()->buffer()); +} + +void record_arithmetic_op( + api::Context* const context, + const api::ShaderInfo& compute_shader, + vTensor& v_in1, + vTensor& v_in2, + vTensor& v_dst) { + api::PipelineBarrier pipeline_barrier{}; + context->submit_compute_job( + compute_shader, + pipeline_barrier, + v_dst.virtual_extents(), + adaptive_work_group_size(v_dst.virtual_extents()), + VK_NULL_HANDLE, + v_dst.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_in1.image(pipeline_barrier, api::PipelineStage::COMPUTE), + v_in2.image(pipeline_barrier, 
api::PipelineStage::COMPUTE), + v_dst.extents_ubo()->buffer()); +} + +void execute_and_check_add( + vTensor& a, + vTensor& b, + vTensor& c, + float a_val, + float b_val) { + // Add shader kernel + api::ShaderInfo kernel = VK_KERNEL(binary_add_nobroadcast__test); + + // Fill input tensors + fill_vtensor(a, a_val); + fill_vtensor(b, b_val); + + // a + b = c + record_arithmetic_op(api::context(), kernel, a, b, c); + + // Extract output tensor + std::vector data_out = extract_vtensor(c); + + // Check output + for (const auto& d : data_out) { + EXPECT_TRUE(d == (a_val + b_val)); + } +} + +// +// Input & Output Utilities +// + +void fill_vtensor(vTensor& vten, std::vector& data) { + api::StorageBuffer staging_buffer(api::context(), api::kFloat, data.size()); + + copy_ptr_to_staging(data.data(), staging_buffer, vten.gpu_nbytes()); + + if (vten.storage_type() == api::StorageType::BUFFER) { + record_nchw_to_buffer_op(api::context(), staging_buffer.buffer(), vten); + } else { + record_nchw_to_image_op(api::context(), staging_buffer.buffer(), vten); + } +} + +void fill_vtensor(ComputeGraph& graph, const IOValueRef idx, float val) { + std::vector data(graph.get_val(idx.value).toTensor().gpu_numel()); + std::fill(data.begin(), data.end(), val); + + graph.copy_into_staging(idx.staging, data.data(), data.size()); +} + +void extract_vtensor(vTensor& vten, std::vector& data) { + api::StorageBuffer staging_buffer( + api::context(), api::kFloat, vten.gpu_numel()); + + if (vten.storage_type() == api::StorageType::BUFFER) { + record_buffer_to_nchw_op(api::context(), vten, staging_buffer.buffer()); + } else { + record_image_to_nchw_op(api::context(), vten, staging_buffer.buffer()); + } + + api::VulkanFence fence = api::context()->fences().get_fence(); + api::context()->submit_cmd_to_gpu(fence.get_submit_handle()); + fence.wait(); + + copy_staging_to_ptr(staging_buffer, data.data(), vten.gpu_nbytes()); +} + +// +// Context Management +// + +void submit_to_gpu() { + api::VulkanFence fence = api::context()->fences().get_fence(); + api::context()->submit_cmd_to_gpu(fence.get_submit_handle()); + fence.wait(); +} + +api::MemoryAllocation allocate_memory_for(const vTensor& vten) { + return api::context()->adapter_ptr()->vma().create_allocation( + vten.get_memory_requirements(), vten.get_allocation_create_info()); +} + +VmaTotalStatistics get_vma_stats() { + return api::context()->adapter_ptr()->vma().get_memory_statistics(); +} + +size_t get_vma_allocation_count() { + return get_vma_stats().total.statistics.allocationCount; +} diff --git a/backends/vulkan/test/utils/test_utils.h b/backends/vulkan/test/utils/test_utils.h new file mode 100644 index 00000000000..c0ded230f9a --- /dev/null +++ b/backends/vulkan/test/utils/test_utils.h @@ -0,0 +1,143 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +#include + +#include + +using namespace at::native::vulkan; + +#define CREATE_FLOAT_TEXTURE(sizes, allocate_memory) \ + vTensor( \ + api::context(), \ + sizes, \ + api::kFloat, \ + api::StorageType::TEXTURE_3D, \ + api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, \ + allocate_memory); + +#define CREATE_FLOAT_BUFFER(sizes, allocate_memory) \ + vTensor( \ + api::context(), \ + sizes, \ + api::kFloat, \ + api::StorageType::BUFFER, \ + api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, \ + allocate_memory); + +#define DEFINE_STAGING_BUFFER_AND_RECORD_TO_GPU_FOR(tensor) \ + api::StorageBuffer staging_buffer_##tensor( \ + api::context(), api::kFloat, tensor.gpu_numel()); \ + record_nchw_to_image_op( \ + api::context(), staging_buffer_##tensor.buffer(), tensor); + +#define DEFINE_STAGING_BUFFER_AND_RECORD_FROM_GPU_FOR(tensor) \ + api::StorageBuffer staging_buffer_##tensor( \ + api::context(), api::kFloat, tensor.gpu_numel()); \ + record_image_to_nchw_op( \ + api::context(), tensor, staging_buffer_##tensor.buffer()); + +// +// Operator Recording +// + +void record_nchw_to_buffer_op( + api::Context* const context, + api::VulkanBuffer& src_buffer, + vTensor& v_dst); + +bool record_buffer_to_nchw_op( + api::Context* const context, + vTensor& v_src, + api::VulkanBuffer& dst_buffer); + +void record_nchw_to_image_op( + api::Context* const context, + api::VulkanBuffer& src_buffer, + vTensor& v_dst); + +void record_image_to_nchw_op( + api::Context* const context, + vTensor& v_src, + api::VulkanBuffer& dst_buffer); + +void record_arithmetic_op( + api::Context* const context, + const api::ShaderInfo& compute_shader, + vTensor& v_in1, + vTensor& v_in2, + vTensor& v_dst); + +void execute_and_check_add( + vTensor& a, + vTensor& b, + vTensor& c, + float a_val, + float b_val); + +// +// Input & Output Utilities +// + +inline void +fill_staging(api::StorageBuffer& staging, float val, int numel = -1) { + if (numel < 0) { + numel = staging.numel(); + } + std::vector data(numel); + std::fill(data.begin(), data.end(), val); + copy_ptr_to_staging(data.data(), staging, sizeof(float) * numel); +} + +void fill_vtensor(vTensor& vten, std::vector& data); + +inline void fill_vtensor(vTensor& vten, float val) { + std::vector vten_data(vten.gpu_numel()); + std::fill(vten_data.begin(), vten_data.end(), val); + + fill_vtensor(vten, vten_data); +} + +void fill_vtensor(ComputeGraph& graph, const IOValueRef idx, float val); + +void extract_vtensor(vTensor& vten, std::vector& data); + +inline std::vector extract_vtensor(vTensor& vten) { + std::vector data_out(vten.gpu_numel()); + extract_vtensor(vten, data_out); + return data_out; +} + +inline void +check_staging_buffer(api::StorageBuffer& staging, float val, int numel = -1) { + if (numel < 0) { + numel = staging.numel(); + } + std::vector data(numel); + copy_staging_to_ptr(staging, data.data(), sizeof(float) * numel); + + for (const auto& d : data) { + EXPECT_TRUE(d == val); + } +} + +// +// Context Management +// + +void submit_to_gpu(); + +api::MemoryAllocation allocate_memory_for(const vTensor& vten); + +VmaTotalStatistics get_vma_stats(); + +size_t get_vma_allocation_count(); diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 6b041ab826a..9df6a8dd2d1 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -19,221 +19,12 @@ #include -using namespace at::native::vulkan; - -#define CREATE_FLOAT_TEXTURE(sizes, 
allocate_memory) \ - vTensor( \ - api::context(), \ - sizes, \ - api::kFloat, \ - api::StorageType::TEXTURE_3D, \ - api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, \ - allocate_memory); - -#define CREATE_FLOAT_BUFFER(sizes, allocate_memory) \ - vTensor( \ - api::context(), \ - sizes, \ - api::kFloat, \ - api::StorageType::BUFFER, \ - api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, \ - allocate_memory); - -// -// Simplified versions of ATen Vulkan legacy functions -// - -void record_nchw_to_buffer_op( - api::Context* const context, - api::VulkanBuffer& src_buffer, - vTensor& v_dst) { - uint32_t buf_len = api::utils::safe_downcast(v_dst.gpu_numel()); - api::utils::uvec3 global_size = {buf_len, 1u, 1u}; - api::utils::uvec3 local_size = {32u, 1u, 1u}; - - api::UniformParamsBuffer cpu_buffer_metadata( - context, v_dst.get_cpu_buffer_metadata()); - api::PipelineBarrier pipeline_barrier{}; - - context->submit_compute_job( - VK_KERNEL(buffer_to_buffer), - pipeline_barrier, - global_size, - local_size, - VK_NULL_HANDLE, - v_dst.buffer( - pipeline_barrier, - api::PipelineStage::COMPUTE, - api::MemoryAccessType::WRITE), - v_dst.buffer_metadata(), - src_buffer, - cpu_buffer_metadata.buffer()); -} - -bool record_buffer_to_nchw_op( - api::Context* const context, - vTensor& v_src, - api::VulkanBuffer& dst_buffer) { - uint32_t buf_len = api::utils::safe_downcast(v_src.numel()); - api::utils::uvec3 global_size = {buf_len, 1u, 1u}; - api::utils::uvec3 local_size = {4u, 1u, 1u}; - - api::UniformParamsBuffer cpu_buffer_metadata( - context, v_src.get_cpu_buffer_metadata()); - api::PipelineBarrier pipeline_barrier{}; - - return context->submit_compute_job( - VK_KERNEL(buffer_to_buffer), - pipeline_barrier, - global_size, - local_size, - VK_NULL_HANDLE, - dst_buffer, - cpu_buffer_metadata.buffer(), - v_src.buffer( - pipeline_barrier, - api::PipelineStage::COMPUTE, - api::MemoryAccessType::WRITE), - v_src.buffer_metadata()); -} - -void record_nchw_to_image_op( - api::Context* const context, - api::VulkanBuffer& src_buffer, - vTensor& v_dst) { - api::utils::uvec3 global_size = v_dst.extents(); - api::utils::uvec3 local_size = adaptive_work_group_size(global_size); - - api::UniformParamsBuffer params(context, create_staging_params(v_dst)); - api::PipelineBarrier pipeline_barrier{}; - - context->submit_compute_job( - get_nchw_to_image_shader(v_dst), - pipeline_barrier, - global_size, - local_size, - VK_NULL_HANDLE, - v_dst.image( - pipeline_barrier, - api::PipelineStage::COMPUTE, - api::MemoryAccessType::WRITE), - src_buffer, - params.buffer()); -} - -bool record_image_to_nchw_op( - api::Context* const context, - vTensor& v_src, - api::VulkanBuffer& dst_buffer) { - api::utils::uvec3 global_size = v_src.extents(); - api::utils::uvec3 local_size = adaptive_work_group_size(global_size); - - api::UniformParamsBuffer params(context, create_staging_params(v_src)); - api::PipelineBarrier pipeline_barrier{}; - - return context->submit_compute_job( - get_image_to_nchw_shader(v_src), - pipeline_barrier, - global_size, - local_size, - VK_NULL_HANDLE, - v_src.image( - pipeline_barrier, - api::PipelineStage::COMPUTE, - api::MemoryAccessType::WRITE), - dst_buffer, - params.buffer()); -} - -void record_arithmetic_op( - api::Context* const context, - const api::ShaderInfo& compute_shader, - vTensor& v_in1, - vTensor& v_in2, - vTensor& v_dst, - const float alpha) { - api::utils::uvec3 global_size = v_dst.extents(); - api::utils::uvec3 local_size = adaptive_work_group_size(global_size); - - ArithmeticParams block{ - 
get_size_as_ivec4(v_dst), - get_size_as_ivec4(v_in1), - get_size_as_ivec4(v_in2), - alpha, - }; - api::UniformParamsBuffer params(context, block); - api::PipelineBarrier pipeline_barrier{}; - - context->submit_compute_job( - compute_shader, - pipeline_barrier, - global_size, - local_size, - VK_NULL_HANDLE, - v_dst.image( - pipeline_barrier, - api::PipelineStage::COMPUTE, - api::MemoryAccessType::WRITE), - v_in1.image(pipeline_barrier, api::PipelineStage::COMPUTE), - v_in2.image(pipeline_barrier, api::PipelineStage::COMPUTE), - params.buffer()); -} - -// -// Utilities -// - -void fill_vtensor(vTensor& vten, std::vector& data) { - api::StorageBuffer staging_buffer(api::context(), api::kFloat, data.size()); - - copy_ptr_to_staging(data.data(), staging_buffer, vten.gpu_nbytes()); - - if (vten.storage_type() == api::StorageType::BUFFER) { - record_nchw_to_buffer_op(api::context(), staging_buffer.buffer(), vten); - } else { - record_nchw_to_image_op(api::context(), staging_buffer.buffer(), vten); - } -} - -void fill_vtensor(ComputeGraph& graph, const IOValueRef idx, float val) { - std::vector data(graph.get_val(idx.value).toTensor().gpu_numel()); - std::fill(data.begin(), data.end(), val); - - graph.copy_into_staging(idx.staging, data.data(), data.size()); -} - -void extract_vtensor(vTensor& vten, std::vector& data) { - api::StorageBuffer staging_buffer( - api::context(), api::kFloat, vten.gpu_numel()); - - if (vten.storage_type() == api::StorageType::BUFFER) { - record_buffer_to_nchw_op(api::context(), vten, staging_buffer.buffer()); - } else { - record_image_to_nchw_op(api::context(), vten, staging_buffer.buffer()); - } +#include - api::VulkanFence fence = api::context()->fences().get_fence(); - api::context()->submit_cmd_to_gpu(fence.get_submit_handle()); - fence.wait(); - - copy_staging_to_ptr(staging_buffer, data.data(), vten.gpu_nbytes()); -} - -api::MemoryAllocation allocate_memory_for(const vTensor& vten) { - return api::context()->adapter_ptr()->vma().create_allocation( - vten.get_memory_requirements(), vten.get_allocation_create_info()); -} - -VmaTotalStatistics get_vma_stats() { - return api::context()->adapter_ptr()->vma().get_memory_statistics(); -} - -size_t get_vma_allocation_count() { - return get_vma_stats().total.statistics.allocationCount; -} +using namespace at::native::vulkan; // -// Test Wrapper +// Compute API Tests // class VulkanComputeAPITest : public ::testing::Test { @@ -251,10 +42,6 @@ class VulkanComputeAPITest : public ::testing::Test { } }; -// -// Compute API Tests -// - TEST_F(VulkanComputeAPITest, retrieve_custom_shader_test) { // Try to get shader from custom shader library const api::ShaderInfo& kernel = VK_KERNEL(test_shader); @@ -262,6 +49,58 @@ TEST_F(VulkanComputeAPITest, retrieve_custom_shader_test) { EXPECT_TRUE(kernel.kernel_name == "test_shader"); } +TEST_F(VulkanComputeAPITest, update_params_between_submit) { + api::context()->set_cmd(/*reusable = */ true); + std::vector sizes = {4, 4, 2}; + vTensor a = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); + + struct Params final { + api::utils::ivec3 size; + int32_t fill; + api::utils::vec4 values; + }; + + Params block{ + {2, 4, 1}, + 0, + {5.0, 5.0, 5.0, 5.0}, + }; + + api::UniformParamsBuffer params(api::context(), block); + + { + api::PipelineBarrier pipeline_barrier{}; + api::context()->submit_compute_job( + VK_KERNEL(fill_texture__test), + pipeline_barrier, + {4, 4, 4}, + {4, 4, 4}, + VK_NULL_HANDLE, + a.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + 
api::MemoryAccessType::WRITE), + params.buffer()); + } + + api::StorageBuffer staging_buffer(api::context(), api::kFloat, a.gpu_numel()); + record_image_to_nchw_op(api::context(), a, staging_buffer.buffer()); + + submit_to_gpu(); + check_staging_buffer(staging_buffer, 5.0f); + + Params new_block{ + {2, 4, 1}, + 0, + {4.0, 4.0, 4.0, 4.0}, + }; + + params.update(new_block); + + submit_to_gpu(); + check_staging_buffer(staging_buffer, 4.0f); +} + TEST_F(VulkanComputeAPITest, buffer_copy_sanity_check) { // Simple test that copies data into a and reads from a std::vector sizes = {4, 4, 1}; @@ -290,9 +129,7 @@ TEST_F(VulkanComputeAPITest, buffer_deferred_allocation_test) { std::vector sizes = {4, 4, 1}; vTensor a = CREATE_FLOAT_BUFFER(sizes, /*allocate_memory = */ false); - // For buffer storage, a small uniform buffer is allocated containing size and - // stride data, which is why the check is for 1 allocation below. - EXPECT_TRUE(get_vma_allocation_count() == 1); + EXPECT_TRUE(get_vma_allocation_count() == 0); // Input data std::vector data_in(a.gpu_numel()); @@ -302,7 +139,7 @@ TEST_F(VulkanComputeAPITest, buffer_deferred_allocation_test) { api::MemoryAllocation a_mem = allocate_memory_for(a); a.buffer().bind_allocation(a_mem); - EXPECT_TRUE(get_vma_allocation_count() == 2); + EXPECT_TRUE(get_vma_allocation_count() == 1); // Fill input tensor fill_vtensor(a, data_in); @@ -325,25 +162,16 @@ TEST_F(VulkanComputeAPITest, texture_add_sanity_check) { vTensor b = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); vTensor c = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); - // Input data - std::vector data_a(a.gpu_numel()); - std::fill(data_a.begin(), data_a.end(), 2.5f); - std::vector data_b(b.gpu_numel()); - std::fill(data_b.begin(), data_b.end(), 1.5f); - - // Add shader kernel - api::ShaderInfo kernel = VK_KERNEL(add); - // Fill input tensors - fill_vtensor(a, data_a); - fill_vtensor(b, data_b); + fill_vtensor(a, 2.5f); + fill_vtensor(b, 1.5f); // a + b -> c - record_arithmetic_op(api::context(), kernel, a, b, c, 1.0f); + record_arithmetic_op( + api::context(), VK_KERNEL(binary_add_nobroadcast__test), a, b, c); // Extract output tensor - std::vector data_out(c.gpu_numel()); - extract_vtensor(c, data_out); + std::vector data_out = extract_vtensor(c); // Check output for (const auto& d : data_out) { @@ -368,7 +196,7 @@ TEST_F(VulkanComputeAPITest, texture_deferred_allocation_test) { std::vector data_b(b.gpu_numel()); std::fill(data_b.begin(), data_b.end(), 1.5f); - api::ShaderInfo kernel = VK_KERNEL(add); + api::ShaderInfo kernel = VK_KERNEL(binary_add_nobroadcast__test); // Allocate memory at the last possible opportunity api::MemoryAllocation a_mem = allocate_memory_for(a); @@ -384,7 +212,7 @@ TEST_F(VulkanComputeAPITest, texture_deferred_allocation_test) { fill_vtensor(a, data_a); fill_vtensor(b, data_b); - record_arithmetic_op(api::context(), kernel, a, b, c, 1.0f); + record_arithmetic_op(api::context(), kernel, a, b, c); std::vector data_c(c.gpu_numel()); extract_vtensor(c, data_c); @@ -434,20 +262,20 @@ TEST_F(VulkanComputeAPITest, texture_resource_aliasing_test) { std::fill(data_d.begin(), data_d.end(), 1.0f); // Get shader kernel for add - api::ShaderInfo kernel = VK_KERNEL(add); + api::ShaderInfo kernel = VK_KERNEL(binary_add_nobroadcast__test); // First, fill a and b with data fill_vtensor(a, data_a); fill_vtensor(b, data_b); // a + b -> c - record_arithmetic_op(api::context(), kernel, a, b, c, 1.0f); + record_arithmetic_op(api::context(), kernel, a, b, c); // Now d can 
be filled with data fill_vtensor(d, data_d); // c + d -> e - record_arithmetic_op(api::context(), kernel, c, d, e, 1.0f); + record_arithmetic_op(api::context(), kernel, c, d, e); // Extract data from e std::vector data_e(e.gpu_numel()); @@ -509,6 +337,104 @@ TEST_F(VulkanComputeAPITest, use_non_bound_textures_fails) { EXPECT_THROW(fill_vtensor(a, data_a), api::Error); } +TEST_F(VulkanComputeAPITest, tensor_reallocation_test) { + std::vector sizes = {4, 4, 1}; + vTensor a = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); + vTensor b = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); + vTensor c = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); + + execute_and_check_add(a, b, c, 3.0f, 5.0f); + + // Redo with new sizes + std::vector new_sizes = {4, 6, 3}; + a.reallocate(new_sizes); + b.reallocate(new_sizes); + c.reallocate(new_sizes); + + // Flush everything + api::context()->flush(); + + execute_and_check_add(a, b, c, 12.0f, 10.0f); +} + +TEST_F( + VulkanComputeAPITest, + tensor_reallocation_with_deferred_allocation_test) { + std::vector sizes = {8, 8, 8}; + vTensor a = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); + vTensor b = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); + vTensor c = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); + + api::MemoryAllocation a_mem = allocate_memory_for(a); + a.image().bind_allocation(a_mem); + api::MemoryAllocation b_mem = allocate_memory_for(b); + b.image().bind_allocation(b_mem); + api::MemoryAllocation c_mem = allocate_memory_for(c); + c.image().bind_allocation(c_mem); + + execute_and_check_add(a, b, c, 4.0f, 8.0f); + + std::vector> new_sizes_list = { + {4, 3, 5}, {4, 1, 7}, {8, 3, 2}, {8, 7, 2}}; + + for (auto& new_sizes : new_sizes_list) { + // Redo with new sizes + a.reallocate(new_sizes); + b.reallocate(new_sizes); + c.reallocate(new_sizes); + + // Flush everything + api::context()->flush(); + + a.image().bind_allocation(a_mem); + b.image().bind_allocation(b_mem); + c.image().bind_allocation(c_mem); + + execute_and_check_add( + a, b, c, float(new_sizes[1] + 4.5f), float(new_sizes[2] + 13.0f)); + } +} + +TEST_F(VulkanComputeAPITest, texture_virtual_resize) { + api::context()->set_cmd(/*reusable = */ true); + std::vector sizes = {8, 12, 12}; + vTensor a = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); + vTensor b = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); + vTensor c = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); + + DEFINE_STAGING_BUFFER_AND_RECORD_TO_GPU_FOR(a) + DEFINE_STAGING_BUFFER_AND_RECORD_TO_GPU_FOR(b) + + fill_staging(staging_buffer_a, 11.5f); + fill_staging(staging_buffer_b, 12.5f); + + record_arithmetic_op( + api::context(), VK_KERNEL(binary_add_nobroadcast__test), a, b, c); + + DEFINE_STAGING_BUFFER_AND_RECORD_FROM_GPU_FOR(c) + + submit_to_gpu(); + check_staging_buffer(staging_buffer_c, 24.0f); + + std::vector> new_sizes_list = { + {4, 2, 4}, {4, 3, 6}, {8, 12, 12}, {8, 1, 1}, {8, 11, 10}}; + + for (auto& new_sizes : new_sizes_list) { + a.virtual_resize(new_sizes); + b.virtual_resize(new_sizes); + c.virtual_resize(new_sizes); + + fill_staging(staging_buffer_a, float(new_sizes[1] + 1.5f), a.gpu_numel()); + fill_staging(staging_buffer_b, float(new_sizes[2] + 55.0f), b.gpu_numel()); + + submit_to_gpu(); + check_staging_buffer( + staging_buffer_c, + float(new_sizes[1] + new_sizes[2] + 56.5f), + c.gpu_numel()); + } +} + // // Compute Graph Tests //
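
The macros added in `indexing_utils.h` encode the channels-packed layout the test shaders rely on: the texel at texture position `pos` holds four consecutive channels of the logical tensor, and the corresponding NCHW buffer offsets differ by one `W * H` plane per lane. The snippet below is a host-side sketch of the same arithmetic, useful for reasoning about the GLSL; the struct and function names are illustrative only and assume the sizes vectors are ordered `(W, H, C, N)` with `gpu_sizes.z` padded to a multiple of 4, which is what the macro expressions imply.

```cpp
// Host-side sketch of the macros in indexing_utils.h, assuming the
// channels-packed convention with sizes ordered as (W, H, C, N).
// The struct and function names here are illustrative only.
struct ivec3 { int x, y, z; };
struct ivec4 { int x, y, z, w; };

// POS_TO_COORD_CHANNELS_PACKED: texture position -> (x, y, c, n) coordinate.
// Each texel along the texture z-axis covers 4 consecutive channels.
ivec4 pos_to_coord_channels_packed(const ivec3& pos, const ivec4& gpu_sizes) {
  return {pos.x, pos.y, (pos.z * 4) % gpu_sizes.z, (pos.z * 4) / gpu_sizes.z};
}

// COORD_TO_BUFFER_IDX: (x, y, c, n) coordinate -> linear NCHW buffer index,
// computed against the CPU sizes (the un-padded channel count).
int coord_to_buffer_idx(const ivec4& coord, const ivec4& cpu_sizes) {
  return coord.x + coord.y * cpu_sizes.x +
      coord.z * cpu_sizes.y * cpu_sizes.x +
      coord.w * cpu_sizes.z * cpu_sizes.y * cpu_sizes.x;
}
```

This is also why `nchw_to_image__test.glsl` and `image_to_nchw__test.glsl` compute `buf_indices = base_index + ivec4(0, 1, 2, 3) * (sizes.x * sizes.y)`: each lane of a texel advances the buffer index by exactly one channel plane.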
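
The staging helpers size their buffers with `gpu_numel()` rather than `numel()`, and the shaders carry both `gpu_sizes` and `cpu_sizes`. Below is a sketch of the padded-size convention this appears to assume; it is not an API added by this diff (vTensor's own `extents()` and `gpu_numel()` are authoritative), just a reference for how the numbers relate under channels packing.

```cpp
// Sketch of the relationship between CPU sizes, GPU (padded) sizes and the
// 3D texture extents assumed by the channels-packed test shaders. The
// function names are illustrative only.
#include <cstdint>

struct uvec3 { uint32_t x, y, z; };

constexpr int64_t align_up_4(int64_t v) {
  return (v + 3) / 4 * 4;
}

// For sizes ordered as (W, H, C, N):
uvec3 channels_packed_extents(int64_t W, int64_t H, int64_t C, int64_t N) {
  // One texel spans 4 channels, so the texture depth is N * ceil(C / 4).
  return {
      static_cast<uint32_t>(W),
      static_cast<uint32_t>(H),
      static_cast<uint32_t>(N * align_up_4(C) / 4)};
}

int64_t channels_packed_gpu_numel(int64_t W, int64_t H, int64_t C, int64_t N) {
  // Every texel contributes 4 elements, including the padded tail channels,
  // which is why the staging buffers in the tests are sized with gpu_numel().
  return W * H * align_up_4(C) * N;
}

// How many of a texel's 4 lanes map to real channels for a given base channel.
int valid_lanes(int64_t c_base, int64_t C) {
  const int64_t remaining = C - c_base;
  return remaining >= 4 ? 4 : (remaining > 0 ? static_cast<int>(remaining) : 0);
}
```

When `C` is not a multiple of 4, the last texel along z has lanes that fall outside the logical tensor: `nchw_to_image__test.glsl` zeroes them with the `valid_c` mask on upload, and `image_to_nchw__test.glsl` guards each store with `coord.z + i < cpu_sizes.data.z` on read-back.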
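
Taken together, the helpers declared in `test_utils.h` let a test be written as create → fill → record → extract → check. The sketch below only illustrates how those pieces compose; the test name is made up, the sizes element type is assumed to be `int64_t`, and the flow mirrors `execute_and_check_add()` and the sanity-check tests above rather than adding anything new.

```cpp
// Minimal sketch of a test built from the helpers declared in test_utils.h.
TEST_F(VulkanComputeAPITest, illustrative_add_roundtrip) {
  std::vector<int64_t> sizes = {4, 4, 1};
  vTensor a = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true);
  vTensor b = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true);
  vTensor c = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true);

  // Stage the inputs and record the generated add variant from all_shaders.yaml.
  fill_vtensor(a, 2.0f);
  fill_vtensor(b, 3.0f);
  record_arithmetic_op(
      api::context(), VK_KERNEL(binary_add_nobroadcast__test), a, b, c);

  // extract_vtensor() submits the recorded work and waits on a fence
  // internally before copying the staging buffer back to the host.
  std::vector<float> out = extract_vtensor(c);
  for (const float v : out) {
    EXPECT_TRUE(v == 5.0f);
  }
}
```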