1
1
import torch
2
- from torchao .prototype .common . bitpacking import pack , unpack
2
+ from torchao .prototype .uintx import pack , unpack , pack_cpu , unpack_cpu
3
3
import pytest
4
4
from torch .utils ._triton import has_triton
5
- from torchao .utils import TORCH_VERSION_AFTER_2_4
6
-
7
- if not TORCH_VERSION_AFTER_2_4 :
8
- pytest .skip ("Unsupported PyTorch version" , allow_module_level = True )
9
-
10
- dtypes = ((2 , 'trinary' , 1 ), (2 , None , 1 ), (3 , None , 2 ), (4 , None , 2 ), (5 , None , 4 ), (6 , None , 4 ), (7 , None , 4 ))
11
- dimensions = (2 , 1 , 0 , - 1 )
12
- orders = (True , False )
13
5
6
# Sub-byte element widths (in bits) exercised by the tests: uint1 .. uint7.
element_bit_width = (1, 2, 3, 4, 5, 6, 7)
# Tensor dimensions along which packing/unpacking is tested.
dimensions = (0, -1, 1)
14
8
15
9
@pytest.fixture(autouse=True)
def run_before_and_after_tests():
    """Per-test setup/teardown applied automatically to every test."""
    # setup: nothing needed
    yield
    # teardown: reset the dynamo cache between tests so the compiled
    # tests don't trip the compilation-cache limit across parametrized runs
    torch._dynamo.reset()
28
@pytest.mark.parametrize("element_bit_width", element_bit_width)
@pytest.mark.parametrize("dim", dimensions)
def test_CPU(element_bit_width, dim):
    """Round-trip check on CPU: unpack_cpu(pack_cpu(x)) reproduces x."""
    shape = (32, 32, 32)
    # valid element values occupy [0, 2**element_bit_width)
    test_tensor = torch.randint(
        0, 2 ** element_bit_width, shape, dtype=torch.uint8, device='cpu'
    )
    packed = pack_cpu(test_tensor, element_bit_width, dim=dim)
    unpacked = unpack_cpu(packed, element_bit_width, dim=dim)
    assert unpacked.allclose(test_tensor)
52
21
53
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("element_bit_width", element_bit_width)
@pytest.mark.parametrize("dim", dimensions)
def test_GPU(element_bit_width, dim):
    """Round-trip check on CUDA: unpack(pack(x)) reproduces x."""
    shape = (32, 32, 32)
    # valid element values occupy [0, 2**element_bit_width)
    test_tensor = torch.randint(
        0, 2 ** element_bit_width, shape, dtype=torch.uint8
    ).cuda()
    packed = pack(test_tensor, element_bit_width, dim=dim)
    unpacked = unpack(packed, element_bit_width, dim=dim)
    assert unpacked.allclose(test_tensor)
79
31
80
32
81
33
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.parametrize("element_bit_width", element_bit_width)
@pytest.mark.parametrize("dim", dimensions)
def test_compile(element_bit_width, dim):
    """Round-trip pack/unpack under torch.compile on CUDA.

    Bug fix: the compiled wrappers (`pack_compile`/`unpack_compile`) were
    created but never invoked — the test called the eager `pack`/`unpack`,
    so compilation was never actually exercised.
    """
    # specialize on the integer bit-width argument so dynamo traces cleanly
    torch._dynamo.config.specialize_int = True
    pack_compile = torch.compile(pack, fullgraph=True)
    unpack_compile = torch.compile(unpack, fullgraph=True)
    test_tensor = torch.randint(
        0, 2 ** element_bit_width, (32, 32, 32), dtype=torch.uint8
    ).cuda()
    packed = pack_compile(test_tensor, element_bit_width, dim=dim)
    unpacked = unpack_compile(packed, element_bit_width, dim=dim)
    assert unpacked.allclose(test_tensor)
45
+
46
# these test cases are for the example pack walk through in the bitpacking.py file
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pack_example():
    """Pin the worked example from bitpacking.py: packing eight 6-bit
    values yields a 4-bit shard and a 2-bit shard with known contents.

    Fix: removed a leftover debug `print` of the shards.
    """
    test_tensor = torch.tensor(
        [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
    ).cuda()
    shard_4, shard_2 = pack(test_tensor, 6)
    # expected shard contents from the documented walkthrough
    assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4)
    assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2)
    unpacked = unpack([shard_4, shard_2], 6)
    assert unpacked.allclose(test_tensor)
56
+
57
def test_pack_example_CPU():
    """CPU variant of the bitpacking.py worked example: same expected
    shard contents, no device transfer.

    Fix: removed a leftover debug `print` of the shards.
    """
    test_tensor = torch.tensor(
        [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
    )
    shard_4, shard_2 = pack(test_tensor, 6)
    # expected shard contents from the documented walkthrough
    assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).allclose(shard_4)
    assert torch.tensor([39, 146], dtype=torch.uint8).allclose(shard_2)
    unpacked = unpack([shard_4, shard_2], 6)
    assert unpacked.allclose(test_tensor)
65
+
66
+
0 commit comments