10 changes: 7 additions & 3 deletions tests/models/test_transformers.py
@@ -100,10 +100,9 @@ def test_distributed(
kwargs_test=kwargs)


@pytest.mark.skipif(
current_platform.is_rocm(),
reason="bitsandbytes quantization is currently not supported in rocm.")
@pytest.mark.parametrize("model, quantization_kwargs", [
("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {}),
("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {}),
(
"meta-llama/Llama-3.2-1B-Instruct",
{
@@ -121,6 +120,11 @@ def test_quantization(
max_tokens: int,
num_logprobs: int,
) -> None:
if (current_platform.is_rocm()
and quantization_kwargs.get("quantization", "") == "bitsandbytes"):
pytest.skip(
"bitsandbytes quantization is currently not supported in rocm.")

with vllm_runner(
model, model_impl="auto", enforce_eager=True,
**quantization_kwargs) as vllm_model: # type: ignore[arg-type]
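Note on the hunk above: the removed `@pytest.mark.skipif` decorator skipped `test_quantization` for every parametrization on ROCm, while the new in-body `pytest.skip` only skips the bitsandbytes case, so the AWQ and GPTQ models still run there. A minimal, self-contained sketch of the same pattern (the `IS_ROCM` flag stands in for `current_platform.is_rocm()` and the parametrization is made up for illustration):

import pytest

IS_ROCM = False  # stand-in for current_platform.is_rocm()

@pytest.mark.parametrize("quantization_kwargs", [
    {},                                # e.g. AWQ/GPTQ checkpoints: no extra kwargs
    {"quantization": "bitsandbytes"},  # the only case unsupported on ROCm
])
def test_quantization_sketch(quantization_kwargs):
    # Skipping inside the test body leaves the other parametrizations running on ROCm.
    if IS_ROCM and quantization_kwargs.get("quantization", "") == "bitsandbytes":
        pytest.skip("bitsandbytes quantization is currently not supported in rocm.")
    assert isinstance(quantization_kwargs, dict)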
22 changes: 17 additions & 5 deletions vllm/model_executor/models/transformers.py
@@ -447,7 +447,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.device_config: DeviceConfig = vllm_config.device_config
self.model_config: ModelConfig = vllm_config.model_config
self.parallel_config: ParallelConfig = vllm_config.parallel_config
self.quant_config: QuantizationConfig = vllm_config.quant_config
self.quant_config: Optional[
QuantizationConfig] = vllm_config.quant_config

self.pp_group = get_pp_group()
self.pp_size = self.pp_group.world_size
@@ -456,7 +457,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

# Weights to skip in `self.load_weights`
self.skip_prefixes: list[str] = []
"""Skip loading weights whose qualname starts with these prefixes."""
self.skip_substrs: list[str] = []
"""Skip loading weights whose qualname contains these substrings."""
self.ignore_unexpected_prefixes: list[str] = []
"""Ignore unexpected weights whose qualname starts with these prefixes.
"""
self.ignore_unexpected_suffixes: list[str] = []
"""Ignore unexpected weights whose qualname ends with these suffixes."""

# Skip loading extra bias for GPTQ models.
if self.quant_config and "gptq" in self.quant_config.get_name():
self.ignore_unexpected_suffixes.append(".bias")

# Set correct attn and init on "meta" to delay allocating GPU tensors
# TODO: @raushan, use the public `model.set_attn_implementation()`
@@ -563,9 +575,7 @@ def tensor_parallel(self):
raise ValueError(
f"{type(self.model)} does not support tensor parallel. {tip}")

def _tensor_parallel(module: nn.Module,
prefix: str = "",
tp_plan=None):
def _tensor_parallel(module: nn.Module, prefix: str, tp_plan=None):
tp_plan = tp_plan or {}

# If the current module is a PreTrainedModel, set the tp_plan for
@@ -597,7 +607,7 @@ def _tensor_parallel(module: nn.Module,
prefix=qual_name,
tp_plan=tp_plan)

_tensor_parallel(self.model)
_tensor_parallel(self.model, prefix="model")

def create_attention_instances(
self,
@@ -696,6 +706,8 @@ def load_weights(self, weights: Iterable[tuple[str,
self,
skip_prefixes=self.skip_prefixes,
skip_substrs=self.skip_substrs,
ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

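Context for the `.bias` suffix registered above: GPTQ checkpoints can ship bias tensors that the Transformers-backend model creates no parameters for ("Skip loading extra bias for GPTQ models"). With `.bias` in `ignore_unexpected_suffixes`, such an unexpected tensor is dropped instead of raising an error; tensors that do have a matching parameter are unaffected. A rough, illustrative sketch of the effect (the checkpoint names are hypothetical):

# Hypothetical checkpoint names; the suffix filter mirrors the behaviour added here.
ignore_unexpected_suffixes = [".bias"]  # registered only when the quant method is GPTQ

checkpoint_names = [
    "model.layers.0.self_attn.q_proj.qweight",  # loaded as usual
    "model.layers.0.self_attn.q_proj.bias",     # unexpected -> ignored, not an error
]

for name in checkpoint_names:
    ignored = any(name.endswith(s) for s in ignore_unexpected_suffixes)
    print(name, "ignored" if ignored else "loaded")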
7 changes: 5 additions & 2 deletions vllm/model_executor/models/utils.py
@@ -109,13 +109,15 @@ def __init__(
skip_prefixes: Optional[list[str]] = None,
skip_substrs: Optional[list[str]] = None,
ignore_unexpected_prefixes: Optional[list[str]] = None,
ignore_unexpected_suffixes: Optional[list[str]] = None,
) -> None:
super().__init__()

self.module = module
self.skip_prefixes = skip_prefixes or []
self.skip_substrs = skip_substrs or []
self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
self.ignore_unexpected_suffixes = ignore_unexpected_suffixes or []
# update default skip_substrs
self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS

@@ -149,8 +151,9 @@ def _can_skip(self, qualname: str) -> bool:
or any(substr in qualname for substr in self.skip_substrs))

def _can_ignore_unexpected(self, qualname: str) -> bool:
return any(
qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes)
return any(iup) or any(ius)

def _load_param(
self,
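A self-contained sketch of the updated matching logic in `_can_ignore_unexpected` (prefix OR suffix match), assuming a loader configured with both lists as in the constructor above; the example names are illustrative only:

def can_ignore_unexpected(qualname: str,
                          ignore_unexpected_prefixes: list[str],
                          ignore_unexpected_suffixes: list[str]) -> bool:
    # An unexpected checkpoint weight may be ignored if its qualified name
    # starts with any registered prefix or ends with any registered suffix.
    return (any(qualname.startswith(p) for p in ignore_unexpected_prefixes)
            or any(qualname.endswith(s) for s in ignore_unexpected_suffixes))

assert can_ignore_unexpected("model.layers.0.mlp.up_proj.bias", [], [".bias"])
assert can_ignore_unexpected("lm_head.extra_weight", ["lm_head."], [".bias"])
assert not can_ignore_unexpected("model.embed_tokens.weight", ["lm_head."], [".bias"])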