diff --git a/examples/models/llama2/CMakeLists.txt b/examples/models/llama2/CMakeLists.txt
index fa3b7cff7e7..dd0a1c022c0 100644
--- a/examples/models/llama2/CMakeLists.txt
+++ b/examples/models/llama2/CMakeLists.txt
@@ -44,7 +44,6 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
 set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
 
 include(${EXECUTORCH_ROOT}/build/Utils.cmake)
-include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
 
 if(NOT PYTHON_EXECUTABLE)
   resolve_python_executable()
@@ -120,25 +119,9 @@ else()
   target_link_options_shared_lib(portable_ops_lib)
 endif()
 
-# quantized ops yaml file operation
-merge_yaml(
-  FUNCTIONS_YAML ${CMAKE_CURRENT_SOURCE_DIR}/ops/quantized.yaml
-  FALLBACK_YAML ${EXECUTORCH_ROOT}/kernels/quantized/quantized.yaml
-  OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR})
-
-gen_selected_ops("${CMAKE_CURRENT_BINARY_DIR}/merged.yaml" "" "")
-generate_bindings_for_kernels(
-  FUNCTIONS_YAML ${CMAKE_CURRENT_BINARY_DIR}/merged.yaml)
-message("Generated files ${gen_command_sources}")
-
-# quantized_merge_ops_lib: Register quantized op kernels into the runtime
-gen_operators_lib(
-  "quantized_merge_ops_lib"
-  KERNEL_LIBS quantized_kernels
-  DEPS executorch)
-target_include_directories(quantized_merge_ops_lib PUBLIC ${_common_include_directories})
-target_link_options_shared_lib(quantized_merge_ops_lib)
-list(APPEND link_libraries quantized_kernels quantized_merge_ops_lib)
+# quantized_ops_lib: Register quantized op kernels into the runtime
+target_link_options_shared_lib(quantized_ops_lib)
+list(APPEND link_libraries quantized_kernels quantized_ops_lib)
 
 if(EXECUTORCH_BUILD_CUSTOM)
   target_link_options_shared_lib(custom_ops)
diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS
index 9da7a26d6d7..96c9e8b6ffd 100644
--- a/examples/models/llama2/TARGETS
+++ b/examples/models/llama2/TARGETS
@@ -42,7 +42,6 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models/llama2:llama_transformer",
-        "//executorch/examples/models/llama2/ops:quantized_aot_lib",
     ],
 )
 
