Skip to content

Commit 2799f7f

Browse files
committed
Add SPLADE sparse embedding model and tests.
- Removed unnecessary torch.no_grad() (handled by vLLM framework)
- Added model loading entry to tests/models/registry.py
- Added SPLADESparsePooler functional + smoke tests to ensure future stability
Signed-off-by: gjgjos <[email protected]>
1 parent 3827c27 commit 2799f7f

File tree

3 files changed

+117
-1
lines changed

3 files changed

+117
-1
lines changed
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import types
5+
import numpy as np
6+
import torch
7+
import torch.nn as nn
8+
import pytest
9+
10+
from vllm.model_executor.models.bert import (
11+
SPLADESparsePooler,
12+
BertMLMHead,
13+
)
14+
15+
16+
# ---------------------------------------------------------------------
17+
# 1) Functional test: SPLADE formula correctness (no HF download needed)
18+
# ---------------------------------------------------------------------
19+
20+
@pytest.mark.parametrize("B,T,H,V", [(2, 3, 5, 7)])
def test_splade_pooler_matches_reference_formula(B, T, H, V):
    """Check SPLADESparsePooler.forward() against a hand-written reference:
    log1p(relu(logits)), then a max over the (masked) sequence axis."""
    torch.manual_seed(0)

    # One [T, H] hidden-state tensor per batch element.
    hidden_states = [torch.randn(T, H) for _ in range(B)]

    # Minimal stand-in for PoolingMetadata: only the fields the pooler reads.
    lens = [T, T - 1]
    ids = torch.tensor(
        [
            [101, 5, 102],  # Batch 0: [CLS], token, [SEP]
            [101, 6, 6],  # Batch 1: [CLS], token, token (last token ignored)
        ],
        dtype=torch.long,
    )
    metadata = types.SimpleNamespace(prompt_lens=lens, prompt_token_ids=ids)

    # Build the MLM head; fall back to a plain Linear if BertMLMHead cannot
    # be constructed in this environment.
    try:
        head = BertMLMHead(hidden_size=H, vocab_size=V, layer_norm_eps=1e-12)
    except Exception:
        head = nn.Linear(H, V, bias=True)

    # Run the pooler under test; result is a list of [V] score vectors.
    pooler = SPLADESparsePooler(mlm_head=head, pooling="max", remove_cls_sep=True)
    pooled = pooler(hidden_states=hidden_states, pooling_metadata=metadata)

    # Shape / finiteness / non-negativity sanity checks.
    assert isinstance(pooled, list)
    assert len(pooled) == B
    for vec in pooled:
        assert vec.shape == (V,)
        assert torch.isfinite(vec).all()
        assert (vec >= 0).all(), "SPLADE outputs must be non-negative."

    def reference(hs: torch.Tensor, length: int, row: torch.Tensor) -> torch.Tensor:
        # Mask out [CLS]/[SEP] positions, then apply the SPLADE formula.
        mask = torch.ones(length, dtype=torch.bool)
        if length > 0 and row[0].item() == 101:  # strip CLS
            mask[0] = False
        if length > 0 and row[length - 1].item() == 102:  # strip SEP
            mask[length - 1] = False

        kept = hs[:length][mask[:length]]
        if kept.numel() == 0:
            return torch.zeros(V, dtype=torch.float32)

        activations = torch.log1p(torch.relu(head(kept)))  # [L', V]
        return activations.max(dim=0).values.to(torch.float32)

    for b in range(B):
        torch.testing.assert_close(
            pooled[b],
            reference(hidden_states[b], lens[b], ids[b]),
            rtol=1e-4,
            atol=1e-4,
        )
79+
80+
81+
# ---------------------------------------------------------------------
82+
# 2) Integration smoke test: end-to-end embedding path wiring
83+
# ---------------------------------------------------------------------
84+
85+
@pytest.mark.cpu_model
def test_bert_splade_sparse_embed_smoke(vllm_runner, monkeypatch):
    """Smoke test: BertSpladeSparseEmbeddingModel loads and emits sparse embeddings."""
    from transformers import AutoTokenizer

    MODEL_ID = "hf-internal-testing/tiny-random-bert"
    hf_overrides = {"architectures": ["BertSpladeSparseEmbeddingModel"]}

    # Pin execution to CPU (optional hardening for hosts that expose GPUs).
    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    vocab_size = AutoTokenizer.from_pretrained(MODEL_ID).vocab_size

    # Embedding requests should be routed through SPLADESparsePooler.
    with vllm_runner(
        MODEL_ID,
        runner="pooling",
        max_model_len=64,
        hf_overrides=hf_overrides,
    ) as vm:
        outs = vm.embed(["hello world", "splade sparse test"])

    # Basic sanity: one vocab-sized, finite, non-negative vector per input.
    assert len(outs) == 2
    for out in outs:
        assert out.shape[0] == vocab_size
        assert np.isfinite(out).all()
        assert (out >= 0).all()

tests/models/registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,9 @@ def check_available_online(
483483
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),
484484
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),
485485
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
486+
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
487+
"naver/splade-v3", is_available_online=False
488+
),
486489
# [Multimodal]
487490
"CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
488491
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),

vllm/model_executor/models/bert.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,6 @@ def get_supported_tasks(self) -> Set[PoolingTask]:
629629
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
630630
return PoolingParamsUpdate(requires_token_ids=True)
631631

632-
@torch.no_grad()
633632
def forward(
634633
self,
635634
hidden_states: Union[torch.Tensor, list[torch.Tensor]],

0 commit comments

Comments
 (0)