
🐛 [Bug] Encountered bug when using Torch-TensorRT  #1123

Closed
@HireezShanPeng

Description


Bug Description

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)

RuntimeError                              Traceback (most recent call last)
Input In [11], in <cell line: 21>()
     18 new_level = torch_tensorrt.logging.Level.Error
     19 torch_tensorrt.logging.set_reportable_log_level(new_level)
---> 21 trt_model = torch_tensorrt.compile(traced_mlm_model,
     22     inputs= [torch_tensorrt.Input(shape=[batch_size, 47, 24], dtype=torch.int32),  # input_ids
     23              torch_tensorrt.Input(shape=[batch_size, 47, 24], dtype=torch.int32),  # token_type_ids
     24              torch_tensorrt.Input(shape=[batch_size, 47, 24], dtype=torch.int32)], # attention_mask
     25     enabled_precisions= {torch.float32}, # Run with 32-bit precision
     26     workspace_size=1000000000,
     27     truncate_long_and_double=True
     28 )

File /opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py:115, in compile(module, ir, inputs, enabled_precisions, **kwargs)
    110         logging.log(
    111             logging.Level.Info,
    112             "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
    113         )
    114         ts_mod = torch.jit.script(module)
--> 115     return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
    116 elif target_ir == _IRType.fx:
    117     raise RuntimeError("fx is currently not supported")

File /opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py:113, in compile(module, inputs, device, disable_tf32, sparse_weights, enabled_precisions, refit, debug, capability, num_min_timing_iters, num_avg_timing_iters, workspace_size, calibrator, truncate_long_and_double, require_full_compilation, min_block_size, torch_executed_ops, torch_executed_modules)
     87     raise ValueError(
     88         "require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. Found: torch_executed_ops: "
     89         + torch_executed_ops + ", torch_executed_modules: " + torch_executed_modules)
     91 spec = {
     92     "inputs": inputs,
     93     "device": device,
   (...)
    110     }
    111 }
--> 113 compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
    114 compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod)
    115 return compiled_module
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py(2148): embedding
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/sparse.py(158): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1117): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1129): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py(235): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1117): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1129): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py(1010): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1117): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1129): _call_impl
/opt/pytorch/torch_tensorrt/notebooks/segment_model/segment.py(481): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1117): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1129): _call_impl
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(958): trace_module
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(741): trace
/tmp/ipykernel_2946/2700711854.py(16): <cell line: 16>
/opt/conda/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3397): run_code
/opt/conda/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3337): run_ast_nodes
/opt/conda/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3134): run_cell_async
/opt/conda/lib/python3.8/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
/opt/conda/lib/python3.8/site-packages/IPython/core/interactiveshell.py(2935): _run_cell
/opt/conda/lib/python3.8/site-packages/IPython/core/interactiveshell.py(2880): run_cell
/opt/conda/lib/python3.8/site-packages/ipykernel/zmqshell.py(528): run_cell
/opt/conda/lib/python3.8/site-packages/ipykernel/ipkernel.py(383): do_execute
/opt/conda/lib/python3.8/site-packages/ipykernel/kernelbase.py(724): execute_request
/opt/conda/lib/python3.8/site-packages/ipykernel/kernelbase.py(400): dispatch_shell
/opt/conda/lib/python3.8/site-packages/ipykernel/kernelbase.py(493): process_one
/opt/conda/lib/python3.8/site-packages/ipykernel/kernelbase.py(504): dispatch_queue
/opt/conda/lib/python3.8/asyncio/events.py(81): _run
/opt/conda/lib/python3.8/asyncio/base_events.py(1859): _run_once
/opt/conda/lib/python3.8/asyncio/base_events.py(570): run_forever
/opt/conda/lib/python3.8/site-packages/tornado/platform/asyncio.py(199): start
/opt/conda/lib/python3.8/site-packages/ipykernel/kernelapp.py(712): start
/opt/conda/lib/python3.8/site-packages/traitlets/config/application.py(846): launch_instance
/opt/conda/lib/python3.8/site-packages/ipykernel_launcher.py(17): <module>
/opt/conda/lib/python3.8/runpy.py(87): _run_code
/opt/conda/lib/python3.8/runpy.py(194): _run_module_as_main
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)

Code

When I followed the steps in this tutorial to process my model, I hit the bug above. I did not move any data or the model to the GPU, so I don't understand why a CUDA device shows up during compilation.
Here is my code:

import torch
import timeit
import numpy as np
import torch_tensorrt
import torch.backends.cudnn as cudnn

from transformers import BertConfig, BertTokenizerFast
from segment_model.segment import SegmentationFontEndBertModel

batch_size = 4
batched_input_ids = torch.tensor([[[1]*24 for _ in range(47)]]*batch_size)
batched_attention_masks = torch.tensor([[[0]*24 for _ in range(47)]]*batch_size)
batched_token_type_ids = torch.tensor([[[1]*24 for _ in range(47)]]*batch_size)
# print(batched_input_ids)
mlm_model_ts = SegmentationFontEndBertModel.from_pretrained("20220601_1860", torchscript=True)
traced_mlm_model = torch.jit.trace(mlm_model_ts, [batched_input_ids, batched_attention_masks, batched_token_type_ids], strict=False)

new_level = torch_tensorrt.logging.Level.Error
torch_tensorrt.logging.set_reportable_log_level(new_level)

trt_model = torch_tensorrt.compile(traced_mlm_model,
    # Inputs listed in the same order as the example inputs given to torch.jit.trace above
    inputs=[torch_tensorrt.Input(shape=[batch_size, 47, 24], dtype=torch.int32),  # input_ids
            torch_tensorrt.Input(shape=[batch_size, 47, 24], dtype=torch.int32),  # attention_mask
            torch_tensorrt.Input(shape=[batch_size, 47, 24], dtype=torch.int32)], # token_type_ids
    enabled_precisions={torch.float32}, # Run with 32-bit precision
    workspace_size=1000000000,
    truncate_long_and_double=True
)
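A possible explanation, offered as a guess rather than a confirmed diagnosis: Torch-TensorRT builds its engines on the GPU and, while partitioning the graph, appears to run parts of the TorchScript module with inputs it creates on cuda:0. If the traced module's embedding weights are still on the CPU, the index_select inside the embedding lookup would see CPU weights and CUDA indices, which matches the error above. The sketch below moves the model and its example inputs to one device and switches to eval mode before tracing. The helper trace_on_device and the TinyEmbeddingModel stand-in are hypothetical illustrations, not part of Torch-TensorRT or of the model in this issue, and casting the ids to int32 is an assumption made to match the dtype declared in torch_tensorrt.Input.

```python
import torch
import torch.nn as nn


class TinyEmbeddingModel(nn.Module):
    """Hypothetical stand-in for the BERT-based model: a single embedding lookup,
    which is where the device mismatch was reported (index_select)."""

    def __init__(self, vocab_size: int = 32, hidden: int = 8):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # token_type_ids is accepted only to mirror the three-input signature.
        emb = self.word_embeddings(input_ids)
        return emb * attention_mask.unsqueeze(-1)


def trace_on_device(model: nn.Module, example_inputs, device: str = "cuda",
                    index_dtype: torch.dtype = torch.int32):
    """Hypothetical helper: move the model and its example inputs to the same
    device, switch to eval mode, and trace there."""
    model = model.to(device).eval()
    # torch.tensor(...) built from Python ints is int64; cast to the dtype the
    # torch_tensorrt.Input specs declare (int32 in the issue).
    inputs = tuple(t.to(device=device, dtype=index_dtype) for t in example_inputs)
    with torch.no_grad():
        traced = torch.jit.trace(model, inputs, strict=False)
    return traced, inputs


if __name__ == "__main__":
    batch_size = 4
    ids = torch.ones(batch_size, 47, 24, dtype=torch.long)
    mask = torch.zeros(batch_size, 47, 24, dtype=torch.long)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    traced, inputs = trace_on_device(TinyEmbeddingModel(), [ids, mask, ids], device)
    # Set of devices holding the traced module's parameters.
    print({p.device.type for p in traced.parameters()})
```

With device="cuda", the returned traced module (now holding CUDA weights) would be the one passed to torch_tensorrt.compile with the same Input specs as above; whether that alone resolves the error for SegmentationFontEndBertModel is an assumption, not something this sketch verifies.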

I ran this code in the NGC PyTorch Docker container nvcr.io/nvidia/pytorch:22.05-py3, which ships with both PyTorch and Torch-TensorRT.

Here is my model structure:

SegmentationFontEndBertModel(
  original_name=SegmentationFontEndBertModel
  (bert): BertModel(
    original_name=BertModel
    (embeddings): BertEmbeddings(
      original_name=BertEmbeddings
      (word_embeddings): Embedding(original_name=Embedding)
      (position_embeddings): Embedding(original_name=Embedding)
      (token_type_embeddings): Embedding(original_name=Embedding)
      (LayerNorm): LayerNorm(original_name=LayerNorm)
      (dropout): Dropout(original_name=Dropout)
    )
    (encoder): BertEncoder(
      original_name=BertEncoder
      (layer): ModuleList(
        original_name=ModuleList
        (0): BertLayer(
          original_name=BertLayer
          (attention): BertAttention(
            original_name=BertAttention
            (self): BertSelfAttention(
              original_name=BertSelfAttention
              (query): Linear(original_name=Linear)
              (key): Linear(original_name=Linear)
              (value): Linear(original_name=Linear)
              (dropout): Dropout(original_name=Dropout)
            )
            (output): BertSelfOutput(
              original_name=BertSelfOutput
              (dense): Linear(original_name=Linear)
              (LayerNorm): LayerNorm(original_name=LayerNorm)
              (dropout): Dropout(original_name=Dropout)
            )
          )
          (intermediate): BertIntermediate(
            original_name=BertIntermediate
            (dense): Linear(original_name=Linear)
            (intermediate_act_fn): GELUActivation(original_name=GELUActivation)
          )
          (output): BertOutput(
            original_name=BertOutput
            (dense): Linear(original_name=Linear)
            (LayerNorm): LayerNorm(original_name=LayerNorm)
            (dropout): Dropout(original_name=Dropout)
          )
        )
        (1): BertLayer(
          original_name=BertLayer
          (attention): BertAttention(
            original_name=BertAttention
            (self): BertSelfAttention(
              original_name=BertSelfAttention
              (query): Linear(original_name=Linear)
              (key): Linear(original_name=Linear)
              (value): Linear(original_name=Linear)
              (dropout): Dropout(original_name=Dropout)
            )
            (output): BertSelfOutput(
              original_name=BertSelfOutput
              (dense): Linear(original_name=Linear)
              (LayerNorm): LayerNorm(original_name=LayerNorm)
              (dropout): Dropout(original_name=Dropout)
            )
          )
          (intermediate): BertIntermediate(
            original_name=BertIntermediate
            (dense): Linear(original_name=Linear)
            (intermediate_act_fn): GELUActivation(original_name=GELUActivation)
          )
          (output): BertOutput(
            original_name=BertOutput
            (dense): Linear(original_name=Linear)
            (LayerNorm): LayerNorm(original_name=LayerNorm)
            (dropout): Dropout(original_name=Dropout)
          )
        )
        (2): BertLayer(
          original_name=BertLayer
          (attention): BertAttention(
            original_name=BertAttention
            (self): BertSelfAttention(
              original_name=BertSelfAttention
              (query): Linear(original_name=Linear)
              (key): Linear(original_name=Linear)
              (value): Linear(original_name=Linear)
              (dropout): Dropout(original_name=Dropout)
            )
            (output): BertSelfOutput(
              original_name=BertSelfOutput
              (dense): Linear(original_name=Linear)
              (LayerNorm): LayerNorm(original_name=LayerNorm)
              (dropout): Dropout(original_name=Dropout)
            )
          )
          (intermediate): BertIntermediate(
            original_name=BertIntermediate
            (dense): Linear(original_name=Linear)
            (intermediate_act_fn): GELUActivation(original_name=GELUActivation)
          )
          (output): BertOutput(
            original_name=BertOutput
            (dense): Linear(original_name=Linear)
            (LayerNorm): LayerNorm(original_name=LayerNorm)
            (dropout): Dropout(original_name=Dropout)
          )
        )
        (3): BertLayer(
          original_name=BertLayer
          (attention): BertAttention(
            original_name=BertAttention
            (self): BertSelfAttention(
              original_name=BertSelfAttention
              (query): Linear(original_name=Linear)
              (key): Linear(original_name=Linear)
              (value): Linear(original_name=Linear)
              (dropout): Dropout(original_name=Dropout)
            )
            (output): BertSelfOutput(
              original_name=BertSelfOutput
              (dense): Linear(original_name=Linear)
              (LayerNorm): LayerNorm(original_name=LayerNorm)
              (dropout): Dropout(original_name=Dropout)
            )
          )
          (intermediate): BertIntermediate(
            original_name=BertIntermediate
            (dense): Linear(original_name=Linear)
            (intermediate_act_fn): GELUActivation(original_name=GELUActivation)
          )
          (output): BertOutput(
            original_name=BertOutput
            (dense): Linear(original_name=Linear)
            (LayerNorm): LayerNorm(original_name=LayerNorm)
            (dropout): Dropout(original_name=Dropout)
          )
        )
      )
    )
  )
)
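The TorchScript traceback points at the embedding lookup inside BertEmbeddings (modeling_bert.py, line 235), and the structure above contains three Embedding layers (word, position and token type). A small helper that reports which device each embedding table lives on can show whether the traced module is still on the CPU right before it is handed to Torch-TensorRT. The function name embedding_devices is hypothetical; it relies on traced submodules exposing their eager class name through original_name, as the printout above shows.

```python
import torch


def embedding_devices(module: torch.nn.Module) -> dict:
    """Hypothetical helper: map every embedding submodule to the device of its
    weight table. Works on eager modules and on traced ScriptModules, whose
    submodules report their eager class through `original_name`."""
    devices = {}
    for name, sub in module.named_modules():
        # Eager modules have no `original_name`, so fall back to the class name.
        kind = getattr(sub, "original_name", type(sub).__name__)
        if kind == "Embedding":
            devices[name] = str(sub.weight.device)
    return devices


# Usage idea: print(embedding_devices(traced_mlm_model)) just before torch_tensorrt.compile
```

For the module in this issue, every entry would be expected to read cpu, since nothing was moved to the GPU before tracing.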

Steps to reproduce the behavior:

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux): Linux (NGC container)
  • How you installed PyTorch (conda, pip, libtorch, source): preinstalled in the NGC Docker container nvcr.io/nvidia/pytorch:22.05-py3
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.8
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:
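Most of the empty fields above can be filled from Python itself. The sketch below uses only the standard library and treats torch and torch_tensorrt as optional: torch.__version__, torch.version.cuda and torch.cuda.get_device_name are existing PyTorch attributes and functions, while torch_tensorrt.__version__ is assumed to be exposed by the installed build. The function name collect_environment is hypothetical.

```python
import platform


def collect_environment() -> dict:
    """Gather the Environment fields (hypothetical helper). torch and
    torch_tensorrt lookups are optional and skipped if they cannot be imported."""
    info = {
        "Python version": platform.python_version(),
        "OS": platform.system(),
        "CPU Architecture": platform.machine(),
    }
    try:
        import torch
        info["PyTorch Version"] = torch.__version__
        info["CUDA version"] = str(torch.version.cuda)  # "None" on CPU-only builds
        if torch.cuda.is_available():
            info["GPU models and configuration"] = torch.cuda.get_device_name(0)
    except ImportError:
        pass
    try:
        import torch_tensorrt
        info["Torch-TensorRT Version"] = torch_tensorrt.__version__
    except Exception:  # import can also fail on missing TensorRT/CUDA libraries
        pass
    return info


if __name__ == "__main__":
    for key, value in collect_environment().items():
        print(f"{key}: {value}")
```

For the build information mentioned above, the log level can be raised with torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug), the same function the issue's code calls with Level.Error.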

Additional context
