|
| 1 | +// Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +// All rights reserved. |
| 3 | +// |
| 4 | +// This source code is licensed under the license found in the |
| 5 | +// LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +#include <torch/library.h> |
| 8 | +#include <torch/script.h> |
| 9 | +#include <torch/torch.h> |
| 10 | +#include <torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h> |
| 11 | +#include <torchao/experimental/kernels/cpu/parallel.h> |
| 12 | + |
| 13 | +template <int weight_nbit> |
| 14 | +at::Tensor pack_weights_cpu( |
| 15 | + const at::Tensor& weight_qvals, |
| 16 | + const at::Tensor& weight_scales, |
| 17 | + // TODO(T200095131): convert to int64_t when supported by AOTI |
| 18 | + // group_size is a meta tensor with size (group_size) |
| 19 | + const at::Tensor& group_size_tensor) { |
| 20 | + int64_t group_size = group_size_tensor.size(0); |
| 21 | + |
| 22 | + TORCH_CHECK( |
| 23 | + weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8"); |
| 24 | + TORCH_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D"); |
| 25 | + |
| 26 | + // In PyTorch, weights are nxk in row-major format (with activations being |
| 27 | + // right-multiplied). |
| 28 | + // In kernel, activations are left-multiplied by kxn transposed |
| 29 | + // weights in column-major format. |
| 30 | + // Note the underlying data is the same in both cases |
| 31 | + int n = weight_qvals.size(0); |
| 32 | + int k = weight_qvals.size(1); |
| 33 | + |
| 34 | + TORCH_CHECK( |
| 35 | + weight_scales.dtype() == torch::kFloat32, |
| 36 | + "weight_scales must be float32"); |
| 37 | + TORCH_CHECK(weight_scales.dim() == 1, "weight_scales must be 1D"); |
| 38 | + TORCH_CHECK( |
| 39 | + weight_scales.size(0) == ((n * k) / group_size), |
| 40 | + "expected 1 scale per group"); |
| 41 | + |
| 42 | + using namespace torchao::operators::cpu::linear:: |
| 43 | + channelwise_8bit_activation_groupwise_lowbit_weight; |
| 44 | + |
| 45 | + auto ukernel_config = get_ukernel_config< |
| 46 | + weight_nbit, |
| 47 | + false /*has_weight_zeros*/, |
| 48 | + false /*has_bias*/, |
| 49 | + false /*has_clamp*/>(); |
| 50 | + auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params( |
| 51 | + ukernel_config, n, /*target_panels_per_thread=*/1); |
| 52 | + |
| 53 | + torchao::set_num_threads(torch::get_num_threads()); |
| 54 | + |
| 55 | + auto packed_weight_data_size = |
| 56 | + get_packed_weight_data_size(ukernel_config, n, k, group_size); |
| 57 | + auto options = torch::TensorOptions().dtype(torch::kInt8); |
| 58 | + |
| 59 | + at::Tensor packed_weights = torch::empty({packed_weight_data_size}, options); |
| 60 | + pack_weight_data_operator( |
| 61 | + ukernel_config, |
| 62 | + pack_weight_tiling_params, |
| 63 | + packed_weights.data_ptr<int8_t>(), |
| 64 | + n, |
| 65 | + k, |
| 66 | + group_size, |
| 67 | + weight_qvals.const_data_ptr<int8_t>(), |
| 68 | + weight_scales.const_data_ptr<float>(), |
| 69 | + /*weight_zeros=*/nullptr); |
| 70 | + |
| 71 | + return packed_weights; |
| 72 | +} |
| 73 | + |
| 74 | +template <int weight_nbit> |
| 75 | +at::Tensor pack_weights_meta( |
| 76 | + const at::Tensor& weight_qvals, |
| 77 | + const at::Tensor& weight_scales, |
| 78 | + // TODO(T200095131): convert to int64_t when supported by AOTI |
| 79 | + // group_size is a meta tensor with size (group_size) |
| 80 | + const at::Tensor& group_size_tensor) { |
| 81 | + int64_t group_size = group_size_tensor.size(0); |
| 82 | + |
| 83 | + int n = weight_qvals.size(0); |
| 84 | + int k = weight_qvals.size(1); |
| 85 | + |
| 86 | + using namespace torchao::operators::cpu::linear:: |
| 87 | + channelwise_8bit_activation_groupwise_lowbit_weight; |
| 88 | + |
| 89 | + auto ukernel_config = get_ukernel_config< |
| 90 | + weight_nbit, |
| 91 | + false /*has_weight_zeros*/, |
| 92 | + false /*has_bias*/, |
| 93 | + false /*has_clamp*/>(); |
| 94 | + |
| 95 | + auto packed_weight_data_size = |
| 96 | + get_packed_weight_data_size(ukernel_config, n, k, group_size); |
| 97 | + return torch::empty({packed_weight_data_size}).to("meta"); |
| 98 | +} |
| 99 | + |
| 100 | +template <int weight_nbit> |
| 101 | +at::Tensor linear_cpu( |
| 102 | + const at::Tensor& packed_weights, |
| 103 | + // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to |
| 104 | + // int64_t when supported by AOTI Currently they are meta tensors with size |
| 105 | + // equal to the int they wrap |
| 106 | + const at::Tensor& n_tensor, |
| 107 | + const at::Tensor& k_tensor, |
| 108 | + const at::Tensor& group_size_tensor, |
| 109 | + const at::Tensor& activations) { |
| 110 | + int n = n_tensor.size(0); |
| 111 | + int k = k_tensor.size(0); |
| 112 | + int group_size = group_size_tensor.size(0); |
| 113 | + |
| 114 | + TORCH_CHECK( |
| 115 | + activations.dtype() == torch::kFloat32, "activations must be float32"); |
| 116 | + TORCH_CHECK(activations.dim() == 2, "activations must be 2D"); |
| 117 | + int m = activations.size(0); |
| 118 | + int k_ = activations.size(1); |
| 119 | + TORCH_CHECK(k == k_, "activation shape is incompatible with packed weights."); |
| 120 | + |
| 121 | + using namespace torchao::operators::cpu::linear:: |
| 122 | + channelwise_8bit_activation_groupwise_lowbit_weight; |
| 123 | + |
| 124 | + auto ukernel_config = get_ukernel_config< |
| 125 | + weight_nbit, |
| 126 | + false /*has_weight_zeros*/, |
| 127 | + false /*has_bias*/, |
| 128 | + false /*has_clamp*/>(); |
| 129 | + auto linear_tiling_params = get_default_linear_tiling_params( |
| 130 | + ukernel_config, |
| 131 | + m, |
| 132 | + n, |
| 133 | + /*target_tiles_per_thread=*/5); |
| 134 | + auto linear_scheduling_policy = |
| 135 | + LinearTileSchedulingPolicy::single_mc_parallel_nc; |
| 136 | + |
| 137 | + torchao::set_num_threads(torch::get_num_threads()); |
| 138 | + |
| 139 | + auto activation_data_buffer_size = get_activation_data_buffer_size( |
| 140 | + ukernel_config, |
| 141 | + linear_tiling_params, |
| 142 | + linear_scheduling_policy, |
| 143 | + m, |
| 144 | + k, |
| 145 | + group_size); |
| 146 | + std::vector<char> activation_data_buffer(activation_data_buffer_size); |
| 147 | + |
| 148 | + at::Tensor output_tensor = torch::empty({m, n}, torch::kFloat32); |
| 149 | + linear_operator( |
| 150 | + ukernel_config, |
| 151 | + linear_tiling_params, |
| 152 | + linear_scheduling_policy, |
| 153 | + activation_data_buffer.data(), |
| 154 | + output_tensor.data_ptr<float>(), |
| 155 | + m, |
| 156 | + n, |
| 157 | + k, |
| 158 | + group_size, |
| 159 | + packed_weights.const_data_ptr<int8_t>(), |
| 160 | + activations.const_data_ptr<float>(), |
| 161 | + /*bias=*/nullptr, |
| 162 | + // Clamp parameters are ignored because config is created from |
| 163 | + // has_clamp = false |
| 164 | + /*clamp_min=*/0.0, |
| 165 | + /*clamp_max=*/0.0); |
| 166 | + |
| 167 | + return output_tensor; |
| 168 | +} |
| 169 | + |
| 170 | +template <int weight_nbit> |
| 171 | +at::Tensor linear_meta( |
| 172 | + const at::Tensor& packed_weights, |
| 173 | + // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to |
| 174 | + // int64_t when supported by AOTI |
| 175 | + // Currently they are meta tensors with size equal to the int they wrap |
| 176 | + const at::Tensor& n_tensor, |
| 177 | + const at::Tensor& k_tensor, |
| 178 | + const at::Tensor& group_size_tensor, |
| 179 | + const at::Tensor& activations) { |
| 180 | + int n = n_tensor.size(0); |
| 181 | + int k = k_tensor.size(0); |
| 182 | + |
| 183 | + int m = activations.size(0); |
| 184 | + int k_ = activations.size(1); |
| 185 | + TORCH_CHECK(k == k_, "activation shape is incompatible with packed weights."); |
| 186 | + return torch::empty({m, n}).to("meta"); |
| 187 | +} |
| 188 | + |
// Operator schema declarations for the `torchao` extension library.
// Each op comes in a pack (weight preprocessing) / linear (matmul) pair,
// instantiated for 3-bit and 4-bit weights. Scalar parameters (n, k,
// group_size) are passed as meta tensors whose size(0) encodes the value
// (TODO(T200095131): switch to int64_t once AOTI supports it).
TORCH_LIBRARY(torchao, m) {
  m.def(
      "_pack_weights_3bit(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
  m.def(
      "_linear_3bit(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
  m.def(
      "_pack_weights_4bit(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
  m.def(
      "_linear_4bit(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
}
| 199 | + |
// CPU kernel registrations: bind each declared op to the corresponding
// templated CPU implementation, instantiated per weight bit-width.
TORCH_LIBRARY_IMPL(torchao, CPU, m) {
  m.impl("_pack_weights_3bit", &pack_weights_cpu<3>);
  m.impl("_linear_3bit", &linear_cpu<3>);
  m.impl("_pack_weights_4bit", &pack_weights_cpu<4>);
  m.impl("_linear_4bit", &linear_cpu<4>);
}
| 206 | + |
// Meta (shape-inference) registrations used for tracing/compilation: these
// compute output sizes without executing the kernels.
TORCH_LIBRARY_IMPL(torchao, Meta, m) {
  m.impl("_pack_weights_3bit", &pack_weights_meta<3>);
  m.impl("_linear_3bit", &linear_meta<3>);
  m.impl("_pack_weights_4bit", &pack_weights_meta<4>);
  m.impl("_linear_4bit", &linear_meta<4>);
}
0 commit comments