diff --git a/examples/models/llama2/ops/TARGETS b/examples/models/llama2/ops/TARGETS
deleted file mode 100644
index 0fbbff56977..00000000000
--- a/examples/models/llama2/ops/TARGETS
+++ /dev/null
@@ -1,5 +0,0 @@
-load(":targets.bzl", "define_common_targets")
-
-oncall("ai_infra_mobile_platform")
-
-define_common_targets()
diff --git a/examples/models/llama2/ops/quantized.yaml b/examples/models/llama2/ops/quantized.yaml
deleted file mode 100644
index 6708510908f..00000000000
--- a/examples/models/llama2/ops/quantized.yaml
+++ /dev/null
@@ -1,11 +0,0 @@
-- func: llama_quantized::DEPRECATED_DO_NOT_USE_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
-  variants: function
-  kernels:
-    - arg_meta: null
-      kernel_name: torch::executor::quantized_embedding_byte_out
-
-- func: llama_quantized::DEPRECATED_DO_NOT_USE_embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
-  variants: function
-  kernels:
-    - arg_meta: null
-      kernel_name: torch::executor::quantized_embedding_byte_dtype_out
diff --git a/examples/models/llama2/ops/quantized_ops.py b/examples/models/llama2/ops/quantized_ops.py
deleted file mode 100644
index 0ad80233626..00000000000
--- a/examples/models/llama2/ops/quantized_ops.py
+++ /dev/null
@@ -1,169 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Optional
-
-import torch
-from torch.library import impl, impl_abstract
-
-# NOTE: this is a hacky way to get around the fact that we can't use quantized_decomposed::embedding_byte in exir directly in eager model. That op can be found under exir/passes/_quant_patterns_and_replacements.py. Ideally we should consolidate these 2 versions.
-# This op share the same signature and C++ kernel implementation with quantized_decomposed::embedding_byte.
-quantized_lib = torch.library.Library(
-    "llama_quantized", "DEF"
-)  # to not be confused with torch.ops.quantized.* ops.
-quantized_lib.define(
-    "DEPRECATED_DO_NOT_USE_embedding_byte(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
-)
-
-quantized_lib.define(
-    "DEPRECATED_DO_NOT_USE_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
-)
-
-quantized_lib.define(
-    "DEPRECATED_DO_NOT_USE_embedding_byte.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
-)
-
-quantized_lib.define(
-    "DEPRECATED_DO_NOT_USE_embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
-    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
-)
-
-
-def embedding_byte_weight_checks(weight, weight_scales, weight_zero_points):
-    assert weight.dtype in [
-        torch.int8,
-        torch.uint8,
-    ], f"Expecting weights to be of dtype in [torch.int8, torch.uint8], but got {weight.dtype}"
-    assert (
-        weight.dim() == 2
-    ), f"Expecting weight tensor to have dim()==2, but found {weight.dim()}"
-
-    assert weight_scales.dtype in [
-        torch.float16,
-        torch.float32,
-    ], f"Expecting weight_scales to be of dtype in [torch.float16, torch.float32], but got {weight_scales.dtype}"
-    assert (
-        weight_scales.dim() == 1 or weight_scales.dim() == 2
-    ), f"Expecting weight_scales tensor to have rank 1 or 2, but found {weight_scales.dim()}"
-    assert weight_scales.size(0) == weight.size(
-        0
-    ), f"Expecting weight and scale tensor to have same number of rows, but found {weight.size()} and {weight_scales.size()}"
-
-    assert (
-        weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
-    ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
-    assert (
-        weight_zero_points is None or weight_zero_points.dim() == 1
-    ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
-    assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
-        0
-    ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
-
-
-@impl(
-    quantized_lib, "DEPRECATED_DO_NOT_USE_embedding_byte", "CompositeExplicitAutograd"
-)
-def embedding_byte(
-    weight: torch.Tensor,
-    weight_scales: torch.Tensor,
-    weight_zero_points: Optional[torch.Tensor],
-    weight_quant_min: int,
-    weight_quant_max: int,
-    indices: torch.Tensor,
-) -> torch.Tensor:
-    embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
-    group_size = weight.size(1) // (
-        weight_scales.size(1) if weight_scales.dim() == 2 else 1
-    )
-    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
-        weight,
-        weight_scales,
-        weight_zero_points,
-        weight_quant_min,
-        weight_quant_max,
-        weight.dtype,
-        group_size,
-        weight_scales.dtype,
-    )
-    return torch.ops.aten.embedding.default(weight, indices)
-
-
-@impl_abstract("llama_quantized::DEPRECATED_DO_NOT_USE_embedding_byte.out")
-def embedding_byte_out_meta(
-    weight: torch.Tensor,
-    weight_scales: torch.Tensor,
-    weight_zero_points: Optional[torch.Tensor],
-    weight_quant_min: int,
-    weight_quant_max: int,
-    indices: torch.Tensor,
-    out: torch.Tensor,
-) -> torch.Tensor:
-    return embedding_byte(
-        weight,
-        weight_scales,
-        weight_zero_points,
-        weight_quant_min,
-        weight_quant_max,
-        indices,
-    )
-
-
-@impl(
-    quantized_lib,
-    "DEPRECATED_DO_NOT_USE_embedding_byte.dtype",
-    "CompositeExplicitAutograd",
-)
-def embedding_byte_dtype(
-    weight: torch.Tensor,
-    weight_scales: torch.Tensor,
-    weight_zero_points: Optional[torch.Tensor],
-    weight_quant_min: int,
-    weight_quant_max: int,
-    indices: torch.Tensor,
-    *,
-    dtype: Optional[torch.dtype] = None,
-) -> torch.Tensor:
-    embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
-    group_size = weight.size(1) // (
-        weight_scales.size(1) if weight_scales.dim() == 2 else 1
-    )
-    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
-        weight,
-        weight_scales,
-        weight_zero_points,
-        weight_quant_min,
-        weight_quant_max,
-        weight.dtype,
-        group_size,
-        dtype,
-    )
-    return torch.ops.aten.embedding.default(weight, indices)
-
-
-@impl_abstract("llama_quantized::DEPRECATED_DO_NOT_USE_embedding_byte.dtype_out")
-def embedding_byte_dtype_out_meta(
-    weight: torch.Tensor,
-    weight_scales: torch.Tensor,
-    weight_zero_points: Optional[torch.Tensor],
-    weight_quant_min: int,
-    weight_quant_max: int,
-    indices: torch.Tensor,
-    *,
-    dtype: Optional[torch.dtype] = None,
-    out: torch.Tensor,
-) -> torch.Tensor:
-    return embedding_byte_dtype(
-        weight,
-        weight_scales,
-        weight_zero_points,
-        weight_quant_min,
-        weight_quant_max,
-        indices,
-        dtype=dtype,
-    )
diff --git a/examples/models/llama2/ops/targets.bzl b/examples/models/llama2/ops/targets.bzl
deleted file mode 100644
index b441773dc3b..00000000000
--- a/examples/models/llama2/ops/targets.bzl
+++ /dev/null
@@ -1,50 +0,0 @@
-load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
-load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")
-
-def define_common_targets():
-    """Defines targets that should be shared between fbcode and xplat.
-
-    The directory containing this targets.bzl file should also contain both
-    TARGETS and BUCK files that call this function.
-    """
-
-    runtime.python_library(
-        name = "quantized_aot_lib",
-        srcs = [
-            "quantized_ops.py",
-        ],
-        visibility = [
-            "//executorch/...",
-            "@EXECUTORCH_CLIENTS",
-        ],
-        deps = [
-            "//caffe2:torch",
-        ],
-    )
-
-    runtime.export_file(
-        name = "quantized.yaml",
-        visibility = [
-            "@EXECUTORCH_CLIENTS",
-        ],
-    )
-
-    et_operator_library(
-        name = "all_quantized_ops",
-        define_static_targets = True,
-        ops_schema_yaml_target = ":quantized.yaml",
-    )
-
-    executorch_generated_lib(
-        name = "generated_lib",
-        custom_ops_yaml_target = ":quantized.yaml",
-        define_static_targets = True,
-        visibility = [
-            "//executorch/...",
-            "@EXECUTORCH_CLIENTS",
-        ],
-        deps = [
-            ":all_quantized_ops",
-            "//executorch/kernels/quantized:quantized_operators",
-        ],
-    )
diff --git a/examples/models/llama2/quantize.py b/examples/models/llama2/quantize.py
index a1acc7695f6..dadd0b31f25 100644
--- a/examples/models/llama2/quantize.py
+++ b/examples/models/llama2/quantize.py
@@ -9,7 +9,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .ops.quantized_ops import *  # noqa
+from executorch.exir.passes._quant_patterns_and_replacements import (  # noqa
+    quantized_decomposed_lib,
+)
 
 
 try:
