11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+ from typing import Optional
34
45import pytest
56import torch
@@ -34,15 +35,15 @@ class Relu3(ReLUSquaredActivation):
3435 [
3536 # Default values based on compile level
3637 # - All by default (no Inductor compilation)
37- ("" , 0 , False , [True ] * 4 , True ),
38- ("" , 1 , True , [True ] * 4 , True ),
39- ("" , 2 , False , [True ] * 4 , True ),
38+ (None , 0 , False , [True ] * 4 , True ),
39+ (None , 1 , True , [True ] * 4 , True ),
40+ (None , 2 , False , [True ] * 4 , True ),
4041 # - None by default (with Inductor)
41- ("" , 3 , True , [False ] * 4 , False ),
42- ("" , 4 , True , [False ] * 4 , False ),
42+ (None , 3 , True , [False ] * 4 , False ),
43+ (None , 4 , True , [False ] * 4 , False ),
4344 # - All by default (without Inductor)
44- ("" , 3 , False , [True ] * 4 , True ),
45- ("" , 4 , False , [True ] * 4 , True ),
45+ (None , 3 , False , [True ] * 4 , True ),
46+ (None , 4 , False , [True ] * 4 , True ),
4647 # Explicitly enabling/disabling
4748 #
4849 # Default: all
@@ -54,7 +55,7 @@ class Relu3(ReLUSquaredActivation):
5455 # All but SiluAndMul
5556 ("all,-silu_and_mul" , 2 , True , [1 , 0 , 1 , 1 ], True ),
5657 # All but ReLU3 (even if ReLU2 is on)
57- ("-relu3,relu2" , 3 , False , [1 , 1 , 1 , 0 ], True ),
58+ ("-relu3,+ relu2" , 3 , False , [1 , 1 , 1 , 0 ], True ),
5859 # RMSNorm and SiluAndMul
5960 ("none,-relu3,+rms_norm,+silu_and_mul" , 4 , False , [1 , 1 , 0 , 0 ], False ),
6061 # All but RMSNorm
@@ -67,12 +68,13 @@ class Relu3(ReLUSquaredActivation):
6768 # All but RMSNorm
6869 ("all,-rms_norm" , 4 , True , [0 , 1 , 1 , 1 ], True ),
6970 ])
70- def test_enabled_ops (env : str , torch_level : int , use_inductor : bool ,
71+ def test_enabled_ops (env : Optional [ str ] , torch_level : int , use_inductor : bool ,
7172 ops_enabled : list [int ], default_on : bool ):
73+ custom_ops = env .split (',' ) if env else []
7274 vllm_config = VllmConfig (
7375 compilation_config = CompilationConfig (use_inductor = bool (use_inductor ),
7476 level = torch_level ,
75- custom_ops = env . split ( "," ) ))
77+ custom_ops = custom_ops ))
7678 with set_current_vllm_config (vllm_config ):
7779 assert CustomOp .default_on () == default_on
7880
0 commit comments