From 67ea005bae7e04f3bc6a37e317ed5a215b3099ae Mon Sep 17 00:00:00 2001 From: kaixuanliu Date: Fri, 7 Mar 2025 01:24:50 +0000 Subject: [PATCH 1/5] optimize flash bert for hpu device Signed-off-by: kaixuanliu --- .../text_embeddings_server/models/__init__.py | 11 ++- .../models/flash_bert.py | 99 ++++++++++++++++--- .../utils/flash_attn.py | 65 +++--------- 3 files changed, 105 insertions(+), 70 deletions(-) diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 679db9fc7..78acc220f 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -1,3 +1,4 @@ +import os import torch from loguru import logger @@ -12,9 +13,17 @@ from text_embeddings_server.utils.device import get_device, use_ipex __all__ = ["Model"] - +ALLOW_REDUCED_PRECISION = os.getenv( + "ALLOW_REDUCED_PRECISION_FP16_BF16", "true" +).lower() in [ + "true", + "1", +] # Disable gradients torch.set_grad_enabled(False) +# WA for perf degradation from pytorch 2.5 +if ALLOW_REDUCED_PRECISION: + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) FLASH_ATTENTION = True try: diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index fce5c3f22..78abb4918 100644 --- a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -8,7 +8,7 @@ from transformers.models.bert import BertConfig from opentelemetry import trace from text_embeddings_server.models import Model -from text_embeddings_server.models.types import FlashBatch, Embedding +from text_embeddings_server.models.types import FlashBatch, Embedding, PaddedBatch from text_embeddings_server.utils.flash_attn import attention from text_embeddings_server.utils.device import use_ipex @@ -166,22 +166,41 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig): self.num_heads = config.num_attention_heads self.device = device - def forward(self, hidden_states, cu_seqlens, max_s): + def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None): residual = hidden_states + bs = 1 + if hidden_states.dim() > 2: + bs, _, hidden_dim = hidden_states.size() + hidden_states = hidden_states.view(-1, hidden_dim) qkv = torch.addmm(self.qkv_bias, hidden_states, self.qkv_weight) - q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split( - self.num_heads, dim=1 - ) + if residual.dim() > 2: + q, k, v = qkv.view(bs, -1, self.num_heads * 3, self.head_size).split( + self.num_heads, dim=2 + ) + else: + q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split( + self.num_heads, dim=1 + ) attn_output = torch.empty_like(q) - attention(q, k, v, attn_output, cu_seqlens, max_s, self.softmax_scale) + attention( + q, + k, + v, + attn_output, + cu_seqlens, + max_s, + self.softmax_scale, + attn_mask=attn_mask, + ) hidden_states = torch.addmm( self.dense_bias, attn_output.view(-1, self.num_heads * self.head_size), self.dense_weight, ) + hidden_states = hidden_states.view(bs, -1, hidden_dim) hidden_states, _ = self.layer_norm.forward(hidden_states, residual) return hidden_states @@ -224,10 +243,14 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig): f"{prefix}.output.LayerNorm", handle, device, dtype, config ) - def forward(self, hidden_states, cu_seqlens, max_s): - hidden_states = 
self.attention.forward(hidden_states, cu_seqlens, max_s) + def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None): + hidden_states = self.attention.forward( + hidden_states, cu_seqlens, max_s, attn_mask + ) residual = hidden_states - + if hidden_states.dim() > 2: + bs, _, hidden_dim = hidden_states.size() + hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = torch.addmm( self.intermediate_bias, hidden_states, self.intermediate_weight ) @@ -237,6 +260,8 @@ def forward(self, hidden_states, cu_seqlens, max_s): hidden_states, self.output_weight, ) + if residual.dim() > 2: + hidden_states = hidden_states.view(bs, -1, hidden_dim) hidden_states, _ = self.layer_norm.forward(hidden_states, residual) return hidden_states @@ -248,9 +273,9 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig): for i in range(config.num_hidden_layers) ] - def forward(self, hidden_states, cu_seqlens, max_s): + def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None): for layer in self.layers: - hidden_states = layer.forward(hidden_states, cu_seqlens, max_s) + hidden_states = layer.forward(hidden_states, cu_seqlens, max_s, attn_mask) return hidden_states @@ -259,10 +284,21 @@ def __init__(self, handle, device, dtype, config: BertConfig): self.embeddings = BertEmbeddings("embeddings", handle, device, dtype, config) self.encoder = BertEncoder("encoder", handle, device, dtype, config) - def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s): + def forward( + self, + input_ids, + token_type_ids, + position_ids, + cu_seqlens, + max_s, + mask=None, + attn_mask=None, + ): embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids) - encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s) - + encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s, attn_mask) + if mask is not None: + outputs = encoder_outputs[mask] + return outputs[cu_seqlens[:-1]] return encoder_outputs[cu_seqlens[:-1]] @@ -271,6 +307,7 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): config = BertConfig.from_pretrained(model_path) with safe_open(model_path / "model.safetensors", framework="pt") as f: model = FlashBertModel(f, device, dtype, config) + self.device = device if device.type == "hpu": from habana_frameworks.torch.hpu import wrap_in_hpu_graph @@ -281,7 +318,8 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): @property def batch_type(self) -> Type[FlashBatch]: - return FlashBatch + # for hpu devices, we use PaddedBatch as we do not have real varlen fwd yet + return FlashBatch if self.device.type != "hpu" else PaddedBatch @tracer.start_as_current_span("embed") def embed(self, batch: FlashBatch) -> List[Embedding]: @@ -300,3 +338,34 @@ def embed(self, batch: FlashBatch) -> List[Embedding]: ) for i in range(len(batch)) ] + + @tracer.start_as_current_span("embed") + def embed(self, batch: PaddedBatch) -> List[Embedding]: + input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32) + max_input_lens = input_lens.max().item() + cu_seqlens = torch.cat( + (input_lens.new_tensor([0]), input_lens.cumsum(-1).int()) + ) + mask = batch.attention_mask.to(torch.bool) + batch_size = input_lens.size(0) + attn_mask = torch.empty( + [batch_size, 1, 1, max_input_lens], device=self.device + ).fill_(float("-inf")) + attn_mask[:, :, :, :max_input_lens].masked_fill_(mask[:, None, None, :], 0) + embedding = self.model.forward( + input_ids=batch.input_ids, + 
token_type_ids=batch.token_type_ids, + position_ids=batch.position_ids, + cu_seqlens=cu_seqlens, + max_s=max_input_lens, + mask=mask, + attn_mask=attn_mask, + ) + cpu_results = embedding.view(-1).tolist() + + return [ + Embedding( + values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size] + ) + for i in range(len(batch)) + ] diff --git a/backends/python/server/text_embeddings_server/utils/flash_attn.py b/backends/python/server/text_embeddings_server/utils/flash_attn.py index ecfc4f9e2..2c718d540 100644 --- a/backends/python/server/text_embeddings_server/utils/flash_attn.py +++ b/backends/python/server/text_embeddings_server/utils/flash_attn.py @@ -62,6 +62,7 @@ def hpu_attn( k, v, out, + attn_mask, seqlen_q, seqlen_k, max_seqlen_q, @@ -71,66 +72,21 @@ def hpu_attn( ): from habana_frameworks.torch.hpex.kernels import FusedSDPA - total_q, num_head, head_size = q.size() - total_k, num_head_k, _ = k.size() - batch_size = seqlen_q.size(0) - 1 - seqlen_q_ = seqlen_q.clone() - seqlen_q_[:batch_size] = seqlen_q[1:] - seqlen_q = (seqlen_q_ - seqlen_q)[:batch_size] - seqlen_k_ = seqlen_k.clone() - seqlen_k_[:batch_size] = seqlen_k[1:] - seqlen_k = (seqlen_k_ - seqlen_k)[:batch_size] - - pad_q = torch.zeros( - [batch_size, max_seqlen_q, num_head, head_size], - dtype=q.dtype, - device=q.device, - ) - pad_k = torch.zeros( - [batch_size, max_seqlen_k, num_head_k, head_size], - dtype=k.dtype, - device=k.device, - ) - pad_v = torch.zeros( - [batch_size, max_seqlen_k, num_head_k, head_size], - dtype=v.dtype, - device=v.device, - ) - q_mask = torch.arange(0, max_seqlen_q, device=q.device)[None, :].repeat( - batch_size, 1 - ) - q_mask = q_mask < seqlen_q[:, None].repeat(1, q_mask.size(-1)) - k_mask = torch.arange(0, max_seqlen_k, device=k.device)[None, :].repeat( - batch_size, 1 - ) - k_mask = k_mask < seqlen_k[:, None].repeat(1, k_mask.size(-1)) - align_mask_seqlen = max_seqlen_k - attn_mask = torch.empty( - [batch_size, 1, 1, align_mask_seqlen], - dtype=q.dtype, - device=q.device, - ).fill_(float("-inf")) - attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0) - - pad_q[q_mask] = q - pad_k[k_mask] = k - pad_v[k_mask] = v - - pad_q = pad_q.permute(0, 2, 1, 3) - pad_k = pad_k.permute(0, 2, 1, 3) - pad_v = pad_v.permute(0, 2, 1, 3) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) if is_causal: attn_mask = None - out_ = FusedSDPA.apply( - pad_q, pad_k, pad_v, attn_mask, 0.0, is_causal, softmax_scale - ) - out_ = out_.permute(0, 2, 1, 3) - out.copy_(out_[q_mask]) + out_ = FusedSDPA.apply(q, k, v, attn_mask, 0.0, is_causal, softmax_scale) + out_ = out_.transpose(1, 2) + out.copy_(out_) return out -def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False): +def attention( + q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False, attn_mask=None +): if HAS_FLASH_ATTN_V2: if use_ipex: import intel_extension_for_pytorch as ipex @@ -157,6 +113,7 @@ def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False): k, v, out, + attn_mask, cu_seqlens, cu_seqlens, max_s, From d69268e9045ec1c7b3ea68468fb8902c521d6beb Mon Sep 17 00:00:00 2001 From: kaixuanliu Date: Fri, 7 Mar 2025 03:35:06 +0000 Subject: [PATCH 2/5] nice code Signed-off-by: kaixuanliu --- .../text_embeddings_server/models/__init__.py | 9 -- .../models/flash_bert.py | 82 ++++++------------- .../text_embeddings_server/utils/device.py | 10 ++- 3 files changed, 34 insertions(+), 67 deletions(-) diff --git 
a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 78acc220f..03fb3891c 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -13,17 +13,8 @@ from text_embeddings_server.utils.device import get_device, use_ipex __all__ = ["Model"] -ALLOW_REDUCED_PRECISION = os.getenv( - "ALLOW_REDUCED_PRECISION_FP16_BF16", "true" -).lower() in [ - "true", - "1", -] # Disable gradients torch.set_grad_enabled(False) -# WA for perf degradation from pytorch 2.5 -if ALLOW_REDUCED_PRECISION: - torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) FLASH_ATTENTION = True try: diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index 78abb4918..89a037b85 100644 --- a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -2,7 +2,7 @@ from pathlib import Path from torch import nn import torch.nn.functional as F -from typing import Type, List +from typing import Type, List, Union from safetensors import safe_open from transformers.activations import ACT2FN from transformers.models.bert import BertConfig @@ -168,13 +168,9 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig): def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None): residual = hidden_states - bs = 1 + qkv = F.linear(hidden_states, self.qkv_weight.T, self.qkv_bias) if hidden_states.dim() > 2: - bs, _, hidden_dim = hidden_states.size() - hidden_states = hidden_states.view(-1, hidden_dim) - - qkv = torch.addmm(self.qkv_bias, hidden_states, self.qkv_weight) - if residual.dim() > 2: + bs = hidden_states.size(0) q, k, v = qkv.view(bs, -1, self.num_heads * 3, self.head_size).split( self.num_heads, dim=2 ) @@ -182,7 +178,6 @@ def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None): q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split( self.num_heads, dim=1 ) - attn_output = torch.empty_like(q) attention( q, @@ -195,12 +190,7 @@ def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None): attn_mask=attn_mask, ) - hidden_states = torch.addmm( - self.dense_bias, - attn_output.view(-1, self.num_heads * self.head_size), - self.dense_weight, - ) - hidden_states = hidden_states.view(bs, -1, hidden_dim) + hidden_states = F.linear(hidden_states, self.dense_weight, self.dense_bias) hidden_states, _ = self.layer_norm.forward(hidden_states, residual) return hidden_states @@ -248,20 +238,9 @@ def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None): hidden_states, cu_seqlens, max_s, attn_mask ) residual = hidden_states - if hidden_states.dim() > 2: - bs, _, hidden_dim = hidden_states.size() - hidden_states = hidden_states.view(-1, hidden_dim) - hidden_states = torch.addmm( - self.intermediate_bias, hidden_states, self.intermediate_weight - ) + hidden_states = F.linear(hidden_states, self.intermediate_weight.T, self.intermediate_bias) hidden_states = self.intermediate_act_fn(hidden_states) - hidden_states = torch.addmm( - self.output_bias, - hidden_states, - self.output_weight, - ) - if residual.dim() > 2: - hidden_states = hidden_states.view(bs, -1, hidden_dim) + hidden_states = F.linear(hidden_states, self.output_weight.T, self.output_bias) hidden_states, _ = self.layer_norm.forward(hidden_states, residual) return hidden_states @@ 
-317,41 +296,30 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): super(FlashBert, self).__init__(model=model, dtype=dtype, device=device) @property - def batch_type(self) -> Type[FlashBatch]: + def batch_type(self) -> Union[FlashBatch, PaddedBatch]: # for hpu devices, we use PaddedBatch as we do not have real varlen fwd yet return FlashBatch if self.device.type != "hpu" else PaddedBatch @tracer.start_as_current_span("embed") - def embed(self, batch: FlashBatch) -> List[Embedding]: - embedding = self.model.forward( - input_ids=batch.input_ids, - token_type_ids=batch.token_type_ids, - position_ids=batch.position_ids, - cu_seqlens=batch.cu_seqlens, - max_s=batch.max_s, - ) - cpu_results = embedding.view(-1).tolist() - - return [ - Embedding( - values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size] + def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]: + if isinstance(batch, PaddedBatch): + input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32) + max_input_lens = input_lens.max().item() + cu_seqlens = torch.cat( + (input_lens.new_tensor([0]), input_lens.cumsum(-1).int()) ) - for i in range(len(batch)) - ] + mask = batch.attention_mask.to(torch.bool) + batch_size = input_lens.size(0) + attn_mask = torch.empty( + [batch_size, 1, 1, max_input_lens], device=self.device + ).fill_(float("-inf")) + attn_mask[:, :, :, :max_input_lens].masked_fill_(mask[:, None, None, :], 0) + elif isinstance(batch, FlashBatch): + cu_seqlens = batch.cu_seqlens + mask = None + attn_mask = None + max_input_lens = batch.max_s - @tracer.start_as_current_span("embed") - def embed(self, batch: PaddedBatch) -> List[Embedding]: - input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32) - max_input_lens = input_lens.max().item() - cu_seqlens = torch.cat( - (input_lens.new_tensor([0]), input_lens.cumsum(-1).int()) - ) - mask = batch.attention_mask.to(torch.bool) - batch_size = input_lens.size(0) - attn_mask = torch.empty( - [batch_size, 1, 1, max_input_lens], device=self.device - ).fill_(float("-inf")) - attn_mask[:, :, :, :max_input_lens].masked_fill_(mask[:, None, None, :], 0) embedding = self.model.forward( input_ids=batch.input_ids, token_type_ids=batch.token_type_ids, @@ -359,7 +327,7 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]: cu_seqlens=cu_seqlens, max_s=max_input_lens, mask=mask, - attn_mask=attn_mask, + attn_mask=attn_mask ) cpu_results = embedding.view(-1).tolist() diff --git a/backends/python/server/text_embeddings_server/utils/device.py b/backends/python/server/text_embeddings_server/utils/device.py index d450b3737..76cd0dbfb 100644 --- a/backends/python/server/text_embeddings_server/utils/device.py +++ b/backends/python/server/text_embeddings_server/utils/device.py @@ -6,6 +6,12 @@ import torch import subprocess +ALLOW_REDUCED_PRECISION = os.getenv( + "ALLOW_REDUCED_PRECISION_FP16_BF16", "true" +).lower() in [ + "true", + "1", +] def _is_ipex_available(): def get_major_and_minor_from_version(full_version): @@ -54,7 +60,9 @@ def get_device(): device = torch.device("cuda") elif is_hpu(): import habana_frameworks.torch.core as htcore - + # WA for perf degradation from pytorch 2.5 + if ALLOW_REDUCED_PRECISION: + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) if hasattr(torch, "hpu") and torch.hpu.is_available(): # type: ignore device = torch.device("hpu") elif use_ipex(): From 80491dc72200b3c377643a051572585bbda5d007 Mon Sep 17 00:00:00 2001 From: kaixuanliu Date: Fri, 7 Mar 2025 06:37:18 +0000 Subject: 
[PATCH 3/5] small adjuest Signed-off-by: kaixuanliu --- .../text_embeddings_server/models/__init__.py | 1 - .../text_embeddings_server/models/flash_bert.py | 16 ++++++++++++---- .../text_embeddings_server/utils/device.py | 2 ++ 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 03fb3891c..51b1087e8 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -1,4 +1,3 @@ -import os import torch from loguru import logger diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index 89a037b85..c0cab5ebd 100644 --- a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -189,8 +189,14 @@ def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None): self.softmax_scale, attn_mask=attn_mask, ) - - hidden_states = F.linear(hidden_states, self.dense_weight, self.dense_bias) + if hidden_states.dim() > 2: + hidden_states = F.linear(hidden_states, self.dense_weight, self.dense_bias) + else: + hidden_states = torch.addmm( + self.dense_bias, + attn_output.view(-1, self.num_heads * self.head_size), + self.dense_weight, + ) hidden_states, _ = self.layer_norm.forward(hidden_states, residual) return hidden_states @@ -238,7 +244,9 @@ def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None): hidden_states, cu_seqlens, max_s, attn_mask ) residual = hidden_states - hidden_states = F.linear(hidden_states, self.intermediate_weight.T, self.intermediate_bias) + hidden_states = F.linear( + hidden_states, self.intermediate_weight.T, self.intermediate_bias + ) hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = F.linear(hidden_states, self.output_weight.T, self.output_bias) hidden_states, _ = self.layer_norm.forward(hidden_states, residual) @@ -327,7 +335,7 @@ def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]: cu_seqlens=cu_seqlens, max_s=max_input_lens, mask=mask, - attn_mask=attn_mask + attn_mask=attn_mask, ) cpu_results = embedding.view(-1).tolist() diff --git a/backends/python/server/text_embeddings_server/utils/device.py b/backends/python/server/text_embeddings_server/utils/device.py index 76cd0dbfb..3f3b04dd7 100644 --- a/backends/python/server/text_embeddings_server/utils/device.py +++ b/backends/python/server/text_embeddings_server/utils/device.py @@ -13,6 +13,7 @@ "1", ] + def _is_ipex_available(): def get_major_and_minor_from_version(full_version): return ( @@ -60,6 +61,7 @@ def get_device(): device = torch.device("cuda") elif is_hpu(): import habana_frameworks.torch.core as htcore + # WA for perf degradation from pytorch 2.5 if ALLOW_REDUCED_PRECISION: torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) From fd4ab866fb6d0c3985285ea6b63e4af7105ed72a Mon Sep 17 00:00:00 2001 From: kaixuanliu Date: Fri, 7 Mar 2025 09:54:47 +0000 Subject: [PATCH 4/5] small fix Signed-off-by: kaixuanliu --- .../models/flash_bert.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index c0cab5ebd..0b2685808 100644 --- 
a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -169,7 +169,11 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig): def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None): residual = hidden_states qkv = F.linear(hidden_states, self.qkv_weight.T, self.qkv_bias) + bs = 1 + hidden_dim = hidden_states.size(-1) + is_flat = True if hidden_states.dim() > 2: + is_flat = False bs = hidden_states.size(0) q, k, v = qkv.view(bs, -1, self.num_heads * 3, self.head_size).split( self.num_heads, dim=2 @@ -189,14 +193,14 @@ def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None): self.softmax_scale, attn_mask=attn_mask, ) - if hidden_states.dim() > 2: - hidden_states = F.linear(hidden_states, self.dense_weight, self.dense_bias) - else: - hidden_states = torch.addmm( - self.dense_bias, - attn_output.view(-1, self.num_heads * self.head_size), - self.dense_weight, - ) + + hidden_states = torch.addmm( + self.dense_bias, + attn_output.view(-1, self.num_heads * self.head_size), + self.dense_weight, + ) + if not is_flat: + hidden_states = hidden_states.view(bs, -1, hidden_dim) hidden_states, _ = self.layer_norm.forward(hidden_states, residual) return hidden_states From 7121111c7f86d7f6d1bd9338c61716b5fb3b1817 Mon Sep 17 00:00:00 2001 From: kaixuanliu Date: Tue, 11 Mar 2025 01:14:50 +0000 Subject: [PATCH 5/5] small fix Signed-off-by: kaixuanliu --- .../python/server/text_embeddings_server/models/__init__.py | 1 - .../python/server/text_embeddings_server/models/flash_bert.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 54ceaf00f..e5cbf72cc 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -15,7 +15,6 @@ __all__ = ["Model"] TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"] - # Disable gradients torch.set_grad_enabled(False) diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index 0b2685808..d563c07dd 100644 --- a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -323,9 +323,9 @@ def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]: mask = batch.attention_mask.to(torch.bool) batch_size = input_lens.size(0) attn_mask = torch.empty( - [batch_size, 1, 1, max_input_lens], device=self.device + [batch_size, 1, 1, mask.shape[-1]], device=self.device ).fill_(float("-inf")) - attn_mask[:, :, :, :max_input_lens].masked_fill_(mask[:, None, None, :], 0) + attn_mask[:, :, :, :].masked_fill_(mask[:, None, None, :], 0) elif isinstance(batch, FlashBatch): cu_seqlens = batch.cu_seqlens mask = None
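
Taken together, the patches above swap the varlen flash-attention path for a padded-batch path on HPU: per-sequence lengths are derived from the attention mask, an additive [batch, 1, 1, seq] mask is built for the fused SDPA kernel, and the per-sequence (CLS) embeddings are recovered by dropping padding and indexing with cu_seqlens. The sketch below illustrates that flow with stock PyTorch as a minimal, standalone example; the helper name padded_sdpa_cls and the use of torch.nn.functional.scaled_dot_product_attention in place of Habana's FusedSDPA are assumptions for illustration, not code from the patches.

import torch
import torch.nn.functional as F

def padded_sdpa_cls(hidden, attention_mask, num_heads):
    # hidden: [bs, seq, dim] padded hidden states; attention_mask: [bs, seq] of 0/1.
    bs, seq, dim = hidden.shape
    head_size = dim // num_heads

    # Real token count per sequence and prefix sums (cu_seqlens), as in the PaddedBatch path.
    input_lens = attention_mask.sum(-1).to(torch.int32)
    cu_seqlens = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
    bool_mask = attention_mask.to(torch.bool)

    # Additive mask: 0 where a key token is real, -inf where it is padding.
    attn_mask = hidden.new_full((bs, 1, 1, seq), float("-inf"))
    attn_mask.masked_fill_(bool_mask[:, None, None, :], 0)

    # Reshape to [bs, heads, seq, head_size] and run a generic SDPA with the padding mask
    # (the HPU path calls FusedSDPA.apply with the same mask layout).
    q = k = v = hidden.view(bs, seq, num_heads, head_size).transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    out = out.transpose(1, 2).reshape(bs, seq, dim)

    # Drop padding tokens, then pick the first token of each sequence via cu_seqlens[:-1].
    flat = out[bool_mask]            # [total_real_tokens, dim]
    return flat[cu_seqlens[:-1]]     # [bs, dim]

# Example: two sequences of length 3 and 5, padded to 5.
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
cls = padded_sdpa_cls(torch.randn(2, 5, 64), mask, num_heads=4)
print(cls.shape)  # torch.Size([2, 64])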