
Commit b92762a

[TRTLLM-6577][feat] Support nano_v2_vlm in pytorch backend
* Update notes that Nano v2 VLM cannot support KV cache reuse.
* Update code according to reviewers' comments.

Signed-off-by: Wanli Jiang <[email protected]>
1 parent: 57377b5

2 files changed: +28 −26 lines


docs/source/legacy/reference/multimodal-feature-support-matrix.md

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 | LLaVA-NeXT | Yes | Yes | Yes | Yes |
 | Llama 4 | Yes | Yes | No | No |
 | Mistral-Small-3.1 | Yes | Yes | No | No |
-| Nano-v2-VLM | Yes | Yes | Yes | No |
+| Nano-v2-VLM | Yes | Yes | No | No |
 | Phi-4-multimodal | Yes | Yes | No | No |
 | Qwen2-VL | Yes | Yes | Yes | Yes |
 | Qwen2.5-VL | Yes | Yes | Yes | Yes |
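
Since the matrix now lists KV cache reuse as unsupported for Nano-v2-VLM, deployments should keep block reuse disabled for this model. The following is a minimal sketch using the LLM API's KvCacheConfig; the checkpoint path and serving setup are illustrative assumptions, not part of this commit:

# Sketch only: disable KV cache block reuse for a model that cannot support it.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(
    model="path/to/NemotronH-Nano-VL-V2",  # placeholder path, not from this commit
    kv_cache_config=KvCacheConfig(enable_block_reuse=False),
)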

tensorrt_llm/_torch/models/modeling_nanov2vlm.py

Lines changed: 27 additions & 25 deletions
@@ -33,6 +33,9 @@ def _is_disagg() -> bool:
 
 # TODO: update the reference config path once Nano v2 VLM is released.
 IMAGE_TOKEN_ID = 131072
+IMG_CONTEXT_TOKEN = "<image>"
+IMG_START_TOKEN = "<img>"
+IMG_END_TOKEN = "</img>"
 
 
 class SquaredReLU(nn.Module):
@@ -41,8 +44,7 @@ def forward(self, x):
         return torch.pow(torch.nn.functional.relu(x), 2)
 
 
-class NanoV2VLVisionEncoder(transformers.PreTrainedModel,
-                            transformers.generation.GenerationMixin):
+class NanoV2VLVisionEncoder(transformers.PreTrainedModel):
 
     def __init__(self,
                  model_config: ModelConfig[transformers.PretrainedConfig]):
@@ -61,20 +63,21 @@ def __init__(self,
         self.llm_hidden_size = config.llm_config.hidden_size
         self.mlp1 = nn.Sequential(
             nn.RMSNorm(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
-                       eps=config.llm_config.rms_norm_eps),
+                       eps=config.llm_config.rms_norm_eps,
+                       dtype=config.torch_dtype),
             nn.Linear(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
                       self.vision_projection_hidden_size,
-                      bias=False), SquaredReLU(),
+                      bias=False,
+                      dtype=config.torch_dtype), SquaredReLU(),
             nn.Linear(self.vision_projection_hidden_size,
                       self.llm_hidden_size,
-                      bias=False))
-        self.mlp1 = self.mlp1.to(config.torch_dtype)
+                      bias=False,
+                      dtype=config.torch_dtype))
 
         # Construct the vision encoder.
         vision_model_config = copy.deepcopy(model_config)
         vision_model_config.pretrained_config = vision_model_config.pretrained_config.vision_config
         self.vision_model = RADIOVisionModel(vision_model_config)
-        self.vision_model.to(config.torch_dtype)
 
     def load_weights(self, weights):
         # Load mlp1 weights.
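
For context on the mlp1 change: passing dtype= to each layer builds the projector directly in the target dtype, so the full-precision weights that the old `.to(config.torch_dtype)` pattern first allocated are never created. The following is a standalone sketch with hypothetical sizes; the real values come from the vision and LLM configs:

import torch
from torch import nn

# Illustrative sizes standing in for vit_hidden_size * int(1/downsample_ratio)**2,
# the projection width, and the LLM hidden size.
vit_dim, proj_dim, llm_dim = 1280 * 4, 4096, 4480
dtype = torch.bfloat16

class SquaredReLU(nn.Module):
    def forward(self, x):
        return torch.pow(torch.nn.functional.relu(x), 2)

# Layers are created in the target dtype, so no fp32 copy is materialized
# (the old pattern built fp32 weights, then cast them with `.to(dtype)`).
mlp1 = nn.Sequential(
    nn.RMSNorm(vit_dim, eps=1e-6, dtype=dtype),
    nn.Linear(vit_dim, proj_dim, bias=False, dtype=dtype),
    SquaredReLU(),
    nn.Linear(proj_dim, llm_dim, bias=False, dtype=dtype),
)
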
@@ -111,7 +114,6 @@ def pixel_shuffle(self, x, scale_factor=0.5):
 
     def extract_feature(self, pixel_values):
         vit_embeds = self.vision_model(pixel_values)
-        vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
         # Down-sampling and projection.
         h = w = int(vit_embeds.shape[1]**0.5)
         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
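
The extract_feature path relies on pixel_shuffle to cut the number of vision tokens before projection. Below is a sketch of that space-to-depth step in the InternVL style, written for illustration rather than copied from this file:

import torch

def pixel_shuffle(x: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor:
    # x: [batch, height, width, channels]. scale_factor=0.5 halves each spatial
    # side and folds the 2x2 neighborhood into the channel dimension.
    n, h, w, c = x.size()
    x = x.view(n, h, int(w * scale_factor), int(c / scale_factor))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, int(w * scale_factor), int(h * scale_factor),
               int(c / (scale_factor * scale_factor)))
    x = x.permute(0, 2, 1, 3).contiguous()
    return x

# 1024 ViT patch embeddings laid out as a 32x32 grid with hidden size 1280
# become a 16x16 grid (256 tokens) with 4x the channels.
tokens = torch.randn(2, 32, 32, 1280)
print(pixel_shuffle(tokens).shape)  # torch.Size([2, 16, 16, 5120])
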
@@ -131,11 +133,11 @@ def forward(self, multimodal_params: List[MultimodalParams]):
         ],
                                          dim=0)
         # -> [num_patches, channel, height, width]
-        batched_num_patches = torch.cat([
+        patch_list = [
             multimodal_param.multimodal_data["num_patches"]
             for multimodal_param in multimodal_params
-        ],
-                                        dim=0).tolist()
+        ]
+        batched_num_patches = torch.cat(patch_list, dim=0).tolist()
         # -> list of[num_patches1, num_patches2, ...]
         batched_image_embeds = self.extract_feature(batched_pixel_values)
         # -> [num_patches, num_image_token, hidden_size]
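
This refactor only separates list construction from the torch.cat call; the behavior is unchanged. A small sketch with made-up patch counts shows what the flattened num_patches list is for; the split at the end is an assumption about downstream use, not code from this file:

import torch

# Made-up values: one request with a 3-patch image, one with 5- and 2-patch images.
params_num_patches = [torch.tensor([3]), torch.tensor([5, 2])]

patch_list = [num_patches for num_patches in params_num_patches]
batched_num_patches = torch.cat(patch_list, dim=0).tolist()  # -> [3, 5, 2]

# One plausible downstream use (an assumption): split the batched embeddings
# [total_patches, num_image_token, hidden] back into per-image chunks.
num_image_token, hidden_size = 256, 4480   # illustrative sizes
batched_image_embeds = torch.randn(sum(batched_num_patches), num_image_token,
                                   hidden_size)
per_image_embeds = torch.split(batched_image_embeds, batched_num_patches, dim=0)
print([e.shape[0] for e in per_image_embeds])  # [3, 5, 2]
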
@@ -176,9 +178,10 @@ def __init__(self,
 
         self.processor = transformers.AutoImageProcessor.from_pretrained(
             model_path, trust_remote_code=True, use_fast=self.use_fast)
-        self.img_context_token = "<image>"
-        self.img_start_token = "<img>"
-        self.img_end_token = "</img>"
+
+        self.img_context_token = IMG_CONTEXT_TOKEN
+        self.img_start_token = IMG_START_TOKEN
+        self.img_end_token = IMG_END_TOKEN
         self.dtype = model_config.torch_dtype
 
     def get_vocab_size(self):
194197
**kwargs,
195198
):
196199

197-
def get_internvl_target_ratios(
200+
def _get_internvl_target_ratios(
198201
min_num: int,
199202
max_num: int,
200203
) -> list[tuple[int, int]]:
@@ -205,8 +208,8 @@ def get_internvl_target_ratios(
                         if min_num <= i * j <= max_num}
             return sorted(target_ratios, key=lambda x: x[0] * x[1])
 
-        def find_closest_aspect_ratio(aspect_ratio, target_ratios, width,
-                                      height, image_size):
+        def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width,
+                                       height, image_size):
             best_factor = float('-inf')
             best_ratio = (1, 1)
             area = width * height
@@ -221,7 +224,7 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width,
                     best_ratio = ratio
             return best_ratio
 
-        def calculate_targets(
+        def _calculate_targets(
             orig_width: int,
             orig_height: int,
             target_ratios: list[tuple[int, int]],
@@ -230,7 +233,7 @@ def calculate_targets(
             aspect_ratio = orig_width / orig_height
 
             # find the closest aspect ratio to the target
-            target_aspect_ratio = find_closest_aspect_ratio(
+            target_aspect_ratio = _find_closest_aspect_ratio(
                 aspect_ratio,
                 target_ratios,
                 width=orig_width,
@@ -243,10 +246,10 @@ def calculate_targets(
 
         image_height = image.height
         image_width = image.width
-        target_ratios = get_internvl_target_ratios(1,
-                                                    self.processor.max_num_tiles)
-        blocks = calculate_targets(image_width, image_height, target_ratios,
-                                   self.image_size)
+        target_ratios = _get_internvl_target_ratios(
+            1, self.processor.max_num_tiles)
+        blocks = _calculate_targets(image_width, image_height, target_ratios,
+                                    self.image_size)
         if self.processor.use_thumbnail and blocks != 1:
             blocks += 1
         num_image_tokens = self.num_image_token * blocks
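
The renamed helpers implement the InternVL-style dynamic tiling that sizes the image placeholder: enumerate tile grids with at most max_num_tiles tiles, pick the grid closest to the image's aspect ratio, optionally add a thumbnail tile, and multiply by num_image_token. The sketch below is condensed, with simplified closest-ratio scoring (the real _find_closest_aspect_ratio also weighs covered area) and illustrative defaults for max_num_tiles and num_image_token, so the exact grid choice may differ from the real helpers:

def get_target_ratios(min_num: int, max_num: int) -> list[tuple[int, int]]:
    # All (cols, rows) grids whose tile count falls in [min_num, max_num].
    ratios = {(i, j)
              for n in range(min_num, max_num + 1)
              for i in range(1, n + 1)
              for j in range(1, n + 1)
              if min_num <= i * j <= max_num}
    return sorted(ratios, key=lambda r: r[0] * r[1])


def count_image_tokens(width: int, height: int, *, max_num_tiles: int = 12,
                       use_thumbnail: bool = True,
                       num_image_token: int = 256) -> int:
    aspect = width / height
    # Simplified scoring: nearest grid aspect ratio only.
    cols, rows = min(get_target_ratios(1, max_num_tiles),
                     key=lambda r: abs(aspect - r[0] / r[1]))
    blocks = cols * rows
    if use_thumbnail and blocks != 1:
        blocks += 1  # extra low-resolution overview tile
    return num_image_token * blocks


# With the defaults above, a 1920x1080 image picks a 2x1 grid, the thumbnail
# adds one tile, and the placeholder expands to 3 * 256 = 768 tokens.
print(count_image_tokens(1920, 1080))  # 768
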
@@ -309,7 +312,7 @@ def __call__(
             model_type="NemotronH_Nano_VL_V2",
             placeholder_metadata=MultimodalPlaceholderMetadata(
                 placeholder_map={
-                    "image": "<image>",
+                    "image": IMG_CONTEXT_TOKEN,
                 },
                 placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
                 placeholders_separator="",
@@ -332,7 +335,6 @@ def __init__(self, model_config: ModelConfig):
 
         if not _is_disagg():
             self.vision_encoder = NanoV2VLVisionEncoder(model_config).eval()
-            self.vision_encoder.to(config.torch_dtype)
 
         llm_model_config = copy.deepcopy(model_config)
         llm_model_config.pretrained_config = llm_model_config.pretrained_config.llm_config
