🐛 [Bug] SDPA decomposition causing TorchTRT to be 2x slower than ONNX on SD3.5 #3682

@cehongwang

Description

In SD3.5 medium, we noticed that Torch-TRT is significantly slower than ONNX-TRT. To reproduce this, I used a 1-layer dummy model and found that Torch-TRT's latency is 2x ONNX-TRT's.

Analyzing the engine layer profiling, I found that for the same fused attention projection layer, Torch-TRT takes roughly 5x longer than ONNX-TRT (0.627 ms vs 0.119 ms average).

Torch-TRT: 
{ "name" : "[MATRIX_MULTIPLY]_[aten_ops_addmm_default]_[transformer_blocks_0_attn_add_v_proj/addmm_12_mm]+[MATRIX_MULTIPLY]_[aten_ops_addmm_default]_[transformer_blocks_0_attn_add_k_proj/addmm_11_mm]+[MATRIX_MULTIPLY]_[aten_ops_addmm_default]_[transformer_blocks_0_attn_add_q_proj/addmm_10_mm]_myl2_15", "timeMs" : 179.394, "averageMs" : 0.627251, "medianMs" : 0.628736, "percentage" : 5.04585 }

ONNX-TRT:
{ "name" : "/transformer_blocks_0/attn/add_v_proj/MatMul+/transformer_blocks_0/attn/add_k_proj/MatMul+/transformer_blocks_0/attn/add_q_proj/MatMul_myl3_16", "timeMs" : 69.5249, "averageMs" : 0.118846, "medianMs" : 0.118784, "percentage" : 1.93445 }

The only noticeable difference between those two kernel names is the suffix: Torch-TRT uses myl2_15 while ONNX-TRT uses myl3_16. I'm not sure whether this makes a difference or not.
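As a quick sanity check on the ratio above, the two per-layer profile records can be compared directly. This assumes trtexec-style `--exportProfile` output (a JSON list of objects with `timeMs`/`averageMs`/`medianMs` fields, as quoted); the layer names are truncated here for brevity:

```python
import json

# The two per-layer records quoted above (names shortened for readability).
torch_trt = json.loads(
    '{"name": "...addmm_10_mm]_myl2_15", "timeMs": 179.394, '
    '"averageMs": 0.627251, "medianMs": 0.628736, "percentage": 5.04585}'
)
onnx_trt = json.loads(
    '{"name": ".../add_q_proj/MatMul_myl3_16", "timeMs": 69.5249, '
    '"averageMs": 0.118846, "medianMs": 0.118784, "percentage": 1.93445}'
)

# Per-invocation slowdown of the fused QKV projection kernel.
ratio = torch_trt["averageMs"] / onnx_trt["averageMs"]
print(f"fused QKV projection: {ratio:.2f}x slower in Torch-TRT")
# → fused QKV projection: 5.28x slower in Torch-TRT
```

The per-invocation ratio (~5.3x) is larger than the total-time ratio (179.4 ms vs 69.5 ms, ~2.6x) because the two engines ran the kernel a different number of times during profiling.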
