Closed
Description
System Info
transformers 4.41.0
Python 3.10
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
This change: 7130a22
Here:
is causing a lot of different software to fail with the error below:
isin() received an invalid combination of arguments - got (test_elements=int, elements=Tensor, ), but expected one of:
* (Tensor elements, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
* (Number element, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
* (Tensor elements, Number test_element, *, bool assume_unique, bool invert, Tensor out)
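The signatures listed in the error point at the cause: the only overload that takes a plain number as the second argument names it `test_element` (singular), so passing an `int` through the `test_elements=` keyword matches none of them. A minimal sketch to check this locally, assuming a recent PyTorch install; the token IDs are arbitrary example values:

```python
import torch

inputs = torch.tensor([[5, 7, 2, 2]])  # example input IDs; 2 plays the role of the pad token
pad_token_id = 2                        # plain Python int, as many callers pass it

# Keyword form used in generation/utils.py: an int via `test_elements=` matches no overload.
try:
    torch.isin(elements=inputs, test_elements=pad_token_id)
    keyword_int_failed = False
except TypeError as err:
    keyword_int_failed = True
    print("TypeError:", str(err).splitlines()[0])

# Passing the number positionally lets the (Tensor, Number) overload match.
positional = torch.isin(inputs, pad_token_id)

# Wrapping the int in a tensor on the same device, as the work-around patch does, also works.
pad_token_tensor = torch.tensor([pad_token_id], device=inputs.device)
wrapped = torch.isin(elements=inputs, test_elements=pad_token_tensor)

print(positional.tolist())               # [[False, False, True, True]]
print(torch.equal(positional, wrapped))  # True
```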
A work-around patch is:
--- /home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/transformers/generation/utils.py 2024-05-26 01:04:39.151177467 -0700
+++ new.py 2024-05-26 01:02:53.993095157 -0700
@@ -468,12 +468,14 @@
             raise ValueError(
                 "Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device."
             )
+        pad_token_tensor = torch.tensor([pad_token_id], device=inputs.device) if pad_token_id is not None else None
+        eos_token_tensor = torch.tensor([eos_token_id], device=inputs.device) if eos_token_id is not None else None
 
         is_pad_token_in_inputs = (pad_token_id is not None) and (
-            torch.isin(elements=inputs, test_elements=pad_token_id).any()
+            torch.isin(elements=inputs, test_elements=pad_token_tensor).any()
         )
         is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
-            torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
+            torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
         )
         can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
         attention_mask_from_padding = inputs.ne(pad_token_id).long()
For example, Coqui XTTS (no longer maintained) fails this way without the patch above.
What is going on?
Expected behavior
No failure. Since the error message lists a Number as an accepted argument type, I expect an integer token ID to be converted properly to a tensor on the same device, so that the call does not fail.
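The expectation above amounts to normalizing token IDs before they reach `torch.isin`. A minimal sketch of that idea, generalizing the work-around patch so that an int, a list of ints, or a tensor is accepted; `to_token_tensor` and `infer_attention_mask` are hypothetical names (not transformers APIs), and the mask logic is a simplified version of the lines shown in the patch:

```python
from typing import List, Optional, Union

import torch

TokenIds = Union[int, List[int], torch.Tensor]


def to_token_tensor(token_ids: Optional[TokenIds], device: torch.device) -> Optional[torch.Tensor]:
    """Hypothetical helper: turn an int, list of ints, or tensor of token IDs into a
    1-D long tensor on `device`, which is always valid as `test_elements=`."""
    if token_ids is None:
        return None
    if isinstance(token_ids, torch.Tensor):
        return token_ids.to(device=device, dtype=torch.long).reshape(-1)
    if isinstance(token_ids, int):
        token_ids = [token_ids]
    return torch.tensor(token_ids, dtype=torch.long, device=device)


def infer_attention_mask(
    inputs: torch.Tensor,
    pad_token_id: Optional[TokenIds],
    eos_token_id: Optional[TokenIds],
) -> torch.Tensor:
    """Hypothetical sketch: infer a mask only when the pad token appears in the
    inputs and differs from the EOS token(s); otherwise return all ones."""
    pad = to_token_tensor(pad_token_id, inputs.device)
    eos = to_token_tensor(eos_token_id, inputs.device)
    default_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
    if pad is None:
        return default_mask
    pad_in_inputs = bool(torch.isin(elements=inputs, test_elements=pad).any())
    pad_is_not_eos = eos is None or not bool(torch.isin(elements=eos, test_elements=pad).any())
    if pad_in_inputs and pad_is_not_eos:
        # 0 where the input is a pad token, 1 everywhere else.
        return (~torch.isin(elements=inputs, test_elements=pad)).long()
    return default_mask


if __name__ == "__main__":
    ids = torch.tensor([[0, 0, 5, 7]])  # left-padded with pad ID 0
    print(infer_attention_mask(ids, 0, 2).tolist())                     # [[0, 0, 1, 1]]
    print(infer_attention_mask(ids, torch.tensor(0), [2, 3]).tolist())  # [[0, 0, 1, 1]]
```

Normalizing once at the boundary keeps every later `torch.isin` call on the Tensor/Tensor overload, which is the same choice the work-around patch makes for the int case.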