
Commit f7d1e1b

Merge pull request huggingface#5 from kaixuanliu/ipex
add hpu flashBert support
2 parents 081ab41 + cbc3ee2 commit f7d1e1b

File tree

5 files changed: 114 additions & 22 deletions


Dockerfile-intel

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ RUN python -m pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url
 RUN cd backends/python/server && \
     make install
 
-FROM vault.habana.ai/gaudi-docker/1.16.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest AS hpu
+FROM vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest AS hpu
 ENV HUGGINGFACE_HUB_CACHE=/data \
     PORT=80

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 5 additions & 6 deletions
@@ -37,6 +37,7 @@ def get_model(model_path: Path, dtype: Optional[str]) :
         raise RuntimeError(f"Unknown dtype {dtype}")
 
     device = get_device()
+    logger.info(f"backend device: {device}")
     config = AutoConfig.from_pretrained(model_path)
     if config.model_type == "bert":
         config: BertConfig

@@ -48,14 +49,12 @@ def get_model(model_path: Path, dtype: Optional[str]) :
         ):
             return FlashBert(model_path, device, datatype) # type: ignore
         if use_ipex() and device.type in ["cpu", "xpu"]:
+            import intel_extension_for_pytorch as ipex
             return FlashBert(model_path, device, datatype) # type: ignore
         if device.type == "hpu":
-            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-            from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
-            adapt_transformers_to_gaudi()
-            model_handle = DefaultModel(model_path, device, datatype)
-            model_handle.model = wrap_in_hpu_graph(model_handle.model, disable_tensor_cache=True)
-            return model_handle
+            import habana_frameworks.torch.core as htcore
+            return FlashBert(model_path, device, datatype)
+
         return DefaultModel(model_path, device, datatype)
     else:
         return DefaultModel(model_path, device, datatype)
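
Note: with this change an HPU host takes the same FlashBert path as the CUDA and IPEX branches instead of wrapping a DefaultModel in an HPU graph (the graph wrapping now happens inside FlashBert, see flash_bert.py below). A minimal usage sketch, assuming the text_embeddings_server package is importable, that a BERT checkpoint exists at the hypothetical path below, and that "bfloat16" is one of the accepted dtype strings:

# Hypothetical sketch: exercising the new dispatch in get_model().
from pathlib import Path

from text_embeddings_server.models import get_model
from text_embeddings_server.utils.device import get_device

model_path = Path("/data/models/my-bert-model")  # hypothetical local checkpoint
print("selected device:", get_device())          # cpu / cuda / hpu depending on the host

# On an HPU host this now returns FlashBert rather than an
# hpu-graph-wrapped DefaultModel as before.
model = get_model(model_path, dtype="bfloat16")
print(type(model).__name__)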

backends/python/server/text_embeddings_server/models/flash_bert.py

Lines changed: 33 additions & 6 deletions
@@ -1,27 +1,43 @@
 import torch
-
 from pathlib import Path
 from torch import nn
+import torch.nn.functional as F
 from typing import Type, List
 from safetensors import safe_open
 from transformers.activations import ACT2FN
 from transformers.models.bert import BertConfig
 from opentelemetry import trace
-
-
 from text_embeddings_server.models import Model
 from text_embeddings_server.models.types import FlashBatch, Embedding
 from text_embeddings_server.utils.flash_attn import attention
 from text_embeddings_server.utils.device import use_ipex
+
 tracer = trace.get_tracer(__name__)
 
+def hpu_add_layer_norm(
+    add: torch.Tensor,
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    epsilon: float,
+    add_back: bool
+):
+    if add is not None:
+        added_tensor = torch.add(add, x, alpha=1.0)
+        output = F.layer_norm(added_tensor, [x.size(-1)], weight, bias, epsilon)
+        if add_back:
+            add.add_(x)
+        return output
+    else:
+        return F.layer_norm(x, [x.size(-1)], weight=weight, bias=bias, eps=epsilon)
 
 class FastLayerNorm:
     def __init__(self, prefix, handle, device, dtype, config: BertConfig):
         self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device)
         self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device)
         self.variance_epsilon = config.layer_norm_eps
         self.device = device
+        self.use_ipex = use_ipex()
 
     def forward(self, hidden_states, residual=None):
         # Flash attention imports

@@ -48,7 +64,7 @@ def forward(self, hidden_states, residual=None):
             )
             if res is None:
                 res = hidden_states
-        elif use_ipex():
+        elif self.use_ipex:
             import intel_extension_for_pytorch as ipex
             normed_hidden_states = ipex.llm.functional.add_layer_norm(
                 residual,

@@ -60,7 +76,16 @@ def forward(self, hidden_states, residual=None):
            )
 
            res = residual if residual is not None else hidden_states
-
+        elif self.device.type == "hpu":
+            normed_hidden_states = hpu_add_layer_norm(
+                residual,
+                hidden_states,
+                self.weight,
+                self.bias,
+                self.variance_epsilon,
+                residual is not None
+            )
+            res = residual if residual is not None else hidden_states
        return normed_hidden_states, res
 
 
@@ -242,7 +267,9 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
         config = BertConfig.from_pretrained(model_path)
         with safe_open(model_path / "model.safetensors", framework="pt") as f:
             model = FlashBertModel(f, device, dtype, config)
-
+        if device.type == "hpu":
+            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+            model = wrap_in_hpu_graph(model, disable_tensor_cache=False)
         self.hidden_size = config.hidden_size
 
         super(FlashBert, self).__init__(model=model, dtype=dtype, device=device)
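
Note: on HPU the layer-norm path no longer goes through ipex.llm.functional.add_layer_norm; the new hpu_add_layer_norm helper performs the residual add and LayerNorm with plain torch ops that HPU graph capture can handle, and FlashBert now wraps the whole FlashBertModel in an HPU graph. A minimal CPU-runnable sanity check of the helper against an unfused reference (the harness below is illustrative and not part of the commit):

import torch
import torch.nn.functional as F


def hpu_add_layer_norm(add, x, weight, bias, epsilon, add_back):
    # Same logic as the helper added above in flash_bert.py.
    if add is not None:
        added_tensor = torch.add(add, x, alpha=1.0)
        output = F.layer_norm(added_tensor, [x.size(-1)], weight, bias, epsilon)
        if add_back:
            add.add_(x)  # leave the residual buffer holding the pre-norm sum for the caller
        return output
    return F.layer_norm(x, [x.size(-1)], weight=weight, bias=bias, eps=epsilon)


torch.manual_seed(0)
hidden = 768
residual = torch.randn(4, hidden)
states = torch.randn(4, hidden)
weight, bias = torch.ones(hidden), torch.zeros(hidden)
eps = 1e-12

# Fused helper vs. an unfused residual-add followed by layer norm.
out = hpu_add_layer_norm(residual.clone(), states, weight, bias, eps, add_back=False)
ref = F.layer_norm(residual + states, [hidden], weight, bias, eps)
assert torch.allclose(out, ref, atol=1e-6)

# With add_back=True (the `residual is not None` case in FastLayerNorm.forward),
# the residual buffer is updated in place so it can be reused as `res`.
buf = residual.clone()
hpu_add_layer_norm(buf, states, weight, bias, eps, add_back=True)
assert torch.allclose(buf, residual + states)
print("hpu_add_layer_norm matches the unfused reference")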

backends/python/server/text_embeddings_server/utils/device.py

Lines changed: 3 additions & 3 deletions
@@ -27,11 +27,11 @@ def get_major_and_minor_from_version(full_version):
         return False
     return True
 
-def _is_hpu() -> bool:
+def is_hpu() -> bool:
     is_hpu_available = True
     try:
         subprocess.run(["hl-smi"], capture_output=True, check=True)
-    except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
+    except:
         is_hpu_available = False
     return is_hpu_available
 

@@ -43,7 +43,7 @@ def get_device() :
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda")
-    elif _is_hpu():
+    elif is_hpu():
         import habana_frameworks.torch.core as htcore
         if hasattr(torch, "hpu") and torch.hpu.is_available(): # type: ignore
             device = torch.device("hpu")

backends/python/server/text_embeddings_server/utils/flash_attn.py

Lines changed: 72 additions & 6 deletions
@@ -1,6 +1,6 @@
 import os
 import torch
-from text_embeddings_server.utils.device import use_ipex
+from text_embeddings_server.utils.device import use_ipex, is_hpu
 
 from loguru import logger
 

@@ -10,7 +10,10 @@
 HAS_FLASH_ATTN = False
 HAS_FLASH_ATTN_V2 = False
 
-if use_ipex():
+is_hpu = is_hpu()
+use_ipex = use_ipex()
+
+if use_ipex or is_hpu:
     HAS_FLASH_ATTN_V2 = True
 else:
     if not torch.cuda.is_available():

@@ -54,14 +57,77 @@
     HAS_FLASH_ATTN = True
 
 
+def hpu_attn(q, k, v, out, seqlen_q, seqlen_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal=False):
+    from habana_frameworks.torch.hpex.kernels import FusedSDPA
+    total_q, num_head, head_size = q.size()
+    total_k, num_head_k, _ = k.size()
+    batch_size = seqlen_q.size(0) - 1
+    seqlen_q_ = seqlen_q.clone()
+    seqlen_q_[:batch_size] = seqlen_q[1:]
+    seqlen_q = (seqlen_q_ - seqlen_q)[:batch_size]
+    seqlen_k_ = seqlen_k.clone()
+    seqlen_k_[:batch_size] = seqlen_k[1:]
+    seqlen_k = (seqlen_k_ - seqlen_k)[:batch_size]
+
+    pad_q = torch.zeros(
+        [batch_size, max_seqlen_q, num_head, head_size],
+        dtype=q.dtype,
+        device=q.device,
+    )
+    pad_k = torch.zeros(
+        [batch_size, max_seqlen_k, num_head_k, head_size],
+        dtype=k.dtype,
+        device=k.device,
+    )
+    pad_v = torch.zeros(
+        [batch_size, max_seqlen_k, num_head_k, head_size],
+        dtype=v.dtype,
+        device=v.device,
+    )
+    q_mask = torch.arange(0, max_seqlen_q, device=q.device)[None, :].repeat(
+        batch_size, 1
+    )
+    q_mask = q_mask < seqlen_q[:, None].repeat(1, q_mask.size(-1))
+    k_mask = torch.arange(0, max_seqlen_k, device=k.device)[None, :].repeat(
+        batch_size, 1
+    )
+    k_mask = k_mask < seqlen_k[:, None].repeat(1, k_mask.size(-1))
+    align_mask_seqlen = max_seqlen_k
+    attn_mask = torch.empty(
+        [batch_size, 1, 1, align_mask_seqlen],
+        dtype=q.dtype,
+        device=q.device,
+    ).fill_(float("-inf"))
+    attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0)
+
+    pad_q[q_mask] = q
+    pad_k[k_mask] = k
+    pad_v[k_mask] = v
+
+    pad_q = pad_q.permute(0, 2, 1, 3)
+    pad_k = pad_k.permute(0, 2, 1, 3)
+    pad_v = pad_v.permute(0, 2, 1, 3)
+    if is_causal:
+        attn_mask = None
+
+    out_ = FusedSDPA.apply(pad_q, pad_k, pad_v, attn_mask, 0.0, is_causal, softmax_scale)
+    out_ = out_.permute(0, 2, 1, 3)
+    out.copy_(out_[q_mask])
+    return out
+
+
 def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
     if HAS_FLASH_ATTN_V2:
-        if use_ipex():
+        if use_ipex:
             import intel_extension_for_pytorch as ipex
-            return ipex.llm.functional.varlen_attention(q, k, v, out, cu_seqlens, cu_seqlens,
-                                                        max_s, max_s, 0, softmax_scale,
-                                                        zero_tensors=False, is_causal=False,
+            return ipex.llm.functional.varlen_attention(q, k, v, out, cu_seqlens, cu_seqlens,
+                                                        max_s, max_s, 0, softmax_scale,
+                                                        zero_tensors=False, is_causal=False,
                                                         return_softmax=False, gen_=None)
+        elif is_hpu:
+            return hpu_attn(q, k, v, out, cu_seqlens, cu_seqlens,
+                            max_s, max_s, softmax_scale, is_causal=False)
+
         else:
             return flash_attn_2_cuda.varlen_fwd(
                 q,
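
Note: hpu_attn bridges the varlen (packed) layout used by the flash-attention API and FusedSDPA, which expects padded [batch, heads, seq, head_size] tensors. It recovers per-sequence lengths from the cumulative cu_seqlens offsets, scatters the packed tokens into zero-padded buffers, builds an additive -inf mask over padding keys, runs the fused SDPA kernel, and copies the valid rows back into out. A CPU-runnable sketch of that pad/unpad scheme, with torch.nn.functional.scaled_dot_product_attention standing in for Habana's FusedSDPA (an assumption for illustration; shapes and masking mirror the function above):

import torch
import torch.nn.functional as F

num_heads, head_size = 2, 8
seq_lens = torch.tensor([3, 5])                       # two sequences of different length
cu_seqlens = F.pad(seq_lens.cumsum(0), (1, 0))        # [0, 3, 8], the packed-batch offsets
total_tokens, max_s = int(cu_seqlens[-1]), int(seq_lens.max())

q = torch.randn(total_tokens, num_heads, head_size)   # packed (varlen) layout
k, v = torch.randn_like(q), torch.randn_like(q)

batch_size = cu_seqlens.numel() - 1
lens = cu_seqlens[1:] - cu_seqlens[:-1]               # recover per-sequence lengths
mask = torch.arange(max_s)[None, :] < lens[:, None]   # [batch, max_s] validity mask


def to_padded(t):
    p = torch.zeros(batch_size, max_s, num_heads, head_size, dtype=t.dtype)
    p[mask] = t                                        # same scatter as pad_q[q_mask] = q
    return p.permute(0, 2, 1, 3)                       # [batch, heads, max_s, head_size]


pad_q, pad_k, pad_v = to_padded(q), to_padded(k), to_padded(v)

# Additive mask: 0 over real keys, -inf over padding, broadcast across heads and queries.
attn_bias = torch.where(mask[:, None, None, :], 0.0, float("-inf"))
out_padded = F.scaled_dot_product_attention(
    pad_q, pad_k, pad_v, attn_mask=attn_bias, scale=head_size**-0.5
)
out = out_padded.permute(0, 2, 1, 3)[mask]            # back to packed layout
print(out.shape)                                      # torch.Size([8, 2, 8])

The fused kernel operates on dense padded batches, which is why the function pays the cost of materializing zero-padded buffers instead of working on the packed layout directly.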
