From a17f098a908d9daf025784931c283c710819dc71 Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Wed, 23 Apr 2025 13:39:11 -0700
Subject: [PATCH 1/2] [ET-VK] Enable int8 tiled compute shader to be used with buffer tensors

Pull Request resolved: https://github.com/pytorch/executorch/pull/10302

## Context

As title. Allow the optimized int8 tiled compute shader to be usable for buffer-backed tensors as well.

## Changes

* Generate buffer variants for the int8 linear tiled shader
* Force the scales tensor to always be a buffer to reduce the number of shader variants that need to be generated
* Generate an additional variant that computes only 1 output row
* Do not require output rows to be an exact multiple of 4 or 6 to use the tiled implementation (see the selection sketch below)
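For reviewers, a minimal sketch of the tile-height selection implied by the last two bullets; `select_tile_suffix` is a hypothetical standalone helper, not part of the codebase (the real logic lives in `add_q_8w_linear_tiled_node` in the diff below). The largest tile that evenly divides the output row count `M` wins, and the new 1-row variant is the fallback, so `M` no longer needs to be a multiple of 4 or 6:

```cpp
#include <cstdint>
#include <string>

// Sketch only: mirrors the suffix/tile selection added in this diff.
std::string select_tile_suffix(int64_t M, int& out_tile_nrows) {
  if (M % 6 == 0) {
    out_tile_nrows = 6;
    return "_o4x6"; // 4 output columns x 6 output rows per invocation
  }
  if (M % 4 == 0) {
    out_tile_nrows = 4;
    return "_o4x4";
  }
  out_tile_nrows = 1; // new fallback variant; works for any M
  return "_o4x1";
}
```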
ghstack-source-id: 279878372

Differential Revision: [D73276277](https://our.internmc.facebook.com/intern/diff/D73276277/)
---
 .../graph/ops/glsl/q_8w_linear_tiled.glsl     | 34 +++++++------
 .../graph/ops/glsl/q_8w_linear_tiled.yaml     | 28 ++++++++---
 .../graph/ops/impl/QuantizedLinearInt8.cpp    | 50 +++++++++++++------
 backends/vulkan/test/op_tests/cases.py        |  1 +
 4 files changed, 75 insertions(+), 38 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl
index c3bd9f41af9..8a8670b4bb3 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl
@@ -17,17 +17,17 @@
 
 ${define_required_extensions(DTYPE)}
 
-$if STORAGE == "buffer":
+$if WEIGHT_STORAGE == "buffer":
   ${define_required_extensions("int8")}
 
 #extension GL_EXT_control_flow_attributes : require
 
 layout(std430) buffer;
 
-${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_weight", "int8", STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_scales", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)}
 
 layout(push_constant) uniform restrict Block {
@@ -50,10 +50,10 @@ void main() {
   VEC4_T b[4];
   VEC4_T c[TILE_ROWS];
 
-  $if STORAGE == "buffer":
+  $if SCALES_STORAGE == "buffer":
     const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
   $else:
-    const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0));
+    const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));
 
   [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
     c[i] = VEC4_T(0.0);
@@ -62,30 +62,32 @@ void main() {
   for (int pos = 0; pos < in_sizes.x; pos += 4) {
     // Preload weight tensor
     [[unroll]] for (int i = 0; i < 4; i++) {
-      $if STORAGE == "buffer":
-        b[i] = t_weight[((pos + i) * B_sizes.x + out_col) >> 2];
+      $if WEIGHT_STORAGE == "buffer":
+        b[i] = t_weight[((pos + i) * out_sizes.x + out_col) >> 2];
       $else:
-        b[i] = VEC4_T(texelFetch(t_weight, ivec3(out_col >> 2, pos + i, 0), 0));
+        b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0));
     }
 
     // Preload input tensor
     [[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
-      $if STORAGE == "buffer":
-        a[i] = t_in[((out_row + i) * in_sizes.x + (pos)) >> 2];
+      $if IN_STORAGE == "buffer":
+        a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
       $else:
        a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
    }
 
-    // Compute partial output
+    // Accumulate output
    [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
      c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3];
    }
  }
 
-  // Store output tensor
+  // Store to output tensor
  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
-    $if STORAGE == "buffer":
-      t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
+    $if OUT_STORAGE == "buffer":
+      if (out_row + i < out_sizes.y) {
+        t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
+      }
    $else:
      imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
  }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml
index b01af47e179..1e8a5e1fe7d 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml
@@ -7,12 +7,26 @@
 q_8w_linear_tiled:
   parameter_names_with_default_values:
     DTYPE: float
-    STORAGE: texture3d
+    IN_STORAGE: texture3d
+    OUT_STORAGE: texture3d
+    WEIGHT_STORAGE: texture2d
+    SCALES_STORAGE: texture2d
     TILE_ROWS: 4
+  generate_variant_forall:
+    TILE_ROWS:
+      - VALUE: 1
+        SUFFIX: o4x1
+      - VALUE: 4
+        SUFFIX: o4x4
+      - VALUE: 6
+        SUFFIX: o4x6
   shader_variants:
-    - NAME: q_8w_linear_tiled_o4x4_texture3d_float
-      STORAGE: texture3d
-      TILE_ROWS: 4
-    - NAME: q_8w_linear_tiled_o4x6_texture3d_float
-      STORAGE: texture3d
-      TILE_ROWS: 6
+    - NAME: q_8w_linear_tiled_texture3d_texture3d_texture2d_texture2d_float
+    - NAME: q_8w_linear_tiled_buffer_buffer_texture2d_texture2d_float
+      IN_STORAGE: buffer
+      OUT_STORAGE: buffer
+    - NAME: q_8w_linear_tiled_buffer_buffer_buffer_buffer_float
+      IN_STORAGE: buffer
+      OUT_STORAGE: buffer
+      WEIGHT_STORAGE: buffer
+      SCALES_STORAGE: buffer
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
index 5054b2e5e9c..64c2d202529 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
@@ -146,28 +146,53 @@ void add_q_8w_linear_tiled_node(
     const ValueRef q_mat2_data,
     const ValueRef scales_data,
     const ValueRef out) {
-  utils::StorageType stype = graph.storage_type_of(out);
+  utils::StorageType q_mat2_storage = utils::kTexture2D;
+
+  uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim();
+  std::vector<int64_t> qmat2_orig_sizes = graph.sizes_of(q_mat2_data);
+  const int64_t ndim = graph.dim_of(q_mat2_data);
+  const int64_t K = qmat2_orig_sizes.at(ndim - 1);
+  const int64_t N = qmat2_orig_sizes.at(ndim - 2);
+
+  if (N > max_extent * 4 || K > max_extent) {
+    q_mat2_storage = utils::kBuffer;
+  }
+
   ValueRef q_mat2 = prepack_standard_hw_transposed(
-      graph, q_mat2_data, stype, utils::kWidthPacked);
+      graph, q_mat2_data, q_mat2_storage, utils::kWidthPacked);
+
+  utils::StorageType scales_storage = utils::kTexture2D;
+  if (N > max_extent) {
+    scales_storage = utils::kBuffer;
+  }
   ValueRef scales =
-      prepack_standard(graph, scales_data, stype, utils::kWidthPacked);
+      prepack_standard(graph, scales_data, scales_storage, utils::kWidthPacked);
 
   std::string kernel_name = "q_8w_linear_tiled";
   kernel_name.reserve(kShaderNameReserve);
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(q_mat2));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(scales));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
   std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
   const int64_t M = utils::val_at(-2, mat1_sizes);
   int out_tile_nrows = 4;
   if (M % 6 == 0) {
     kernel_name += "_o4x6";
     out_tile_nrows = 6;
+  } else if (M % 4 == 0) {
+    kernel_name += "_o4x4";
+    out_tile_nrows = 4;
   } else {
-    kernel_name += "_o4x4";
-    out_tile_nrows = 4;
+    kernel_name += "_o4x1";
+    out_tile_nrows = 1;
   }
 
-  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
-  add_dtype_suffix(kernel_name, graph.dtype_of(out));
-
   utils::uvec3 global_wg_size = graph.logical_limits_of(out);
   global_wg_size[1] = global_wg_size[1] / out_tile_nrows;
@@ -209,18 +234,13 @@ bool can_use_tiled_impl(
   if (graph.size_at<int>(-1, mat1) % 4 != 0) {
     return false;
   }
-  // Check that M is a multiple of 4 or 6
-  if (graph.size_at<int>(-2, mat1) % 4 != 0 &&
-      graph.size_at<int>(-2, mat1) % 6 != 0) {
-    return false;
-  }
-  // Check that the storage type is texture
-  // TODO(ssjia): Add support for buffer storage in the tiled impl
-  if (graph.storage_type_of(out) != utils::kTexture3D) {
+  // Check that N is a multiple of 4
+  if (graph.size_at<int>(-1, out) % 4 != 0) {
     return false;
   }
   // Check that the packed dim is the width dim
-  if (graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
+  if (graph.packed_dim_of(mat1) != WHCN::kWidthDim &&
+      graph.packed_dim_of(out) != WHCN::kWidthDim) {
     return false;
   }
   // Check that no special axis mapping is used for the input
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index f97b2c51370..525f74609a6 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -152,6 +152,7 @@ def get_linear_inputs():
 @register_test_suite("aten._weight_int8pack_mm.default")
 def get_weight_int8pack_mm_inputs():
     MKN_list = [
+        [3, 480, 256],
         [6, 480, 256],
         [6, 256, 1024],
         [6, 1024, 256],

From 02a11152dbf212b00bae87892cf560b886ed1816 Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Wed, 23 Apr 2025 13:58:42 -0700
Subject: [PATCH 2/2] [ET-VK] Add coop shader for int8 linear

Pull Request resolved: https://github.com/pytorch/executorch/pull/10304

Title says it all!

## Changes

Add some utility functions to `ComputeGraph` to get the device name and look for substrings within the device name.

Apply the co-operative shader for vector * matrix computations, except on Adreno 702, where experimentation showed it performs worse than the tiled algorithm.
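For context, a rough serial analogue of the co-operative reduction the new shader performs (illustrative C++, not part of this diff): each of the 8 workers assigned to an output tile accumulates a strided slice of the K dimension into a private partial sum, and worker 0 reduces the partials after a barrier.

```cpp
#include <array>
#include <cstddef>
#include <vector>

// Serial sketch of the co-operative (vector * matrix) reduction in
// q_8w_linear_coop.glsl; names are illustrative. In the shader, the outer
// loop runs concurrently across NWORKERS threads sharing one output tile.
constexpr int kNumWorkers = 8;

float coop_dot(const std::vector<float>& a, const std::vector<float>& w) {
  std::array<float, kNumWorkers> partial{};
  for (int wid = 0; wid < kNumWorkers; ++wid) {
    // Worker `wid` covers elements wid, wid + 8, wid + 16, ...
    for (std::size_t k = wid; k < a.size(); k += kNumWorkers) {
      partial[wid] += a[k] * w[k];
    }
  }
  // Shader equivalent: write partial_c to shared memory, barrier(), then
  // worker 0 sums the per-worker partials.
  float sum = 0.0f;
  for (float p : partial) {
    sum += p;
  }
  return sum;
}
```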
ghstack-source-id: 279884141

Differential Revision: [D73279548](https://our.internmc.facebook.com/intern/diff/D73279548/)
---
 .../vulkan/runtime/graph/ComputeGraph.cpp     |   5 +
 backends/vulkan/runtime/graph/ComputeGraph.h  |   9 ++
 .../graph/ops/glsl/q_8w_linear_coop.glsl      | 122 ++++++++++++++++++
 .../graph/ops/glsl/q_8w_linear_coop.yaml      |  28 ++++
 .../graph/ops/impl/QuantizedLinearInt8.cpp    |  22 +++-
 backends/vulkan/runtime/vk_api/Adapter.h      |   9 ++
 backends/vulkan/test/op_tests/cases.py        |   3 +
 7 files changed, 196 insertions(+), 2 deletions(-)
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.yaml

diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index 5109e198206..7fde7e04f91 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -179,6 +179,11 @@ utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
   return utils::kChannelsPacked;
 }
 
+bool ComputeGraph::device_name_contains(const char* substr) {
+  return context_->adapter_ptr()->device_name().find(substr) !=
+      std::string::npos;
+}
+
 void ComputeGraph::check_no_active_value_ptrs() {
   VK_CHECK_COND(
       values_in_use_ == 0,
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index 3d46aa327b8..d09597ad778 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -443,6 +443,15 @@ class ComputeGraph final {
   utils::GPUMemoryLayout suggested_memory_layout(
       const std::vector<int64_t>& sizes);
 
+  inline bool device_is_adreno() {
+    return context_->adapter_ptr()->device_type() == vkapi::DeviceType::ADRENO;
+  }
+  const std::string& device_name() {
+    return context()->adapter_ptr()->device_name();
+  }
+
+  bool device_name_contains(const char* substr);
+
   //
   // Graph Building
   //
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl
new file mode 100644
index 00000000000..c8ccbacffc1
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define T ${buffer_scalar_type(DTYPE)}
+#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
+
+#define TILE_ROWS ${TILE_ROWS}
+
+#define NGROUPS 8
+#define NWORKERS 8
+
+${define_required_extensions(DTYPE)}
+
+$if WEIGHT_STORAGE == "buffer":
+  ${define_required_extensions("int8")}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)}
+
+layout(push_constant) uniform restrict Block {
+  ivec4 out_sizes;
+  ivec4 in_sizes;
+  ivec4 weight_sizes;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+shared VEC4_T partial_c[NGROUPS][NWORKERS][TILE_ROWS];
+
+void main() {
+  const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
+  const uint out_col = gl_GlobalInvocationID.x << 2;
+
+  const int gid = int(gl_LocalInvocationID.x); // group id
+  const int wid = int(gl_LocalInvocationID.z); // worker id
+
+  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
+    return;
+  }
+
+  VEC4_T a[TILE_ROWS];
+  VEC4_T b[4];
+  VEC4_T local_c[TILE_ROWS];
+
+  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
+    local_c[i] = VEC4_T(0.0);
+  }
+
+  $if SCALES_STORAGE == "buffer":
+    const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
+  $else:
+    const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));
+
+  for (int pos = 4 * wid; pos < in_sizes.x; pos += (4 * NWORKERS)) {
+    // Preload t_weight
+    [[unroll]] for (int i = 0; i < 4; i++) {
+      $if WEIGHT_STORAGE == "buffer":
+        b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2];
+      $else:
+        b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0));
+    }
+    // Preload t_in
+    for (int i = 0; i < TILE_ROWS; i++) {
+      $if IN_STORAGE == "buffer":
+        a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
+      $else:
+        a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
+    }
+
+    // Accumulate partial output
+    [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
+      local_c[i] += a[i].x * b[0] +
+                    a[i].y * b[1] +
+                    a[i].z * b[2] +
+                    a[i].w * b[3];
+    }
+  }
+
+  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
+    partial_c[gid][wid][i] = local_c[i];
+  }
+
+  memoryBarrierShared();
+  barrier();
+
+  if (wid != 0) {
+    return;
+  }
+
+  VEC4_T c[TILE_ROWS];
+
+  for (int row = 0; row < TILE_ROWS; ++row) {
+    c[row] = VEC4_T(0.0);
+    [[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) {
+      c[row] += partial_c[gid][worker][row];
+    }
+  }
+
+  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
+    $if OUT_STORAGE == "buffer":
+      if (out_row + i < out_sizes.y) {
+        t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
+      }
+    $else:
+      imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.yaml
new file mode 100644
index 00000000000..5daf28132e6
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.yaml
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+q_8w_linear_coop:
+  parameter_names_with_default_values:
+    DTYPE: float
+    IN_STORAGE: texture3d
+    OUT_STORAGE: texture3d
+    WEIGHT_STORAGE: texture2d
+    SCALES_STORAGE: texture2d
+    TILE_ROWS: 4
+  generate_variant_forall:
+    TILE_ROWS:
+      - VALUE: 1
+        SUFFIX: o4x1
+  shader_variants:
+    - NAME: q_8w_linear_coop_texture3d_texture3d_texture2d_texture2d_float
+    - NAME: q_8w_linear_coop_buffer_buffer_texture2d_texture2d_float
+      IN_STORAGE: buffer
+      OUT_STORAGE: buffer
+    - NAME: q_8w_linear_coop_buffer_buffer_buffer_buffer_float
+      IN_STORAGE: buffer
+      OUT_STORAGE: buffer
+      WEIGHT_STORAGE: buffer
+      SCALES_STORAGE: buffer
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
index 64c2d202529..4a10f469be0 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
@@ -142,6 +142,7 @@ void add_q_8w_linear_node(
 
 void add_q_8w_linear_tiled_node(
     ComputeGraph& graph,
+    const bool use_coop_algorithm,
     const ValueRef mat1,
     const ValueRef q_mat2_data,
     const ValueRef scales_data,
@@ -168,7 +169,8 @@ void add_q_8w_linear_tiled_node(
   ValueRef scales =
       prepack_standard(graph, scales_data, scales_storage, utils::kWidthPacked);
 
-  std::string kernel_name = "q_8w_linear_tiled";
+  std::string kernel_name =
+      use_coop_algorithm ? "q_8w_linear_coop" : "q_8w_linear_tiled";
   kernel_name.reserve(kShaderNameReserve);
   add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
   add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
@@ -197,6 +199,9 @@ void add_q_8w_linear_tiled_node(
   global_wg_size[1] = global_wg_size[1] / out_tile_nrows;
 
   utils::uvec3 local_wg_size{64, 1, 1};
+  if (use_coop_algorithm) {
+    local_wg_size = {8, 1, 8};
+  }
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
@@ -257,13 +262,26 @@ bool can_use_tiled_impl(
   return true;
 }
 
+bool can_use_coop_impl(ComputeGraph& graph, const ValueRef mat1) {
+  // Do not use the coop algorithm for Adreno 702; manual experimentation shows
+  // that it performs worse than the tiled algorithm.
+  // TODO(ssjia): Determine a more robust heuristic to determine when the coop
+  // algorithm should be used, instead of depending on specific device identity.
+  if (graph.device_is_adreno() && graph.device_name_contains("702")) {
+    return false;
+  }
+  // Check that the computation is vector * matrix
+  return (graph.size_at<int>(-2, mat1) == 1);
+}
+
 void weight_int8pack_mm(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
   check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]);
   if (can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])) {
+    bool use_coop_algorithm = can_use_coop_impl(graph, args[0]);
     return add_q_8w_linear_tiled_node(
-        graph, args[0], args[1], args[2], args[3]);
+        graph, use_coop_algorithm, args[0], args[1], args[2], args[3]);
   }
   return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]);
 }
diff --git a/backends/vulkan/runtime/vk_api/Adapter.h b/backends/vulkan/runtime/vk_api/Adapter.h
index d73ed1bc0ce..8ae61095be8 100644
--- a/backends/vulkan/runtime/vk_api/Adapter.h
+++ b/backends/vulkan/runtime/vk_api/Adapter.h
@@ -122,6 +122,15 @@ class Adapter final {
     return physical_device_.timestamp_period;
   }
 
+  // Device Identity
+  inline const std::string& device_name() const {
+    return physical_device_.device_name;
+  }
+
+  inline vkapi::DeviceType device_type() const {
+    return physical_device_.device_type;
+  }
+
   // Queue Management
 
   Queue request_queue();
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 525f74609a6..4a12f16bbf9 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -152,6 +152,9 @@ def get_linear_inputs():
 @register_test_suite("aten._weight_int8pack_mm.default")
 def get_weight_int8pack_mm_inputs():
     MKN_list = [
+        [1, 480, 256],
+        [1, 1024, 1024],
+        [1, 1024, 256],
         [3, 480, 256],
         [6, 480, 256],
         [6, 256, 1024],
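The new `M = 1` test shapes above exercise the co-op path. For cross-checking shader output, a minimal reference sketch for `aten._weight_int8pack_mm`, under the assumption that the op computes `out = A @ W^T * scales` with per-output-channel scales (shapes follow the `[M, K, N]` entries in `MKN_list`; the function name is illustrative):

```cpp
#include <cstdint>
#include <vector>

// Hedged reference sketch: out[m][n] = scales[n] * sum_k a[m][k] * w[n][k],
// with a: [M, K] float, w: [N, K] int8, scales: [N].
std::vector<float> weight_int8pack_mm_ref(
    const std::vector<float>& a,      // row-major, M x K
    const std::vector<int8_t>& w,     // row-major, N x K
    const std::vector<float>& scales, // length N
    int64_t M, int64_t K, int64_t N) {
  std::vector<float> out(M * N, 0.0f);
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        acc += a[m * K + k] * static_cast<float>(w[n * K + k]);
      }
      out[m * N + n] = acc * scales[n];
    }
  }
  return out;
}
```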