🐛 [Bug] SymInt placeholder nodes not fully removed in remove_sym_nodes pass, causing issues with TensorRT lowering #3981

@lwlsaysnuaa

Bug Description

When compiling a model with torch.compile(..., backend="tensorrt") and a dynamic batch dimension marked via torch._dynamo.mark_dynamic, the remove_sym_nodes pass in Torch-TensorRT does not fully remove torch.SymInt placeholder nodes.

After the pass runs, a torch.SymInt placeholder node (%s0) is still present in the FX graph. This causes problems for downstream lowering and codegen, which expect these SymInt placeholders to be eliminated and replaced with tensor-based size queries (e.g., via torch.ops.aten.sym_size).
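
One way to confirm which nodes still consume a leftover placeholder is to walk the graph's placeholders and list their users. The sketch below uses a small, hand-built torch.fx graph that only loosely mirrors the graph in this issue; the helper name placeholder_users and the toy graph are illustrative assumptions, not part of Torch-TensorRT.

```python
import torch
import torch.fx as fx


def placeholder_users(gm: fx.GraphModule) -> dict:
    """Map each placeholder's name to the names of the nodes that consume it."""
    return {
        node.name: sorted(user.name for user in node.users)
        for node in gm.graph.nodes
        if node.op == "placeholder"
    }


# Hand-built toy graph (an assumption for illustration): a symbolic batch
# size "s0" feeds both an expand and a reshape, as in the reported graph.
graph = fx.Graph()
s0 = graph.placeholder("s0")
x = graph.placeholder("x")
cls_token = graph.placeholder("cls_token")
expanded = graph.call_method("expand", args=(cls_token, s0, -1, -1))
joined = graph.call_function(torch.cat, args=([expanded, x],), kwargs={"dim": 1})
reshaped = graph.call_method("reshape", args=(joined, s0, -1))
graph.output(reshaped)
gm = fx.GraphModule(torch.nn.Module(), graph)

users = placeholder_users(gm)
print(users)
```

Any placeholder that is meant to be removed but still has a non-empty user list would block a straightforward erase of that node.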

I believe the pass should instead replace those SymInt placeholders with aten.sym_size calls on the corresponding Tensor placeholders (e.g., the first dimension of the input tensor), and then remove the SymInt placeholders entirely.
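
A minimal sketch of that idea, using only public torch.fx APIs on a hand-built toy graph, might look like the following. This is not the actual remove_sym_nodes implementation: the helper name replace_symint_placeholder is hypothetical, and the mapping from the symbol to a (tensor, dim) pair is passed in by hand here, whereas a real pass would have to derive it from the graph's metadata.

```python
import torch
import torch.fx as fx


def replace_symint_placeholder(
    gm: fx.GraphModule, sym_name: str, tensor_name: str, dim: int
) -> fx.GraphModule:
    """Hypothetical helper: rewire users of a SymInt placeholder to
    aten.sym_size.int(tensor, dim), then erase the placeholder."""
    graph = gm.graph
    placeholders = [n for n in graph.nodes if n.op == "placeholder"]
    by_target = {n.target: n for n in placeholders}
    sym_node = by_target[sym_name]
    tensor_node = by_target[tensor_name]

    # Insert the size query after the last placeholder so that it is
    # defined before any node that needs it.
    with graph.inserting_after(placeholders[-1]):
        size_node = graph.call_function(
            torch.ops.aten.sym_size.int, args=(tensor_node, dim)
        )

    sym_node.replace_all_uses_with(size_node)
    graph.erase_node(sym_node)  # safe now: the placeholder has no users left
    graph.lint()
    gm.recompile()
    return gm


def build_toy_graph() -> fx.GraphModule:
    """Toy stand-in for the reported graph: reshape(x, s0, -1)."""
    graph = fx.Graph()
    s0 = graph.placeholder("s0")
    x = graph.placeholder("x")
    out = graph.call_method("reshape", args=(x, s0, -1))
    graph.output(out)
    return fx.GraphModule(torch.nn.Module(), graph)


gm = replace_symint_placeholder(build_toy_graph(), "s0", "x", 0)
remaining_inputs = [n.target for n in gm.graph.nodes if n.op == "placeholder"]
out = gm(torch.randn(4, 6))  # the module now takes only the tensor input
```

After the rewrite, the module's signature no longer includes the symbolic size, so callers pass only the tensor and the batch size is read from it at run time.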

To Reproduce

Reproduction Code:

import torch
import torch.nn as nn
import torch_tensorrt
import logging
logging.basicConfig(level=logging.DEBUG)

class ExpandReshapeModel(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.embed_dim = embed_dim
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.qkv_proj(x)
        reshaped_qkv = x.reshape(
            batch_size,
            x.size(1),
            3,
            12,
            -1
        )
        return reshaped_qkv

model = ExpandReshapeModel(embed_dim=768).cuda().eval()
model = torch.compile(model, backend="tensorrt")

x = torch.randn(4, 196, 768).cuda()
# Mark dim 0 (batch) as dynamic with an allowed range of [2, 32].
torch._dynamo.mark_dynamic(x, index=0, min=2, max=32)

out = model(x)
print(out.shape)

FX Graph Before / After remove_sym_nodes

From the debug log of torch_tensorrt.dynamo.lowering.passes.remove_sym_nodes:

Current behavior (after the pass, with the SymInt placeholder still present):

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_sym_nodes:Removed SymInt placeholders:
graph():
    %s0 : torch.SymInt [num_users=2] = placeholder[target=s0]
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %l_self_parameters_cls_token_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_parameters_cls_token_]
    %l_self_modules_qkv_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_qkv_proj_parameters_weight_]
    %l_self_modules_qkv_proj_parameters_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_qkv_proj_parameters_bias_]
    %clone_default_3 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_modules_qkv_proj_parameters_bias_,), kwargs = {})
    %clone_default_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_modules_qkv_proj_parameters_weight_,), kwargs = {})
    %clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_parameters_cls_token_,), kwargs = {})
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %cls_token : [num_users=1] = call_method[target=expand](args = (%clone_default_1, %s0, -1, -1), kwargs = {})
    %x : [num_users=1] = call_function[target=torch.cat](args = ([%cls_token, %clone_default],), kwargs = {dim: 1})
    %x_1 : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%x, %clone_default_2, %clone_default_3), kwargs = {})
    %reshaped_qkv : [num_users=1] = call_method[target=reshape](args = (%x_1, %s0, 197, 3, 12, -1), kwargs = {})
    return (reshaped_qkv,)

As shown, %s0 : torch.SymInt = placeholder[target=s0] is still present and used in expand and reshape.

Expected / Desired Transformation

The problematic torch.SymInt placeholder should be removed and replaced by a dynamic size extracted from the relevant tensor placeholder (here, the batch dimension of %l_x_), via aten.sym_size.

For example, a correct/desired transformed FX graph would look like:

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_sym_nodes:Removed SymInt placeholders:
graph():
    %l_x_ : torch.Tensor [num_users=2] = placeholder[target=L_x_]
    %sym_size : [num_users=2] = call_function[target=torch.ops.aten.sym_size](args = (%l_x_, 0), kwargs = {})
    %l_self_parameters_cls_token_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_parameters_cls_token_]
    %l_self_modules_qkv_proj_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_qkv_proj_parameters_weight_]
    %l_self_modules_qkv_proj_parameters_bias_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_qkv_proj_parameters_bias_]
    %clone_default_3 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_modules_qkv_proj_parameters_bias_,), kwargs = {})
    %clone_default_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_modules_qkv_proj_parameters_weight_,), kwargs = {})
    %clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_self_parameters_cls_token_,), kwargs = {})
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %cls_token : [num_users=1] = call_method[target=expand](args = (%clone_default_1, %sym_size, -1, -1), kwargs = {})
    %x : [num_users=1] = call_function[target=torch.cat](args = ([%cls_token, %clone_default],), kwargs = {dim: 1})
    %x_1 : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%x, %clone_default_2, %clone_default_3), kwargs = {})
    %reshaped_qkv : [num_users=1] = call_method[target=reshape](args = (%x_1, %sym_size, 197, 3, 12, -1), kwargs = {})
    return (reshaped_qkv,)

Here, %sym_size = torch.ops.aten.sym_size(%l_x_, 0) replaces the SymInt placeholder %s0, and %s0 is completely removed.

I am happy to provide additional logs, the full FX graph before/after, or try a patch if you can point to the relevant parts of the code.

Expected behavior

After remove_sym_nodes runs, the graph should contain no torch.SymInt placeholders: %s0 is gone and its users (expand and reshape) take %sym_size instead, exactly as in the desired graph under "Expected / Desired Transformation" above.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 2.6):
  • PyTorch Version (e.g. 2.6):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context
