@@ -80,22 +80,22 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   } else {
     // Target device is current device
     target_device += std::to_string(curr_device.id);
+  }
+
+  // For each input, ensure its current device is the desired target device
+  for (size_t i = 0; i < inputs.size(); i++) {
+    at::Tensor* in = &inputs[i];
+    std::string current_tensor_device = in->device().str();
 
-    // For each input, ensure its current device is the desired target device
-    for (size_t i = 0; i < inputs.size(); i++) {
-      at::Tensor* in = &inputs[i];
-      std::string current_tensor_device = in->device().str();
-
-      // If current device string does not match target device, display warning and move tensor accordingly
-      if (current_tensor_device != target_device) {
-        LOG_WARNING(
-            "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
-            << " but should be on " << target_device << ". This tensor is being moved by the runtime but "
-            << "for performance considerations, ensure your inputs are all on GPU "
-            << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
-            << "warning persists.");
-        *in = in->to(torch::Device(target_device));
-      }
+    // If current device string does not match target device, display warning and move tensor accordingly
+    if (current_tensor_device != target_device) {
+      LOG_WARNING(
+          "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
+          << " but should be on " << target_device << ". This tensor is being moved by the runtime but "
+          << "for performance considerations, ensure your inputs are all on GPU "
+          << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
+          << "warning persists.");
+      *in = in->to(torch::Device(target_device));
     }
   }
 
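Note on the change above: previously the per-input device check and move lived inside the else branch, so it only ran when the engine's target device was already the current device; the patch closes the else block early and hoists the loop so every call to execute_engine validates its inputs and relocates them if needed. As a rough caller-side sketch (not part of the patch; the helper name to_engine_device and the "cuda:0" string are assumed for illustration), moving inputs to the engine's device ahead of time avoids both the runtime warning and the copy inside the hot path:

#include <torch/torch.h>
#include <string>
#include <vector>

// Minimal sketch: move each input tensor to the engine's target device up front,
// mirroring the check the runtime performs in execute_engine.
std::vector<at::Tensor> to_engine_device(std::vector<at::Tensor> inputs, const std::string& target_device) {
  for (auto& in : inputs) {
    // Tensor::to returns the input unchanged when it is already on the target device
    if (in.device().str() != target_device) {
      in = in.to(torch::Device(target_device));
    }
  }
  return inputs;
}

// Usage (device string format assumed to match the runtime's "cuda:" + id):
//   auto prepared = to_engine_device(std::move(inputs), "cuda:0");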