Commit 8b22cad

Author: Martin Yuan

Apply hybrid quantization on Mimi

1 parent 878a7f6

File tree: 3 files changed, +140 -3 lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 3 additions & 1 deletion

@@ -1060,7 +1060,9 @@ def _load_llama_model(
             model.vocab_size,
             metadata_str,
         ),
-        args=args,
+        qnn=args.qnn,
+        export_only=args.export_only,
+        output_name=args.output_name,
     )
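The identical substitution shows up in the LLaVA export below: call sites stop forwarding the whole argparse namespace and pass only the fields the builder actually reads. A minimal sketch of that pattern in isolation (the helper name and the example Namespace are illustrative, not part of this commit):

from argparse import Namespace


def narrow_export_kwargs(args: Namespace) -> dict:
    # Forward only the export-related fields, mirroring the call sites in this commit.
    return {
        "qnn": args.qnn,
        "export_only": args.export_only,
        "output_name": args.output_name,
    }


if __name__ == "__main__":
    demo = Namespace(qnn=False, export_only=False, output_name="llama3")
    print(narrow_export_kwargs(demo))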

examples/models/llava/export_llava.py

Lines changed: 3 additions & 1 deletion

@@ -92,7 +92,9 @@ def forward(self, input_pos, embeddings):
         use_kv_cache=True,
         example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
         dynamic_shapes=dynamic_shapes,
-        args=llava.text_model_args,
+        qnn=llava.text_model_args.qnn,
+        export_only=llava.text_model_args.export_only,
+        output_name=llava.text_model_args.output_name,
     )

     dtype_override = DType.fp32

examples/models/moshi/mimi/test_mimi.py

Lines changed: 134 additions & 1 deletion
@@ -2,6 +2,8 @@
 import os
 import random
 import unittest
+from functools import partial
+from executorch.examples.models.llama.source_transformation.quantize import quantize
 
 import numpy as np
 import requests
@@ -22,13 +24,23 @@
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.export import export, ExportedProgram
 from torch.utils._pytree import tree_flatten
+from executorch.extension.llm.export.builder import DType, LLMEdgeManager
+from omegaconf import OmegaConf
+from pathlib import Path
+from executorch.examples.models.llama.export_llama_lib import (
+    _get_source_transforms,
+    get_quantizer_and_quant_params,
+)
+from executorch.extension.llm.export.partitioner_lib import (
+    get_xnnpack_partitioner,
+)
+import logging
 
 proxies = {
     "http": "http://fwdproxy:8080",
     "https": "http://fwdproxy:8080",
 }
 
-
 def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
     assert x.shape == y.shape, "Tensor shapes do not match"
     x = x.float()
@@ -219,6 +231,127 @@ def forward(self, x):
         print(f"SQNR: {sqnr}")
         torch.testing.assert_close(eager_res, res[0], atol=1e-3, rtol=1e-3)
 
+    def test_exported_decoder_hybrid_quant(self):
+        class MimiDecode(nn.Module):
+            def __init__(self, mimi: nn.Module):
+                super().__init__()
+                self.mimi_model = mimi
+
+            def forward(self, x):
+                return self.mimi_model.decode(x)
+
+        sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
+        pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
+        chunk = sample_pcm[..., 0:pcm_chunk_size]
+        input = self.mimi.encode(chunk)
+
+        mimi_decode = MimiDecode(self.mimi)
+        eager_output = mimi_decode(input)  # eager reference (not yet compared below)
+
+        config_dict = {
+            "model": "llama3",
+            "checkpoint": None,
+            "max_seq_length": 128,
+            "dtype_override": "fp32",
+            "use_kv_cache": False,
+            "generate_full_logits": False,
+            "enable_dynamic_shape": False,
+            "verbose": True,
+            "qnn": False,
+            "export_only": False,
+            "output_name": "output",
+            "output_dir": "/tmp",
+            "xnnpack_extended_ops": True,
+            "quantization": {
+                "mode": "8da4w",
+                "group_size": 64,
+            },
+        }
+
+        # Create a DictConfig object from the dictionary
+        config = OmegaConf.create(config_dict)
+
+        edge_manager = LLMEdgeManager(
+            model=mimi_decode,
+            modelname=config.model,
+            max_seq_len=config.max_seq_length,
+            dtype=config.dtype_override,
+            use_kv_cache=config.use_kv_cache,
+            generate_full_logits=config.generate_full_logits,
+            example_inputs=(input,),
+            example_kwarg_inputs=None,
+            dynamic_shapes=None,
+            enable_dynamic_shape=config.enable_dynamic_shape,
+            verbose=config.verbose,
+        )
+
+        dtype_override = DType[config.dtype_override]
+        edge_manager.to_dtype(dtype_override)
+
+        transforms = []
+
+        # Linear layers -> 8da4w (8-bit dynamic activations, 4-bit grouped weights).
+        # TODO: look into decode_latent as an "embedding" layer
+        if config.quantization.mode:
+            quant_args = {
+                "group_size": config.quantization.group_size,
+            }
+            transforms.append(
+                partial(
+                    quantize,
+                    **quant_args,
+                    qmode=config.quantization.mode,
+                    computation_dtype=dtype_override,
+                    checkpoint_dtype=None,
+                    checkpoint_path=(
+                        Path(path) if (path := config.checkpoint) is not None else None
+                    ),
+                )
+            )
+
+        llm_manager = edge_manager.source_transform(transforms)
+        builder_exported = llm_manager.export()
+        builder_exported.run_canonical_optimizations()
+
+        # Lower to XNNPACK.
+        partitioners = []
+
+        # Order matters here: the dynamic-quantization partitioner must be applied
+        # first when both xnnpack and xnnpack_extended_ops are enabled.
+        partitioners.append(
+            get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
+        )
+
+        if config.xnnpack_extended_ops:
+            partitioners.append(
+                get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
+            )
+
+        for partitioner in partitioners:
+            logging.info(f"--> {partitioner.__class__.__name__}")
+
+        builder = builder_exported.to_edge_transform_and_lower(partitioners)
+        print_delegation_info(builder.edge_manager.exported_program().graph_module)
+
+        builder = builder.to_executorch()
+        builder.save_to_pte("mimi_4bit")
+
 
 if __name__ == "__main__":
     unittest.main()
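Taken together, the new test applies a two-stage recipe: an 8da4w source transform quantizes the decoder's linear layers, then XNNPACK lowering partitions the dynamically quantized ops ahead of the remaining fp32 ops, which is presumably the "hybrid" in the commit title. Below is a condensed sketch of that flow with the unittest scaffolding and OmegaConf indirection stripped away; the function name, wrapper module, and hard-coded values are stand-ins, and only APIs already used in the diff above appear.

from functools import partial

import torch.nn as nn

from executorch.examples.models.llama.source_transformation.quantize import quantize
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
from executorch.extension.llm.export.partitioner_lib import get_xnnpack_partitioner


def export_hybrid_quantized(decoder: nn.Module, example_input, pte_name: str = "mimi_4bit"):
    """Sketch of the flow in test_exported_decoder_hybrid_quant; values are assumptions."""
    manager = LLMEdgeManager(
        model=decoder,
        modelname="mimi_decode",
        max_seq_len=128,
        dtype="fp32",
        use_kv_cache=False,
        generate_full_logits=False,
        example_inputs=(example_input,),
        example_kwarg_inputs=None,
        dynamic_shapes=None,
        enable_dynamic_shape=False,
        verbose=True,
    )
    manager.to_dtype(DType["fp32"])

    # Stage 1: quantize linear weights to 4-bit groups with 8-bit dynamic activations.
    manager = manager.source_transform(
        [
            partial(
                quantize,
                qmode="8da4w",
                group_size=64,
                computation_dtype=DType["fp32"],
                checkpoint_dtype=None,
                checkpoint_path=None,
            )
        ]
    )

    # Stage 2: export, then lower to XNNPACK. The dynamic-quant-only partitioner
    # runs first; the extended-ops partitioner picks up the remaining fp32 ops.
    exported = manager.export()
    exported.run_canonical_optimizations()
    builder = exported.to_edge_transform_and_lower(
        [
            get_xnnpack_partitioner(dynamic_quant_only_partitioner=True),
            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False),
        ]
    ).to_executorch()
    builder.save_to_pte(pte_name)
    return builder

The partitioner order mirrors the comment in the test: when xnnpack_extended_ops is enabled, the dynamically quantized linears must be claimed before the generic XNNPACK partitioner sees the graph.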
