Skip to content

Commit 4038adc

Browse files
committed
[TRTLLM-6577][feat] Support nano_v2_vlm in pytorch backend
* Support MMMU eval. * Update codes to match with latest HF codes. Signed-off-by: Wanli Jiang <[email protected]>
1 parent f96e7ba commit 4038adc

File tree

10 files changed

+118
-137
lines changed

10 files changed

+118
-137
lines changed

cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_
4949
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin[];
5050
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90_cu_cubin[];
5151
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90_cu_cubin[];
52-
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90_cu_cubin[];
5352
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90_cu_cubin[];
5453
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin[];
5554
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin[];
@@ -261,7 +260,7 @@ extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_192_tma_ws_sm90
261260
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
262261
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_softcapping_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
263262
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
264-
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
263+
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
265264
extern void run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_32_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
266265
extern void run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_40_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
267266
extern void run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_48_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
@@ -282,7 +281,8 @@ extern void run_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_104_alibi_tma_w
282281
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_160_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
283282
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_192_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
284283
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
285-
extern void run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_sage_64_64_256_output_bf16_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
284+
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
285+
extern void run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_output_bf16_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
286286
extern void run_fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
287287
extern void run_fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_40_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
288288
extern void run_fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_48_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
@@ -1354,7 +1354,6 @@ extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_w
13541354
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin_len;
13551355
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90_cu_cubin_len;
13561356
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90_cu_cubin_len;
1357-
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90_cu_cubin_len;
13581357
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90_cu_cubin_len;
13591358
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin_len;
13601359
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin_len;
@@ -1976,8 +1975,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
19761975
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 196864, 384, 64, 2, 2, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_softcapping_tma_ws_sm90},
19771976
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90_kernel", 164096, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90},
19781977
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_causal_tma_ws_sm90_kernel", 164096, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_tma_ws_sm90},
1979-
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90},
1980-
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_causal_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90},
1978+
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 0, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
1979+
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_causal_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 1, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
19811980
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_32_causal_alibi_tma_ws_sm90_kernel", 82304, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_32_alibi_tma_ws_sm90},
19821981
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 40, 40, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_40_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_40_alibi_tma_ws_sm90},
19831982
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 256, 48, 48, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_48_causal_alibi_tma_ws_sm90_kernel", 164224, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_48_alibi_tma_ws_sm90},
@@ -2000,8 +1999,10 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
20001999
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 160, 160, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_160_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, 1, 2, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_160_alibi_tma_ws_sm90},
20012000
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 192, 192, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_192_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, 1, 2, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_192_alibi_tma_ws_sm90},
20022001
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 128, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_causal_alibi_tma_ws_sm90_kernel", 229632, 384, 64, 1, 2, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_paged_kv_256_alibi_tma_ws_sm90},
2003-
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 256, 80, 80, 64, 64, 256, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_sage_64_64_256_output_bf16_tma_ws_sm90_kernel", 196864, 384, 64, 0, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_256_S_qkv_80_sage_64_64_256_output_bf16_tma_ws_sm90},
2004-
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 256, 128, 128, 64, 64, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90_kernel", 196864, 384, 64, 0, 0, false, true, true, true, false, false, false, false, nullptr},
2002+
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90},
2003+
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_causal_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_output_bf16_tma_ws_sm90},
2004+
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 0, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_output_bf16_tma_ws_sm90},
2005+
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_causal_softmax_output_bf16_tma_ws_sm90_kernel", 164096, 384, 64, 1, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_e4m3_64_128_S_q_k_v_192x128_softmax_output_bf16_tma_ws_sm90},
20052006
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_tma_ws_sm90_kernel", 73984, 384, 64, 0, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_tma_ws_sm90},
20062007
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_tma_ws_sm90},
20072008
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_sliding_or_chunked_causal_tma_ws_sm90_kernel", 73984, 384, 64, 2, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_fp16_fp32_64_256_S_qkv_32_tma_ws_sm90},

docs/source/legacy/reference/multimodal-feature-support-matrix.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
| LLaVA-NeXT | Yes | Yes | Yes | Yes |
99
| Llama 4 | Yes | Yes | No | No |
1010
| Mistral-Small-3.1 | Yes | Yes | No | No |
11-
| Nano-v2-VLM | Yes | Yes | No | No |
1211
| Phi-4-multimodal | Yes | Yes | No | No |
1312
| Qwen2-VL | Yes | Yes | Yes | Yes |
1413
| Qwen2.5-VL | Yes | Yes | Yes | Yes |

docs/source/models/supported-models.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ Note: Support for other models may vary. Features marked "N/A" are not applicabl
5151
| LlavaNextForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
5252
| Llama4ForConditionalGeneration | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
5353
| Mistral3ForConditionalGeneration | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | L + I |
54-
| NemotronH_Nano_VL_V2 | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + V |
54+
| NemotronH_Nano_VL_V2 | Yes | Yes | Yes | Yes | Yes | No | Yes | No | L + I + V |
5555
| Phi4MMForCausalLM | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + A |
5656
| Qwen2VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
5757
| Qwen2_5_VLForConditionalGeneration | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ blake3
7272
soundfile
7373
triton==3.3.1; platform_machine == "x86_64"
7474
tiktoken
75-
timm
7675
blobfile
7776
openai-harmony==0.0.4
7877
nvidia-cutlass-dsl==4.1.0; python_version >= "3.12"

0 commit comments

Comments
 (0)