Commit e5b001d

Update readme; Format code; Add example yaml.

1 parent c38e77d · commit e5b001d

8 files changed: +182 -30 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
 
 <h2 id="Updates">🔥 Updates</h2>
 
+* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)).
 * **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM.
 * **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
 * **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/).

doc/README.md

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
 
 <h2 id="Updates">🔥 Updates</h2>
 
+* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)).
 * **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM.
 * **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
 * **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. The detailed tutorial is [here](./en/DeepseekR1_V3_tutorial.md).

doc/SUMMARY.md

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 - [Injection Tutorial](en/injection_tutorial.md)
 - [Multi-GPU Tutorial](en/multi-gpu-tutorial.md)
 - [Use FP8 GPU Kernel](en/fp8_kernel.md)
+- [Use AMD GPU](en/ROCm.md)
 # Server
 - [Server](en/api/server/server.md)
 - [Website](en/api/server/website.md)

doc/en/ROCm.md

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
# ROCm Support for ktransformers (Beta)

## Introduction

### Overview
In our effort to expand GPU architecture support beyond NVIDIA, we are excited to introduce **AMD GPU support through ROCm** in ktransformers (Beta release). This implementation has been tested and developed using EPYC 9274F processors and AMD Radeon 7900 XTX GPUs.

## Installation Guide

### 1. Install ROCm Driver
Begin by installing the ROCm driver for your AMD GPU:
- [Official ROCm Installation Guide for Radeon GPUs](https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/install-radeon.html)

### 2. Set Up Conda Environment
We recommend using Miniconda3/Anaconda3 for environment management:

```bash
# Download and run the Miniconda installer
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh

# Create environment
conda create --name ktransformers python=3.11
conda activate ktransformers

# Install required libraries
conda install -c conda-forge libstdcxx-ng

# Verify GLIBCXX version (should include 3.4.32)
strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX
```

> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3`.
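If the environment lives somewhere other than `~/anaconda3`, the path used in the `strings` check above can be printed from the active environment itself; this is a small convenience sketch, not part of the original tutorial:

```python
# Convenience sketch (not from the tutorial): print the libstdc++ shipped with
# the currently active conda environment, so the GLIBCXX check can be pointed
# at the right file regardless of where Anaconda/Miniconda is installed.
import os
import sys

libstdcxx = os.path.join(sys.prefix, "lib", "libstdc++.so.6")
print(libstdcxx, "->", "found" if os.path.exists(libstdcxx) else "missing")
```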
### 3. Install PyTorch for ROCm
Install PyTorch with ROCm 6.2.4 support:

```bash
pip3 install torch torchvision torchaudio \
    --index-url https://download.pytorch.org/whl/rocm6.2.4
pip3 install packaging ninja cpufeature numpy
```

> **Tip:** For other ROCm versions, visit [PyTorch Previous Versions](https://pytorch.org/get-started/previous-versions/).
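After installing, a quick sanity check confirms that the ROCm build of PyTorch actually sees the AMD GPU. This is a convenience sketch rather than part of the tutorial; note that ROCm builds expose the GPU through the regular `torch.cuda` namespace:

```python
# Sanity check for the ROCm PyTorch install (convenience sketch, not from the
# tutorial). On ROCm builds, the HIP backend is exposed through torch.cuda.
import torch

print("torch:", torch.__version__)           # ROCm wheels usually carry a +rocm suffix
print("HIP runtime:", torch.version.hip)     # None on CUDA-only or CPU-only builds
print("GPU visible:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
```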
### 4. Build ktransformers

```bash
# Clone repository
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
git submodule update --init

# Optional: Compile web interface
# See: api/server/website.md

# Install dependencies
bash install.sh
```
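Once `install.sh` finishes, a minimal smoke test can confirm the package is importable. This is a convenience sketch, assuming the build registered the distribution under the name `ktransformers`:

```python
# Post-install smoke test (convenience sketch, not from the tutorial). An
# ImportError here means the build did not complete; the version lookup assumes
# the distribution is registered as "ktransformers".
import importlib.metadata

import ktransformers  # noqa: F401

try:
    print("ktransformers", importlib.metadata.version("ktransformers"))
except importlib.metadata.PackageNotFoundError:
    print("ktransformers imported, but no distribution metadata was found")
```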
## Running DeepSeek-R1 Models

### Configuration for 24GB VRAM GPUs
Use our optimized configuration for constrained VRAM:

```bash
python ktransformers/local_chat.py \
    --model_path deepseek-ai/DeepSeek-R1 \
    --gguf_path <path_to_gguf_files> \
    --optimize_config_path ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml \
    --cpu_infer <cpu_cores + 1>
```

> **Beta Note:** Current Q8 linear implementation (Marlin alternative) shows suboptimal performance. Expect optimizations in future releases.
### Configuration for 40GB+ VRAM GPUs
For better performance on high-VRAM GPUs:

1. Modify `DeepSeek-V3-Chat.yaml`, replacing all instances of `KLinearMarlin` with `KLinearTorch` (a scripted way to do this is sketched below).

2. Execute with:

```bash
python ktransformers/local_chat.py \
    --model_path deepseek-ai/DeepSeek-R1 \
    --gguf_path <path_to_gguf_files> \
    --optimize_config_path <modified_yaml_path> \
    --cpu_infer <cpu_cores + 1>
```

> **Tip:** If you have two 24GB AMD GPUs, you can apply the same modification to `ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` and run with that file instead.
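The substitution in step 1 can also be scripted. Below is a minimal sketch (not part of the original tutorial); the output filename is arbitrary and is what you would pass to `--optimize_config_path`:

```python
# Minimal sketch (not from the tutorial): copy the rule file with every
# KLinearMarlin swapped for KLinearTorch, writing the result to a new file.
# Paths assume you run this from the repository root; adjust as needed.
from pathlib import Path

src = Path("ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml")
dst = Path("ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-torch.yaml")

dst.write_text(src.read_text().replace("KLinearMarlin", "KLinearTorch"))
print(f"wrote {dst}; pass it via --optimize_config_path")
```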
## Known Limitations
- Marlin operations are not supported on the ROCm platform
- The current Q8 linear implementation shows reduced performance (Beta limitation)

ktransformers/operators/linear.py

Lines changed: 2 additions & 25 deletions
@@ -187,8 +187,6 @@ def __init__(
         config: PretrainedConfig,
         orig_module: nn.Module = None,
         device: str = "cuda",
-        group_size: int = 128,  # use a larger group size to reduce quantization noise
-        percentile: float = 99.99,  # new: percentile at which outliers are clipped
         **kwargs,
     ):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
@@ -199,8 +197,6 @@ def __init__(
         self.weight_zero_point = None
         self.bias = None
         self.loaded = False
-        self.group_size = group_size
-        self.percentile = percentile
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         orig_dtype = x.dtype
@@ -246,16 +242,9 @@ def _dequantize_weight(self, q_matrix, scales, bits=8):
         # For Q4, ensure the values stay within 4-bit range
         if bits == 4:
             q_matrix = torch.clamp(q_matrix, -7, 7)
-
-        # Get matrix shape
         rows, cols = q_matrix.shape
-
-        # Convert to float32
         dequant_matrix = q_matrix.to(torch.float32)
-
-        # Create broadcasted scales: reshape scales to [1, cols] for broadcasting
         scales_broadcast = scales.view(1, cols)
-
         # Apply dequantization to all columns at once using matrix multiplication
         dequant_matrix = dequant_matrix * scales_broadcast
 
@@ -285,21 +274,14 @@ def _quantize_weight(self, matrix, bits=8):
 
         # Determine quantization parameters based on bits
         if bits == 8:
-            # Q8: range is -127 to 127
             max_int = 127
             qtype = torch.int8
         elif bits == 4:
-            # Q4: range is -7 to 7 (using 4-bit signed integers)
             max_int = 7
-            qtype = torch.int8  # We'll still use int8 storage but limit to 4-bit range
+            qtype = torch.int8  # We'll still use int8 storage but limit to 4-bit range, wait for native support
         else:
             raise ValueError("Quantization bits must be either 8 or 4")
-
-        # Initialize results and scale factors
-        q_matrix = torch.zeros_like(matrix, dtype=qtype)
-        scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
-
-        # Initialize scale factors
+
         scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
 
         # Calculate max absolute value for each column
@@ -370,13 +352,8 @@ def unload(self):
 class KLinearFP8(KLinearBase):
     # this kernel requires special handling for weight
     # Please load the weight file downloaded from KVCache.AI
-    marlin_q_w: torch.Tensor
-    marlin_s: torch.Tensor
-    g_idx: torch.Tensor
-    sort_indices: torch.Tensor
     has_bias: bool
     weight: torch.Tensor
-    scale_w: torch.Tensor
     bias: torch.Tensor
     def __init__(
         self,
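For readers skimming this diff: the quantization scheme that remains in `_quantize_weight`/`_dequantize_weight` is per-column symmetric quantization, where each column is scaled by its maximum absolute value divided by the integer range, stored as int8, and recovered by multiplying with the per-column scale. A standalone sketch of that idea follows; the function names and signatures here are illustrative and are not the repo's API:

```python
# Illustrative sketch of per-column symmetric Q8 quantization, mirroring the
# scheme kept by the diff above: scale each column by max(|x|)/127, store int8,
# and dequantize by multiplying with the per-column scales. Not the repo's API.
import torch

def quantize_per_column(matrix: torch.Tensor, max_int: int = 127):
    # One scale per column, derived from that column's largest magnitude.
    scales = matrix.abs().amax(dim=0).clamp(min=1e-8) / max_int
    q = torch.clamp(torch.round(matrix / scales), -max_int, max_int).to(torch.int8)
    return q, scales.to(torch.float32)

def dequantize_per_column(q: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # Broadcast scales across rows, analogous to scales.view(1, cols) in the diff.
    return q.to(torch.float32) * scales.view(1, -1)

if __name__ == "__main__":
    w = torch.randn(16, 8)
    q, s = quantize_per_column(w)
    w_hat = dequantize_per_column(q, s)
    print("max abs reconstruction error:", (w - w_hat).abs().max().item())
```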

ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
-      generate_op: "KLinearQ8"
+      generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
 
 - match:
@@ -24,7 +24,7 @@
     kwargs:
       generate_device: "cpu"
       prefill_device: "cuda"
-      generate_op: "KLinearTorch"
+      generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
 
 - match:

ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml

Lines changed: 3 additions & 3 deletions
@@ -14,7 +14,7 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
-      generate_op: "KLinearQ8"
+      generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
 
 - match:
@@ -23,9 +23,9 @@
   replace:
     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
     kwargs:
-      generate_device: "cpu"
+      generate_device: "cuda"
       prefill_device: "cuda"
-      generate_op: "KLinearCPUInfer"
+      generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml (new example YAML referenced by the ROCm tutorial above)

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearCPUInfer"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cpu"
      prefill_device: "cuda"
      generate_op: "KLinearQ8"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      absorb_for_prefill: False  # change this to True to enable long context (prefill may be slower)
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
