
Commit f4c8109

metascroy authored and HDCharles committed
Lowbit custom torch op
Differential Revision: D61896155
Pull Request resolved: #780
1 parent ed8c423 commit f4c8109

5 files changed: +516 −0 lines changed

Lines changed: 42 additions & 0 deletions (a new CMake build file)
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)
project(examples)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_BUILD_TYPE Release)

add_compile_options("-Wall" "-Werror")

include(CMakePrintHelpers)
message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}")
include_directories(${TORCHAO_LIBRARIES})

add_library(
  torchao_dep
  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
)

include(FetchContent)
FetchContent_Declare(pthreadpool
  GIT_REPOSITORY https://github.com/Maratyszcza/pthreadpool.git
  GIT_TAG master)
FetchContent_MakeAvailable(pthreadpool)

find_package(Torch REQUIRED)
message("TORCH_INCLUDE_DIRS: ${TORCH_INCLUDE_DIRS}")
include_directories("${TORCH_INCLUDE_DIRS}")

add_library(torch_custom_op SHARED torch_custom_op.cpp)
target_link_libraries(torch_custom_op PRIVATE "${TORCH_LIBRARIES}")
target_link_libraries(torch_custom_op PRIVATE torchao_dep)
target_compile_definitions(torch_custom_op PRIVATE TORCHAO_PARALLEL_PTHREADPOOL=1)
target_link_libraries(torch_custom_op PRIVATE pthreadpool)
Lines changed: 18 additions & 0 deletions (a new build script)
@@ -0,0 +1,18 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../../..

export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples/torch_custom_op
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
    -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
    -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
    -B ${CMAKE_OUT}
cmake --build ${CMAKE_OUT}
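
Once the build finishes, the shared library must be loaded into the Python process before the torchao ops become visible. A minimal sketch, assuming a Linux build where CMake emits libtorch_custom_op.so into ${CMAKE_OUT}; the exact file name and extension are platform-dependent (e.g. .dylib on macOS):

import torch

# Hypothetical output path; adjust to wherever CMake placed the library.
torch.ops.load_library(
    "/tmp/cmake-out/torch_ao/examples/torch_custom_op/libtorch_custom_op.so"
)

# The schemas registered by TORCH_LIBRARY(torchao, m) in the C++ file
# later in this commit are then callable:
print(torch.ops.torchao._linear_4bit)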
Lines changed: 83 additions & 0 deletions (a new Python example script)
@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torch_custom_op import quantize, replace_linear_with_quantized_linear
import torch
import copy

group_size = 16
m = 1
n = 4096
k = 4096
nbit = 4
n_layers = 10

print("Creating random model")
layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)]
model = torch.nn.Sequential(*layers)
model = model.eval()

print("Quantizing random model")
quantized_model = copy.deepcopy(model)
quantized_model = quantized_model.eval()
replace_linear_with_quantized_linear(quantized_model, kwargs={"group_size": group_size, "nbit": nbit})

print("Creating random activations")
activations = torch.randn(m, k, dtype=torch.float32)

print("Exporting quantized model")
exported = torch.export.export(quantized_model, (activations,))

print("Using torch.compile on quantized model")
quantized_model_compiled = torch.compile(quantized_model)
with torch.no_grad():
    quantized_model_compiled(activations)

print("Compiling quantized model with AOTI")
torch._export.aot_compile(
    quantized_model,
    (activations,),
    options={"aot_inductor.output_path": "/tmp/torch_custom_op_example_model.so"},
)

print("Running AOTI")
fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu")
fn(activations)


print("Checking correctness on layer 0")

rtol = 1e-05

# default is 1e-8, but PyTorch and C++ (and ARM neon) have different rounding
# conventions for ties (PyTorch rounds half to even and C++ rounds half to odd)
# TODO(T200109708): address this
atol = 1e-05

linear = model[0]
quantized_linear = quantized_model[0]
weight_qvals, weight_scales = quantize(linear.weight, group_size, quantized_linear.nbit, scale_only=True)

activation_qvals, activations_scales, activations_zeros = quantize(activations, k, 8, False)
activations_dequantized = activations_scales * (activation_qvals - activations_zeros)
weights_dequantized = (weight_qvals.reshape(-1, group_size) * weight_scales.reshape(-1, 1)).reshape(n, k)

with torch.no_grad():
    result = quantized_linear(activations)
    expected_result = torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0))
    non_quantized_result = linear(activations)

if not (torch.allclose(result, expected_result, rtol=rtol, atol=atol)):
    rand_idxs = torch.randint(0, result.shape[1], (5,))
    print("rand_idxs: ", rand_idxs)
    print("kernel_result[rand_idxs]: ", result[0][rand_idxs])
    print("expected_result[rand_idxs]: ", expected_result[0][rand_idxs])
    assert False
else:
    print("Correctness check passed")

print("kernel_result[0:5]: ", result[0][0:5])
print("non_quantized_result[0:5]: ", non_quantized_result[0][0:5])
Lines changed: 212 additions & 0 deletions (a new C++ source file implementing the custom ops)
@@ -0,0 +1,212 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <vector>

#include <torch/library.h>
#include <torch/script.h>
#include <torch/torch.h>
#include <torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h>
#include <torchao/experimental/kernels/cpu/parallel.h>

template <int weight_nbit>
at::Tensor pack_weights_cpu(
    const at::Tensor& weight_qvals,
    const at::Tensor& weight_scales,
    // TODO(T200095131): convert to int64_t when supported by AOTI
    // group_size is a meta tensor with size (group_size)
    const at::Tensor& group_size_tensor) {
  int64_t group_size = group_size_tensor.size(0);

  TORCH_CHECK(
      weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8");
  TORCH_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D");

  // In PyTorch, weights are nxk in row-major format (with activations being
  // right-multiplied).
  // In the kernel, activations are left-multiplied by kxn transposed
  // weights in column-major format.
  // Note the underlying data is the same in both cases.
  int n = weight_qvals.size(0);
  int k = weight_qvals.size(1);

  TORCH_CHECK(
      weight_scales.dtype() == torch::kFloat32,
      "weight_scales must be float32");
  TORCH_CHECK(weight_scales.dim() == 1, "weight_scales must be 1D");
  TORCH_CHECK(
      weight_scales.size(0) == ((n * k) / group_size),
      "expected 1 scale per group");

  using namespace torchao::operators::cpu::linear::
      channelwise_8bit_activation_groupwise_lowbit_weight;

  auto ukernel_config = get_ukernel_config<
      weight_nbit,
      false /*has_weight_zeros*/,
      false /*has_bias*/,
      false /*has_clamp*/>();
  auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params(
      ukernel_config, n, /*target_panels_per_thread=*/1);

  torchao::set_num_threads(torch::get_num_threads());

  auto packed_weight_data_size =
      get_packed_weight_data_size(ukernel_config, n, k, group_size);
  auto options = torch::TensorOptions().dtype(torch::kInt8);

  at::Tensor packed_weights = torch::empty({packed_weight_data_size}, options);
  pack_weight_data_operator(
      ukernel_config,
      pack_weight_tiling_params,
      packed_weights.data_ptr<int8_t>(),
      n,
      k,
      group_size,
      weight_qvals.const_data_ptr<int8_t>(),
      weight_scales.const_data_ptr<float>(),
      /*weight_zeros=*/nullptr);

  return packed_weights;
}

template <int weight_nbit>
at::Tensor pack_weights_meta(
    const at::Tensor& weight_qvals,
    const at::Tensor& weight_scales,
    // TODO(T200095131): convert to int64_t when supported by AOTI
    // group_size is a meta tensor with size (group_size)
    const at::Tensor& group_size_tensor) {
  int64_t group_size = group_size_tensor.size(0);

  int n = weight_qvals.size(0);
  int k = weight_qvals.size(1);

  using namespace torchao::operators::cpu::linear::
      channelwise_8bit_activation_groupwise_lowbit_weight;

  auto ukernel_config = get_ukernel_config<
      weight_nbit,
      false /*has_weight_zeros*/,
      false /*has_bias*/,
      false /*has_clamp*/>();

  auto packed_weight_data_size =
      get_packed_weight_data_size(ukernel_config, n, k, group_size);
  return torch::empty({packed_weight_data_size}).to("meta");
}

template <int weight_nbit>
at::Tensor linear_cpu(
    const at::Tensor& packed_weights,
    // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
    // int64_t when supported by AOTI. Currently they are meta tensors with
    // size equal to the int they wrap.
    const at::Tensor& n_tensor,
    const at::Tensor& k_tensor,
    const at::Tensor& group_size_tensor,
    const at::Tensor& activations) {
  int n = n_tensor.size(0);
  int k = k_tensor.size(0);
  int group_size = group_size_tensor.size(0);

  TORCH_CHECK(
      activations.dtype() == torch::kFloat32, "activations must be float32");
  TORCH_CHECK(activations.dim() == 2, "activations must be 2D");
  int m = activations.size(0);
  int k_ = activations.size(1);
  TORCH_CHECK(k == k_, "activation shape is incompatible with packed weights.");

  using namespace torchao::operators::cpu::linear::
      channelwise_8bit_activation_groupwise_lowbit_weight;

  auto ukernel_config = get_ukernel_config<
      weight_nbit,
      false /*has_weight_zeros*/,
      false /*has_bias*/,
      false /*has_clamp*/>();
  auto linear_tiling_params = get_default_linear_tiling_params(
      ukernel_config,
      m,
      n,
      /*target_tiles_per_thread=*/5);
  auto linear_scheduling_policy =
      LinearTileSchedulingPolicy::single_mc_parallel_nc;

  torchao::set_num_threads(torch::get_num_threads());

  auto activation_data_buffer_size = get_activation_data_buffer_size(
      ukernel_config,
      linear_tiling_params,
      linear_scheduling_policy,
      m,
      k,
      group_size);
  std::vector<char> activation_data_buffer(activation_data_buffer_size);

  at::Tensor output_tensor = torch::empty({m, n}, torch::kFloat32);
  linear_operator(
      ukernel_config,
      linear_tiling_params,
      linear_scheduling_policy,
      activation_data_buffer.data(),
      output_tensor.data_ptr<float>(),
      m,
      n,
      k,
      group_size,
      packed_weights.const_data_ptr<int8_t>(),
      activations.const_data_ptr<float>(),
      /*bias=*/nullptr,
      // Clamp parameters are ignored because the config is created with
      // has_clamp = false
      /*clamp_min=*/0.0,
      /*clamp_max=*/0.0);

  return output_tensor;
}

template <int weight_nbit>
at::Tensor linear_meta(
    const at::Tensor& packed_weights,
    // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
    // int64_t when supported by AOTI.
    // Currently they are meta tensors with size equal to the int they wrap.
    const at::Tensor& n_tensor,
    const at::Tensor& k_tensor,
    const at::Tensor& group_size_tensor,
    const at::Tensor& activations) {
  int n = n_tensor.size(0);
  int k = k_tensor.size(0);

  int m = activations.size(0);
  int k_ = activations.size(1);
  TORCH_CHECK(k == k_, "activation shape is incompatible with packed weights.");
  return torch::empty({m, n}).to("meta");
}

TORCH_LIBRARY(torchao, m) {
  m.def(
      "_pack_weights_3bit(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
  m.def(
      "_linear_3bit(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
  m.def(
      "_pack_weights_4bit(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
  m.def(
      "_linear_4bit(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
}

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
  m.impl("_pack_weights_3bit", &pack_weights_cpu<3>);
  m.impl("_linear_3bit", &linear_cpu<3>);
  m.impl("_pack_weights_4bit", &pack_weights_cpu<4>);
  m.impl("_linear_4bit", &linear_cpu<4>);
}

TORCH_LIBRARY_IMPL(torchao, Meta, m) {
  m.impl("_pack_weights_3bit", &pack_weights_meta<3>);
  m.impl("_linear_3bit", &linear_meta<3>);
  m.impl("_pack_weights_4bit", &pack_weights_meta<4>);
  m.impl("_linear_4bit", &linear_meta<4>);
}
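
Because AOTI does not yet support plain int64_t arguments for these ops (see the TODO(T200095131) comments), n, k, and group_size are smuggled in as empty tensors whose size(0) encodes the integer. A minimal end-to-end sketch of calling the registered 4-bit ops under that convention, assuming the shared library has been loaded as shown earlier and that the host supports the aarch64 kernels:

import torch

n, k, group_size = 4096, 4096, 16
weight_qvals = torch.randint(-8, 8, (n, k), dtype=torch.int8)  # signed 4-bit range
weight_scales = torch.rand(n * k // group_size)                # one scale per group
activations = torch.randn(1, k, dtype=torch.float32)

# Each integer argument travels as an empty tensor whose size(0) is the value.
packed = torch.ops.torchao._pack_weights_4bit(
    weight_qvals, weight_scales, torch.empty(group_size))
out = torch.ops.torchao._linear_4bit(
    packed, torch.empty(n), torch.empty(k), torch.empty(group_size), activations)
print(out.shape)  # torch.Size([1, 4096])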
