
Commit fc979a9

nice code
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent e161e83 commit fc979a9

File tree

3 files changed: +9 −13 lines


backends/python/server/text_embeddings_server/models/flash_bert.py

Lines changed: 4 additions & 2 deletions
@@ -12,7 +12,7 @@
 from text_embeddings_server.models import Model
 from text_embeddings_server.models.types import FlashBatch, Embedding
 from text_embeddings_server.utils.flash_attn import attention
-
+from text_embeddings_server.utils.device import use_ipex
 tracer = trace.get_tracer(__name__)


@@ -25,6 +25,8 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):

     def forward(self, hidden_states, residual=None):
         # Flash attention imports
+        normed_hidden_states = None
+        res = None
         if self.device.type == "cuda":
             import dropout_layer_norm
             normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@@ -46,7 +48,7 @@ def forward(self, hidden_states, residual=None):
             )
             if res is None:
                 res = hidden_states
-        else:
+        elif use_ipex():
             import intel_extension_for_pytorch as ipex
             normed_hidden_states = ipex.llm.functional.add_layer_norm(
                 residual,
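
Why the two added None assignments in forward matter: if neither the CUDA nor the IPEX branch runs, names bound only inside a branch would be undefined when the method uses them afterwards. A minimal, self-contained illustration of the pattern (hypothetical function and backend strings, not the module's code):

def fused_add_layer_norm(hidden_states, residual=None, device_type="cpu"):
    # Bind both outputs before branching so every path returns defined names.
    normed_hidden_states = None
    res = None
    if device_type == "cuda":
        # stand-in for the fused dropout_add_ln_fwd kernel
        normed_hidden_states, res = hidden_states, residual
    elif device_type == "xpu":
        # stand-in for ipex.llm.functional.add_layer_norm
        normed_hidden_states = hidden_states
    if res is None:
        res = hidden_states
    return normed_hidden_states, res

# With no matching backend the call still returns (None, hidden_states)
# instead of raising UnboundLocalError.
print(fused_add_layer_norm([1.0, 2.0]))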

backends/python/server/text_embeddings_server/utils/device.py

Lines changed: 1 addition & 10 deletions
@@ -32,16 +32,7 @@ def _is_hpu() -> bool:
     try:
         subprocess.run(["hl-smi"], capture_output=True, check=True)
     except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
-        if not os.path.exists('/dev/accel/accel0') and not os.path.exists(
-                '/dev/accel/accel_controlD0'):
-            # last resort...
-            try:
-                output = subprocess.check_output(
-                    'lsmod | grep habanalabs | wc -l', shell=True)
-                is_hpu_available = int(output) > 0
-            except (ValueError, FileNotFoundError, PermissionError,
-                    subprocess.CalledProcessError):
-                is_hpu_available = False
+        is_hpu_available = False
     return is_hpu_available

 def use_ipex() -> bool:
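
With the fallback probes gone, _is_hpu() reduces to a single hl-smi check. A sketch of how the simplified helper reads, assuming the flag is initialised to True just above the hunk (those context lines are not shown here):

import subprocess

def _is_hpu() -> bool:
    # Treat the machine as having a Habana/Gaudi HPU only if `hl-smi` runs
    # successfully; a missing binary, a permission error, or a non-zero exit
    # now simply means "no HPU" (no lsmod or /dev/accel fallback any more).
    is_hpu_available = True
    try:
        subprocess.run(["hl-smi"], capture_output=True, check=True)
    except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
        is_hpu_available = False
    return is_hpu_available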

backends/python/server/text_embeddings_server/utils/flash_attn.py

Lines changed: 4 additions & 1 deletion
@@ -58,7 +58,10 @@ def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
     if HAS_FLASH_ATTN_V2:
         if use_ipex():
             import intel_extension_for_pytorch as ipex
-            return ipex.llm.functional.varlen_attention(q, k, v, out, cu_seqlens, cu_seqlens, max_s, max_s, 0, softmax_scale, zero_tensors=False, is_causal=False, return_softmax=False, gen_=None)
+            return ipex.llm.functional.varlen_attention(q, k, v, out, cu_seqlens, cu_seqlens,
+                                                        max_s, max_s, 0, softmax_scale,
+                                                        zero_tensors=False, is_causal=False,
+                                                        return_softmax=False, gen_=None)
         else:
             return flash_attn_2_cuda.varlen_fwd(
                 q,
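
The flash_attn.py change only re-wraps the long varlen_attention call; behaviour is unchanged. For context, attention() (defined in this file) operates on packed variable-length batches, with cu_seqlens holding cumulative sequence lengths and out receiving the result. A hypothetical call with made-up shapes, dtype, and device, purely for illustration:

import torch

from text_embeddings_server.utils.flash_attn import attention

# Two sequences of lengths 3 and 5, packed along the token axis with no padding.
num_heads, head_dim = 12, 64
total_tokens = 3 + 5
device = "cuda"  # or the XPU device when the IPEX backend is active

q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device=device)  # cumulative lengths
max_s = 5                      # longest sequence in the packed batch
softmax_scale = head_dim ** -0.5

attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False)
# `out` now holds the attention output for all 8 packed tokens.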

0 commit comments
