
[Bug]: Missing detection of BFloat16 for CPU ARM #11814

Description

@wallashss

Your current environment

The output of `python collect_env.py`
Collecting environment information...
PyTorch version: 2.5.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (aarch64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0
Clang version: Could not collect
CMake version: version 3.31.2
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov  6 2024, 20:22:13) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.9.12-200.fc40.aarch64-aarch64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         aarch64
CPU op-mode(s):                       64-bit
Byte Order:                           Little Endian
CPU(s):                               4
On-line CPU(s) list:                  0-3
Vendor ID:                            Apple
Model:                                0
Thread(s) per core:                   1
Core(s) per cluster:                  4
Socket(s):                            -
Cluster(s):                           1
Stepping:                             0x0
BogoMIPS:                             48.00
Flags:                                fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 asimddp sha512 asimdfhm dit uscat ilrcpc flagm ssbs sb paca pacg dcpodp flagm2 frint
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-3
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; __user pointer sanitization
Vulnerability Spectre v2:             Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pyzmq==26.2.0
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.47.1
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.4.post2.dev543+g313608dab.d20250102
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
Could not collect

LD_LIBRARY_PATH=/usr/local/lib/python3.10/dist-packages/cv2/../../lib64:


Model Input Dumps

No response

🐛 Describe the bug

On ARM CPUs, not all processors support bfloat16 (note that the Flags line in the environment output above has no bf16 entry). On such hosts, trying to run inference crashes with the following stack trace:

ERROR 01-07 18:11:29 engine.py:135] RuntimeError('"rms_norm_impl" not implemented for \'BFloat16\'')
ERROR 01-07 18:11:29 engine.py:135] Traceback (most recent call last):
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/multiprocessing/engine.py", line 133, in start
ERROR 01-07 18:11:29 engine.py:135]     self.run_engine_loop()
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/multiprocessing/engine.py", line 196, in run_engine_loop
ERROR 01-07 18:11:29 engine.py:135]     request_outputs = self.engine_step()
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/multiprocessing/engine.py", line 214, in engine_step
ERROR 01-07 18:11:29 engine.py:135]     raise e
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/multiprocessing/engine.py", line 205, in engine_step
ERROR 01-07 18:11:29 engine.py:135]     return self.engine.step()
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 1394, in step
ERROR 01-07 18:11:29 engine.py:135]     outputs = self.model_executor.execute_model(
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/cpu_executor.py", line 201, in execute_model
ERROR 01-07 18:11:29 engine.py:135]     output = self.driver_method_invoker(self.driver_worker,
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/cpu_executor.py", line 298, in _driver_method_invoker
ERROR 01-07 18:11:29 engine.py:135]     return getattr(driver, method)(*args, **kwargs)
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 344, in execute_model
ERROR 01-07 18:11:29 engine.py:135]     output = self.model_runner.execute_model(
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 01-07 18:11:29 engine.py:135]     return func(*args, **kwargs)
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/cpu_model_runner.py", line 530, in execute_model
ERROR 01-07 18:11:29 engine.py:135]     hidden_states = model_executable(
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
ERROR 01-07 18:11:29 engine.py:135]     return self._call_impl(*args, **kwargs)
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
ERROR 01-07 18:11:29 engine.py:135]     return forward_call(*args, **kwargs)
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/qwen2.py", line 477, in forward
ERROR 01-07 18:11:29 engine.py:135]     hidden_states = self.model(input_ids, positions, kv_caches,
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/compilation/decorators.py", line 168, in __call__
ERROR 01-07 18:11:29 engine.py:135]     return self.forward(*args, **kwargs)
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/qwen2.py", line 340, in forward
ERROR 01-07 18:11:29 engine.py:135]     hidden_states, residual = layer(
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
ERROR 01-07 18:11:29 engine.py:135]     return self._call_impl(*args, **kwargs)
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
ERROR 01-07 18:11:29 engine.py:135]     return forward_call(*args, **kwargs)
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/qwen2.py", line 243, in forward
ERROR 01-07 18:11:29 engine.py:135]     hidden_states = self.input_layernorm(hidden_states)
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
ERROR 01-07 18:11:29 engine.py:135]     return self._call_impl(*args, **kwargs)
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
ERROR 01-07 18:11:29 engine.py:135]     return forward_call(*args, **kwargs)
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/custom_op.py", line 24, in forward
ERROR 01-07 18:11:29 engine.py:135]     return self._forward_method(*args, **kwargs)
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/custom_op.py", line 48, in forward_cpu
ERROR 01-07 18:11:29 engine.py:135]     return self.forward_cuda(*args, **kwargs)
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/layernorm.py", line 94, in forward_cuda
ERROR 01-07 18:11:29 engine.py:135]     ops.rms_norm(
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/vllm/_custom_ops.py", line 182, in rms_norm
ERROR 01-07 18:11:29 engine.py:135]     torch.ops._C.rms_norm(out, input, weight, epsilon)
ERROR 01-07 18:11:29 engine.py:135]   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116, in __call__
ERROR 01-07 18:11:29 engine.py:135]     return self._op(*args, **(kwargs or {}))
ERROR 01-07 18:11:29 engine.py:135] RuntimeError: "rms_norm_impl" not implemented for 'BFloat16'

Server

vllm serve Qwen/Qwen2.5-0.5B-Instruct

Request

curl http://localhost:8002/v1/completions -H "Content-Type: application/json" -d '{
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "prompt": ["How to make pizza"],
    "max_tokens": 100,
    "temperature": 0
  }'
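
The same crash can also be reproduced without the server through the offline API (a minimal sketch assuming the same model; dtype is left at its default, which resolves to the checkpoint's native bfloat16):

```python
from vllm import LLM, SamplingParams

# dtype defaults to "auto", which resolves to Qwen2.5's native bfloat16
# and hits the same "rms_norm_impl" error on a host without bf16 support.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")
outputs = llm.generate(["How to make pizza"],
                       SamplingParams(temperature=0, max_tokens=100))
print(outputs[0].outputs[0].text)
```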

Fix suggestion

Whether bfloat16 is supported is device-dependent, so the ideal solution is for vLLM to check whether the device supports it. I see two ways to address this problem:

  • At build time: the easiest way. During the CPU build we would check whether the host supports the feature (like in cpu_extension.cmake), add bindings on the Python side to query whether the build supports it, and handle the places where it is used so we can allow it or not and avoid crashes. This solution works well when the user builds and installs vLLM from source.
  • At runtime: part of the implementation is similar to the build-time option, but we should always build with bfloat16 support, and check not only whether the build supports it but also whether bfloat16 is available on the host at runtime (the latter can probably be done in Python alone; see the sketch after this list). This option is the proper one if we plan to distribute prebuilt CPU builds in the future.
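
For the runtime path, here is a minimal sketch of the Python-only host check on Linux/aarch64, assuming detection via the bf16 feature flag the kernel exposes in /proc/cpuinfo:

```python
import platform
import re

def linux_cpu_supports_bf16() -> bool:
    """Best-effort check for the ARMv8 bf16 feature on Linux/aarch64."""
    if platform.system() != "Linux" or platform.machine() != "aarch64":
        return False
    try:
        with open("/proc/cpuinfo") as f:
            cpuinfo = f.read()
    except OSError:
        return False
    # /proc/cpuinfo lists per-CPU features on a "Features" line
    # (lscpu, as in the report above, shows the same list as "Flags").
    match = re.search(r"^Features\s*:\s*(.+)$", cpuinfo, re.MULTILINE)
    return bool(match) and "bf16" in match.group(1).split()
```

On the host from this report, the flag list contains asimdhp and asimddp but no bf16, so this check would return False and vLLM could fall back to a supported dtype such as float32 instead of crashing.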

This issue also matters for Macs with Apple chips: the M1 does not support bfloat16, but newer chips do and could take advantage of this feature.
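
On macOS there is no /proc/cpuinfo; an equivalent probe could query the hw.optional.arm.FEAT_BF16 sysctl (assumed here to be set to 1 on Apple Silicon that implements FEAT_BF16, and absent on M1):

```python
import subprocess

def macos_supports_bf16() -> bool:
    """Best-effort check for FEAT_BF16 on Apple Silicon (macOS)."""
    try:
        out = subprocess.run(
            ["sysctl", "-n", "hw.optional.arm.FEAT_BF16"],
            capture_output=True, text=True, check=True,
        ).stdout.strip()
    except (OSError, subprocess.CalledProcessError):
        # Key absent (e.g. on M1) or sysctl unavailable.
        return False
    return out == "1"
```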

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.

Labels

bug (Something isn't working), stale (Over 90 days of inactivity)
