1
1
import contextlib
2
2
import functools
3
- from typing import List , Optional , Tuple , Union
3
+ from typing import TYPE_CHECKING , List , Optional , Tuple , Union
4
4
5
5
import torch
6
+ import torch .library
6
7
7
8
import vllm .envs as envs
8
9
from vllm ._core_ext import ScalarType
25
26
import vllm ._moe_C # noqa: F401
26
27
supports_moe_ops = True
27
28
29
+ if TYPE_CHECKING :
30
+
31
+ def register_fake (fn ):
32
+ return lambda name : fn
33
+ else :
34
+ try :
35
+ from torch .library import register_fake
36
+ except ImportError :
37
+ from torch .library import impl_abstract as register_fake
38
+
28
39
29
40
def hint_on_error (fn ):
30
41
@@ -266,7 +277,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
266
277
267
278
if hasattr (torch .ops ._C , "gptq_gemm" ):
268
279
269
- @torch . library . register_fake ("_C::gptq_gemm" )
280
+ @register_fake ("_C::gptq_gemm" )
270
281
def _gptq_gemm_fake (a : torch .Tensor , b_q_weight : torch .Tensor ,
271
282
b_gptq_qzeros : torch .Tensor ,
272
283
b_gptq_scales : torch .Tensor , b_g_idx : torch .Tensor ,
@@ -301,15 +312,15 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
301
312
302
313
if hasattr (torch .ops ._C , "gptq_marlin_24_gemm" ):
303
314
304
- @torch . library . register_fake ("_C::gptq_marlin_24_gemm" )
315
@register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              b_meta: torch.Tensor, b_scales: torch.Tensor,
                              workspace: torch.Tensor,
                              b_q_type: ScalarType, size_m: int,
                              size_n: int, size_k: int) -> torch.Tensor:
    """Fake implementation registered for ``_C::gptq_marlin_24_gemm``.

    Returns an uninitialized ``(size_m, size_n)`` tensor that takes its
    dtype and device from ``a``. The remaining arguments are accepted but
    not read here.
    """
    out_shape = (size_m, size_n)
    return torch.empty(out_shape, dtype=a.dtype, device=a.device)
311
322
312
- @torch . library . register_fake ("_C::gptq_marlin_gemm" )
323
+ @register_fake ("_C::gptq_marlin_gemm" )
313
324
def _gptq_marlin_gemm_fake (a : torch .Tensor ,
314
325
b_q_weight : torch .Tensor ,
315
326
b_scales : torch .Tensor ,
@@ -326,12 +337,12 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
326
337
use_fp32_reduce : bool = False ) -> torch .Tensor :
327
338
return torch .empty ((size_m , size_n ), device = a .device , dtype = a .dtype )
328
339
329
- @torch . library . register_fake ("_C::ggml_dequantize" )
340
@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
                          n: int) -> torch.Tensor:
    """Fake implementation registered for ``_C::ggml_dequantize``.

    Returns an uninitialized float16 tensor of shape ``(m, n)`` placed on
    ``W``'s device. ``quant_type`` is accepted but not read here.
    """
    out_shape = (m, n)
    return torch.empty(out_shape, device=W.device, dtype=torch.float16)
