1+ import pytest
2+ import torch
3+ import numpy as np
4+ from compressed_tensors .utils .helpers import pack_bitmasks , unpack_bitmasks
5+ from compressed_tensors .compressors .sparse_compressors .sparse_24_bitmask import (
6+ get_24_bytemasks ,
7+ sparse24_bitmask_compress ,
8+ sparse24_bitmask_decompress ,
9+ Sparse24BitMaskTensor ,
10+ )
11+
12+
class TestPackBitmasks:
    """Tests covering the pack_bitmasks fast paths."""

    def test_pack_bitmasks_correctness_cpu(self):
        """PyTorch packing must agree with NumPy's packbits on CPU."""
        shapes = [
            (1, 8),
            (1, 16),
            (10, 7),
            (10, 8),
            (10, 9),
            (100, 100),
            (128, 256),
            (1000, 1000),
        ]

        for shape in shapes:
            mask = torch.rand(shape) > 0.5

            got = pack_bitmasks(mask)

            # Reference: NumPy little-endian bit packing along the last axis.
            want = torch.from_numpy(
                np.packbits(mask.numpy(), axis=-1, bitorder="little")
            )

            assert torch.equal(got, want), (
                f"Mismatch for shape {shape}: PyTorch != NumPy"
            )

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_pack_bitmasks_gpu(self):
        """GPU packing stays on-device and matches the CPU result."""
        for shape in [(128, 256), (1024, 1024)]:
            cpu_mask = torch.rand(shape) > 0.5

            gpu_packed = pack_bitmasks(cpu_mask.cuda())
            assert gpu_packed.is_cuda, "Result should stay on GPU"

            cpu_packed = pack_bitmasks(cpu_mask)
            assert torch.equal(gpu_packed.cpu(), cpu_packed), (
                f"GPU result differs from CPU for shape {shape}"
            )

    def test_pack_unpack_roundtrip(self):
        """pack_bitmasks followed by unpack_bitmasks is lossless."""
        for shape in [(10, 16), (128, 256), (100, 999)]:
            mask = torch.rand(shape) > 0.5

            restored = unpack_bitmasks(pack_bitmasks(mask), list(shape))

            assert torch.equal(mask, restored), (
                f"Roundtrip failed for shape {shape}"
            )

    def test_edge_cases(self):
        """Degenerate inputs: empty, single element, all-False, all-True."""
        # Empty tensor packs to an empty result.
        packed = pack_bitmasks(torch.empty(0, 0, dtype=torch.bool))
        assert packed.shape == (0, 0)

        # A single True bit occupies the low bit of one byte.
        packed = pack_bitmasks(torch.tensor([[True]]))
        assert packed.shape == (1, 1)
        assert packed[0, 0] == 1

        # All False -> every packed byte is zero.
        packed = pack_bitmasks(torch.zeros(10, 16, dtype=torch.bool))
        assert torch.all(packed == 0)

        # All True -> every packed byte is 0xFF (16 bits -> 2 bytes per row).
        packed = pack_bitmasks(torch.ones(10, 16, dtype=torch.bool))
        assert torch.equal(packed, torch.full((10, 2), 255, dtype=torch.uint8))
