Commit 09a0d3f

Merge pull request huggingface#3 from kaixuanliu/ipex
add XPU and HPU support
2 parents f61b8bd + fc979a9 commit 09a0d3f

File tree

10 files changed: +303 additions, -52 deletions

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ tracing = "0.1"
 serde = { version = "1.0", features = ["serde_derive"] }
 serde_json = "1.0"
 thiserror = "1.0"
+rand = "0.8"


 [patch.crates-io]

Dockerfile-intel

Lines changed: 65 additions & 13 deletions
@@ -1,6 +1,6 @@
+ARG PLATFORM=cpu
 FROM lukemathwalker/cargo-chef:latest-rust-1.75-bookworm AS chef
 WORKDIR /usr/src
-
 ENV SCCACHE=0.5.4
 ENV RUSTC_WRAPPER=/usr/local/bin/sccache

@@ -54,8 +54,7 @@ COPY proto proto

 RUN cargo build --release --bin text-embeddings-router -F grpc -F python --no-default-features && sccache -s

-FROM intel/intel-optimized-pytorch:2.3.0-pip-base as base
-
+FROM intel/intel-optimized-pytorch:2.4.0-pip-base AS cpu
 ENV HUGGINGFACE_HUB_CACHE=/data \
     PORT=80

@@ -72,26 +71,79 @@ COPY backends backends
 COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
 COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
 COPY backends/python/server/requirements-intel.txt backends/python/server/requirements.txt
+
+RUN python -m pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu
+
 RUN cd backends/python/server && \
     make install

-RUN python -m pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/test/cpu
-RUN python -m pip uninstall -y intel-extension-for-pytorch
-RUN git clone https://github.com/intel/intel-extension-for-pytorch.git &&\
-    cd intel-extension-for-pytorch &&\
-    git reset --hard 620a9bfd9db42813931a857e78fa3f5d298be200 &&\
-    git submodule sync &&\
-    git submodule update --init --recursive &&\
-    python setup.py install
+FROM vault.habana.ai/gaudi-docker/1.16.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest AS hpu
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    PORT=80
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    build-essential \
+    git \
+    cmake \
+    ninja-build \
+    python3-dev &&\
+    rm -rf /var/lib/apt/lists/*
+
+WORKDIR /usr/src
+COPY backends backends
+COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
+COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
+COPY backends/python/server/requirements-hpu.txt backends/python/server/requirements.txt
+
+RUN cd backends/python/server && \
+    make install
+
+FROM intel/intel-extension-for-pytorch:2.1.40-xpu AS xpu
+
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    PORT=80
+RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
+    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
+
+RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null
+
+RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
+    | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
+
+RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils
+WORKDIR /usr/src
+RUN pip install PyYAML
+RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
+RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
+
+ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
+ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
+ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
+ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
+ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
+ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
+ENV CCL_ZE_IPC_EXCHANGE=sockets
+ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
+ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
+
+RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
+
+COPY backends backends
+COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
+COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
+COPY backends/python/server/requirements-intel.txt backends/python/server/requirements.txt
+RUN cd backends/python/server && \
+    make install

-FROM base as grpc
+FROM ${PLATFORM} AS grpc

 COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

 ENTRYPOINT ["text-embeddings-router"]
 CMD ["--json-output"]

-FROM base
+FROM ${PLATFORM}

 COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

backends/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ text-embeddings-backend-candle = { path = "candle", optional = true }
 text-embeddings-backend-ort = { path = "ort", optional = true }
 tokio = { workspace = true }
 tracing = { workspace = true }
+rand = { workspace = true }

 [features]
 clap = ["dep:clap", "text-embeddings-backend-core/clap"]

backends/python/server/requirements-hpu.txt (new file)

Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
+charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
+click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
+colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
+deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec[http]==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.22.2 ; python_version >= "3.9" and python_version < "3.13"
+humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
+jinja2==3.1.3 ; python_version >= "3.9" and python_version < "3.13"
+loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "3.13"
+mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
+networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
+numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+optimum-habana==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
+optimum==1.20.0 ; python_version >= "3.9" and python_version < "3.13"
+packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
+pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
+pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
+six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
+sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
+tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.40.2 ; python_version >= "3.9" and python_version < "3.13"
+transformers[sentencepiece]==4.40.2 ; python_version >= "3.9" and python_version < "3.13"
+typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
+tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
+win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
+wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+xxhash==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
+yarl==1.9.4 ; python_version >= "3.9" and python_version < "3.13"
+zipp==3.18.1 ; python_version >= "3.9" and python_version < "3.13"
+pyrsistent==0.20.0 ; python_version >= "3.9" and python_version < "3.13"

backends/python/server/requirements-intel.txt

Lines changed: 2 additions & 1 deletion
@@ -40,4 +40,5 @@ typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
+pyrsistent==0.20.0 ; python_version >= "3.9" and python_version < "3.13"

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 18 additions & 19 deletions
@@ -1,3 +1,4 @@
+from pyrsistent import s
 import torch

 from loguru import logger
@@ -25,38 +26,36 @@
     __all__.append(FlashBert)


-def get_model(model_path: Path, dtype: Optional[str]):
+def get_model(model_path: Path, dtype: Optional[str]) :
     if dtype == "float32":
-        dtype = torch.float32
+        datatype = torch.float32
     elif dtype == "float16":
-        dtype = torch.float16
+        datatype = torch.float16
     elif dtype == "bfloat16":
-        dtype = torch.bfloat16
+        datatype = torch.bfloat16
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")

     device = get_device()
     config = AutoConfig.from_pretrained(model_path)
-
     if config.model_type == "bert":
         config: BertConfig
         if (
             device.type == "cuda"
             and config.position_embedding_type == "absolute"
-            and dtype in [torch.float16, torch.bfloat16]
+            and datatype in [torch.float16, torch.bfloat16]
             and FLASH_ATTENTION
         ):
-            return FlashBert(model_path, device, dtype)
-        elif (
-            device.type == "cpu"
-            and use_ipex()
-        ):
-            logger.info("Use the flashBert for CPU")
-            return FlashBert(model_path, device, dtype)
-        else:
-            return DefaultModel(model_path, device, dtype)
+            return FlashBert(model_path, device, datatype)  # type: ignore
+        if use_ipex() and device.type in ["cpu", "xpu"]:
+            return FlashBert(model_path, device, datatype)  # type: ignore
+        if device.type == "hpu":
+            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+            from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+            adapt_transformers_to_gaudi()
+            model_handle = DefaultModel(model_path, device, datatype)
+            model_handle.model = wrap_in_hpu_graph(model_handle.model, disable_tensor_cache=True)
+            return model_handle
+        return DefaultModel(model_path, device, datatype)
     else:
-        try:
-            return DefaultModel(model_path, device, dtype)
-        except:
-            raise RuntimeError(f"Unsupported model_type {config.model_type}")
+        return DefaultModel(model_path, device, datatype)
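
A rough usage sketch of the dispatch above (the checkpoint path and dtype are illustrative, not taken from the PR): for a BERT checkpoint, get_model now returns FlashBert on CUDA with flash attention or on CPU/XPU when IPEX is usable, returns a DefaultModel whose underlying module is wrapped in an HPU graph on Gaudi, and falls back to a plain DefaultModel otherwise.

    # Hypothetical caller; assumes the server package and a local BERT checkpoint are available.
    from pathlib import Path
    from text_embeddings_server.models import get_model

    model = get_model(Path("/data/models/bert-base-uncased"), dtype="bfloat16")
    # On HPU this is a DefaultModel whose .model attribute was wrapped by wrap_in_hpu_graph;
    # on CPU/XPU with IPEX (or CUDA with flash attention) it is a FlashBert instance.
    print(type(model).__name__)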

backends/python/server/text_embeddings_server/models/flash_bert.py

Lines changed: 4 additions & 2 deletions
@@ -12,7 +12,7 @@
 from text_embeddings_server.models import Model
 from text_embeddings_server.models.types import FlashBatch, Embedding
 from text_embeddings_server.utils.flash_attn import attention
-
+from text_embeddings_server.utils.device import use_ipex
 tracer = trace.get_tracer(__name__)


@@ -25,6 +25,8 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):

     def forward(self, hidden_states, residual=None):
         # Flash attention imports
+        normed_hidden_states = None
+        res = None
         if self.device.type == "cuda":
             import dropout_layer_norm
             normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@@ -46,7 +48,7 @@ def forward(self, hidden_states, residual=None):
             )
             if res is None:
                 res = hidden_states
-        else:
+        elif use_ipex():
             import intel_extension_for_pytorch as ipex
             normed_hidden_states = ipex.llm.functional.add_layer_norm(
                 residual,
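
For intuition, a plain-PyTorch reference of what both fused paths above are expected to compute; this equivalence is an assumption based on the kernel names (dropout_add_ln_fwd with zero dropout, and IPEX add_layer_norm), not something stated in the PR.

    import torch

    def add_layer_norm_reference(residual, hidden_states, weight, bias, eps=1e-12):
        # residual may be None on the first call, matching forward(hidden_states, residual=None)
        res = hidden_states if residual is None else hidden_states + residual
        normed = torch.nn.functional.layer_norm(res, (res.shape[-1],), weight, bias, eps)
        return normed, res
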
backends/python/server/text_embeddings_server/utils/device.py

Lines changed: 24 additions & 16 deletions

@@ -1,20 +1,21 @@
 import os
-from loguru import logger
+from loguru import logger  # type: ignore
 import importlib
 from packaging import version
 import torch
+import subprocess

-def is_ipex_available():
+def _is_ipex_available():
     def get_major_and_minor_from_version(full_version):
         return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)

-    _torch_version = importlib.metadata.version("torch")
-    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
+    _torch_version = importlib.metadata.version("torch")  # type: ignore
+    if importlib.util.find_spec("intel_extension_for_pytorch") is None:  # type: ignore
         return False
     _ipex_version = "N/A"
     try:
-        _ipex_version = importlib.metadata.version("intel_extension_for_pytorch")
-    except importlib.metadata.PackageNotFoundError:
+        _ipex_version = importlib.metadata.version("intel_extension_for_pytorch")  # type: ignore
+    except importlib.metadata.PackageNotFoundError:  # type: ignore
         return False
     torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
     ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
@@ -26,22 +27,29 @@ def get_major_and_minor_from_version(full_version):
         return False
     return True

-def use_ipex() :
+def _is_hpu() -> bool:
+    is_hpu_available = True
+    try:
+        subprocess.run(["hl-smi"], capture_output=True, check=True)
+    except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
+        is_hpu_available = False
+    return is_hpu_available
+
+def use_ipex() -> bool:
     value = os.environ.get("USE_IPEX", "True").lower()
-    if value in ["true", "1"] and is_ipex_available():
-        return True
-    else:
-        return False
+    return (value in ["true", "1"] and _is_ipex_available())

 def get_device() :
+    device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda")
-    elif is_ipex_available():
+    elif _is_hpu():
+        import habana_frameworks.torch.core as htcore
+        if hasattr(torch, "hpu") and torch.hpu.is_available():  # type: ignore
+            device = torch.device("hpu")
+    elif use_ipex():
        if hasattr(torch, "xpu") and torch.xpu.is_available():
             device = torch.device("xpu")
-        else:
-            device = torch.device("cpu")
-    else:
-        device = torch.device("cpu")
+
     return device
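
A minimal sketch of how the backend picks a device after this change (the import path assumes the server package is installed; values are illustrative): CUDA wins first, then HPU when hl-smi succeeds and torch.hpu is usable, then XPU via IPEX, and plain CPU otherwise, with the IPEX path toggled by the USE_IPEX environment variable.

    import os
    from text_embeddings_server.utils.device import get_device, use_ipex

    os.environ["USE_IPEX"] = "1"   # on by default; set to "0" to force the non-IPEX path
    device = get_device()          # precedence: cuda > hpu > xpu > cpu
    print(device.type, "ipex enabled:", use_ipex())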

backends/python/server/text_embeddings_server/utils/flash_attn.py

Lines changed: 4 additions & 1 deletion
@@ -58,7 +58,10 @@ def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
     if HAS_FLASH_ATTN_V2:
         if use_ipex():
             import intel_extension_for_pytorch as ipex
-            return ipex.llm.functional.varlen_attention(q, k, v, out, cu_seqlens, cu_seqlens, max_s, max_s, 0, softmax_scale, zero_tensors=False, is_causal=False, return_softmax=False, gen_=None)
+            return ipex.llm.functional.varlen_attention(q, k, v, out, cu_seqlens, cu_seqlens,
+                                                        max_s, max_s, 0, softmax_scale,
+                                                        zero_tensors=False, is_causal=False,
+                                                        return_softmax=False, gen_=None)
         else:
             return flash_attn_2_cuda.varlen_fwd(
                 q,
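
A sketch of the packed-batch inputs the attention() signature above expects; the tensor shapes and sizes here are assumptions based on common flash-attention conventions, not taken from the PR.

    import torch

    seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)      # three packed sequences
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lens, 0, dtype=torch.int32), (1, 0))
    max_s = int(seq_lens.max())                                 # longest sequence in the batch
    total_tokens, num_heads, head_dim = int(seq_lens.sum()), 12, 64
    q = k = v = torch.randn(total_tokens, num_heads, head_dim)
    out = torch.empty_like(q)
    # attention(q, k, v, out, cu_seqlens, max_s, head_dim ** -0.5) then dispatches to
    # ipex.llm.functional.varlen_attention or flash_attn_2_cuda.varlen_fwd.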
