Skip to content

Commit d1e994b

Browse files
authored
Update TEQ train dataloader (#1554)
1 parent 941fed3 commit d1e994b

File tree

1 file changed

+3
-1
lines changed
  • neural_compressor/adaptor/torch_utils

1 file changed

+3
-1
lines changed

neural_compressor/adaptor/torch_utils/teq.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,9 @@ def train(
256256

257257
while global_steps <= train_steps:
258258
for inputs in dataloader:
259-
if isinstance(inputs, dict):
259+
if isinstance(inputs, torch.Tensor):
260+
input_id = inputs
261+
elif isinstance(inputs, dict):
260262
input_id = inputs["input_ids"]
261263
else:
262264
input_id = inputs[0]

0 commit comments

Comments
 (0)