Skip to content

Commit 23739cb

Browse files
committed
fix: Transfer calibration data to gpu when it is not a batch
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent a032c3a commit 23739cb

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

py/trtorch/ptq.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def get_batch(self, names):
3434
# Treat the first element as input and others as targets.
3535
if isinstance(batch, list):
3636
batch = batch[0].to(self.device)
37+
else:
38+
batch = batch.to(self.device)
3739
return [batch.data_ptr()]
3840

3941

0 commit comments

Comments
 (0)