Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,10 @@ steps:
commands:
# temporary install here since we need nightly, will move to requirements/test.in
# after torchao 0.12 release, and pin a working version of torchao nightly here

# since torchao nightly is only compatible with torch nightly currently
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
# we can only upgrade after this is resolved
- pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization

Expand Down
20 changes: 20 additions & 0 deletions tests/quantization/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,25 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
print(output)


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
reason="since torchao nightly is only compatible with torch nightly"
"currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
"torchao tests that requires newer versions (0.14.0.dev+) for now")
def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
torch._dynamo.reset()
model_name = ("torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2"
"-0.14.0.dev")
with vllm_runner(model_name=model_name,
quantization="torchao",
dtype="bfloat16",
pt_load_map_location="cuda:0") as llm:
output = llm.generate_greedy(["The capital of France is"],
max_tokens=32)

assert output
print(output)


if __name__ == "__main__":
pytest.main([__file__])
16 changes: 9 additions & 7 deletions vllm/model_executor/layers/quantization/torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,18 +152,20 @@ def torchao_quantize_param_data(param: torch.Tensor,
from torchao.quantization import quantize_

assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
"""
Avoid real weight allocation for faster load, since we will
"""
Avoid real weight allocation for faster load, since we will
end up setting it to param.
"""
with torch.device("meta"):
dummy_linear = torch.nn.Linear(param.shape[1],
param.shape[0],
bias=False)
# linear can't be top level module since quantize_ is inplace
# while some of our configs need to do module swap, and only non-top
# level modules support module swap
dummy_linear = torch.nn.Sequential(
torch.nn.Linear(param.shape[1], param.shape[0], bias=False))

dummy_linear.weight = param
dummy_linear[0].weight = param
quantize_(dummy_linear, torchao_config)
return dummy_linear.weight
return dummy_linear[0].weight
Comment on lines +166 to +168
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

After quantization, the module may have been swapped. Instead of directly accessing dummy_linear[0].weight, retrieve the weight by inspecting the module's parameters. This avoids making fragile assumptions about the internal structure of the quantized module, which may change in future torchao versions.

    dummy_linear[0].weight = param
    quantize_(dummy_linear, torchao_config)
    # After quantization, the module may have been swapped.
    # We retrieve the single parameter, which is the quantized weight.
    params = list(dummy_linear.parameters())
    assert len(params) == 1, (
        "Expected the dummy module to have exactly one parameter after "
        f"quantization, but found {len(params)}."
    )
    return params[0].data



class TorchAOLinearMethod(LinearMethodBase):
Expand Down