[Feature] Add support for naver/splade-v3 (BERT-based sparse embedding model) #26339
Conversation
Signed-off-by: gjgjos <[email protected]>
Code Review
This pull request adds support for the `naver/splade-v3` sparse embedding model by introducing `BertSpladeSparseEmbeddingModel` and `SPLADESparsePooler`. The implementation is well-tested and demonstrates correctness against Hugging Face and TEI frameworks. My review focuses on improving the robustness and maintainability of the new `BertSpladeSparseEmbeddingModel` class, particularly in the `load_weights` method, where I've identified opportunities for optimization and safer error handling.
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    if not hasattr(self, "mlm_head"):
        cfg = self.model.config
        self.mlm_head = BertMLMHead(
            hidden_size=cfg.hidden_size,
            vocab_size=cfg.vocab_size,
            layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
        )

    weights_list = list(weights)
    loaded: set[str] = set()

    model_side: list[tuple[str, torch.Tensor]] = []
    for k, w in weights_list:
        if k.startswith("cls.predictions."):
            continue
        name = k
        if name.startswith("model."):
            name = name[len("model.") :]
        if name.startswith("bert."):
            name = name[len("bert.") :]
        model_side.append((name, w))

    other, stacked = self.model._load_weights(model_side)
    loaded.update({"model." + n for n in stacked})

    other_prefixed = [("model." + n, w) for (n, w) in other]
    loader_top = AutoWeightsLoader(
        self, skip_prefixes=["pooler.", "mlm_head.", "lm_head."]
    )
    loaded_other = loader_top.load_weights(other_prefixed)
    loaded.update(loaded_other)

    name_map = {
        "cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
        "cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
        "cls.predictions.transform.LayerNorm.weight": "mlm_head.layer_norm.weight",
        "cls.predictions.transform.LayerNorm.bias": "mlm_head.layer_norm.bias",
        "cls.predictions.decoder.weight": "mlm_head.decoder.weight",
        "cls.predictions.decoder.bias": "mlm_head.decoder.bias",
    }
    extras: list[tuple[str, torch.Tensor]] = []
    for k, w in weights_list:
        name = k
        if name.startswith("model."):
            name = name[len("model.") :]
        if name.startswith("bert."):
            name = name[len("bert.") :]
        tgt = name_map.get(name)
        if tgt is not None:
            extras.append((tgt, w))

    if extras:
        mlm_loader = AutoWeightsLoader(self)
        loaded_mlm = mlm_loader.load_weights(extras)
        loaded.update(loaded_mlm)

    try:
        emb_w = self.model.embeddings.word_embeddings.weight
        dec_w = self.mlm_head.decoder.weight
        if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr():
            self.mlm_head.decoder.weight = emb_w
    except Exception:
        pass

    return loaded
The `load_weights` method can be improved for efficiency, correctness, and robustness.

- Inefficient and Buggy Weight Processing: The method iterates over `weights_list` twice. This is inefficient, and the filtering logic is buggy; for example, a weight named `bert.cls.predictions...` would be incorrectly processed by both loops. This can be optimized into a single, correct loop.
- Unsafe Exception Handling: The `try...except Exception: pass` block is too broad and silently swallows all errors. This can hide critical bugs during weight tying. It should be narrowed to specific exceptions like `AttributeError`.
- Redundant Initialization: The `mlm_head` initialization logic is duplicated across `__init__`, `_build_pooler`, and `load_weights`. This should be centralized in a helper method to improve maintainability. While not included in this suggestion, it's a recommended refactoring for the class.

The following suggestion refactors `load_weights` to address the efficiency, bug, and exception-handling issues.
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    if not hasattr(self, "mlm_head"):
        cfg = self.model.config
        self.mlm_head = BertMLMHead(
            hidden_size=cfg.hidden_size,
            vocab_size=cfg.vocab_size,
            layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
        )

    weights_list = list(weights)
    loaded: set[str] = set()

    # Map HF MLM-head checkpoint names onto this module's mlm_head parameters.
    name_map = {
        "cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
        "cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
        "cls.predictions.transform.LayerNorm.weight": "mlm_head.layer_norm.weight",
        "cls.predictions.transform.LayerNorm.bias": "mlm_head.layer_norm.bias",
        "cls.predictions.decoder.weight": "mlm_head.decoder.weight",
        "cls.predictions.decoder.bias": "mlm_head.decoder.bias",
    }

    # Single pass: strip "model."/"bert." prefixes, then route each weight
    # either to the MLM head (extras) or to the backbone (model_side).
    model_side: list[tuple[str, torch.Tensor]] = []
    extras: list[tuple[str, torch.Tensor]] = []
    for k, w in weights_list:
        stripped_name = k
        if stripped_name.startswith("model."):
            stripped_name = stripped_name[len("model."):]
        if stripped_name.startswith("bert."):
            stripped_name = stripped_name[len("bert."):]
        tgt = name_map.get(stripped_name)
        if tgt is not None:
            extras.append((tgt, w))
        else:
            model_side.append((stripped_name, w))

    other, stacked = self.model._load_weights(model_side)
    loaded.update({"model." + n for n in stacked})

    other_prefixed = [("model." + n, w) for (n, w) in other]
    loader_top = AutoWeightsLoader(
        self, skip_prefixes=["pooler.", "mlm_head.", "lm_head."]
    )
    loaded_other = loader_top.load_weights(other_prefixed)
    loaded.update(loaded_other)

    if extras:
        mlm_loader = AutoWeightsLoader(self)
        loaded_mlm = mlm_loader.load_weights(extras)
        loaded.update(loaded_mlm)

    # Tie the decoder weight to the word embeddings when shapes match.
    try:
        emb_w = self.model.embeddings.word_embeddings.weight
        dec_w = self.mlm_head.decoder.weight
        if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr():
            self.mlm_head.decoder.weight = emb_w
    except AttributeError:
        # It's possible that the model doesn't have these attributes.
        # Silently passing is acceptable if weight tying is optional.
        pass

    return loaded
Can you also address this? The weight loading logic indeed looks quite complicated
💡 Codex Review
Here are some automated review suggestions for this pull request.
vllm/model_executor/models/bert.py
Outdated
@torch.no_grad()
def forward(
    self,
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> Union[torch.Tensor, list[torch.Tensor]]:
    if isinstance(hidden_states, torch.Tensor):
        hs_list = [hidden_states]
    else:
        hs_list = list(hidden_states)

    for i, hs in enumerate(hs_list):
        if hs.dim() == 3 and hs.size(0) == 1:
            hs_list[i] = hs.squeeze(0)  # [L, H]
        elif hs.dim() != 2:
            raise ValueError(f"Expected [L,H] or [1,L,H], got {tuple(hs.shape)}")

    B = len(hs_list)
    H = hs_list[0].size(-1)

    raw_lens = getattr(pooling_metadata, "prompt_lens", None)

    def _fallback_lens_from_hs():
        return [int(h.size(0)) for h in hs_list]

    if raw_lens is None:
        lens = _fallback_lens_from_hs()
    elif isinstance(raw_lens, int):
        lens = [int(raw_lens)] * B
    else:
        try:
            tmp = list(raw_lens)
            if len(tmp) == B:
                lens = [int(x) for x in tmp]
            elif len(tmp) == 1:
                lens = [int(tmp[0])] * B
            else:
                lens = _fallback_lens_from_hs()
        except TypeError:
            lens = _fallback_lens_from_hs()

    max_len = max(int(h.size(0)) for h in hs_list)
    device = hs_list[0].device

    # pad to [B, T, H]
    padded = hs_list[0].new_zeros((B, max_len, H))  # zeros
    attn_mask = torch.zeros((B, max_len), dtype=torch.bool, device=device)

    for i, (hs, L) in enumerate(zip(hs_list, lens)):
        L = int(L)
        L = min(L, max_len)
        padded[i, :L] = hs[:L]
        attn_mask[i, :L] = True
Pooler ignores batching layout and drops extra requests

The new `SPLADESparsePooler.forward` wraps the incoming `hidden_states` tensor into a single item whenever it is a 2-D tensor (lines 638-649) and never consults the `pooling_metadata.pooling_cursor` that encodes how multiple requests are concatenated. In the vLLM runner, embeddings are pooled from a single `[total_tokens, hidden]` tensor containing all prompts in a batch. With the current logic, only the first prompt in the batch is padded and pooled while the remaining prompts are silently ignored, causing incorrect or missing embeddings whenever more than one request is processed together. The pooler should use `pooling_cursor` (as done in `SimplePooler`) to split the tensor per request before applying the MLM head.
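For illustration, here is a minimal sketch of the kind of per-request splitting being suggested, assuming the per-prompt token counts are available from the pooling metadata (the helper name and attribute access are illustrative, not the runner's actual API):

import torch

def split_by_prompt(hidden_states: torch.Tensor, prompt_lens: list[int]) -> list[torch.Tensor]:
    # hidden_states: [total_tokens, hidden] with all prompts of the batch concatenated.
    # Returns one [len_i, hidden] slice per request, in batch order, ready for per-prompt pooling.
    assert hidden_states.dim() == 2
    assert sum(prompt_lens) == hidden_states.size(0)
    return list(torch.split(hidden_states, prompt_lens, dim=0))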
@hmellor I guess transformers backend can't really handle custom poolers based on the current design, right?

Right now no, there is no way to register custom poolers. It wouldn't be too hard to add a … Or do you mean a mechanism to register custom poolers in the Transformers backend with no upstream changes?

Yeah that's what I'm thinking. I guess implementing this in vLLM is the most reasonable solution without upstream changes then.
Force-pushed from 2799f7f to 3106979.
Oh that wouldn't require any upstream changes. These changes would be made in https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/transformers_pooling.py

The only caveat is that it would mean users have to install Transformers from source because the Transformers side refactor that enables the Transformers backend for BERT models is not in a release yet.
…h.no_grad() (handled by vLLM framework)
- Added model loading entry to tests/models/registry.py
- Added SPLADESparsePooler functional + smoke tests to ensure future stability

Signed-off-by: gjgjos <[email protected]>
Force-pushed from 3ab178a to 657860b.
/gemini review
Code Review
This pull request adds support for the `naver/splade-v3` sparse embedding model. The implementation is well-structured, introducing `BertSpladeSparseEmbeddingModel` and `SPLADESparsePooler`. The accompanying tests are thorough, covering both functional correctness and integration with vLLM's serving capabilities.

My review identifies two high-severity issues. First, a broad `except Exception: pass` in the weight loading logic could mask critical errors and lead to silent failures. Second, the SPLADE pooling method is hardcoded to 'max', preventing users from selecting the 'sum' method, which is mentioned as supported. Addressing these points will improve the robustness and configurability of the new model support.
""" | ||
|
||
def __init__( | ||
self, *, vllm_config: VllmConfig, prefix: str = "", splade_pooling: str = "max" |
The `splade_pooling` parameter in `BertSpladeSparseEmbeddingModel.__init__` is hardcoded with a default value of `"max"` and is not exposed for user configuration. The PR description mentions that `SPLADESparsePooler` supports both "max" and "sum" pooling, which implies this should be configurable. Currently, there is no way for a user to select "sum" pooling.

This parameter should be made configurable, for instance by reading it from `vllm_config.model_config.pooler_config` or `vllm_config.model_config.hf_config`.
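A minimal sketch of one way to make this configurable inside `__init__`, reading the mode from the Hugging Face config mentioned above (the `splade_pooling` attribute name and the fallback value are assumptions, not existing config fields):

# Hypothetical sketch inside BertSpladeSparseEmbeddingModel.__init__:
hf_config = vllm_config.model_config.hf_config
splade_pooling = getattr(hf_config, "splade_pooling", "max")  # attribute name assumed
if splade_pooling not in ("max", "sum"):
    raise ValueError(f"Unsupported SPLADE pooling mode: {splade_pooling!r}")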
try:
    emb_w = self.model.embeddings.word_embeddings.weight
    dec_w = self.mlm_head.decoder.weight
    if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr():
        self.mlm_head.decoder.weight = emb_w
except Exception:
    pass
The `try...except Exception: pass` block is too broad and can hide important errors during weight loading. For instance, if `self.model.embeddings` or other attributes do not exist due to a model structure mismatch, an `AttributeError` would be silently ignored, making debugging difficult. This could lead to weights not being tied when they should be, resulting in incorrect model behavior. It's better to catch more specific exceptions, like `AttributeError`, or at least log a warning if an exception occurs.
Suggested change:

try:
    emb_w = self.model.embeddings.word_embeddings.weight
    dec_w = self.mlm_head.decoder.weight
    if dec_w.shape == emb_w.shape and dec_w.data_ptr() != emb_w.data_ptr():
        self.mlm_head.decoder.weight = emb_w
except AttributeError:
    # It's possible that some BERT variants may not have this structure.
    # If we can't find the weights to tie, it's not a critical
    # error, as the model can still function with untied weights.
    pass
Thanks for contributing. There are some places where the code can be cleaned up quite a bit, but overall this is a nice addition.
Signed-off-by: gjgjos <[email protected]>
Thanks for the thorough review, @maxdebayser! I’ve pushed an update addressing all your points:
float("-inf"), device=scores.device, dtype=scores.dtype | ||
) | ||
masked = scores.masked_fill(~valid_mask.unsqueeze(-1), neg_inf) | ||
pooled = masked.max(dim=1).values |
Suggested change:

pooled = masked.amax(dim=1)
assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2

device = hidden_states.device
H: int = int(self.mlm_head.dense.in_features)
Suggested change:

H = int(self.mlm_head.dense.in_features)
device = hidden_states.device
H: int = int(self.mlm_head.dense.in_features)

hs_list = list(torch.split(hidden_states, lens, dim=0))
Suggested change:

hs_list = torch.split(hidden_states, lens, dim=0)
I think the default `tuple` should work fine already, no need to change to `list`. Should also rename the variable though.
Purpose
This PR adds official support for the `naver/splade-v3` model, a BERT-based sparse retrieval model utilizing the SPLADE pooling mechanism.

The implementation introduces the `BertSpladeSparseEmbeddingModel` class, extending `BertEmbeddingModel` to generate sparse lexical embeddings from the MLM head output (`log1p(ReLU(logits))`, sketched below), fully compatible with vLLM's embedding API (`/v1/embeddings` and `/pooling` endpoints).

This enables users to serve SPLADE models via vLLM with high performance and verified consistency against Hugging Face's `SparseEncoder` and TEI (Text Embeddings Inference) frameworks.
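For reference, a minimal sketch of the SPLADE activation described above (pooling the MLM-head logits over the token axis); this is illustrative only, not the PR's exact implementation:

import torch

def splade_activation(mlm_logits: torch.Tensor, pooling: str = "max") -> torch.Tensor:
    # mlm_logits: [seq_len, vocab_size] MLM-head logits for a single prompt.
    scores = torch.log1p(torch.relu(mlm_logits))
    # Pool over the token axis to obtain one activation per vocabulary entry.
    return scores.amax(dim=0) if pooling == "max" else scores.sum(dim=0)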
Implementation Details

New model registration

Architecture
- Backbone: `bert`
- Head: MLM head (`cls.predictions.*`)
- Pooling: `SPLADESparsePooler` (supports `max` or `sum`)
- Output: sparse lexical embedding vector (dimension = vocab size ≈ 30k)

Modified files
- `bert.py` → added `BertSpladeSparseEmbeddingModel`
- `registry.py` → registered model under the `"bert"` family (see the sketch after this list)
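A hedged sketch of the kind of entry this adds, assuming vLLM's usual "architecture name → (module, class)" mapping in registry.py; the dict name below is hypothetical and the actual registry layout may differ:

# Illustrative only; the actual dict in vllm/model_executor/models/registry.py may differ.
_SPARSE_EMBEDDING_MODELS = {
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
}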
Test Plan

1️⃣ vLLM-based Docker serving
Run script
Server log highlights
✅ Successfully initialized with torch.compile graph caching and KVCache disabled (sparse embedding mode).
The `/v1/embeddings` route was available for inference.

2️⃣ vLLM Inference Test (Python Client) — Actual response & parsed preview
Request
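The exact request script is not reproduced above; an illustrative request against a locally served instance could look like the following (host, port, and input text are assumptions):

import requests

# Illustrative only: query vLLM's OpenAI-compatible embeddings endpoint.
resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"model": "naver/splade-v3", "input": ["What is SPLADE?"]},
)
resp.raise_for_status()
embedding = resp.json()["data"][0]["embedding"]  # vocab-sized sparse lexical vector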
Actual response JSON (shape)
Parsing helper & preview
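The original helper is likewise not shown; a minimal version that surfaces the strongest activations, assuming the `transformers` tokenizer for the model, could be:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("naver/splade-v3")

def preview_sparse(embedding: list[float], top_k: int = 10) -> list[tuple[str, float]]:
    # Keep only nonzero activations and report the top tokens by weight.
    nonzero = [(i, v) for i, v in enumerate(embedding) if v > 0.0]
    nonzero.sort(key=lambda pair: pair[1], reverse=True)
    return [(tokenizer.convert_ids_to_tokens(i), round(v, 4)) for i, v in nonzero[:top_k]]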
Observed output
3️⃣ Hugging Face `SparseEncoder` Verification

Result
✅ The vLLM and Hugging Face results are numerically identical (within 1e-4 float tolerance) across all nonzero indices and values.
4️⃣ TEI (Text Embeddings Inference) Consistency Test
Container launch
Test via curl
Response
✅ The TEI server’s output is functionally equivalent to the vLLM response, confirming correct sparse pooling and alignment of activation magnitudes.
Test Result Summary
All three implementations produce identical sparse activation patterns and values, demonstrating full correctness and interoperability.
Notes
- No regression for existing `BertEmbeddingModel` or dense embedding workflows.
- Sparse embedding fully integrated with `PoolingTask.embed`.
- Works with FlashAttention backend and torch.compile graph caching.
- TEI consistency ensures vLLM can serve SPLADE models interchangeably in hybrid retrieval systems.
- Clearly described purpose (add SPLADE support for `naver/splade-v3`)
- Test plan included (vLLM, HF, TEI parity)
- Verified consistent outputs across frameworks
- Registry and pooling code updated
- No backward-compatibility issues introduced
- (Optional) Update `supported_models.md`
- (Optional) Add release note entry
✅ Summary:
This PR adds end-to-end integration of the BERT-based `naver/splade-v3` sparse embedding model into vLLM.