Commit a872c53

Works
1 parent da772d2 commit a872c53

File tree

4 files changed, +30 -31 lines changed


optimum/commands/export/executorch.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ def parse_args_executorch(parser):
             "Options:\n"
             " 8da4w - 8-bit dynamic activation, 4-bit weight\n"
             " 8da8w - 8-bit dynamic activation, 8-bit weight\n"
-            " 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight and 8-bit weight\n"
+            " 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight; fallback on 8-bit dynamic activation, 8-bit weight per-channel where group size doesn't divide block size cleanly\n"
             " 4w - 4-bit weight only\n"
             " 8w - 8-bit weight only"
         ),
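For reference, here is a minimal standalone sketch of how a value such as "8da4w,8da8w" splits into a primary entry plus an optional fallback, mirroring the validation in quantization.py below. The helper name split_qlinear_config is illustrative and not part of the codebase.

# Illustrative helper (not in the repository) mirroring the comma-splitting
# and validation performed in optimum/exporters/executorch/quantization.py.
def split_qlinear_config(value: str):
    entries = value.split(",")
    if any(entry == "" for entry in entries):
        raise ValueError("Linear quantization config entries must be non-empty.")
    if len(entries) > 2:
        raise ValueError("Expected at most one fallback linear quantization config, got more than one comma.")
    primary = entries[0]
    fallback = entries[1] if len(entries) == 2 else None
    return primary, fallback

print(split_qlinear_config("8da4w,8da8w"))  # ('8da4w', '8da8w')
print(split_qlinear_config("4w"))           # ('4w', None)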

optimum/exporters/executorch/quantization.py

Lines changed: 14 additions & 19 deletions
@@ -40,7 +40,7 @@ def quantize_model_(
     if qlinear_config == "8w":
         assert (
             qembedding_group_size == 0
-        ), "8-bit embedding quantization only supports per-channel at the moment, please use qembedding_group_size = 0."
+        ), "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0."
         if qembedding_group_size == 0:
             embedding_weight_granularity = PerAxis(0)
         else:
@@ -67,6 +67,7 @@ def quantize_model_(
         )
 
     if qlinear_config:
+
         def build_linear_config(config_key: str, granularity):
             if config_key == "8da4w":
                 return Int8DynamicActivationIntxWeightConfig(
@@ -94,9 +95,7 @@ def build_linear_config(config_key: str, granularity):
         if any(cfg == "" for cfg in qlinear_configs):
             raise ValueError("Linear quantization config entries must be non-empty.")
         if len(qlinear_configs) > 2:
-            raise ValueError(
-                "Expected at most one fallback linear quantization config, got more than one comma."
-            )
+            raise ValueError("Expected at most one fallback linear quantization config, got more than one comma.")
 
         primary_linear_config_key = qlinear_configs[0]
         fallback_linear_config_key = qlinear_configs[1] if len(qlinear_configs) == 2 else None
@@ -109,16 +108,16 @@ def build_linear_config(config_key: str, granularity):
             )
             fallback_linear_config_key = None
         else:
-            assert qlinear_group_size % 2 == 0, f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
+            assert (
+                qlinear_group_size % 2 == 0
+            ), f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
             linear_weight_granularity = PerGroup(qlinear_group_size)
 
         logging.info("Quantizing linear layers.")
-        primary_linear_config = build_linear_config(
-            primary_linear_config_key, linear_weight_granularity
-        )
+        primary_linear_config = build_linear_config(primary_linear_config_key, linear_weight_granularity)
 
         # First, quantize layers that are compatible with group quantization
-        def quant_filter(module, fqn):
+        def per_group_filter(module, fqn):
             if isinstance(module, torch.nn.Linear):
                 # Check if hidden dimension is divisible by group size
                 # For Linear layers, weight shape is [out_features, in_features]
@@ -129,20 +128,16 @@ def quant_filter(module, fqn):
         quantize_(
             eager_model,
             primary_linear_config,
-            filter_fn=quant_filter,
+            filter_fn=per_group_filter,
         )
 
         # Then, quantize incompatible layers using the fallback per-axis config
         if fallback_linear_config_key is not None:
-            fallback_linear_config = build_linear_config(
-                fallback_linear_config_key, PerAxis(0)
-            )
-
-            def per_channel_filter(module, fqn):
+            fallback_linear_config = build_linear_config(fallback_linear_config_key, PerAxis(0))
+
+            def per_token_filter(module, fqn):
                 if isinstance(module, torch.nn.Linear):
-                    # Only quantize layers that are NOT compatible with group quantization
-                    # and haven't been quantized yet
-                    return not quant_filter(module, fqn)
+                    return module.weight.shape[1] % qlinear_group_size != 0
                 return False
 
             logging.info(
@@ -152,7 +147,7 @@ def per_channel_filter(module, fqn):
             quantize_(
                 eager_model,
                 fallback_linear_config,
-                filter_fn=per_channel_filter,
+                filter_fn=per_token_filter,
             )
 
     unwrap_tensor_subclass(eager_model)
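The net effect is a two-pass scheme: group-quantize every Linear whose in_features is a multiple of the group size, then sweep the remaining Linears with a per-channel (per-axis-0) fallback. Below is a minimal self-contained sketch of that pattern, assuming torchao's quantize_ API and the Int8DynamicActivationIntxWeightConfig / PerGroup / PerAxis names used above; the import paths, config arguments, and layer sizes are illustrative and may differ across torchao versions.

import torch
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerAxis, PerGroup

group_size = 32
model = torch.nn.Sequential(
    torch.nn.Linear(2048, 2048),  # 2048 % 32 == 0 -> handled by the per-group pass
    torch.nn.Linear(1000, 2048),  # 1000 % 32 != 0 -> handled by the fallback pass
)

# Primary config: 8-bit dynamic activations, 4-bit weights, per-group granularity.
primary_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4, weight_granularity=PerGroup(group_size)
)
# Fallback config: 8-bit dynamic activations, 8-bit weights, per-channel granularity.
fallback_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int8, weight_granularity=PerAxis(0)
)

def per_group_filter(module, fqn):
    # Linear weights are [out_features, in_features]; group quantization needs
    # in_features to be a multiple of the group size.
    return isinstance(module, torch.nn.Linear) and module.weight.shape[1] % group_size == 0

def fallback_filter(module, fqn):
    return isinstance(module, torch.nn.Linear) and module.weight.shape[1] % group_size != 0

quantize_(model, primary_config, filter_fn=per_group_filter)
quantize_(model, fallback_config, filter_fn=fallback_filter)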

optimum/exporters/executorch/tasks/multimodal_text_to_text.py

Lines changed: 12 additions & 8 deletions
@@ -14,6 +14,7 @@
 
 
 import json
+import logging
 import os.path
 
 import torchao
@@ -201,15 +202,24 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
     qembedding_group_size = kwargs.get("qembedding_group_size", None)
 
     # Quantize decoder linear weights.
+    if qlinear_config:
+        logging.info("Quantizing decoder linears...")
     quantize_decoder_kwargs = {
         "eager_model": getattr(eager_model, decoder_name),
         "qlinear_config": qlinear_config,
     }
+    quantize_lm_head_kwargs = {
+        "eager_model": eager_model.lm_head,
+        "qlinear_config": qlinear_config,
+    }
     if qlinear_group_size is not None:
         quantize_decoder_kwargs["qlinear_group_size"] = qlinear_group_size
     quantize_model_(**quantize_decoder_kwargs)
+    quantize_model_(**quantize_lm_head_kwargs)
 
     # Quantize encoder linear weights.
+    if qlinear_encoder_config:
+        logging.info("Quantizing encoder linears...")
     quantize_encoder_kwargs = {
         "eager_model": getattr(eager_model, encoder_name),
         "qlinear_config": qlinear_encoder_config,
@@ -219,6 +229,8 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
     quantize_model_(**quantize_encoder_kwargs)
 
     # Quantize decoder embeddings.
+    if qembedding_config:
+        logging.info("Quantizing decoder embeddings...")
     quantize_decoder_embedding_kwargs = {
         "eager_model": eager_model,
         "qembedding_config": qembedding_config,
@@ -227,14 +239,6 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
         quantize_decoder_embedding_kwargs["qembedding_group_size"] = qembedding_group_size
     quantize_model_(**quantize_decoder_embedding_kwargs)
 
-    # Quantize lm_head
-    if hasattr(eager_model, "lm_head") and qlinear_config is not None:
-        quantize_model_(
-            eager_model=eager_model.lm_head,
-            qlinear_config=qlinear_config,
-            qlinear_group_size=qlinear_group_size if qlinear_group_size is not None else 0,
-        )
-    print(eager_model)
     return MultiModalTextToTextExportableModule(
         model=eager_model,
         modality="audio" if audio_encoder_name else "vision",
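One practical note: the new logging.info calls are invisible under Python's default WARNING log level. A minimal sketch (not part of the change) to surface them when loading the model from a Python session:

import logging

# Emit INFO-level records so messages such as "Quantizing decoder linears..."
# and "Quantizing encoder linears..." show up on the console.
logging.basicConfig(level=logging.INFO)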

tests/models/test_modeling_gemma3.py

Lines changed: 3 additions & 3 deletions
@@ -309,9 +309,9 @@ def test_gemma3_image_vision_with_custom_sdpa_kv_cache_8da4w_8we(self):
             use_custom_kv_cache=True,
             qlinear="8da4w",
             qlinear_group_size=32,
-            # Can't quantize the encoder a the moment, hidden dim of 4304 doesn't fit ExecuTorch's
-            # XNNPack 32-group size quantized kernels. See https://github.com/pytorch/executorch/issues/14221.
-            qembedding_config="8w",
+            qlinear_encoder="8da4w,8da8w",
+            qlinear_encoder_group_size=32,
+            qembedding="8w",
         )
 
         # Generate
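The encoder previously had to stay unquantized because its hidden dimension of 4304 is not a multiple of the 32-element group size (see the removed comment and the linked ExecuTorch issue); with qlinear_encoder="8da4w,8da8w" those layers now fall back to per-channel 8da8w instead of being skipped. A quick check of the arithmetic (4096 is just an illustrative divisible size for contrast):

group_size = 32
for in_features in (4096, 4304):
    if in_features % group_size == 0:
        print(f"{in_features}: per-group 8da4w ({in_features // group_size} groups of {group_size})")
    else:
        quotient, remainder = divmod(in_features, group_size)
        print(f"{in_features}: per-channel 8da8w fallback ({in_features} = {quotient} * {group_size} + {remainder})")
# 4096: per-group 8da4w (128 groups of 32)
# 4304: per-channel 8da8w fallback (4304 = 134 * 32 + 16)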
