8 changes: 6 additions & 2 deletions vllm/compilation/collective_fusion.py
@@ -380,8 +380,12 @@ def __init__(self, config: VllmConfig):

         self.dump_patterns(config, self.patterns)
 
-    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
-        # only do replace for specific shapes
+    # This pass is applied on top of the sequence parallelism pass.
+    # It inherits the same applicability condition as `SequenceParallelismPass`.
+    # See `SequenceParallelismPass.is_applicable` for more details.
+    def is_applicable(self, shape: Optional[int]) -> bool:
+        if self.splitting_ops is None or self.splitting_ops == []:
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0

2 changes: 1 addition & 1 deletion vllm/compilation/inductor_pass.py
@@ -96,7 +96,7 @@ def hash_dict(dict_: dict[Any, Any]):
         encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
         return hashlib.sha256(encoded).hexdigest()
 
-    def is_applicable_for_shape(self, shape: Optional[int]):
+    def is_applicable(self, shape: Optional[int]):
         return True


2 changes: 1 addition & 1 deletion vllm/compilation/pass_manager.py
@@ -71,7 +71,7 @@ def __call__(self, graph: fx.Graph):

         shape = get_pass_context().runtime_shape
         for pass_ in self.passes:
-            if pass_.is_applicable_for_shape(shape):
+            if pass_.is_applicable(shape):
                 pass_(graph)
         VllmInductorPass.dump_prefix += 1
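
For context, a self-contained sketch (not vLLM code; the class names, tp_size, and shapes below are illustrative) of the gating pattern this loop implements: each pass reports, per runtime shape, whether it should run.

    from typing import Optional

    class ToyPass:
        def is_applicable(self, shape: Optional[int]) -> bool:
            # Default: run for every shape, including the symbolic-shape (None) case.
            return True

        def __call__(self, graph: list) -> None:
            graph.append(type(self).__name__)

    class TpAlignedPass(ToyPass):
        """Only runs when the token count is concrete and divisible by tp_size."""

        def __init__(self, tp_size: int) -> None:
            self.tp_size = tp_size

        def is_applicable(self, shape: Optional[int]) -> bool:
            return shape is not None and shape % self.tp_size == 0

    def run_passes(passes: list, graph: list, runtime_shape: Optional[int]) -> None:
        for pass_ in passes:
            if pass_.is_applicable(runtime_shape):
                pass_(graph)

    graph: list = []
    run_passes([ToyPass(), TpAlignedPass(tp_size=4)], graph, runtime_shape=512)
    print(graph)  # ['ToyPass', 'TpAlignedPass'] -- 512 % 4 == 0, so both run
    run_passes([ToyPass(), TpAlignedPass(tp_size=4)], graph, runtime_shape=None)
    print(graph)  # [..., 'ToyPass'] -- symbolic shape: TpAlignedPass is skipped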

17 changes: 16 additions & 1 deletion vllm/compilation/sequence_parallelism.py
@@ -468,7 +468,22 @@ def __init__(self, config: VllmConfig):
             self.device).register(self.patterns)
         self.dump_patterns(config, self.patterns)
 
-    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
+    # When sequence parallelism is enabled, the residual tensor from RMSNorm
+    # needs to be split along the sequence dimension. However, this dimension
+    # is symbolic during piecewise compilation, and splitting symbolic shapes
+    # is not supported.
+    #
+    # This pass is therefore only applied when the sequence dimension is
+    # concrete:
+    # 1. In full-graph compilation mode (no splitting ops are used). In this
+    #    case num_tokens is always padded to a multiple of
+    #    tensor_parallel_size, so there is no need to check shape % tp_size == 0.
+    # 2. For a specific shape provided at compile time (e.g., from
+    #    `compile_sizes`), which must be divisible by the tensor-parallel
+    #    size.
+    def is_applicable(self, shape: Optional[int]) -> bool:
+        if self.splitting_ops is None or self.splitting_ops == []:
+            return True
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
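
As a quick illustration of the condition described in the comment above (tp_size and the shapes below are made-up values, and "attn" stands in for any splitting op):

    tp_size = 4

    def is_applicable(shape, splitting_ops):
        if splitting_ops is None or splitting_ops == []:
            return True                    # full-graph compile: num_tokens is padded
        return shape is not None and shape % tp_size == 0

    print(is_applicable(None, []))         # True  - full-graph mode, always applicable
    print(is_applicable(None, ["attn"]))   # False - piecewise compile, symbolic shape
    print(is_applicable(512, ["attn"]))    # True  - concrete shape, 512 % 4 == 0
    print(is_applicable(510, ["attn"]))    # False - 510 is not divisible by tp_size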

1 change: 1 addition & 0 deletions vllm/compilation/vllm_inductor_pass.py
@@ -29,6 +29,7 @@ class VllmInductorPass(InductorPass):

     def __init__(self, config: VllmConfig):
         self.pass_config = config.compilation_config.pass_config
+        self.splitting_ops = config.compilation_config.splitting_ops
         self.model_dtype = config.model_config.dtype if config.model_config \
             else None
         self.device = config.device_config.device if config.device_config \