[Feature] Add support for naver/splade-v3 (BERT-based sparse embedding model) #26339
base: main
Changes from all commits: 3827c27, 657860b, 693c658, 415137d, 0c22312, 3276ca4, b220766, 706a735
@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import types

import numpy as np
import pytest
import torch
import torch.nn as nn

from vllm.model_executor.models.bert import (
    BertMLMHead,
    SPLADESparsePooler,
)

# ---------------------------------------------------------------------
# 1) Functional test: SPLADE formula correctness (no HF download needed)
# ---------------------------------------------------------------------


@pytest.mark.parametrize("B,T,H,V", [(2, 3, 5, 7)])
def test_splade_pooler_matches_reference_formula(B, T, H, V):
    """Ensure SPLADESparsePooler forward() matches the mathematical formula:
    log1p(relu(logits)) -> max over sequence length (after masking)."""
    torch.manual_seed(0)

    # Prepare [B] sequences of shape [T, H]
    hs_list = [torch.randn(T, H) for _ in range(B)]

    # Simulate PoolingMetadata (only required fields)
    prompt_lens = [T, T - 1]
    token_ids = torch.tensor(
        [
            [101, 5, 102],  # Batch 0: [CLS], token, [SEP]
            [101, 6, 6],  # Batch 1: [CLS], token, token (last token ignored)
        ],
        dtype=torch.long,
    )
    meta = types.SimpleNamespace(prompt_lens=prompt_lens, prompt_token_ids=token_ids)

    # MLM head (prefer BertMLMHead, fallback to Linear if unavailable)
    try:
        mlm_head = BertMLMHead(hidden_size=H, vocab_size=V, layer_norm_eps=1e-12)
    except Exception:
        mlm_head = nn.Linear(H, V, bias=True)

    # Forward pass through SPLADE pooler
    pooler = SPLADESparsePooler(mlm_head=mlm_head, pooling="max", remove_cls_sep=True)
    pooled = pooler(hidden_states=hs_list, pooling_metadata=meta)  # list of [V]

    # Basic output checks
    assert isinstance(pooled, list) and len(pooled) == B
    for vec in pooled:
        assert vec.shape == (V,)
        assert torch.isfinite(vec).all()
        assert (vec >= 0).all(), "SPLADE outputs must be non-negative."

    # Reference implementation for comparison
    def ref_one(hs: torch.Tensor, L: int, tid_row: torch.Tensor) -> torch.Tensor:
        keep = torch.ones(L, dtype=torch.bool)
        if L > 0 and tid_row[0].item() == 101:  # remove CLS
            keep[0] = False
        if L > 0 and tid_row[L - 1].item() == 102:  # remove SEP
            keep[L - 1] = False

        valid = hs[:L][keep[:L]]
        if valid.numel() == 0:
            return torch.zeros(V, dtype=torch.float32)

        logits = mlm_head(valid)  # [L', V]
        scores = torch.log1p(torch.relu(logits))  # [L', V]
        return scores.max(dim=0).values.to(torch.float32)

    torch.testing.assert_close(
        pooled[0],
        ref_one(hs_list[0], prompt_lens[0], token_ids[0]),
        rtol=1e-4,
        atol=1e-4,
    )
    torch.testing.assert_close(
        pooled[1],
        ref_one(hs_list[1], prompt_lens[1], token_ids[1]),
        rtol=1e-4,
        atol=1e-4,
    )


# ---------------------------------------------------------------------
# 2) Integration smoke test: end-to-end embedding path wiring
# ---------------------------------------------------------------------


@pytest.mark.cpu_model
def test_bert_splade_sparse_embed_smoke(vllm_runner, monkeypatch):
    """Ensure BertSpladeSparseEmbeddingModel loads and produces sparse embeddings."""
    from transformers import AutoTokenizer

    MODEL_ID = "hf-internal-testing/tiny-random-bert"
    hf_overrides = {"architectures": ["BertSpladeSparseEmbeddingModel"]}

    # Enforce CPU-only execution (optional)
    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    vocab_size = tok.vocab_size

    # The embed path should route through SPLADESparsePooler
    with vllm_runner(
        MODEL_ID,
        runner="pooling",
        max_model_len=64,
        hf_overrides=hf_overrides,
    ) as vm:
        outs = vm.embed(["hello world", "splade sparse test"])

    # Basic sanity checks
    assert len(outs) == 2
    assert outs[0].shape[0] == vocab_size
    assert outs[1].shape[0] == vocab_size
    assert np.isfinite(outs[0]).all() and (outs[0] >= 0).all()
    assert np.isfinite(outs[1]).all() and (outs[1] >= 0).all()
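Aside: each SPLADE embedding above is a vocabulary-sized, non-negative vector, so it can be mapped back to interpretable token weights with the tokenizer. A minimal sketch of such a decoding step (illustrative only, not part of the PR; the helper name and top_k cutoff are made up):

import numpy as np

def sparse_vec_to_tokens(vec, tokenizer, top_k: int = 10) -> dict[str, float]:
    # Keep the top_k highest-weighted vocabulary entries with positive weight.
    vec = np.asarray(vec)
    idx = np.argsort(-vec)[:top_k]
    id_to_token = {i: t for t, i in tokenizer.get_vocab().items()}
    return {id_to_token[int(i)]: float(vec[int(i)]) for i in idx if vec[int(i)] > 0}

For example, sparse_vec_to_tokens(outs[0], tok) in the smoke test above would list the most strongly activated vocabulary terms for "hello world".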
@@ -573,6 +573,229 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    return token_type_ids


class BertMLMHead(nn.Module):
    def __init__(
        self, hidden_size: int, vocab_size: int, layer_norm_eps: float = 1e-12
    ):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.GELU()
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=True)

    def tie_weights_with_embeddings(self, embeddings_weight: torch.Tensor):
        self.decoder.weight = embeddings_weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self.dense(hidden_states)
        x = self.activation(x)
        x = self.layer_norm(x)
        logits = self.decoder(x)
        return logits


class SPLADESparsePooler(Pooler):
    """
    SPLADE sparse pooling:
        logits = mlm_head(hidden_states)
        -> log1p(relu(logits))
        -> (max | sum over L)
        -> [V]

    Padding positions are masked out, [CLS]/[SEP] tokens are optionally
    removed, and the remaining token scores are pooled into a single
    vocabulary-sized vector per prompt.
    """

    def __init__(
        self,
        mlm_head: nn.Module,
        cls_token_id: Optional[int] = 101,
        sep_token_id: Optional[int] = 102,
        pooling: str = "max",
        remove_cls_sep: bool = True,
    ):
        super().__init__()
        assert pooling in ("max", "sum")
        self.mlm_head = mlm_head
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.pooling = pooling
        self.remove_cls_sep = remove_cls_sep

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"embed"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        lens_tensor: torch.Tensor = pooling_metadata.prompt_lens
        lens: list[int] = lens_tensor.tolist()
        B: int = len(lens)
        max_len: int = int(lens_tensor.max().item())

        if isinstance(hidden_states, list):
            hs_list = hidden_states
        else:
            hs_list = torch.split(hidden_states, lens, dim=0)

        device = hs_list[0].device
        H = hs_list[0].size(-1)

        padded = hidden_states.new_zeros((B, max_len, H))
        valid_mask = torch.zeros((B, max_len), dtype=torch.bool, device=device)
        for i, (hs, L) in enumerate(zip(hs_list, lens)):
            L = int(L)
            padded[i, :L] = hs
            valid_mask[i, :L] = True

        token_ids = pooling_metadata.prompt_token_ids
        if self.remove_cls_sep and token_ids is not None:
            for i, L in enumerate(lens):
                if (
                    self.cls_token_id is not None
                    and int(token_ids[i, 0].item()) == self.cls_token_id
                ):
                    valid_mask[i, 0] = False
                if (
                    self.sep_token_id is not None
                    and int(token_ids[i, L - 1].item()) == self.sep_token_id
                ):
                    valid_mask[i, L - 1] = False

Review comment (on the token-id comparison): If the tensor dtype is int, you don't need to cast.

        flat = padded.reshape(B * max_len, H)
        logits = self.mlm_head(flat)
        V = int(logits.size(-1))
        logits = logits.view(B, max_len, V)

        # SPLADE activation
        scores = torch.log1p(torch.relu(logits))  # [B, T, V]

Review comment: If I'm not mistaken, since hidden_states is always a concatenated tensor in V1, up to this point you don't need to split and construct a padded batch. This would save some expensive memory operations. In the code below this point you would have to loop over the split scores, which isn't as pretty, but a mask would not be required.
        if self.pooling == "sum":
            pooled = (scores * valid_mask.to(scores.dtype).unsqueeze(-1)).sum(dim=1)
        else:
            neg_inf = torch.tensor(
                float("-inf"), device=scores.device, dtype=scores.dtype
            )
            masked = scores.masked_fill(~valid_mask.unsqueeze(-1), neg_inf)
            pooled = masked.amax(dim=1)
            pooled = torch.where(
                torch.isneginf(pooled), torch.zeros_like(pooled), pooled
            )

        return pooled.contiguous()
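Following the review suggestion above, a minimal sketch of a padding-free variant, assuming hidden_states always arrives as one [sum(lens), H] tensor with prompts concatenated (the standalone function, its name, and its signature are illustrative, not part of the PR; lens and token_ids correspond to prompt_lens and prompt_token_ids from PoolingMetadata):

def splade_pool_flat(
    mlm_head: nn.Module,
    hidden_states: torch.Tensor,  # [sum(lens), H], prompts concatenated
    lens: list[int],
    token_ids: torch.Tensor,  # [B, max_len] prompt token ids
    cls_token_id: int = 101,
    sep_token_id: int = 102,
    pooling: str = "max",
) -> torch.Tensor:
    # Run the MLM head once over the flat token stream, apply the SPLADE
    # activation, then pool each prompt's slice; no padded batch or mask needed.
    scores = torch.log1p(torch.relu(mlm_head(hidden_states)))  # [sum(lens), V]
    pooled = []
    offset = 0
    for i, L in enumerate(lens):
        s = scores[offset:offset + L]
        offset += L
        start = 1 if int(token_ids[i, 0]) == cls_token_id else 0  # drop [CLS]
        end = L - 1 if int(token_ids[i, L - 1]) == sep_token_id else L  # drop [SEP]
        s = s[start:end]
        if s.numel() == 0:
            pooled.append(scores.new_zeros(scores.size(-1)))
        elif pooling == "sum":
            pooled.append(s.sum(dim=0))
        else:
            pooled.append(s.amax(dim=0))
    return torch.stack(pooled, dim=0)  # [B, V]

The per-prompt loop trades a bit of elegance for skipping the [B, max_len, H] copy and the validity mask, as the reviewer notes.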

@default_pooling_type("CLS")
class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
    """
    BertEmbeddingModel + SPLADE sparse embedding.
    - Logits are produced by self.mlm_head.
    - Pooling: SPLADESparsePooler(mlm_head, ...).
    """

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", splade_pooling: str = "max"
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        cfg = vllm_config.model_config.hf_config

        # MLM head
        self.mlm_head = BertMLMHead(
            hidden_size=cfg.hidden_size,
            vocab_size=cfg.vocab_size,
            layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
        )

        self._splade_pooling = splade_pooling
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler = self._build_pooler(pooler_config)

Review comment (on the splade_pooling parameter): This parameter should be made configurable, for instance by reading it from …
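On the configurability point raised in the review comment above, one hedged option is to read the mode from the model's HF config; the splade_pooling attribute name below is an assumption for illustration, not something the PR, vLLM, or Hugging Face defines:

def _resolve_splade_pooling(hf_config) -> str:
    # Assumed config attribute; fall back to "max" when it is absent.
    mode = getattr(hf_config, "splade_pooling", "max")
    if mode not in ("max", "sum"):
        raise ValueError(f"Unsupported SPLADE pooling mode: {mode!r}")
    return mode

In __init__ this could then replace the hard-coded keyword argument, e.g. self._splade_pooling = _resolve_splade_pooling(cfg).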
    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        cfg = self.model.config

        if not hasattr(self, "mlm_head"):
            self.mlm_head = BertMLMHead(
                hidden_size=cfg.hidden_size,
                vocab_size=cfg.vocab_size,
                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
            )

        pooling_mode = getattr(self, "_splade_pooling", "max")

        cls_id = getattr(cfg, "cls_token_id", None)
        sep_id = getattr(cfg, "sep_token_id", None)

        return DispatchPooler(
            {
                "encode": Pooler.for_encode(pooler_config),
                "embed": SPLADESparsePooler(
                    mlm_head=self.mlm_head,
                    cls_token_id=cls_id,
                    sep_token_id=sep_id,
                    pooling=pooling_mode,  # "max" or "sum"
                    remove_cls_sep=True,
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        if not hasattr(self, "mlm_head"):
            cfg = self.model.config
            self.mlm_head = BertMLMHead(
                hidden_size=cfg.hidden_size,
                vocab_size=cfg.vocab_size,
                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
            )

        def _strip(name: str) -> str:
            for p in ("model.", "bert."):
                if name.startswith(p):
                    name = name[len(p) :]
            return name

        weights_list = list(weights)
        model_side: list[tuple[str, torch.Tensor]] = []
        mlm_side: list[tuple[str, torch.Tensor]] = []

        for k, w in weights_list:
            name = _strip(k)
            if name.startswith("cls.predictions."):
                mlm_side.append((name, w))
            else:
                model_side.append((name, w))

        loaded: set[str] = set()
        loaded_model = self.model.load_weights(model_side)
        loaded.update({"model." + n for n in loaded_model})

        if mlm_side:
            name_map = {
                "cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
                "cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
                "cls.predictions.transform.LayerNorm.weight": (
                    "mlm_head.layer_norm.weight"
                ),
                "cls.predictions.transform.LayerNorm.bias": (
                    "mlm_head.layer_norm.bias"
                ),
                "cls.predictions.decoder.weight": "mlm_head.decoder.weight",
                "cls.predictions.decoder.bias": "mlm_head.decoder.bias",
            }
            remapped = [(name_map[n], w) for n, w in mlm_side if n in name_map]
            if remapped:
                loaded_mlm = AutoWeightsLoader(self).load_weights(remapped)
                loaded.update(loaded_mlm)

        return loaded


@default_pooling_type("CLS")
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
    """A model that uses Bert to provide embedding functionalities.
Review comment: @maxdebayser under what circumstances are the hidden states a list? It seems that GPUModelRunner._pool annotates it as a tensor.

Reply: If I'm not mistaken, now that we've removed the V0 code, it can never be a list.
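For context, a hedged sketch of how the new path might be exercised offline once merged, mirroring the smoke test above; whether the hf_overrides entry is still needed for naver/splade-v3 depends on how the architecture mapping ends up registered:

from vllm import LLM

llm = LLM(
    model="naver/splade-v3",
    runner="pooling",
    hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]},
)
outs = llm.embed(["hello world"])
vec = outs[0].outputs.embedding  # vocabulary-sized sparse vector
print(len(vec), sum(v > 0 for v in vec))  # dimension and number of active terms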