
Commit 867ec9a

groupsize consistency
Summary: half of the APIs used groupsize and half used group_size; this change switches them all to groupsize.

Test Plan:

python eval.py -q int8wo --limit 1
wikitext: {'word_perplexity,none': 12.204889603121593, 'byte_perplexity,none': 1.5965674184201175, 'bits_per_byte,none': 0.6749734750293632, 'alias': 'wikitext'}

python generate.py --quantization int4wo-64
Average tokens/sec: 13.93
Average Bandwidth: 52.04 GB/s
Peak Memory Usage: 15.92 GB
Model Size: 3.74 GB

Reviewers:
Subscribers:
Tasks:
Tags:
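For callers, this is purely a keyword rename: the quantization entry points now accept groupsize where they previously accepted group_size. A minimal before/after sketch based on the call sites changed below (the import path and the toy model are assumptions, not part of this commit):

import torch
from torchao.quantization.quant_api import quantize, int4_weight_only  # import path assumed

model = torch.nn.Sequential(torch.nn.Linear(256, 128)).to(torch.bfloat16).cuda()

# before this commit:
# quantize(model, int4_weight_only(group_size=32))

# after this commit:
quantize(model, int4_weight_only(groupsize=32))  # 32 weights share one scale/zero pair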
1 parent ef1e745 commit 867ec9a

31 files changed: 166 additions and 166 deletions

benchmarks/benchmark_hqq.py

Lines changed: 11 additions & 11 deletions
@@ -31,7 +31,7 @@ def bench_custom_kernel(
     W_q,
     scales,
     zeros,
-    group_size,
+    groupsize,
     transposed=False,
     kernel_type="max_autotune",
     fp8_fast_accum=False,
@@ -45,7 +45,7 @@ def fn():
         scales.T,
         zeros.T,
         transposed=transposed,
-        group_size=group_size,
+        groupsize=groupsize,
         fp8_fast_accum=fp8_fast_accum,
         kernel_type=kernel_type,
     )
@@ -65,11 +65,11 @@ def reference_fn():


 def run_benchmark(
-    shape, group_size, dtype, axis=1, transposed=False, quant_dtype=torch.uint8
+    shape, groupsize, dtype, axis=1, transposed=False, quant_dtype=torch.uint8
 ):
     qcfg = {
         **BASE_QUANT_CONFIG,
-        **dict(group_size=group_size, axis=axis),
+        **dict(groupsize=groupsize, axis=axis),
     }
     M, N, K = shape

@@ -103,7 +103,7 @@ def run_benchmark(
     scales = scales.reshape(N, -1)
     zeros = zeros.reshape(N, -1)
     tt_time = bench_custom_kernel(
-        x, W_q, scales, zeros, group_size, transposed=transposed
+        x, W_q, scales, zeros, groupsize, transposed=transposed
     )

     should_run_tinygemm = dtype == torch.bfloat16 and not transposed
@@ -114,7 +114,7 @@ def run_benchmark(
     )
     int4_time = bench_hqq(x, hqq_int4mm, transposed=transposed, tinygemm=True)

-    print(f"{shape=}, {group_size=}, {dtype=}, {transposed=}:")
+    print(f"{shape=}, {groupsize=}, {dtype=}, {transposed=}:")

     print(
         f"Ref: {ref_time:.4f}ms",
@@ -146,7 +146,7 @@ def run_benchmark(
     "M",
     "N",
     "K",
-    "group_size",
+    "groupsize",
     "dtype",
     "transposed",
     "ref",
@@ -159,16 +159,16 @@ def run_benchmark(
 print(torch.cuda.get_device_properties(0))

 for shape in SHAPES:
-    for group_size in GROUP_SIZES:
+    for groupsize in GROUP_SIZES:
         for dtype in DTYPES:
             for transposed in TRANSPOSED:
                 timings = run_benchmark(
-                    shape, group_size, dtype, transposed=transposed
+                    shape, groupsize, dtype, transposed=transposed
                 )
-                data.append((*shape, group_size, dtype, transposed, *timings))
+                data.append((*shape, groupsize, dtype, transposed, *timings))

 output = StringIO()
 df = pd.DataFrame(data, columns=HEADERS)
 df.to_csv(output, index=False)
 print(output.getvalue())
-# df.to_csv("benchmark_hqq_tinygemm.csv", index=False)
+# df.to_csv("benchmark_hqq_tinygemm.csv", index=False)
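The groupsize value swept by this benchmark is the number of consecutive weights that share one scale/zero pair. A minimal sketch of that grouped dequantization, mirroring the reference construction in test/hqq/test_triton_mm.py below (the shapes here are illustrative):

import torch

N, K, groupsize = 16, 64, 32
W_q = torch.randint(0, 2**4, size=(N, K), dtype=torch.uint8)               # 4-bit codes stored as uint8
scales = torch.arange((N * K) // groupsize, dtype=torch.float32)[:, None]  # one scale per group
zeros = torch.zeros_like(scales)                                           # one zero-point per group

# every `groupsize` consecutive weights are shifted by their zero-point and scaled together
W_dq = ((W_q.reshape(-1, groupsize).to(scales.dtype) - zeros) * scales).reshape(N, K)
assert W_dq.shape == (N, K)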

benchmarks/dora/bench_utils.py

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ def setup_dora_base_layers(layer_type, in_features, out_features, dtype):
     # HQQ
     quant_config = BaseQuantizeConfig(
         nbits=4,
-        group_size=64,
+        groupsize=64,
         quant_zero=False,
         quant_scale=False,
         offload_meta=True,

test/dora/test_dora_layer.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def test_dora_layer(
     elif model_type == "HQQDoRALinear":
         quant_config = BaseQuantizeConfig(
             nbits=4,
-            group_size=64,
+            groupsize=64,
             quant_zero=False,
             quant_scale=False,
             offload_meta=True,
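Both DoRA call sites build the same HQQ base-layer config; only the groupsize keyword changes. A minimal sketch with the values taken from the two hunks above (assuming BaseQuantizeConfig is the HQQ config class these files already import):

quant_config = BaseQuantizeConfig(
    nbits=4,            # 4-bit weights
    groupsize=64,       # was group_size=64 before this commit
    quant_zero=False,   # leave zero-points unquantized
    quant_scale=False,  # leave scales unquantized
    offload_meta=True,  # offload quantization metadata
)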

test/dtypes/test_affine_quantized.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ class TestAffineQuantized(TestCase):
     def test_tensor_core_layout_transpose(self):
         t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
         shape = t.shape
-        apply_int4_weight_only_quant = int4_weight_only(group_size=32)
+        apply_int4_weight_only_quant = int4_weight_only(groupsize=32)
         aqt = apply_int4_weight_only_quant(t)
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
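Here groupsize sets the per-group granularity of the int4 weight-only quantizer under test. A minimal sketch of the call pattern this test exercises (import path assumed; CUDA and bfloat16 are required by the tensor-core layout):

import torch
from torchao.quantization.quant_api import int4_weight_only  # import path assumed

t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
apply_int4_weight_only_quant = int4_weight_only(groupsize=32)  # 32 weights per scale/zero group
aqt = apply_int4_weight_only_quant(t)
assert aqt.shape == t.shape  # the quantized tensor keeps the logical shape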

test/hqq/test_triton_mm.py

Lines changed: 15 additions & 15 deletions
@@ -72,16 +72,16 @@ def _arg_to_id(arg):


 @pytest.mark.parametrize(
-    "shape, group_size, axis, dtype, transposed, kernel_type",
+    "shape, groupsize, axis, dtype, transposed, kernel_type",
     TEST_CONFIGS,
     ids=_arg_to_id,
 )
 def test_mixed_mm(
-    shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8
+    shape, groupsize, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8
 ):
     qcfg = {
         **BASE_QUANT_CONFIG,
-        **dict(group_size=group_size, axis=axis),
+        **dict(groupsize=groupsize, axis=axis),
     }
     M, N, K = shape

@@ -117,7 +117,7 @@ def test_mixed_mm(
         scales.T,
         zeros.T,
         transposed=True,
-        group_size=group_size,
+        groupsize=groupsize,
         fp8_fast_accum=False,
         kernel_type=kernel_type,
     )
@@ -132,7 +132,7 @@ def test_mixed_mm(
         scales.T,
         zeros.T,
         transposed=False,
-        group_size=group_size,
+        groupsize=groupsize,
         fp8_fast_accum=False,
         kernel_type=kernel_type,
     )
@@ -147,7 +147,7 @@ def test_mixed_mm(
 # Only for debugging kernel without dependency on HQQ and with no autotuning
 def _test_mixed_mm(
     shape,
-    group_size,
+    groupsize,
     BLOCK_M,
     BLOCK_N,
     BLOCK_K,
@@ -159,7 +159,7 @@ def _test_mixed_mm(
 ):
     qcfg = {
         **BASE_QUANT_CONFIG,
-        **dict(group_size=group_size, axis=axis),
+        **dict(groupsize=groupsize, axis=axis),
     }
     M, N, K = shape

@@ -169,9 +169,9 @@ def _test_mixed_mm(
     quant_config.update({"weight_quant_params": qcfg})
     W_q = torch.randint(0, int(2**4), size=(N, K), dtype=quant_dtype, device="cuda")

-    scales = torch.arange((N * K) // group_size, dtype=dtype, device="cuda")[:, None]
+    scales = torch.arange((N * K) // groupsize, dtype=dtype, device="cuda")[:, None]
     zeros = torch.zeros_like(scales)
-    W_dq = ((W_q.reshape(-1, group_size) - zeros) * scales).reshape(N, K)
+    W_dq = ((W_q.reshape(-1, groupsize) - zeros) * scales).reshape(N, K)
     scales = scales.reshape(N, -1)
     zeros = zeros.reshape(N, -1)

@@ -187,7 +187,7 @@ def _test_mixed_mm(
         scales.T,
         zeros.T,
         transposed=True,
-        group_size=group_size,
+        groupsize=groupsize,
         fp8_fast_accum=False,
         kernel_type=kernel_type,
         BLOCK_M=BLOCK_M,
@@ -205,14 +205,14 @@ def _test_mixed_mm(
         scales.T,
         zeros.T,
         transposed=False,
-        group_size=group_size,
+        groupsize=groupsize,
         fp8_fast_accum=False,
         kernel_type=kernel_type,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         BLOCK_K=BLOCK_K,
     )
-    msg = f"shape={shape}, group_size={group_size}, axis={axis}, dtype={dtype}, transposed={transposed}, kernel_type={kernel_type}, quant_dtype={quant_dtype}"
+    msg = f"shape={shape}, groupsize={groupsize}, axis={axis}, dtype={dtype}, transposed={transposed}, kernel_type={kernel_type}, quant_dtype={quant_dtype}"

     check(
         hqq_out,
@@ -229,18 +229,18 @@ def _test_mixed_mm(
     BLOCK_M, BLOCK_N, BLOCK_K = shape
     BLOCK_K = K // 2
     BLOCK_N = N // 2
-    group_size = BLOCK_K
+    groupsize = BLOCK_K
     _test_mixed_mm(
         shape,
-        group_size=group_size,
+        groupsize=groupsize,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         BLOCK_K=BLOCK_K,
         transposed=False,
     )
     _test_mixed_mm(
         shape,
-        group_size=group_size,
+        groupsize=groupsize,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         BLOCK_K=BLOCK_K,
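Every triton_mixed_mm call in this file changes the same way. A minimal sketch of the renamed keyword in a non-transposed call, with the tensor preparation reduced to placeholders (x, packed_w, scales, and zeros are assumed to be built as in the test bodies above):

out = triton_mixed_mm(
    x,                      # (M, K) activations
    packed_w,               # packed 4-bit weights
    scales.T,               # per-group scales, transposed as the tests do
    zeros.T,                # per-group zero-points, transposed as the tests do
    transposed=False,
    groupsize=groupsize,    # was group_size= before this commit
    fp8_fast_accum=False,
    kernel_type="max_autotune",
)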

test/hqq/test_triton_qkv_fused.py

Lines changed: 12 additions & 12 deletions
@@ -75,42 +75,42 @@ def fuse_qkv(W_qs, scales, zeros):
     """
     Args:
         W_qs (list[torch.Tensor]): len 3 list of tensors with shapes Nq x K, Nk x K, Nv x K where Nk == Nv
-        scales (list[torch.Tensor]): each is N x (K // group_size), with same N requirements per W_qs
+        scales (list[torch.Tensor]): each is N x (K // groupsize), with same N requirements per W_qs
         zeros (list[torch.Tensor]): same as scales

     Returns:
         qkv (torch.Tensor): (N_qkv x K) where N_qkv = Nq + Nk + Nv
-        scales (torch.Tensor): (N_qkv x (K // group_size))
-        zeros (torch.Tensor): (N_qkv x (K // group_size))
+        scales (torch.Tensor): (N_qkv x (K // groupsize))
+        zeros (torch.Tensor): (N_qkv x (K // groupsize))
     """
     qkv = torch.cat(W_qs, dim=0)  # Fuse along N
     fused_scales = torch.cat([s for s in scales], dim=0)
     fused_zeros = torch.cat([z for z in zeros], dim=0)
     return qkv, fused_scales, fused_zeros


-def ref_proj(x, packed_w, scale, zero, group_size, kernel_type, transposed=False):
+def ref_proj(x, packed_w, scale, zero, groupsize, kernel_type, transposed=False):
     return triton_mixed_mm(
         x,
         packed_w,
         scale.T,
         zero.T,
         transposed=transposed,
-        group_size=group_size,
+        groupsize=groupsize,
         fp8_fast_accum=False,
         kernel_type=kernel_type,
     )


 @pytest.mark.parametrize(
-    "q_shape, kv_shape, group_size, axis, dtype, transposed, kernel_type",
+    "q_shape, kv_shape, groupsize, axis, dtype, transposed, kernel_type",
     TEST_CONFIGS,
     ids=_arg_to_id,
 )
 def test_mixed_mm(
     q_shape,
     kv_shape,
-    group_size,
+    groupsize,
     axis,
     dtype,
     transposed,
@@ -136,7 +136,7 @@ def test_mixed_mm(

     qcfg = {
         **BASE_QUANT_CONFIG,
-        **dict(group_size=group_size, axis=axis),
+        **dict(groupsize=groupsize, axis=axis),
     }

     quant_config = BaseQuantizeConfig(
@@ -172,7 +172,7 @@ def test_mixed_mm(
     xs = [torch.randn(seqlen, n, dtype=dtype, device=device) for n in Ns]
     x_fused = torch.cat(xs, dim=1)
     q_ref, k_ref, v_ref = [
-        ref_proj(x, p, s, z, group_size, kernel_type, transposed=True)
+        ref_proj(x, p, s, z, groupsize, kernel_type, transposed=True)
         for x, p, s, z in zip(xs, packed_ws, scales, zeros)
     ]
     tt_fused = triton_mixed_mm(
@@ -181,7 +181,7 @@ def test_mixed_mm(
         scales_fused.T,
         zeros_fused.T,
         transposed=True,
-        group_size=group_size,
+        groupsize=groupsize,
         fp8_fast_accum=False,
         kernel_type=kernel_type,
     )
@@ -191,7 +191,7 @@ def test_mixed_mm(
     x = torch.randn(seqlen, K, dtype=dtype, device=device)

     q_ref, k_ref, v_ref = [
-        ref_proj(x, p, s, z, group_size, kernel_type)
+        ref_proj(x, p, s, z, groupsize, kernel_type)
         for p, s, z in zip(packed_ws, scales, zeros)
     ]

@@ -201,7 +201,7 @@ def test_mixed_mm(
         scales_fused.T,
         zeros_fused.T,
         transposed=False,
-        group_size=group_size,
+        groupsize=groupsize,
         fp8_fast_accum=False,
         kernel_type=kernel_type,
     )
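The docstring edits above only rename the group dimension; fuse_qkv itself concatenates Q, K, and V along N, and the per-group scales and zeros are concatenated the same way. A minimal shape sketch of what the docstring describes (toy sizes; plain float tensors stand in for the real packed weights):

import torch

K, groupsize = 64, 32
Nq, Nk, Nv = 16, 8, 8  # Nk == Nv, as the docstring requires
W_qs = [torch.zeros(n, K) for n in (Nq, Nk, Nv)]
scales = [torch.ones(n, K // groupsize) for n in (Nq, Nk, Nv)]
zeros = [torch.zeros(n, K // groupsize) for n in (Nq, Nk, Nv)]

qkv = torch.cat(W_qs, dim=0)             # (N_qkv, K) with N_qkv = Nq + Nk + Nv
fused_scales = torch.cat(scales, dim=0)  # (N_qkv, K // groupsize)
fused_zeros = torch.cat(zeros, dim=0)    # (N_qkv, K // groupsize)
assert qkv.shape == (Nq + Nk + Nv, K)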

test/integration/test_integration.py

Lines changed: 1 addition & 1 deletion
@@ -833,7 +833,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         def api(mod):
             if TORCH_VERSION_AFTER_2_4:
                 kwargs_copy = kwargs.copy()
-                kwargs_copy["group_size"] = groupsize
+                kwargs_copy["groupsize"] = groupsize
                 del kwargs_copy["groupsize"]
                 quantize(mod, int4_weight_only(**kwargs_copy))
                 unwrap_tensor_subclass(mod)

test/quantization/test_galore_quant.py

Lines changed: 2 additions & 2 deletions
@@ -42,7 +42,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
     bnb_norm = (g.reshape(-1, blocksize) / qstate.absmax[:, None]).reshape(g.shape)

     tt_q, tt_norm, tt_absmax = triton_quantize_blockwise(
-        g, qmap, group_size=blocksize, return_normalized=True
+        g, qmap, groupsize=blocksize, return_normalized=True
     )
     tt_check = torch.allclose(ref_bnb, tt_q)

@@ -87,5 +87,5 @@ def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
     q, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize)

     dq_ref = F.dequantize_blockwise(q, qstate)
-    dq = triton_dequant_blockwise(q, qmap, qstate.absmax, group_size=blocksize)
+    dq = triton_dequant_blockwise(q, qmap, qstate.absmax, groupsize=blocksize)
     assert torch.allclose(dq, dq_ref)
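The GaLore blockwise kernels take the same renamed keyword. A minimal round-trip sketch built only from the calls visible in this diff (g and qmap stand for the gradient tensor and quantization code map the test constructs elsewhere; the block size value and feeding the returned absmax straight into the dequantizer are assumptions):

# g: CUDA gradient tensor, qmap: quantization code map (both built elsewhere in the test)
tt_q, tt_norm, tt_absmax = triton_quantize_blockwise(
    g, qmap, groupsize=2048, return_normalized=True  # was group_size= before this commit
)
dq = triton_dequant_blockwise(tt_q, qmap, tt_absmax, groupsize=2048)
# dq approximates g up to blockwise quantization error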
