
Commit 87869f2

Intx Quantization Tensor Class (#468)

* init class
* tensor subclasses work but slow?
* fixed frame break
* removed a print
* llama profile added
* perf
* added profile time
* added intx quantization to benchmark scripts
* add tests
* Delete trace.json
* Delete profile.txt
* separated dtype and affine quant WIP
* works without compile
* separated stuff, added tests
* remove intx from api til ready
* undo spacing in aqt
* updated torch_dispatch
* updated test
* re-added missing comment
* remove new line
* add new line
* whitespace fix
* whitespace fix
* fixed test
* refactored implements, actually fixed tests
* tests only run on nightly
* clean up from pr reviews
1 parent d582f9a commit 87869f2
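
Based on the imports and calls added in benchmarks/benchmark_uintx.py below, the new prototype API is applied to a module roughly as follows. This is a minimal sketch, not part of the commit: uintx_affine_weight_only lives under torchao.prototype.uintx and its exact signature may change.

import torch
from torchao.quantization.quant_api import quantize_
from torchao.prototype.uintx import uintx_affine_weight_only

# toy fp16 model on CUDA (illustrative only)
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 512, bias=True, dtype=torch.float16)
).cuda()

# swap each Linear weight for a uint4 affine-quantized tensor subclass,
# mirroring quantize_(m, uintx_affine_weight_only(bit_size)) in the benchmark
quantize_(model, uintx_affine_weight_only(4))

# the quantized module can still be compiled and called like the original
model = torch.compile(model, fullgraph=True)
out = model(torch.randn(1024, dtype=torch.float16, device="cuda"))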

File tree

8 files changed: +680 −339 lines


benchmarks/benchmark_bitpacking.py

Lines changed: 0 additions & 93 deletions
This file was deleted.

benchmarks/benchmark_uintx.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
from math import log
from copy import deepcopy

import torch
from torchao.utils import unwrap_tensor_subclass
from torchao.prototype.uintx import uintx_affine_weight_only, pack, unpack, pack_cpu, unpack_cpu
from torchao.quantization.quant_api import quantize_

class Linear16(torch.nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(scale*2, scale, bias=True, dtype=torch.float16).cuda(),
            torch.nn.Linear(scale, scale, bias=True, dtype=torch.float16).cuda(),
            torch.nn.Linear(scale, scale//2, bias=True, dtype=torch.float16).cuda(),
        )

    def forward(self, x):
        return self.net(x)


def benchmark(function, args, num_runs):
    # warmup
    torch._dynamo.reset()
    for i in range(100):
        function(*args)
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    for _ in range(num_runs):
        function(*args)

    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs


def profile_bitpack():
    from torch.profiler import profile, record_function, ProfilerActivity
    fake_tensor = [torch.randint(2**8, (512,512), dtype=torch.uint8).cuda()]
    func = torch.compile(unpack_cpu, fullgraph=True)
    with profile(activities=[
            ProfilerActivity.CPU,
            ProfilerActivity.CUDA],
            record_shapes=True,
            with_stack=True
    ) as prof:

        for _ in range(1000):
            unpacked = func(fake_tensor, 4)

    # Print a summary
    with open("profile-bitpack.txt", "a") as f:
        print(f'{func}', file=f)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10), file=f)
    prof.export_chrome_trace("trace.json")
    '''
    CPU perf:
        unpack_gpu
        Self CPU time total: 602.501ms

        unpack_cpu
        Self CPU time total: 415.469ms
    GPU perf:
        unpack_gpu on gpu:
        Self CPU time total: 58.512ms
        Self CUDA time total: 5.083ms

        unpack_cpu:
        Self CPU time total: 96.947ms
        Self CUDA time total: 5.253ms
    '''

def uintx_vs_fp16(nbits=[1,2,3,4,5,6,7], scales=[256, 512, 1024], repeats=30):
    results = []
    nbits.sort()
    scales.sort()
    for scale in scales:
        test_input = torch.randn(scale*2, dtype=torch.float16).cuda()
        forward_args = [test_input]
        times = [scale]

        fp16 = Linear16(scale)
        fp16c = torch.compile(fp16, fullgraph=True)
        fp16_time = benchmark(fp16c.forward, forward_args, repeats)
        times.append(fp16_time)
        for bit_size in nbits:
            m = deepcopy(fp16)
            quantize_(m, uintx_affine_weight_only(bit_size))
            m = torch.compile(m, fullgraph=True)
            uintx_time = benchmark(m.forward, forward_args, repeats)
            times.append(uintx_time)
        print(f'scale={scale} done')

        results.append(times)
    print("----------- benchmark results -----------")
    for result in results:
        print(f"scale: {result[0]} fp16 time:{result[1]: .2f}ms speedups:")
        for i in range(2, len(result)):
            print(f"int{nbits[i-2]}: {result[1]/result[i]: .2f}x")


if __name__ == "__main__":
    uintx_vs_fp16(nbits=[4,7])
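
The speedup figures printed at the end are simply the fp16 time divided by each uintx time, both in milliseconds per forward pass as returned by benchmark(). A toy illustration of that printout math (numbers invented, not measured):

# invented numbers, only to show how the "speedups" line is computed
fp16_time = 2.40   # ms per forward pass for the fp16 baseline
int4_time = 1.95   # ms per forward pass for the uint4-quantized model
print(f"int4: {fp16_time/int4_time: .2f}x")   # prints "int4:  1.23x"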

test/prototype/test_bitpacking.py

Lines changed: 45 additions & 122 deletions
@@ -1,143 +1,66 @@
 import torch
-from torchao.prototype.common.bitpacking import pack, unpack
+from torchao.prototype.uintx import pack, unpack, pack_cpu, unpack_cpu
 import pytest
 from torch.utils._triton import has_triton
-from torchao.utils import TORCH_VERSION_AFTER_2_4
-
-if not TORCH_VERSION_AFTER_2_4:
-    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
-
-dtypes = ((2, 'trinary', 1), (2, None, 1), (3, None, 2), (4, None, 2), (5, None, 4), (6, None, 4), (7, None, 4))
-dimensions = (2, 1, 0, -1)
-orders = (True, False)
 
+element_bit_width = (1,2,3,4,5,6,7)
+dimensions = (0, -1, 1)
 
 @pytest.fixture(autouse=True)
 def run_before_and_after_tests():
-    # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501
-
-    # setup (currently do nothing)
-
-    # tests will run here
     yield
+    torch._dynamo.reset() # reset cache between tests
 
-    # teardown
-    # avoid dynamo cache limit issues
-    torch._dynamo.reset()
-
-@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("element_bit_width", element_bit_width)
 @pytest.mark.parametrize("dim", dimensions)
-@pytest.mark.parametrize("order", orders)
-def test_CPU(dtype, dim, order):
-    element_bit_width, element_type, expected_pack_size = dtype
-    shape = [4, 4, 4]
-    if element_type == "trinary":
-        test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8, device='cpu')
-    else:
-        test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8, device='cpu')
-
-    packed = pack(test_tensor,
-                  element_bit_width,
-                  element_type=element_type,
-                  dim = dim,
-                  order = order,
-                  container_dtype = torch.uint8)
-    assert(packed.shape[dim] == expected_pack_size)
-    unpacked = unpack(packed,
-                      element_bit_width,
-                      element_type=element_type,
-                      dim = dim,
-                      order = order)
+def test_CPU(element_bit_width, dim):
+    test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8, device='cpu')
+    packed = pack_cpu(test_tensor, element_bit_width, dim = dim)
+    unpacked = unpack_cpu(packed, element_bit_width, dim = dim)
     assert(unpacked.allclose(test_tensor))
 
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("dtype", dtypes)
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("element_bit_width", element_bit_width)
 @pytest.mark.parametrize("dim", dimensions)
-@pytest.mark.parametrize("order", orders)
-def test_GPU(dtype, dim, order):
-    element_bit_width, element_type, expected_pack_size = dtype
-    shape = [4, 4, 4]
-    if element_type == "trinary":
-        test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda()
-    else:
-        test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda()
-
-    packed = pack(test_tensor,
-                  element_bit_width,
-                  element_type=element_type,
-                  dim = dim,
-                  order = order,
-                  container_dtype = torch.uint8)
-    assert(packed.shape[dim] == expected_pack_size)
-    unpacked = unpack(packed,
-                      element_bit_width,
-                      element_type=element_type,
-                      order = order,
-                      dim = dim)
+def test_GPU(element_bit_width, dim):
+    test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8).cuda()
+    packed = pack(test_tensor, element_bit_width, dim = dim)
+    unpacked = unpack(packed, element_bit_width, dim = dim)
     assert(unpacked.allclose(test_tensor))
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("dim", dimensions)
-@pytest.mark.parametrize("order", orders)
-def test_padding(dtype, dim, order):
-    element_bit_width, element_type, expected_pack_size = dtype
-    torch._dynamo.config.specialize_int = True
-    shape = [4, 4, 4]
-    shape[dim] = 5
-
-    if element_type == "trinary":
-        test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda()
-    else:
-        test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda()
-
-    packed = pack(test_tensor,
-                  element_bit_width,
-                  element_type=element_type,
-                  dim = dim,
-                  container_dtype = torch.uint8,
-                  order = order,
-                  pad = True)
-    assert packed.shape[dim] == expected_pack_size+1, f"packed.shape[dim] {packed.shape[dim]}" # +1 for this scenario
-    unpacked = unpack(packed,
-                      element_bit_width,
-                      element_type=element_type,
-                      dim = dim,
-                      order = order)
-    slices = [slice(None)] * packed.ndim
-    slices[dim] = slice(None, 5)
-    assert unpacked[slices].allclose(test_tensor)
-
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("element_bit_width", element_bit_width)
 @pytest.mark.parametrize("dim", dimensions)
-@pytest.mark.parametrize("order", orders)
-def test_compile(dtype, dim, order):
-    pack_compile = torch.compile(pack, fullgraph=True, dynamic=True)
-    unpack_compile = torch.compile(unpack, fullgraph=True, dynamic=True)
-    element_bit_width, element_type, expected_pack_size = dtype
+def test_compile(element_bit_width, dim):
     torch._dynamo.config.specialize_int = True
-    shape = [4, 4, 4]
-    if element_type == "trinary":
-        test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda()
-    else:
-        test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.int8).cuda()
-
-    packed = pack_compile(test_tensor, element_bit_width,
-                          element_type=element_type,
-                          dim = dim,
-                          container_dtype = torch.int8,
-                          order = order)
-    assert(packed.shape[dim] == expected_pack_size)
-    unpacked = unpack_compile(packed,
-                              element_bit_width,
-                              element_type=element_type,
-                              dim = dim,
-                              order = order)
+    pack_compile = torch.compile(pack, fullgraph=True)
+    unpack_compile = torch.compile(unpack, fullgraph=True)
+    test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8).cuda()
+    packed = pack(test_tensor, element_bit_width, dim = dim)
+    unpacked = unpack(packed, element_bit_width, dim = dim)
     assert(unpacked.allclose(test_tensor))
+
+# these test cases are for the example pack walk through in the bitpacking.py file
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_pack_example():
+    test_tensor = torch.tensor([0x30,0x29,0x17,0x5,0x20,0x16,0x9,0x22], dtype=torch.uint8).cuda()
+    shard_4, shard_2 = pack(test_tensor, 6)
+    print(shard_4, shard_2)
+    assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4)
+    assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2)
+    unpacked = unpack([shard_4, shard_2], 6)
+    assert unpacked.allclose(test_tensor)
+
+def test_pack_example_CPU():
+    test_tensor = torch.tensor([0x30,0x29,0x17,0x5,0x20,0x16,0x9,0x22], dtype=torch.uint8)
+    shard_4, shard_2 = pack(test_tensor, 6)
+    print(shard_4, shard_2)
+    assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).allclose(shard_4)
+    assert torch.tensor([39, 146], dtype=torch.uint8).allclose(shard_2)
+    unpacked = unpack([shard_4, shard_2], 6)
+    assert unpacked.allclose(test_tensor)
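
The expected tensors asserted in test_pack_example can be reproduced by hand: each 6-bit element is split into a 4-bit shard and a 2-bit shard, and each shard is packed into uint8 containers. Below is a minimal sketch that reproduces the asserted values; it mirrors the layout implied by the test, not necessarily the internal torchao packing code.

import torch

vals = torch.tensor([0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8)

# split each 6-bit value into its low 4 bits and high 2 bits
low4 = vals & 0x0F
high2 = (vals >> 4) & 0x03

# two 4-bit fields per byte: element i+4 goes into the upper nibble of byte i
shard_4 = low4[:4] | (low4[4:] << 4)
# four 2-bit fields per byte: element i lands in byte i % 2 at shift (i // 2) * 2
shard_2 = high2[0:2] | (high2[2:4] << 2) | (high2[4:6] << 4) | (high2[6:8] << 6)

print(shard_4)  # tensor([  0, 105, 151,  37], dtype=torch.uint8)
print(shard_2)  # tensor([ 39, 146], dtype=torch.uint8)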
