
isin() received an invalid combination of arguments #31040

@pseudotensor

Description

System Info

transformers 4.41.0
Python 3.10

Who can help?

@ArthurZucker @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

This change: 7130a22

Here:

https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L486-L493

is causing a lot of different software to fail with the error below:

 isin() received an invalid combination of arguments - got (test_elements=int, elements=Tensor, ), but expected one of:
 * (Tensor elements, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
 * (Number element, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
 * (Tensor elements, Number test_element, *, bool assume_unique, bool invert, Tensor out)
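
The failure is reproducible in isolation. Judging from the printed overloads, the variant that accepts a plain Number binds it to the keyword `test_element` (singular), so passing an int as `test_elements=` matches no overload on the PyTorch build used here. A minimal sketch (the `inputs`/`pad_token_id` values are made up for illustration):

```python
import torch

inputs = torch.tensor([[0, 5, 5]])
pad_token_id = 0  # a plain Python int, as transformers passes it

# Fails with "isin() received an invalid combination of arguments":
# the keyword `test_elements` only matches the (Tensor, Tensor) overload,
# and the Number overload spells its keyword `test_element`.
torch.isin(elements=inputs, test_elements=pad_token_id)

# Works: the positional form binds the int to the (Tensor, Number) overload.
torch.isin(inputs, pad_token_id)

# Also works: promote the scalar to a tensor on the same device first.
torch.isin(elements=inputs, test_elements=torch.tensor([pad_token_id], device=inputs.device))
```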

A workaround patch is:

--- /home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/transformers/generation/utils.py      2024-05-26 01:04:39.151177467 -0700
+++ new.py      2024-05-26 01:02:53.993095157 -0700
@@ -468,12 +468,14 @@
             raise ValueError(
                 "Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device."
             )
+        pad_token_tensor = torch.tensor([pad_token_id], device=inputs.device) if pad_token_id is not None else None
+        eos_token_tensor = torch.tensor([eos_token_id], device=inputs.device) if eos_token_id is not None else None
 
         is_pad_token_in_inputs = (pad_token_id is not None) and (
-            torch.isin(elements=inputs, test_elements=pad_token_id).any()
+            torch.isin(elements=inputs, test_elements=pad_token_tensor).any()
         )
         is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
-            torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
+            torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
         )
         can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
         attention_mask_from_padding = inputs.ne(pad_token_id).long()
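
For environments where editing site-packages is not an option, an alternative is a small runtime shim that promotes Python scalars before delegating to the original `torch.isin`. This is only a sketch (the shim name is made up, and globally replacing `torch.isin` should be done with care):

```python
import torch

_torch_isin = torch.isin  # keep a reference to the original

def _isin_promoting_scalars(elements, test_elements, **kwargs):
    """Promote Python scalars to tensors so the (Tensor, Tensor) overload
    always matches, regardless of keyword vs. positional passing."""
    if not torch.is_tensor(elements):
        ref = test_elements if torch.is_tensor(test_elements) else None
        elements = torch.tensor(elements, device=ref.device if ref is not None else None)
    if not torch.is_tensor(test_elements):
        test_elements = torch.tensor(test_elements, device=elements.device)
    return _torch_isin(elements, test_elements, **kwargs)

torch.isin = _isin_promoting_scalars
```

Importing this shim before calling `generate()` should make both failing `isin` calls in the diff above succeed without editing the installed package.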

For example, Coqui XTTS (no longer maintained) fails like this without the above patch.

What is going on?

Expected behavior

No failure. Since the overloads say a Number is allowed, I expect a Number to be converted properly to a tensor on the same device, so there should be no failure.
