122 changes: 122 additions & 0 deletions tests/models/language/pooling/test_splade_sparse_pooler.py
@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import types

import numpy as np
import pytest
import torch
import torch.nn as nn

from vllm.model_executor.models.bert import (
    BertMLMHead,
    SPLADESparsePooler,
)

# ---------------------------------------------------------------------
# 1) Functional test: SPLADE formula correctness (no HF download needed)
# ---------------------------------------------------------------------


@pytest.mark.parametrize("B,T,H,V", [(2, 3, 5, 7)])
def test_splade_pooler_matches_reference_formula(B, T, H, V):
    """Ensure SPLADESparsePooler forward() matches the mathematical formula:
    log1p(relu(logits)) -> max over sequence length (after masking)."""
    torch.manual_seed(0)

    # Prepare B sequences of shape [T, H]
    hs_list = [torch.randn(T, H) for _ in range(B)]

    # Simulate PoolingMetadata (only the fields the pooler reads).
    # prompt_lens must be a tensor: the pooler calls .tolist() and .max() on it.
    prompt_lens = [T, T - 1]
    token_ids = torch.tensor(
        [
            [101, 5, 102],  # Batch 0: [CLS], token, [SEP]
            [101, 6, 6],  # Batch 1: [CLS], token, token (last token ignored)
        ],
        dtype=torch.long,
    )
    meta = types.SimpleNamespace(
        prompt_lens=torch.tensor(prompt_lens), prompt_token_ids=token_ids
    )

    # MLM head (prefer BertMLMHead, fall back to a plain Linear if unavailable)
    try:
        mlm_head = BertMLMHead(hidden_size=H, vocab_size=V, layer_norm_eps=1e-12)
    except Exception:
        mlm_head = nn.Linear(H, V, bias=True)

    # Forward pass through the SPLADE pooler
    pooler = SPLADESparsePooler(mlm_head=mlm_head, pooling="max", remove_cls_sep=True)
    pooled = pooler(hidden_states=hs_list, pooling_metadata=meta)

    # The pooler returns a [B, V] tensor; split it into per-sequence rows
    # so the checks below work uniformly.
    if isinstance(pooled, torch.Tensor):
        pooled = list(pooled.unbind(dim=0))

    # Basic output checks
    assert len(pooled) == B
    for vec in pooled:
        assert vec.shape == (V,)
        assert torch.isfinite(vec).all()
        assert (vec >= 0).all(), "SPLADE outputs must be non-negative."

    # Reference implementation for comparison
    def ref_one(hs: torch.Tensor, L: int, tid_row: torch.Tensor) -> torch.Tensor:
        keep = torch.ones(L, dtype=torch.bool)
        if L > 0 and tid_row[0].item() == 101:  # remove CLS
            keep[0] = False
        if L > 0 and tid_row[L - 1].item() == 102:  # remove SEP
            keep[L - 1] = False

        valid = hs[:L][keep[:L]]
        if valid.numel() == 0:
            return torch.zeros(V, dtype=torch.float32)

        logits = mlm_head(valid)  # [L', V]
        scores = torch.log1p(torch.relu(logits))  # [L', V]
        return scores.max(dim=0).values.to(torch.float32)

    torch.testing.assert_close(
        pooled[0],
        ref_one(hs_list[0], prompt_lens[0], token_ids[0]),
        rtol=1e-4,
        atol=1e-4,
    )
    torch.testing.assert_close(
        pooled[1],
        ref_one(hs_list[1], prompt_lens[1], token_ids[1]),
        rtol=1e-4,
        atol=1e-4,
    )


# ---------------------------------------------------------------------
# 2) Integration smoke test: end-to-end embedding path wiring
# ---------------------------------------------------------------------


@pytest.mark.cpu_model
def test_bert_splade_sparse_embed_smoke(vllm_runner, monkeypatch):
    """Ensure BertSpladeSparseEmbeddingModel loads and produces sparse embeddings."""
    from transformers import AutoTokenizer

    MODEL_ID = "hf-internal-testing/tiny-random-bert"
    hf_overrides = {"architectures": ["BertSpladeSparseEmbeddingModel"]}

    # Enforce CPU-only execution (optional)
    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    vocab_size = tok.vocab_size

    # The embed path should route through SPLADESparsePooler
    with vllm_runner(
        MODEL_ID,
        runner="pooling",
        max_model_len=64,
        hf_overrides=hf_overrides,
    ) as vm:
        outs = vm.embed(["hello world", "splade sparse test"])

    # Basic sanity checks: one [vocab_size] sparse vector per input,
    # finite and non-negative everywhere.
    assert len(outs) == 2
    for out in outs:
        vec = np.asarray(out)
        assert vec.shape[0] == vocab_size
        assert np.isfinite(vec).all() and (vec >= 0).all()
3 changes: 3 additions & 0 deletions tests/models/registry.py
@@ -485,6 +485,9 @@ def check_available_online(
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
"naver/splade-v3", is_available_online=False
),
# [Multimodal]
"CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
223 changes: 223 additions & 0 deletions vllm/model_executor/models/bert.py
@@ -573,6 +573,229 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    return token_type_ids


class BertMLMHead(nn.Module):
    def __init__(
        self, hidden_size: int, vocab_size: int, layer_norm_eps: float = 1e-12
    ):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.GELU()
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=True)

    def tie_weights_with_embeddings(self, embeddings_weight: torch.Tensor):
        self.decoder.weight = embeddings_weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self.dense(hidden_states)
        x = self.activation(x)
        x = self.layer_norm(x)
        logits = self.decoder(x)
        return logits


class SPLADESparsePooler(Pooler):
    """
    SPLADE sparse pooling:
        logits = mlm_head(hidden_states)
        -> log1p(relu(logits))
        -> (max | sum over the sequence dimension)
        -> [V]

    Padding positions are masked out, [CLS]/[SEP] tokens are optionally
    removed, and the remaining token scores are pooled into a single
    vocabulary-sized sparse vector.
    """

    def __init__(
        self,
        mlm_head: nn.Module,
        cls_token_id: Optional[int] = 101,
        sep_token_id: Optional[int] = 102,
        pooling: str = "max",
        remove_cls_sep: bool = True,
    ):
        super().__init__()
        assert pooling in ("max", "sum")
        self.mlm_head = mlm_head
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.pooling = pooling
        self.remove_cls_sep = remove_cls_sep

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"embed"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        lens_tensor: torch.Tensor = pooling_metadata.prompt_lens
        lens: list[int] = lens_tensor.tolist()
        B: int = len(lens)
        max_len: int = int(lens_tensor.max().item())

        if isinstance(hidden_states, list):
            hs_list = hidden_states
[Review comment — Member] @maxdebayser under what circumstances are the hidden states a list? It seems that GPUModelRunner._pool annotates it as a tensor.

[Reply — Contributor] If I'm not mistaken, now that we've removed the V0 code, it can never be a list.
        else:
            hs_list = torch.split(hidden_states, lens, dim=0)

        device = hs_list[0].device
        H = hs_list[0].size(-1)

        # Note: allocate from hs_list[0] rather than hidden_states, which may
        # be a list on this branch and has no new_zeros method.
        padded = hs_list[0].new_zeros((B, max_len, H))
        valid_mask = torch.zeros((B, max_len), dtype=torch.bool, device=device)
        for i, (hs, L) in enumerate(zip(hs_list, lens)):
            L = int(L)
            padded[i, :L] = hs
            valid_mask[i, :L] = True

        token_ids = pooling_metadata.prompt_token_ids
        if self.remove_cls_sep and token_ids is not None:
            for i, L in enumerate(lens):
                if (
                    self.cls_token_id is not None
                    and int(token_ids[i, 0].item()) == self.cls_token_id
                ):
                    valid_mask[i, 0] = False
                if (
                    self.sep_token_id is not None
                    and int(token_ids[i, L - 1].item()) == self.sep_token_id
[Review comment — Contributor] If the tensor dtype is int, you don't need to cast.
                ):
                    valid_mask[i, L - 1] = False

        flat = padded.reshape(B * max_len, H)
        logits = self.mlm_head(flat)
        V = int(logits.size(-1))
        logits = logits.view(B, max_len, V)

        # SPLADE activation
        scores = torch.log1p(torch.relu(logits))  # [B, T, V]
[Review comment — Contributor] If I'm not mistaken, since hidden_states is always a concatenated tensor in V1, up to this point you don't need to split and construct a padded batch. This would save some expensive memory operations. In the code below this point you would have to loop over the split scores, which isn't as pretty, but a mask would not be required. (A sketch of this variant follows the forward() method below.)

        if self.pooling == "sum":
            pooled = (scores * valid_mask.to(scores.dtype).unsqueeze(-1)).sum(dim=1)
        else:
            neg_inf = torch.tensor(
                float("-inf"), device=scores.device, dtype=scores.dtype
            )
            masked = scores.masked_fill(~valid_mask.unsqueeze(-1), neg_inf)
            pooled = masked.amax(dim=1)
            # Sequences with no valid tokens pool to -inf; clamp them to zero.
            pooled = torch.where(
                torch.isneginf(pooled), torch.zeros_like(pooled), pooled
            )

        return pooled.contiguous()
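A minimal sketch of the padding-free variant suggested in the review comment above, assuming hidden_states always arrives as the concatenated [sum(lens), H] tensor (CLS/SEP removal is omitted for brevity; the names mirror the method above):

    scores = torch.log1p(torch.relu(self.mlm_head(hidden_states)))  # [sum(L), V]
    pooled_rows = []
    for seq_scores in torch.split(scores, lens, dim=0):  # one [L_i, V] slice each
        pooled_rows.append(
            seq_scores.sum(dim=0) if self.pooling == "sum" else seq_scores.amax(dim=0)
        )
    pooled = torch.stack(pooled_rows, dim=0)  # [B, V]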


@default_pooling_type("CLS")
class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
"""
BertEmbeddingModel + SPLADE sparse embedding.
- Make logits by self.mlm_head
- pooler: SPLADESparsePooler(mlm_head...)
"""

def __init__(
self, *, vllm_config: VllmConfig, prefix: str = "", splade_pooling: str = "max"
[Review comment — Contributor, high] The splade_pooling parameter in BertSpladeSparseEmbeddingModel.__init__ is hardcoded with a default value of "max" and is not exposed for user configuration. The PR description mentions that SPLADESparsePooler supports both "max" and "sum" pooling, which implies this should be configurable. Currently, there is no way for a user to select "sum" pooling.

This parameter should be made configurable, for instance by reading it from vllm_config.model_config.pooler_config or vllm_config.model_config.hf_config. (A sketch follows the __init__ body below.)

    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        cfg = vllm_config.model_config.hf_config

        # MLM head
        self.mlm_head = BertMLMHead(
            hidden_size=cfg.hidden_size,
            vocab_size=cfg.vocab_size,
            layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
        )

        self._splade_pooling = splade_pooling
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        # Rebuild the pooler now that mlm_head and the pooling mode exist
        # (the base class builds an initial pooler before they are set).
        self.pooler = self._build_pooler(pooler_config)
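A minimal sketch of the configurability fix suggested in the review comment above, assuming a splade_pooling attribute on the HF config (the attribute name is illustrative, not an established field):

        # In __init__, derive the pooling mode from the checkpoint config
        # instead of a hardcoded constructor default.
        self._splade_pooling = getattr(cfg, "splade_pooling", splade_pooling)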

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        cfg = self.model.config

        # This method is first called from the base class __init__, before
        # mlm_head is assigned, so create the head on demand.
        if not hasattr(self, "mlm_head"):
            self.mlm_head = BertMLMHead(
                hidden_size=cfg.hidden_size,
                vocab_size=cfg.vocab_size,
                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
            )

        pooling_mode = getattr(self, "_splade_pooling", "max")

        cls_id = getattr(cfg, "cls_token_id", None)
        sep_id = getattr(cfg, "sep_token_id", None)

        return DispatchPooler(
            {
                "encode": Pooler.for_encode(pooler_config),
                "embed": SPLADESparsePooler(
                    mlm_head=self.mlm_head,
                    cls_token_id=cls_id,
                    sep_token_id=sep_id,
                    pooling=pooling_mode,  # "max" or "sum"
                    remove_cls_sep=True,
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        if not hasattr(self, "mlm_head"):
            cfg = self.model.config
            self.mlm_head = BertMLMHead(
                hidden_size=cfg.hidden_size,
                vocab_size=cfg.vocab_size,
                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
            )

        def _strip(name: str) -> str:
            for p in ("model.", "bert."):
                if name.startswith(p):
                    name = name[len(p) :]
            return name

        weights_list = list(weights)
        model_side: list[tuple[str, torch.Tensor]] = []
        mlm_side: list[tuple[str, torch.Tensor]] = []

        for k, w in weights_list:
            name = _strip(k)
            if name.startswith("cls.predictions."):
                mlm_side.append((name, w))
            else:
                model_side.append((name, w))

        loaded: set[str] = set()
        loaded_model = self.model.load_weights(model_side)
        loaded.update({"model." + n for n in loaded_model})

        if mlm_side:
            name_map = {
                "cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
                "cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
                "cls.predictions.transform.LayerNorm.weight": "mlm_head.layer_norm.weight",
                "cls.predictions.transform.LayerNorm.bias": "mlm_head.layer_norm.bias",
                "cls.predictions.decoder.weight": "mlm_head.decoder.weight",
                "cls.predictions.decoder.bias": "mlm_head.decoder.bias",
            }
            remapped = [(name_map[n], w) for n, w in mlm_side if n in name_map]
            if remapped:
                loaded_mlm = AutoWeightsLoader(self).load_weights(remapped)
                loaded.update(loaded_mlm)

        return loaded


@default_pooling_type("CLS")
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -171,6 +171,7 @@
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),