From 3fdd8cab8c58db0be666f3454c41f73ad5964743 Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Mon, 7 Apr 2025 10:42:17 -0700
Subject: [PATCH] [ET-VK][ez] Allow logit linear layer to be lowered to Vulkan

Pull Request resolved: https://github.com/pytorch/executorch/pull/9918

## Context

Due to the poor performance of Vulkan's int4 linear operator, the final logit layer of the transformer model was not being delegated to Vulkan; it was instead quantized and executed with the XNNPACK delegate.

However, with D72412950 / https://github.com/pytorch/executorch/pull/9883, decent performance can now be achieved with Vulkan's int4 linear op. Therefore, the final logit layer can be lowered to Vulkan.

## Changes

* Remove the limit from `VkInt4WeightOnlyQuantizer` that caused it to ignore the logit layer of the transformer (see the usage sketch below)
* Do not apply the XNNPACK partitioner and quantizer when lowering with Vulkan
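
## Example

A minimal usage sketch (illustrative, not part of this patch) of what the change enables: with the `feature_limit` check removed, `VkInt4WeightOnlyQuantizer` also quantizes a logit projection whose `out_features` exceeds 16384. `TinyLM` is a made-up stand-in for a transformer, and the import path is assumed from the file path in this diff.

```python
import torch

# Assumed import path, mirroring backends/vulkan/_passes/int4_weight_only_quantizer.py;
# adjust to however VkInt4WeightOnlyQuantizer is exported in your checkout.
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
    VkInt4WeightOnlyQuantizer,
)


class TinyLM(torch.nn.Module):
    """Toy stand-in for a transformer: one hidden linear plus a large logit projection."""

    def __init__(self, dim: int = 256, vocab_size: int = 32000) -> None:
        super().__init__()
        self.hidden = torch.nn.Linear(dim, dim, bias=False)
        # out_features > 16384: previously skipped by the now-removed feature_limit check.
        self.output = torch.nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.output(torch.nn.functional.silu(self.hidden(x)))


model = TinyLM()
# groupsize mirrors the q_group_size default (256) used in quantize.py below.
model = VkInt4WeightOnlyQuantizer(groupsize=256).quantize(model)
# Expected: both nn.Linear modules, including the logit layer, are replaced
# by VkWeightOnlyInt4Linear.
print(type(model.hidden).__name__, type(model.output).__name__)
```
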
ghstack-source-id: 276566114

Differential Revision: [D72480177](https://our.internmc.facebook.com/intern/diff/D72480177/)
---
 .../vulkan/_passes/int4_weight_only_quantizer.py    | 14 +-------------
 backends/vulkan/op_registry.py                      |  2 ++
 examples/models/llama/export_llama_lib.py           |  4 ----
 .../models/llama/source_transformation/quantize.py  | 11 -----------
 4 files changed, 3 insertions(+), 28 deletions(-)

diff --git a/backends/vulkan/_passes/int4_weight_only_quantizer.py b/backends/vulkan/_passes/int4_weight_only_quantizer.py
index 409cbb4b755..d0b73b8af0e 100644
--- a/backends/vulkan/_passes/int4_weight_only_quantizer.py
+++ b/backends/vulkan/_passes/int4_weight_only_quantizer.py
@@ -118,9 +118,6 @@ def _vk_replace_linear_int4(
     # Use custom vulkan linear layer as default
     linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
     copy_weights: bool = False,
-    # Serves the same purpose as `tensor_dim_limit` in
-    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-    feature_limit: int = 16384,
 ):
     for name, child in module.named_children():
         if isinstance(child, torch.nn.Linear) and (
@@ -131,8 +128,6 @@ def _vk_replace_linear_int4(
             if (
                 _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
                 or padding_allowed
-            ) and (
-                child.out_features < feature_limit and child.in_features < feature_limit
             ):
                 new_linear = linear_class(
                     child.in_features,
@@ -175,7 +170,6 @@ def __init__(
         inner_k_tiles: Optional[int] = 8,
         device: torch.device = torch.device("cpu"),  # noqa
         precision: torch.dtype = torch.float32,
-        feature_limit: int = 16384,
     ) -> None:
         super().__init__()
         assert inner_k_tiles in [2, 4, 8]
@@ -186,9 +180,6 @@ def __init__(
         self.padding_allowed: bool = padding_allowed
         self.device: torch.device = device
         self.precision: torch.dtype = precision
-        # Serves the same purpose as `tensor_dim_limit` in
-        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-        self.feature_limit = feature_limit
 
     @torch.no_grad()
     def _create_quantized_state_dict(
@@ -197,10 +188,7 @@ def _create_quantized_state_dict(
         cur_state_dict = model.state_dict()
         for fqn, mod in model.named_modules():
             # Add additional check to make sure features do not exceed feature limit
-            if isinstance(mod, torch.nn.Linear) and (
-                mod.out_features < self.feature_limit
-                and mod.in_features < self.feature_limit
-            ):
+            if isinstance(mod, torch.nn.Linear):
                 out_features = mod.out_features
                 in_features = mod.in_features
                 logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 97a428c77aa..026f1db9273 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -393,6 +393,7 @@ def register_int8_mm_op(features: OpFeatures):
 
 @update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
 def register_int4_mm_op(features: OpFeatures):
+    features.buffer_impl = True
     features.texture_impl = TextureImplFeatures(
         uses_axis_map=False,
         valid_packed_dims={PackedDim.WIDTH},
@@ -401,6 +402,7 @@ def register_int4_mm_op(features: OpFeatures):
     features.optimal_storage = VkStorageType.TEXTURE_3D
     features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
     features.handles_own_prepacking = True
+    features.skip_limits_check = {1}
     return features
 
 
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 8e6d4fefb0e..249a25f23c4 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -793,10 +793,6 @@ def _to_edge_and_lower_llama(  # noqa: C901
                 args.enable_dynamic_shape,
             )
         )
-        # Apply XNNPACK after Vulkan so that undelegated ops can be accelerated by XNNPACK
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
-        )
         modelname = f"vulkan_{modelname}"
 
         # Need to remove asserts from the graph to prevent graph breaks
diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 2ef016de097..d51d4378705 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -206,17 +206,6 @@ def quantize(  # noqa C901
         q_group_size = 256 if group_size is None else group_size
         model = VkInt4WeightOnlyQuantizer(groupsize=q_group_size).quantize(model)
 
-        # Apply additional quantizer for linear layers that aren't lowered to Vulkan
-        # at the moment
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
-
-        # 1. Quantize in checkpoint dtype.
-        model = Int8DynActInt4WeightQuantizer(
-            precision=checkpoint_torch_dtype, groupsize=q_group_size
-        ).quantize(model)
-        # 2. Set the computation dtype (what weights/acts dequantize to).
-        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
-
         return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")