
Commit 6b1336f

bnellnm authored and Alvant committed
[Bugfix] Try to handle older versions of pytorch (vllm-project#9086)
Signed-off-by: Alvant <[email protected]>
1 parent 7034bd6 commit 6b1336f

File tree

3 files changed (+41 lines, -21 lines)


tests/kernels/test_awq.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -1,11 +1,14 @@
 import os
 
+import pytest
 import torch
 
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops  # noqa: F401
 
 
+@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
+                    reason="AWQ is not supported on this GPU type.")
 def test_awq_dequantize_opcheck():
     os.environ["VLLM_USE_TRITON_AWQ"] = "0"
     qweight = torch.randint(-2000000000,
@@ -21,6 +24,8 @@ def test_awq_dequantize_opcheck():
             (qweight, scales, zeros, split_k_iters, thx, thy))
 
 
+@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
+                    reason="AWQ is not supported on this GPU type.")
 def test_awq_gemm_opcheck():
     os.environ["VLLM_USE_TRITON_AWQ"] = "0"
     input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
```
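Both tests gate themselves on the presence of the compiled kernel: vLLM only registers `awq_dequantize` and `awq_gemm` in the `torch.ops._C` namespace when the extension was built with support for them, so probing with `hasattr` skips the test cleanly on other builds. A minimal sketch of the same guard pattern; `some_op` is a placeholder name, not an op touched by this commit:

```python
import pytest
import torch

from vllm import _custom_ops as ops  # noqa: F401  # importing registers the compiled _C ops

# Build the skip marker once; hasattr() returns False when the op was not
# compiled into this build, so the test is skipped instead of erroring.
requires_some_op = pytest.mark.skipif(
    not hasattr(torch.ops._C, "some_op"),  # placeholder op name
    reason="some_op is not supported on this GPU type.")


@requires_some_op
def test_some_op_available():
    assert hasattr(torch.ops._C, "some_op")
```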

tests/kernels/test_awq_marlin.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -7,6 +7,7 @@
 
 from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
                                  torch_moe_single)
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
     fused_marlin_moe, single_marlin_moe)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
@@ -21,6 +22,9 @@
 @pytest.mark.parametrize("e", [8, 64])
 @pytest.mark.parametrize("topk", [2, 6])
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
+@pytest.mark.skipif(not (ops.supports_moe_ops
+                         and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
+                    reason="Marlin is not supported on this GPU type.")
 def test_fused_marlin_moe_awq(
     m: int,
     n: int,
```
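This test needs two checks because the MoE kernels live in a separate optional extension: `ops.supports_moe_ops` is a flag that `vllm/_custom_ops.py` sets only when `import vllm._moe_C` succeeds (visible in the next file), and the `hasattr` probe then confirms the specific op. A rough sketch of how such a flag is derived, mirroring the pattern from `_custom_ops.py`:

```python
import contextlib

import torch

# The flag stays False unless the optional MoE extension imports cleanly;
# contextlib.suppress swallows the ImportError on builds without it.
supports_moe_ops = False
with contextlib.suppress(ImportError):
    import vllm._moe_C  # noqa: F401
    supports_moe_ops = True

# A test should only run when the extension loaded *and* the op exists.
marlin_moe_available = supports_moe_ops and hasattr(torch.ops._moe_C,
                                                    "marlin_gemm_moe")
```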

vllm/_custom_ops.py

Lines changed: 32 additions & 21 deletions
```diff
@@ -1,8 +1,9 @@
 import contextlib
 import functools
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
+import torch.library
 
 import vllm.envs as envs
 from vllm._core_ext import ScalarType
@@ -25,6 +26,16 @@
     import vllm._moe_C  # noqa: F401
     supports_moe_ops = True
 
+if TYPE_CHECKING:
+
+    def register_fake(fn):
+        return lambda name: fn
+else:
+    try:
+        from torch.library import register_fake
+    except ImportError:
+        from torch.library import impl_abstract as register_fake
+
 
 def hint_on_error(fn):
 
@@ -266,7 +277,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 if hasattr(torch.ops._C, "gptq_gemm"):
 
-    @torch.library.register_fake("_C::gptq_gemm")
+    @register_fake("_C::gptq_gemm")
     def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                         b_gptq_qzeros: torch.Tensor,
                         b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
@@ -301,15 +312,15 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
 
-    @torch.library.register_fake("_C::gptq_marlin_24_gemm")
+    @register_fake("_C::gptq_marlin_24_gemm")
     def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                   b_meta: torch.Tensor, b_scales: torch.Tensor,
                                   workspace: torch.Tensor,
                                   b_q_type: ScalarType, size_m: int,
                                   size_n: int, size_k: int) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
 
-    @torch.library.register_fake("_C::gptq_marlin_gemm")
+    @register_fake("_C::gptq_marlin_gemm")
     def _gptq_marlin_gemm_fake(a: torch.Tensor,
                                b_q_weight: torch.Tensor,
                                b_scales: torch.Tensor,
@@ -326,12 +337,12 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
                                use_fp32_reduce: bool = False) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
 
-    @torch.library.register_fake("_C::ggml_dequantize")
+    @register_fake("_C::ggml_dequantize")
     def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
                               n: int) -> torch.Tensor:
         return torch.empty((m, n), dtype=torch.float16, device=W.device)
 
-    @torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
+    @register_fake("_C::ggml_mul_mat_vec_a8")
     def _ggml_mul_mat_vec_a8_fake(
         W: torch.Tensor,
         X: torch.Tensor,
@@ -340,7 +351,7 @@ def _ggml_mul_mat_vec_a8_fake(
     ) -> torch.Tensor:
         return torch.empty((1, row), dtype=torch.float16, device=W.device)
 
-    @torch.library.register_fake("_C::ggml_mul_mat_a8")
+    @register_fake("_C::ggml_mul_mat_a8")
     def _ggml_mul_mat_a8_fake(
         W: torch.Tensor,
         X: torch.Tensor,
@@ -350,7 +361,7 @@ def _ggml_mul_mat_a8_fake(
         batch = X.size(0)
         return torch.empty((batch, row), dtype=torch.float16, device=W.device)
 
-    @torch.library.register_fake("_C::marlin_qqq_gemm")
+    @register_fake("_C::marlin_qqq_gemm")
     def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               s_tok: torch.Tensor, s_ch: torch.Tensor,
                               s_group: torch.Tensor, workspace: torch.Tensor,
@@ -360,7 +371,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                            dtype=torch.float16,
                            device=a.device)
 
-    @torch.library.register_fake("_C::marlin_gemm")
+    @register_fake("_C::marlin_gemm")
     def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                           b_scales: torch.Tensor, workspace: torch.Tensor,
                           size_m: int, size_n: int,
@@ -369,7 +380,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                            dtype=torch.float16,
                            device=a.device)
 
-    @torch.library.register_fake("_C::awq_dequantize")
+    @register_fake("_C::awq_dequantize")
     def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                              zeros: torch.Tensor, split_k_iters: int, thx: int,
                              thy: int) -> torch.Tensor:
@@ -380,7 +391,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                            dtype=scales.dtype,
                            device=scales.device)
 
-    @torch.library.register_fake("_C::awq_gemm")
+    @register_fake("_C::awq_gemm")
     def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                        qzeros: torch.Tensor, scales: torch.Tensor,
                        split_k_iters: int) -> torch.Tensor:
@@ -389,7 +400,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                            dtype=input.dtype,
                            device=input.device).sum(0)
 
-    @torch.library.register_fake("_C::aqlm_gemm")
+    @register_fake("_C::aqlm_gemm")
     def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
                         codebooks: torch.Tensor, scales: torch.Tensor,
                         codebook_partition_sizes: List[int],
@@ -405,7 +416,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
             output_sizes.append(-1)
         return flat_output.reshape(tuple(output_sizes))
 
-    @torch.library.register_fake("_C::aqlm_dequant")
+    @register_fake("_C::aqlm_dequant")
     def _aqlm_dequant_fake(
             codes: torch.Tensor, codebooks: torch.Tensor,
             codebook_partition_sizes: List[int]) -> torch.Tensor:
@@ -415,14 +426,14 @@ def _aqlm_dequant_fake(
                            dtype=codebooks.dtype,
                            device=codebooks.device)
 
-    @torch.library.register_fake("_C::fp8_marlin_gemm")
+    @register_fake("_C::fp8_marlin_gemm")
     def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor, workspace: torch.Tensor,
                               num_bits: int, size_m: int, size_n: int,
                               size_k: int) -> torch.Tensor:
         return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
 
-    @torch.library.register_fake("_C::machete_gemm")
+    @register_fake("_C::machete_gemm")
     def machete_gemm_fake(
         a: torch.Tensor,
         # Should be the tensor returned by machete_prepack_B
@@ -440,13 +451,13 @@ def machete_gemm_fake(
         n = b_q.size(1)
         return torch.empty((m, n), device=a.device, dtype=a.dtype)
 
-    @torch.library.register_fake("_C::machete_prepack_B")
+    @register_fake("_C::machete_prepack_B")
     def machete_prepack_B_fake(b_q_weight: torch.Tensor,
                                b_type: ScalarType) -> torch.Tensor:
         return torch.empty_like(b_q_weight,
                                 memory_format=torch.contiguous_format)
 
-    @torch.library.register_fake("_C::causal_conv1d_fwd")
+    @register_fake("_C::causal_conv1d_fwd")
     def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
                                bias_: Optional[torch.Tensor],
                                conv_states: Optional[torch.Tensor],
@@ -456,15 +467,15 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
                                silu_activation: bool) -> torch.Tensor:
         return torch.empty_like(x)
 
-    @torch.library.register_fake("_C::causal_conv1d_update")
+    @register_fake("_C::causal_conv1d_update")
     def causal_conv1d_update_fake(
             x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
             bias_: Optional[torch.Tensor], silu_activation: bool,
             cache_seqlens: Optional[torch.Tensor],
             conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
         return torch.empty_like(x)
 
-    @torch.library.register_fake("_C::selective_scan_fwd")
+    @register_fake("_C::selective_scan_fwd")
     def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
                                 A: torch.Tensor, B: torch.Tensor,
                                 C: torch.Tensor, D_: Optional[torch.Tensor],
@@ -639,7 +650,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
 
 if hasattr(torch.ops._C, "permute_cols"):
 
-    @torch.library.register_fake("_C::permute_cols")
+    @register_fake("_C::permute_cols")
     def _permute_cols_fake(a: torch.Tensor,
                            perm: torch.Tensor) -> torch.Tensor:
         return torch.empty_like(a)
@@ -837,7 +848,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
 
 if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
 
-    @torch.library.register_fake("_moe_C::marlin_gemm_moe")
+    @register_fake("_moe_C::marlin_gemm_moe")
     def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
                              sorted_ids: torch.Tensor,
                              topk_weights: torch.Tensor,
```
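The core of the fix is the small shim near the top of the file: newer PyTorch releases expose `torch.library.register_fake`, while older ones provide the same decorator as `torch.library.impl_abstract`, so the module imports whichever name is available as a local `register_fake`, and every fake-op registration below switches from `torch.library.register_fake(...)` to that local name. A self-contained sketch of the pattern; the `mylib::my_gemm` op is purely illustrative and not part of vLLM:

```python
from typing import TYPE_CHECKING

import torch
import torch.library

if TYPE_CHECKING:
    # Give type checkers a stable signature regardless of the installed torch.
    def register_fake(fn):
        return lambda name: fn
else:
    try:
        # Newer PyTorch spells the decorator register_fake ...
        from torch.library import register_fake
    except ImportError:
        # ... older releases expose the same functionality as impl_abstract.
        from torch.library import impl_abstract as register_fake

# Only register a fake (meta) implementation when the real op exists;
# "mylib::my_gemm" is a placeholder, the commit guards the real "_C::..."
# and "_moe_C::..." ops the same way.
if hasattr(torch.ops.mylib, "my_gemm"):

    @register_fake("mylib::my_gemm")
    def _my_gemm_fake(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # A fake impl only describes output shape/dtype/device, which is all
        # torch.compile and opcheck need for tracing.
        return torch.empty((a.shape[0], b.shape[1]),
                           dtype=a.dtype,
                           device=a.device)
```

With this indirection, the fake-implementation registrations never reference `torch.library.register_fake` directly, so a single import-time fallback covers every op in the file.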

0 commit comments