@@ -377,7 +379,7 @@ def __init__(
 
     @torch.no_grad()
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        return torch.ops.llama_quantized.DEPRECATED_DO_NOT_USE_embedding_byte.dtype(
+        return torch.ops.quantized_decomposed.embedding_byte.dtype(
             self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
         )
 
diff --git a/examples/models/llama2/runner/targets.bzl b/examples/models/llama2/runner/targets.bzl
index 2a0dbe8dab0..1bf28ad1ccb 100644
--- a/examples/models/llama2/runner/targets.bzl
+++ b/examples/models/llama2/runner/targets.bzl
@@ -4,9 +4,9 @@ def _get_operator_lib(aten = False):
     if aten:
         return ["//executorch/kernels/aten:generated_lib"]
     elif runtime.is_oss:
-        return ["//executorch/kernels/portable:generated_lib", "//executorch/examples/models/llama2/custom_ops:custom_ops", "//executorch/examples/models/llama2/ops:generated_lib"]
+        return ["//executorch/kernels/portable:generated_lib", "//executorch/examples/models/llama2/custom_ops:custom_ops"]
     else:
-        return ["//executorch/configurations:optimized_native_cpu_ops", "//executorch/examples/models/llama2/custom_ops:custom_ops", "//executorch/examples/models/llama2/ops:generated_lib"]
+        return ["//executorch/configurations:optimized_native_cpu_ops", "//executorch/examples/models/llama2/custom_ops:custom_ops"]
 
 def define_common_targets():
     for aten in (True, False):
diff --git a/exir/passes/_quant_patterns_and_replacements.py b/exir/passes/_quant_patterns_and_replacements.py
index 267d934dae6..bf06ce37c5c 100644
--- a/exir/passes/_quant_patterns_and_replacements.py
+++ b/exir/passes/_quant_patterns_and_replacements.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
-from typing import Callable, List, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import torch
 from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
@@ -15,6 +15,7 @@
 )
 from torch import fx
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
+from torch.library import impl, impl_abstract
 
 
 __all__ = [
@@ -34,6 +35,143 @@
     "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
 )
 
+quantized_decomposed_lib.define(
+    "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
+)
+
+
+def embedding_byte_weight_checks(weight, weight_scales, weight_zero_points):
+    assert weight.dtype in [
+        torch.int8,
+        torch.uint8,
+    ], f"Expecting weights to be of dtype in [torch.int8, torch.uint8], but got {weight.dtype}"
+    assert (
+        weight.dim() == 2
+    ), f"Expecting weight tensor to have dim()==2, but found {weight.dim()}"
+
+    assert weight_scales.dtype in [
+        torch.float16,
+        torch.float32,
+    ], f"Expecting weight_scales to be of dtype in [torch.float16, torch.float32], but got {weight_scales.dtype}"
+    assert (
+        weight_scales.dim() == 1 or weight_scales.dim() == 2
+    ), f"Expecting weight_scales tensor to have rank 1 or 2, but found {weight_scales.dim()}"
+    assert weight_scales.size(0) == weight.size(
+        0
+    ), f"Expecting weight and scale tensor to have same number of rows, but found {weight.size()} and {weight_scales.size()}"
+
+    assert (
+        weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
+    ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
+    assert (
+        weight_zero_points is None or weight_zero_points.dim() == 1
+    ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
+    assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
+        0
+    ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
+
+
+@impl(quantized_decomposed_lib, "embedding_byte", "CompositeExplicitAutograd")
+def embedding_byte(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+) -> torch.Tensor:
+    embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = weight.size(1) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        weight_scales.dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@impl_abstract("quantized_decomposed::embedding_byte.out")
+def embedding_byte_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_byte(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+    )
+
+
+@impl(quantized_decomposed_lib, "embedding_byte.dtype", "CompositeExplicitAutograd")
+def embedding_byte_dtype(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = weight.size(1) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@impl_abstract("quantized_decomposed::embedding_byte.dtype_out")
+def embedding_byte_dtype_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_byte_dtype(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        dtype,
+    )
+
+
 quantized_decomposed_lib.define(
     "mixed_mm(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points) -> Tensor",
 )
diff --git a/kernels/quantized/targets.bzl b/kernels/quantized/targets.bzl
index f907ed557ae..2c951fa4a7b 100644
--- a/kernels/quantized/targets.bzl
+++ b/kernels/quantized/targets.bzl
@@ -9,23 +9,54 @@ def define_common_targets():
         ],
     )
 
+    # Excluding embedding_byte ops because we choose to define them
+    # in python separately, mostly to be easy to share with oss.
     et_operator_library(
-        name = "all_quantized_ops",
-        ops_schema_yaml_target = ":quantized.yaml",
+        name = "quantized_ops_need_aot_registration",
+        ops = [
+            "quantized_decomposed::add.out",
+            "quantized_decomposed::choose_qparams.Tensor_out",
+            "quantized_decomposed::dequantize_per_channel.out",
+            "quantized_decomposed::dequantize_per_tensor.out",
+            "quantized_decomposed::dequantize_per_tensor.Tensor_out",
+            "quantized_decomposed::mixed_linear.out",
+            "quantized_decomposed::mixed_mm.out",
+            "quantized_decomposed::quantize_per_channel.out",
+            "quantized_decomposed::quantize_per_tensor.out",
+            "quantized_decomposed::quantize_per_tensor.Tensor_out",
+        ],
         define_static_targets = True,
     )
 
     # lib used to register quantized ops into EXIR
+    exir_custom_ops_aot_lib(
+        name = "custom_ops_generated_lib",
+        yaml_target = ":quantized.yaml",
+        visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
+        kernels = [":quantized_operators_aten"],
+        deps = [
+            ":quantized_ops_need_aot_registration",
+        ],
+    )
+
+    # lib used to register quantized ops into EXIR
+    # TODO: merge this with custom_ops_generated_lib
     exir_custom_ops_aot_lib(
         name = "aot_lib",
         yaml_target = ":quantized.yaml",
         visibility = ["//executorch/..."],
        kernels = [":quantized_operators_aten"],
         deps = [
-            ":all_quantized_ops",
+            ":quantized_ops_need_aot_registration",
         ],
     )
 
+    et_operator_library(
+        name = "all_quantized_ops",
+        ops_schema_yaml_target = ":quantized.yaml",
+        define_static_targets = True,
+    )
+
     for aten_mode in (True, False):
         aten_suffix = "_aten" if aten_mode else ""
 
@@ -49,6 +80,7 @@ def define_common_targets():
             ],
             custom_ops_yaml_target = ":quantized.yaml",
             custom_ops_aten_kernel_deps = [":quantized_operators_aten"] if aten_mode else [],
+            custom_ops_requires_aot_registration = False,
             aten_mode = aten_mode,
             visibility = [
                 "//executorch/...",
diff --git a/shim/xplat/executorch/codegen/codegen.bzl b/shim/xplat/executorch/codegen/codegen.bzl
index 42cea1ae35d..3de73770e26 100644
--- a/shim/xplat/executorch/codegen/codegen.bzl
+++ b/shim/xplat/executorch/codegen/codegen.bzl
@@ -332,6 +332,7 @@ def executorch_generated_lib(
         define_static_targets = False,
         custom_ops_aten_kernel_deps = [],
         custom_ops_requires_runtime_registration = True,
+        custom_ops_requires_aot_registration = True,
         visibility = [],
         aten_mode = False,
         manual_registration = False,
@@ -536,7 +537,7 @@ def executorch_generated_lib(
            platforms = platforms,
         )
 
-    if custom_ops_yaml_target:
+    if custom_ops_yaml_target and custom_ops_requires_aot_registration:
         exir_custom_ops_aot_lib(
             name = "custom_ops_" + name,
             yaml_target = custom_ops_yaml_target,