
Commit 2048c4e

Authored by Jerry Zhang
[torchao] Support quantization configs using module swap (#21982)
Signed-off-by: Jerry Zhang <[email protected]>
1 parent d133601 commit 2048c4e

File tree

3 files changed (+33, -7 lines)


.buildkite/test-pipeline.yaml

Lines changed: 4 additions & 0 deletions
@@ -507,6 +507,10 @@ steps:
   commands:
   # temporary install here since we need nightly, will move to requirements/test.in
   # after torchao 0.12 release, and pin a working version of torchao nightly here
+
+  # since torchao nightly is only compatible with torch nightly currently
+  # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
+  # we can only upgrade after this is resolved
   - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
   - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization

tests/quantization/test_torchao.py

Lines changed: 20 additions & 0 deletions
@@ -75,5 +75,25 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
     print(output)
 
 
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+@pytest.mark.skip(
+    reason="since torchao nightly is only compatible with torch nightly "
+    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
+    "torchao tests that require newer versions (0.14.0.dev+) for now")
+def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
+    torch._dynamo.reset()
+    model_name = ("torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2"
+                  "-0.14.0.dev")
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0") as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+    print(output)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
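For context, the fixture-based test above corresponds roughly to the following standalone use of vLLM's public API. This is a sketch, assuming the vllm_runner fixture wraps vllm.LLM and that generate_greedy amounts to greedy decoding (temperature 0); it would only run once the torchao version constraint noted in the skip reason is lifted.

```python
from vllm import LLM, SamplingParams

# Same model and flags as the test; the test additionally passes
# pt_load_map_location="cuda:0", omitted here for brevity.
llm = LLM(model="torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2"
                "-0.14.0.dev",
          quantization="torchao",
          dtype="bfloat16")

# Greedy decoding, mirroring generate_greedy(max_tokens=32) in the test.
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```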

vllm/model_executor/layers/quantization/torchao.py

Lines changed: 9 additions & 7 deletions
@@ -152,18 +152,20 @@ def torchao_quantize_param_data(param: torch.Tensor,
     from torchao.quantization import quantize_
 
     assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
-    """
-    Avoid real weight allocation for faster load, since we will
+    """
+    Avoid real weight allocation for faster load, since we will
     end up setting it to param.
     """
     with torch.device("meta"):
-        dummy_linear = torch.nn.Linear(param.shape[1],
-                                       param.shape[0],
-                                       bias=False)
+        # linear can't be top level module since quantize_ is inplace
+        # while some of our configs need to do module swap, and only non-top
+        # level modules support module swap
+        dummy_linear = torch.nn.Sequential(
+            torch.nn.Linear(param.shape[1], param.shape[0], bias=False))
 
-        dummy_linear.weight = param
+        dummy_linear[0].weight = param
         quantize_(dummy_linear, torchao_config)
-        return dummy_linear.weight
+        return dummy_linear[0].weight
 
 
 class TorchAOLinearMethod(LinearMethodBase):
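The three added comment lines carry the rationale for the whole commit: quantize_ mutates a module tree in place, and a config that quantizes by replacing a module outright (module swap) can only replace a child, because the swap is done by reassigning the attribute on the parent; the caller's own reference to a top-level module can never be rewritten. A minimal PyTorch sketch of that constraint — swap_linears is a hypothetical stand-in for a module-swapping transform, not torchao's actual implementation:

```python
import torch.nn as nn

def swap_linears(module: nn.Module) -> None:
    """Toy in-place transform in the style of quantize_: walks the children
    of `module` and replaces every nn.Linear via setattr on its parent.
    The root module itself is never replaced, since only the caller holds
    a reference to it."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, nn.Identity())  # stand-in for a quantized module
        else:
            swap_linears(child)

# Passing the Linear as the top-level module: nothing can swap it.
top = nn.Linear(4, 4, bias=False)
swap_linears(top)
print(type(top).__name__)        # Linear -- unchanged

# Wrapping it in nn.Sequential gives the transform a parent to swap under.
wrapped = nn.Sequential(nn.Linear(4, 4, bias=False))
swap_linears(wrapped)
print(type(wrapped[0]).__name__)  # Identity -- the child was swapped
```

Wrapping the dummy Linear in nn.Sequential therefore gives quantize_ a parent to perform the swap under, and dummy_linear[0].weight reads the (possibly swapped) child's weight back out.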
