
Commit 3b3ae50
Commit message: Works
1 parent da772d2 commit 3b3ae50

4 files changed, +31 -31 lines


optimum/commands/export/executorch.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ def parse_args_executorch(parser):
             "Options:\n"
             " 8da4w - 8-bit dynamic activation, 4-bit weight\n"
             " 8da8w - 8-bit dynamic activation, 8-bit weight\n"
-            " 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight and 8-bit weight\n"
+            " 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight; fallback on 8-bit dynamic activation, 8-bit weight per-channel where group size doesn't divide block size cleanly\n"
             " 4w - 4-bit weight only\n"
             " 8w - 8-bit weight only"
         ),
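
For context, a minimal sketch of how a comma-separated value like the one documented above splits into a primary config and an optional fallback. The variable names here are illustrative; the real parsing lives in quantize_model_ (see the next file).

# Illustrative only: how "8da4w,8da8w" decomposes into primary + fallback entries.
qlinear_config = "8da4w,8da8w"
qlinear_configs = qlinear_config.split(",")
primary = qlinear_configs[0]                                           # "8da4w": 4-bit weights, per-group
fallback = qlinear_configs[1] if len(qlinear_configs) == 2 else None   # "8da8w": 8-bit weights, per-channel
print(primary, fallback)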

optimum/exporters/executorch/quantization.py

Lines changed: 13 additions & 19 deletions
@@ -40,7 +40,7 @@ def quantize_model_(
     if qlinear_config == "8w":
         assert (
             qembedding_group_size == 0
-        ), "8-bit embedding quantization only supports per-channel at the moment, please use qembedding_group_size = 0."
+        ), "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0."
     if qembedding_group_size == 0:
         embedding_weight_granularity = PerAxis(0)
     else:
@@ -94,9 +94,7 @@ def build_linear_config(config_key: str, granularity):
     if any(cfg == "" for cfg in qlinear_configs):
         raise ValueError("Linear quantization config entries must be non-empty.")
     if len(qlinear_configs) > 2:
-        raise ValueError(
-            "Expected at most one fallback linear quantization config, got more than one comma."
-        )
+        raise ValueError("Expected at most one fallback linear quantization config, got more than one comma.")

     primary_linear_config_key = qlinear_configs[0]
     fallback_linear_config_key = qlinear_configs[1] if len(qlinear_configs) == 2 else None
@@ -109,16 +107,16 @@ def build_linear_config(config_key: str, granularity):
         )
         fallback_linear_config_key = None
     else:
-        assert qlinear_group_size % 2 == 0, f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
+        assert (
+            qlinear_group_size % 2 == 0
+        ), f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
         linear_weight_granularity = PerGroup(qlinear_group_size)

     logging.info("Quantizing linear layers.")
-    primary_linear_config = build_linear_config(
-        primary_linear_config_key, linear_weight_granularity
-    )
+    primary_linear_config = build_linear_config(primary_linear_config_key, linear_weight_granularity)

     # First, quantize layers that are compatible with group quantization
-    def quant_filter(module, fqn):
+    def per_group_filter(module, fqn):
         if isinstance(module, torch.nn.Linear):
             # Check if hidden dimension is divisible by group size
             # For Linear layers, weight shape is [out_features, in_features]
@@ -129,20 +127,16 @@ def quant_filter(module, fqn):
     quantize_(
         eager_model,
         primary_linear_config,
-        filter_fn=quant_filter,
+        filter_fn=per_group_filter,
     )

     # Then, quantize incompatible layers using the fallback per-axis config
     if fallback_linear_config_key is not None:
-        fallback_linear_config = build_linear_config(
-            fallback_linear_config_key, PerAxis(0)
-        )
-
-        def per_channel_filter(module, fqn):
+        fallback_linear_config = build_linear_config(fallback_linear_config_key, PerAxis(0))
+
+        def per_token_filter(module, fqn):
             if isinstance(module, torch.nn.Linear):
-                # Only quantize layers that are NOT compatible with group quantization
-                # and haven't been quantized yet
-                return not quant_filter(module, fqn)
+                return module.weight.shape[1] % qlinear_group_size != 0
             return False

         logging.info(
@@ -152,7 +146,7 @@ def per_channel_filter(module, fqn):
         quantize_(
             eager_model,
             fallback_linear_config,
-            filter_fn=per_channel_filter,
+            filter_fn=per_token_filter,
        )

     unwrap_tensor_subclass(eager_model)
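
A self-contained sketch of the filter split above. It assumes, as the surrounding comments suggest, that per_group_filter returns True when in_features divides evenly by the group size (its return statement is not visible in the hunk); the two filters then partition Linear layers so each one is quantized exactly once.

import torch

qlinear_group_size = 32  # example value

def per_group_filter(module, fqn):
    # True for Linear layers whose in_features divide evenly by the group size
    if isinstance(module, torch.nn.Linear):
        return module.weight.shape[1] % qlinear_group_size == 0
    return False

def per_token_filter(module, fqn):
    # Complement of the above: these layers take the per-channel fallback config
    if isinstance(module, torch.nn.Linear):
        return module.weight.shape[1] % qlinear_group_size != 0
    return False

layers = {"fits": torch.nn.Linear(4096, 8), "does_not_fit": torch.nn.Linear(4304, 8)}
for name, layer in layers.items():
    target = "per-group" if per_group_filter(layer, name) else "per-channel fallback"
    print(f"{name}: {target}")
# fits: per-group
# does_not_fit: per-channel fallback   (4304 % 32 == 16)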

optimum/exporters/executorch/tasks/multimodal_text_to_text.py

Lines changed: 14 additions & 8 deletions
@@ -14,6 +14,7 @@


 import json
+import logging
 import os.path

 import torchao
@@ -201,15 +202,24 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
     qembedding_group_size = kwargs.get("qembedding_group_size", None)

     # Quantize decoder linear weights.
+    if qlinear_config:
+        logging.info("Quantizing decoder linears...")
     quantize_decoder_kwargs = {
         "eager_model": getattr(eager_model, decoder_name),
         "qlinear_config": qlinear_config,
     }
+    quantize_lm_head_kwargs = {
+        "eager_model": eager_model.lm_head,
+        "qlinear_config": qlinear_config,
+    }
     if qlinear_group_size is not None:
         quantize_decoder_kwargs["qlinear_group_size"] = qlinear_group_size
     quantize_model_(**quantize_decoder_kwargs)
+    quantize_model_(**quantize_lm_head_kwargs)

     # Quantize encoder linear weights.
+    if qlinear_encoder_config:
+        logging.info("Quantizing encoder linears...")
     quantize_encoder_kwargs = {
         "eager_model": getattr(eager_model, encoder_name),
         "qlinear_config": qlinear_encoder_config,
@@ -219,6 +229,8 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
     quantize_model_(**quantize_encoder_kwargs)

     # Quantize decoder embeddings.
+    if qembedding_config:
+        logging.info("Quantizing decoder embeddings...")
     quantize_decoder_embedding_kwargs = {
         "eager_model": eager_model,
         "qembedding_config": qembedding_config,
@@ -227,14 +239,8 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
         quantize_decoder_embedding_kwargs["qembedding_group_size"] = qembedding_group_size
     quantize_model_(**quantize_decoder_embedding_kwargs)

-    # Quantize lm_head
-    if hasattr(eager_model, "lm_head") and qlinear_config is not None:
-        quantize_model_(
-            eager_model=eager_model.lm_head,
-            qlinear_config=qlinear_config,
-            qlinear_group_size=qlinear_group_size if qlinear_group_size is not None else 0,
-        )
-    print(eager_model)
+    # TODO: Quantize encoder embeddings.
+
     return MultiModalTextToTextExportableModule(
         model=eager_model,
         modality="audio" if audio_encoder_name else "vision",
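
A minimal, runnable sketch of the call pattern above, with a stand-in quantize_model_ (the real one is in optimum/exporters/executorch/quantization.py) and a toy model; it shows how the decoder and lm_head now share the same linear config, while the group size is only forwarded when the caller set it.

import torch

def quantize_model_(eager_model, qlinear_config=None, qlinear_group_size=32, **kwargs):
    # Stand-in with an assumed default group size, just to make the sketch runnable.
    print(f"{type(eager_model).__name__}: config={qlinear_config}, group_size={qlinear_group_size}")

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = torch.nn.Linear(8, 8)
        self.lm_head = torch.nn.Linear(8, 8)

eager_model = TinyModel()
qlinear_config, qlinear_group_size = "8da4w,8da8w", None

quantize_decoder_kwargs = {"eager_model": eager_model.decoder, "qlinear_config": qlinear_config}
quantize_lm_head_kwargs = {"eager_model": eager_model.lm_head, "qlinear_config": qlinear_config}
if qlinear_group_size is not None:
    quantize_decoder_kwargs["qlinear_group_size"] = qlinear_group_size

quantize_model_(**quantize_decoder_kwargs)   # decoder linears
quantize_model_(**quantize_lm_head_kwargs)   # lm_head, same config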

tests/models/test_modeling_gemma3.py

Lines changed: 3 additions & 3 deletions
@@ -309,9 +309,9 @@ def test_gemma3_image_vision_with_custom_sdpa_kv_cache_8da4w_8we(self):
             use_custom_kv_cache=True,
             qlinear="8da4w",
             qlinear_group_size=32,
-            # Can't quantize the encoder a the moment, hidden dim of 4304 doesn't fit ExecuTorch's
-            # XNNPack 32-group size quantized kernels. See https://github.com/pytorch/executorch/issues/14221.
-            qembedding_config="8w",
+            qlinear_encoder="8da4w,8da8w",
+            qlinear_encoder_group_size=32,
+            qembedding="8w",
         )

         # Generate
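
For reference, the divisibility issue described by the removed comment (and that the new 8da4w,8da8w fallback is meant to handle) can be checked directly; the numbers come from that comment.

# Gemma3's vision encoder hidden dim from the removed comment, vs. the 32-wide group size.
hidden_dim, group_size = 4304, 32
print(hidden_dim % group_size)  # 16 -> not divisible, so those linears take the 8da8w per-channel fallback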