333
344
334
- @torch . library . register_fake ("_C::ggml_mul_mat_vec_a8" )
345
+ @register_fake ("_C::ggml_mul_mat_vec_a8" )
335
346
def _ggml_mul_mat_vec_a8_fake (
336
347
W : torch .Tensor ,
337
348
X : torch .Tensor ,
@@ -340,7 +351,7 @@ def _ggml_mul_mat_vec_a8_fake(
340
351
) -> torch .Tensor :
341
352
return torch .empty ((1 , row ), dtype = torch .float16 , device = W .device )
342
353
343
- @torch . library . register_fake ("_C::ggml_mul_mat_a8" )
354
+ @register_fake ("_C::ggml_mul_mat_a8" )
344
355
def _ggml_mul_mat_a8_fake (
345
356
W : torch .Tensor ,
346
357
X : torch .Tensor ,
@@ -350,7 +361,7 @@ def _ggml_mul_mat_a8_fake(
350
361
batch = X .size (0 )
351
362
return torch .empty ((batch , row ), dtype = torch .float16 , device = W .device )
352
363
353
- @torch . library . register_fake ("_C::marlin_qqq_gemm" )
364
+ @register_fake ("_C::marlin_qqq_gemm" )
354
365
def _marlin_qqq_gemm_fake (a : torch .Tensor , b_q_weight : torch .Tensor ,
355
366
s_tok : torch .Tensor , s_ch : torch .Tensor ,
356
367
s_group : torch .Tensor , workspace : torch .Tensor ,
@@ -360,7 +371,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
360
371
dtype = torch .float16 ,
361
372
device = a .device )
362
373
363
- @torch . library . register_fake ("_C::marlin_gemm" )
374
+ @register_fake ("_C::marlin_gemm" )
364
375
def _marlin_gemm_fake (a : torch .Tensor , b_q_weight : torch .Tensor ,
365
376
b_scales : torch .Tensor , workspace : torch .Tensor ,
366
377
size_m : int , size_n : int ,
@@ -369,7 +380,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
369
380
dtype = torch .float16 ,
370
381
device = a .device )
371
382
372
- @torch . library . register_fake ("_C::awq_dequantize" )
383
+ @register_fake ("_C::awq_dequantize" )
373
384
def _awq_dequantize_fake (qweight : torch .Tensor , scales : torch .Tensor ,
374
385
zeros : torch .Tensor , split_k_iters : int , thx : int ,
375
386
thy : int ) -> torch .Tensor :
@@ -380,7 +391,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
380
391
dtype = scales .dtype ,
381
392
device = scales .device )
382
393
383
- @torch . library . register_fake ("_C::awq_gemm" )
394
+ @register_fake ("_C::awq_gemm" )
384
395
def _awq_gemm_fake (input : torch .Tensor , qweight : torch .Tensor ,
385
396
qzeros : torch .Tensor , scales : torch .Tensor ,
386
397
split_k_iters : int ) -> torch .Tensor :
@@ -389,7 +400,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
389
400
dtype = input .dtype ,
390
401
device = input .device ).sum (0 )
391
402
392
- @torch . library . register_fake ("_C::aqlm_gemm" )
403
+ @register_fake ("_C::aqlm_gemm" )
393
404
def _aqlm_gemm_fake (input : torch .Tensor , codes : torch .Tensor ,
394
405
codebooks : torch .Tensor , scales : torch .Tensor ,
395
406
codebook_partition_sizes : List [int ],
@@ -405,7 +416,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
405
416
output_sizes .append (- 1 )
406
417
return flat_output .reshape (tuple (output_sizes ))
407
418
408
- @torch . library . register_fake ("_C::aqlm_dequant" )
419
+ @register_fake ("_C::aqlm_dequant" )
409
420
def _aqlm_dequant_fake (
410
421
codes : torch .Tensor , codebooks : torch .Tensor ,
411
422
codebook_partition_sizes : List [int ]) -> torch .Tensor :
@@ -415,14 +426,14 @@ def _aqlm_dequant_fake(
415
426
dtype = codebooks .dtype ,
416
427
device = codebooks .device )
417
428
418
- @torch . library . register_fake ("_C::fp8_marlin_gemm" )
429
@register_fake("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                          b_scales: torch.Tensor, workspace: torch.Tensor,
                          num_bits: int, size_m: int, size_n: int,
                          size_k: int) -> torch.Tensor:
    """Fake implementation registered for ``_C::fp8_marlin_gemm``.

    Returns an uninitialized ``(size_m, size_n)`` tensor with the dtype and
    device of ``a``. The other arguments are accepted but not read here.
    """
    out_shape = (size_m, size_n)
    return torch.empty(out_shape, device=a.device, dtype=a.dtype)
424
435
425
- @torch . library . register_fake ("_C::machete_gemm" )
436
+ @register_fake ("_C::machete_gemm" )
426
437
def machete_gemm_fake (
427
438
a : torch .Tensor ,
428
439
# Should be the tensor returned by machete_prepack_B
@@ -440,13 +451,13 @@ def machete_gemm_fake(
440
451
n = b_q .size (1 )
441
452
return torch .empty ((m , n ), device = a .device , dtype = a .dtype )
442
453
443
- @torch . library . register_fake ("_C::machete_prepack_B" )
454
@register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
                           b_type: ScalarType) -> torch.Tensor:
    """Fake implementation registered for ``_C::machete_prepack_B``.

    Returns an uninitialized tensor shaped like ``b_q_weight`` in
    contiguous memory format. ``b_type`` is accepted but not read here.
    """
    layout = torch.contiguous_format
    return torch.empty_like(b_q_weight, memory_format=layout)
448
459
449
- @torch . library . register_fake ("_C::causal_conv1d_fwd" )
460
+ @register_fake ("_C::causal_conv1d_fwd" )
450
461
def causal_conv1d_fwd_fake (x : torch .Tensor , weight : torch .Tensor ,
451
462
bias_ : Optional [torch .Tensor ],
452
463
conv_states : Optional [torch .Tensor ],
@@ -456,15 +467,15 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
456
467
silu_activation : bool ) -> torch .Tensor :
457
468
return torch .empty_like (x )
458
469
459
- @torch . library . register_fake ("_C::causal_conv1d_update" )
470
@register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(
        x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
        bias_: Optional[torch.Tensor], silu_activation: bool,
        cache_seqlens: Optional[torch.Tensor],
        conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
    """Fake implementation registered for ``_C::causal_conv1d_update``.

    Returns an uninitialized tensor with the same shape, dtype and device
    as ``x``. All other arguments are accepted but not read here.
    """
    placeholder = torch.empty_like(x)
    return placeholder
466
477
467
- @torch . library . register_fake ("_C::selective_scan_fwd" )
478
+ @register_fake ("_C::selective_scan_fwd" )
468
479
def selective_scan_fwd_fake (u : torch .Tensor , delta : torch .Tensor ,
469
480
A : torch .Tensor , B : torch .Tensor ,
470
481
C : torch .Tensor , D_ : Optional [torch .Tensor ],
@@ -639,7 +650,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
639
650
640
651
if hasattr (torch .ops ._C , "permute_cols" ):
641
652
642
- @torch . library . register_fake ("_C::permute_cols" )
653
@register_fake("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor,
                       perm: torch.Tensor) -> torch.Tensor:
    """Fake implementation registered for ``_C::permute_cols``.

    Returns an uninitialized tensor with the same shape, dtype and device
    as ``a``. ``perm`` is accepted but not read here.
    """
    placeholder = torch.empty_like(a)
    return placeholder
@@ -837,7 +848,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
837
848
838
849
if supports_moe_ops and hasattr (torch .ops ._moe_C , "marlin_gemm_moe" ):
839
850
840
- @torch . library . register_fake ("_moe_C::marlin_gemm_moe" )
851
+ @register_fake ("_moe_C::marlin_gemm_moe" )
841
852
def marlin_gemm_moe_fake (a : torch .Tensor , b_q_weights : torch .Tensor ,
842
853
sorted_ids : torch .Tensor ,
843
854
topk_weights : torch .Tensor ,
0 commit comments