From ddddf33338924070fa51fa59072697ceae020501 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 19 Jan 2025 21:41:00 +0800 Subject: [PATCH 1/2] fix sym_tensor_indices Signed-off-by: youkaichao --- vllm/compilation/backends.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index d7f4dcb7a20f..29ebb45226df 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -624,9 +624,12 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: ] # index of tensors that have symbolic shapes (batch size) + # for weights and static buffers, they will have concrete shapes, + # and therefore x.numel() will be an integer. self.sym_tensor_indices = [ i for i, x in enumerate(fake_args) if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) + and not isinstance(x.numel(), int) ] # compiler managed cudagraph input buffers From d5c42f2b20a6454f55e8fcb0da8942d5deccbb88 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 20 Jan 2025 10:25:58 +0800 Subject: [PATCH 2/2] use torch check Signed-off-by: youkaichao --- vllm/compilation/backends.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 29ebb45226df..955c25f30051 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -624,12 +624,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: ] # index of tensors that have symbolic shapes (batch size) - # for weights and static buffers, they will have concrete shapes, - # and therefore x.numel() will be an integer. + # for weights and static buffers, they will have concrete shapes. + # symbolic shape only happens for input tensors. + from torch.fx.experimental.symbolic_shapes import is_symbolic self.sym_tensor_indices = [ i for i, x in enumerate(fake_args) - if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) - and not isinstance(x.numel(), int) + if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \ + any(is_symbolic(d) for d in x.size()) ] # compiler managed cudagraph input buffers