97+
98+
class TestSparse24Compression:
    """Tests for the sparse 2:4 bitmask compression path."""

    def test_compression_preserves_sparsity(self):
        """Compress/decompress keeps the 2:4 pattern and the kept values."""
        tensor = torch.randn(128, 256)

        # The 2:4 mask should zero out half of the elements.
        mask = get_24_bytemasks(tensor)
        zero_fraction = (~mask).sum().item() / mask.numel()
        assert abs(zero_fraction - 0.5) < 0.01, "Should have ~50% sparsity"

        # Full round trip through compress -> decompress.
        compressed, bitmask = sparse24_bitmask_compress(tensor)
        decompressed = sparse24_bitmask_decompress(compressed, bitmask, tensor.shape)

        # Sparsity survives the round trip.
        restored_fraction = (decompressed == 0).sum().item() / decompressed.numel()
        assert abs(restored_fraction - 0.5) < 0.01, "Decompressed should maintain sparsity"

        # The surviving values survive unchanged (up to float tolerance).
        assert torch.allclose(tensor[mask], decompressed[mask], rtol=1e-5)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_gpu_compression(self):
        """Compression on GPU produces CPU-resident storage and correct values."""
        tensor = torch.randn(256, 512).cuda()

        compressed_tensor = Sparse24BitMaskTensor.from_dense(tensor)

        # Storage tensors are expected on CPU regardless of the input device.
        assert compressed_tensor.compressed.device.type == "cpu"
        assert compressed_tensor.bitmask.device.type == "cpu"

        decompressed = compressed_tensor.decompress()
        mask = get_24_bytemasks(tensor.cpu())
        assert torch.allclose(tensor.cpu()[mask], decompressed[mask], rtol=1e-5)

    def test_various_dtypes(self):
        """Compression round-trips float32, float16 and bfloat16 inputs."""
        for dtype in (torch.float32, torch.float16, torch.bfloat16):
            # NOTE(review): bfloat16 is skipped on CPU-only hosts here;
            # modern torch supports bfloat16 on CPU — confirm the skip is needed.
            if dtype == torch.bfloat16 and not torch.cuda.is_available():
                continue

            tensor = torch.randn(64, 128, dtype=dtype)
            decompressed = Sparse24BitMaskTensor.from_dense(tensor).decompress()
            mask = get_24_bytemasks(tensor)

            tolerance = 1e-3 if dtype == torch.float16 else 1e-5
            assert torch.allclose(
                tensor[mask].float(),
                decompressed[mask].float(),
                rtol=tolerance,
            )

    def test_deterministic_sparsity(self):
        """Repeated mask computation on the same tensor yields the same mask."""
        tensor = torch.randn(128, 256)

        masks = [get_24_bytemasks(tensor) for _ in range(3)]

        assert torch.equal(masks[0], masks[1])
        assert torch.equal(masks[1], masks[2])

    def test_topk_optimization(self):
        """topk(sorted=False) selects the same elements as sorted=True."""
        tensor = torch.randn(128, 256)

        groups = tensor.view(-1, 4)
        magnitudes = groups.abs()

        # Same top-2 selection with and without index sorting.
        idx_sorted = magnitudes.topk(2, dim=1, largest=True, sorted=True).indices
        idx_unsorted = magnitudes.topk(2, dim=1, largest=True, sorted=False).indices

        # Order of the indices is irrelevant; the selected sets must match.
        sel_sorted = torch.zeros_like(groups, dtype=torch.bool)
        sel_sorted.scatter_(1, idx_sorted, True)

        sel_unsorted = torch.zeros_like(groups, dtype=torch.bool)
        sel_unsorted.scatter_(1, idx_unsorted, True)

        assert torch.equal(sel_sorted, sel_unsorted)
191+
192+
class TestPerformance:
    """Performance regression tests."""

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_gpu_faster_than_cpu_transfer(self):
        """Compressing on-GPU should beat transferring to CPU and compressing there.

        Uses time.perf_counter() (monotonic, high resolution) instead of
        time.time(), which is wall-clock and can jump backwards — unsuitable
        for benchmarking.
        """
        import time

        tensor = torch.randn(4096, 4096).cuda()

        # Time GPU processing; synchronize so queued kernels are included.
        torch.cuda.synchronize()
        start = time.perf_counter()
        compressed, bitmask = sparse24_bitmask_compress(tensor)
        torch.cuda.synchronize()
        gpu_time = time.perf_counter() - start

        # Time the device-to-host transfer plus CPU compression.
        torch.cuda.synchronize()
        start = time.perf_counter()
        tensor_cpu = tensor.cpu()
        compressed_cpu, bitmask_cpu = sparse24_bitmask_compress(tensor_cpu)
        cpu_time = time.perf_counter() - start

        # GPU should be faster for large tensors.
        assert gpu_time < cpu_time, \
            f"GPU ({gpu_time:.3f}s) should be faster than CPU transfer ({cpu_time:.3f}s)"
220+
221+
if __name__ == "__main__":
    # Propagate pytest's exit code so CI / shells see test failures;
    # the original discarded pytest.main()'s return value.
    raise SystemExit(pytest.main([__file__, "-v"]))