Closed
Labels
bug (Something isn't working)
Description
Bug Description
When attempting to torch.compile (backend = torch_tensorrt) the google/paligemma2-3b-pt-224 model, I encountered the message below with both torch_tensorrt 2.6 and 2.7 dev.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node _to_copy [aten._to_copy.default] (Inputs: (arg0_1: (1, 3, 224, 224)@torch.float16) | Outputs: (_to_copy: (1, 3, 224, 224)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node arg1_1 (kind: arg1_1, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: arg1_1 [shape=[1152, 3, 14, 14], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node arg1_1 [arg1_1] (Inputs: () | Outputs: (arg1_1: (1152, 3, 14, 14)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node arg2_1 (kind: arg2_1, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: arg2_1 [shape=[1152], dtype=DataType.HALF]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node arg2_1 [arg2_1] (Inputs: () | Outputs: (arg2_1: (1152,)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node convolution (kind: aten.convolution.default, args: ('_to_copy <Node>', 'arg1_1 <Node>', 'arg2_1 <Node>', ['14 <int>', '14 <int>'], ['0 <int>'], ['1 <int>', '1 <int>'], 'False <bool>', ['0 <int>'], '1 <int>'))
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.convolution.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.convolution.default
DEBUG:torch_tensorrt [TensorRT Conversion Context]:Kernel weights are not set yet. Kernel weights must be set using setInput(1, kernel_tensor) API call.
ERROR:torch_tensorrt [TensorRT Conversion Context]:IConvolutionLayer::setPaddingNd: Error Code 3: API Usage Error (Parameter check failed, condition: (padding.nbDims == 2 || padding.nbDims == 3) && allDimsGtEq(padding, 0) && allDimsLtEq(padding, kMAX_PADDING). )
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node convolution [aten.convolution.default] (Inputs: (_to_copy: (1, 3, 224, 224)@torch.float16, arg1_1: (1152, 3, 14, 14)@torch.float16, arg2_1: (1152,)@torch.float16, [14, 14], [0], [1, 1], False, [0], 1) | Outputs: (convolution: (1, 1152, 16, 16)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node reshape_default (kind: aten.reshape.default, args: ('convolution <Node>', ['1 <int>', '1152 <int>', '256 <int>']))
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.reshape.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.reshape.default
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node reshape_default [aten.reshape.default] (Inputs: (convolution: (1, 1152, 16, 16)@torch.float16, [1, 1152, 256]) | Outputs: (reshape_default: (1, 1152, 256)@torch.float16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node permute (kind: aten.permute.default, args: ('reshape_default <Node>', ['0 <int>', '2 <int>', '1 <int>']))
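Reading the failing node in the log: the traced aten.convolution call carries a one-element padding list (['0 <int>']), while IConvolutionLayer::setPaddingNd only accepts 2-D or 3-D padding, which is exactly the parameter check that fails. A minimal sketch of the kind of normalization that would sidestep this, assuming the converter sees the raw padding sequence from the node args (the helper below is hypothetical, not torch_tensorrt API):

```python
# Hypothetical helper (not part of torch_tensorrt): expand a one-element
# padding sequence to one entry per spatial dim before it reaches
# IConvolutionLayer.setPaddingNd, which requires 2 or 3 dims.
def normalize_padding(padding, num_spatial_dims=2):
    padding = list(padding)
    if len(padding) == 1:
        # aten.convolution can carry a single value meant to apply to every
        # spatial dim; TensorRT wants the expanded per-dim form.
        padding = padding * num_spatial_dims
    if len(padding) not in (2, 3):
        raise ValueError("TensorRT padding must be 2-D or 3-D")
    return tuple(padding)

assert normalize_padding([0]) == (0, 0)   # the case from the log above
assert normalize_padding([1, 1]) == (1, 1)  # already per-dim, unchanged
```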
To Reproduce
Steps to reproduce the behavior:
import torch
import torch_tensorrt
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image
DEVICE = "cuda:0"
model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16).eval()
model.to(DEVICE).to(torch.float16)
# model.forward = model.forward.to(torch.float16).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)
prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.float16).to(DEVICE)
input_len = model_inputs["input_ids"].shape[-1]
# model.config.token_healing = False
with torch.inference_mode():
    pyt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
pyt_generation_out = pyt_generation[0][input_len:]
pyt_decoded = processor.decode(pyt_generation_out, skip_special_tokens=True)
print("=============================")
print("pyt_generation whole text:")
print(pyt_generation)
print("=============================")
print("=============================")
print("PyTorch generated text:")
print(pyt_decoded)
print("=============================")
with torch_tensorrt.logging.debug():
    torch._dynamo.mark_dynamic(model_inputs["input_ids"], 1, min=2, max=1023)
    model.forward = torch.compile(
        model.forward,
        backend="tensorrt",
        dynamic=None,
        options={
            "enabled_precisions": {torch.float16},
            "disable_tf32": True,
            "min_block_size": 1,
            # "use_explicit_typing": True,
            # "use_fp32_acc": True,
            "debug": True,
            # "use_aot_joint_export": False,
        },
    )
with torch.inference_mode():
    trt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
trt_generation_out = trt_generation[0][input_len:]
trt_decoded = processor.decode(trt_generation_out, skip_special_tokens=True)
print(trt_generation)
print("TensorRT generated text:")
print(trt_decoded)
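For triage, the failure looks like it is in the vision tower's patch embedding (the log shows fp16 weights of shape (1152, 3, 14, 14) applied to a (1, 3, 224, 224) input), so a much smaller repro along these lines may hit the same converter path without loading the full model. This is an untested sketch with shapes taken from the log; whether it reproduces the one-element padding from the trace is not certain, since nn.Conv2d normally normalizes integer padding to a per-dim tuple (padding="valid" is my assumption based on how the SigLIP patch embedding is configured in transformers):

```python
import torch
import torch_tensorrt  # noqa: F401  (importing registers the "tensorrt" backend)

# Stand-in for the patch embedding seen in the log: kernel 14, stride 14,
# no padding, fp16 weights of shape (1152, 3, 14, 14).
conv = torch.nn.Conv2d(3, 1152, kernel_size=14, stride=14, padding="valid")
conv = conv.half().eval().to("cuda")

x = torch.randn(1, 3, 224, 224, dtype=torch.float16, device="cuda")

compiled = torch.compile(
    conv,
    backend="tensorrt",
    options={"enabled_precisions": {torch.float16}, "min_block_size": 1},
)

with torch.inference_mode():
    out = compiled(x)  # expected shape (1, 1152, 16, 16), per the log
print(out.shape)
```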