
Commit b91e65a

Correct order of overflowing_tokens for slow tokenizer (#13179)
* correct order of overflowing_tokens for slow tokenizer (issue fix #13148)
* python 3.9 requires sentencepiece version 0.1.94 or above
* slicing of ids fixed in truncated_sequence()
* Update setup.py
* Correct order of overflowing tokens for pair of sentences
* code reformatted
* Update tokenization_utils_base.py
* reformatting file
* test to check single_input added
* missing function restored
* test to check pair_input overflowing tokens order
* added an error message for pair of seq and longest_first strategy
* test for pair_input modified
* variable name corrected
* fixed a typo in error message
* requested changes implemented
* required test added
* Corrected the message to match test message
* added error message for Luke Tokenizer
* lost test recovered
* docstring for truncate_sequences and prepare_for_model updated
* docstring for luke tokenizer updated
* updated ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING
* aligned text and fixed punctuation
* improved style and quality of code
* fixed error_msg in truncate_sequences
* replaced encode_plus method with regular call method
* clean up
* rephrased the docstring
Parent: c9184a2

File tree: 3 files changed (+119, −65 lines)


src/transformers/models/luke/tokenization_luke.py

Lines changed: 16 additions & 3 deletions
@@ -85,7 +85,9 @@
             `What are attention masks? <../glossary.html#attention-mask>`__
         return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether or not to return overflowing token sequences.
+            Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
+            of pairs) is provided with :obj:`truncation_strategy = longest_first` or :obj:`True`, an error is
+            raised instead of returning overflowing tokens.
         return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to return special tokens mask information.
         return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -1037,8 +1039,9 @@ def prepare_for_model(
         Prepares a sequence of input id, entity id and entity span, or a pair of sequences of inputs ids, entity ids,
         entity spans so that it can be used by the model. It adds special tokens, truncates sequences if overflowing
         while taking into account the special tokens and manages a moving window (with user defined stride) for
-        overflowing tokens
-
+        overflowing tokens. Please Note, for `pair_ids` different than `None` and `truncation_strategy = longest_first`
+        or `True`, it is not possible to return overflowing tokens. Such a combination of arguments will raise an
+        error.

         Args:
             ids (:obj:`List[int]`):
@@ -1078,6 +1081,16 @@ def prepare_for_model(
                 "results in an undefined behavior. Please set add_special_tokens to True or "
                 "set return_token_type_ids to None."
             )
+        if (
+            return_overflowing_tokens
+            and truncation_strategy == TruncationStrategy.LONGEST_FIRST
+            and pair_ids is not None
+        ):
+            raise ValueError(
+                "Not possible to return overflowing tokens for pair of sequences with the "
+                "`longest_first`. Please select another truncation strategy than `longest_first`, "
+                "for instance `only_second` or `only_first`."
+            )

         # Load from model defaults
         if return_token_type_ids is None:
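To make the new guard concrete, here is a minimal sketch of the behavior it enforces, seen from the calling side. It assumes a transformers install that includes this commit; the checkpoint name, sentences, and max_length values are illustrative, and use_fast=False forces the slow (pure-Python) code path patched here:

from transformers import AutoTokenizer

# Slow tokenizers go through the patched prepare_for_model().
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)

seq_0 = "This is the first sequence of the pair."
seq_1 = "And this is the second, noticeably longer, sequence of the pair."

try:
    # Pair input + longest_first + return_overflowing_tokens now raises.
    tokenizer(
        seq_0,
        seq_1,
        max_length=20,
        truncation="longest_first",
        return_overflowing_tokens=True,
    )
except ValueError as err:
    print(err)  # "Not possible to return overflowing tokens for pair of sequences ..."

# A single-side strategy still returns the overflow, now in original order.
encoded = tokenizer(
    seq_0,
    seq_1,
    max_length=20,
    stride=2,
    truncation="only_second",
    return_overflowing_tokens=True,
)
print(encoded["overflowing_tokens"])  # tokens truncated from seq_1, left to right

Note that `only_second` needs the second sequence to be long enough to absorb all removed tokens, which is why seq_1 is the longer sentence here.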

src/transformers/tokenization_utils_base.py

Lines changed: 42 additions & 24 deletions
@@ -1324,7 +1324,9 @@ def all_special_ids(self) -> List[int]:
             `What are attention masks? <../glossary.html#attention-mask>`__
         return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether or not to return overflowing token sequences.
+            Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
+            of pairs) is provided with :obj:`truncation_strategy = longest_first` or :obj:`True`, an error is
+            raised instead of returning overflowing tokens.
         return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to return special tokens mask information.
         return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -2838,7 +2840,9 @@ def prepare_for_model(
         """
         Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
         adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
-        manages a moving window (with user defined stride) for overflowing tokens
+        manages a moving window (with user defined stride) for overflowing tokens. Please Note, for `pair_ids`
+        different than `None` and `truncation_strategy = longest_first` or `True`, it is not possible to return
+        overflowing tokens. Such a combination of arguments will raise an error.

         Args:
             ids (:obj:`List[int]`):
@@ -2870,6 +2874,17 @@ def prepare_for_model(
                 "set return_token_type_ids to None."
             )

+        if (
+            return_overflowing_tokens
+            and truncation_strategy == TruncationStrategy.LONGEST_FIRST
+            and pair_ids is not None
+        ):
+            raise ValueError(
+                "Not possible to return overflowing tokens for pair of sequences with the "
+                "`longest_first`. Please select another truncation strategy than `longest_first`, "
+                "for instance `only_second` or `only_first`."
+            )
+
         # Load from model defaults
         if return_token_type_ids is None:
             return_token_type_ids = "token_type_ids" in self.model_input_names
@@ -2977,7 +2992,8 @@ def truncate_sequences(

         Returns:
             :obj:`Tuple[List[int], List[int], List[int]]`: The truncated ``ids``, the truncated ``pair_ids`` and the
-            list of overflowing tokens.
+            list of overflowing tokens. Note: The `longest_first` strategy returns empty list of overflowing_tokens if
+            a pair of sequences (or a batch of pairs) is provided.
         """
         if num_tokens_to_remove <= 0:
             return ids, pair_ids, []
@@ -2986,34 +3002,36 @@ def truncate_sequences(
         truncation_strategy = TruncationStrategy(truncation_strategy)

         overflowing_tokens = []
-        if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
-            for _ in range(num_tokens_to_remove):
-                if pair_ids is None or len(ids) > len(pair_ids):
-                    if not overflowing_tokens:
-                        window_len = min(len(ids), stride + 1)
-                    else:
-                        window_len = 1
-                    overflowing_tokens.extend(ids[-window_len:])
-                    ids = ids[:-1]
-                else:
-                    if not overflowing_tokens:
-                        window_len = min(len(pair_ids), stride + 1)
-                    else:
-                        window_len = 1
-                    overflowing_tokens.extend(pair_ids[-window_len:])
-                    pair_ids = pair_ids[:-1]
-        elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
+        if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
+            truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
+        ):
             if len(ids) > num_tokens_to_remove:
                 window_len = min(len(ids), stride + num_tokens_to_remove)
                 overflowing_tokens = ids[-window_len:]
                 ids = ids[:-num_tokens_to_remove]
             else:
-                logger.error(
-                    f"We need to remove {num_tokens_to_remove} to truncate the input"
+                error_msg = (
+                    f"We need to remove {num_tokens_to_remove} to truncate the input "
                     f"but the first sequence has a length {len(ids)}. "
-                    f"Please select another truncation strategy than {truncation_strategy}, "
-                    f"for instance 'longest_first' or 'only_second'."
                 )
+                if truncation_strategy == TruncationStrategy.ONLY_FIRST:
+                    error_msg = (
+                        error_msg + "Please select another truncation strategy than "
+                        f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
+                    )
+                logger.error(error_msg)
+        elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
+            logger.warning(
+                f"Be aware, overflowing tokens are not returned for the setting you have chosen,"
+                f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
+                f"truncation strategy. So the returned list will always be empty even if some "
+                f"tokens have been removed."
+            )
+            for _ in range(num_tokens_to_remove):
+                if pair_ids is None or len(ids) > len(pair_ids):
+                    ids = ids[:-1]
+                else:
+                    pair_ids = pair_ids[:-1]
         elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
             if len(pair_ids) > num_tokens_to_remove:
                 window_len = min(len(pair_ids), stride + num_tokens_to_remove)
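To see what the rewrite fixes, here is a simplified, self-contained re-implementation of the single-slice branch above. It is a sketch, not the library code itself; the function name is hypothetical and plain integers stand in for token ids:

from typing import List, Tuple

def truncate_single(ids: List[int], num_tokens_to_remove: int, stride: int = 0) -> Tuple[List[int], List[int]]:
    # Mirrors the ONLY_FIRST branch (also used for LONGEST_FIRST with a single
    # sequence): the overflow window is taken in one slice, so the overflowing
    # tokens keep their original left-to-right order.
    if num_tokens_to_remove <= 0:
        return ids, []
    window_len = min(len(ids), stride + num_tokens_to_remove)
    overflowing_tokens = ids[-window_len:]
    return ids[:-num_tokens_to_remove], overflowing_tokens

ids = list(range(10))  # stand-in token ids: [0, 1, ..., 9]
truncated, overflow = truncate_single(ids, num_tokens_to_remove=3, stride=2)
print(truncated)  # [0, 1, 2, 3, 4, 5, 6]
print(overflow)   # [5, 6, 7, 8, 9] -- stride tokens of context, then the removed tokens, in order

The deleted LONGEST_FIRST loop built the overflow list one token at a time from the right, which is why slow tokenizers returned it out of order (issue #13148). For sequence pairs, that strategy now removes tokens without collecting an overflow list at all and logs a warning instead.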

tests/test_tokenization_common.py

Lines changed: 61 additions & 38 deletions
@@ -941,6 +941,7 @@ def test_maximum_encoding_length_single_input(self):
                 self.assertEqual(truncated_sequence, sequence[:-2])

                 self.assertEqual(len(overflowing_tokens), 2 + stride)
+                self.assertEqual(overflowing_tokens, sequence[-(2 + stride) :])

     def test_maximum_encoding_length_pair_input(self):
         tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
@@ -1053,18 +1054,18 @@ def test_maximum_encoding_length_pair_input(self):
                     overflow_first_sequence if len(seq0_tokens) > len(seq1_tokens) else overflow_second_sequence
                 )

-                information = tokenizer.encode_plus(
-                    seq_0,
-                    seq_1,
-                    max_length=len(sequence) - 2,
-                    add_special_tokens=False,
-                    stride=stride,
-                    truncation="longest_first",
-                    return_overflowing_tokens=True,
-                    # add_prefix_space=False,
-                )
                 # Overflowing tokens are handled quite differently in slow and fast tokenizers
                 if isinstance(tokenizer, PreTrainedTokenizerFast):
+                    information = tokenizer(
+                        seq_0,
+                        seq_1,
+                        max_length=len(sequence) - 2,
+                        add_special_tokens=False,
+                        stride=stride,
+                        truncation="longest_first",
+                        return_overflowing_tokens=True,
+                        # add_prefix_space=False,
+                    )
                     truncated_sequence = information["input_ids"][0]
                     overflowing_tokens = information["input_ids"][1]
                     self.assertEqual(len(information["input_ids"]), 2)
@@ -1075,28 +1076,39 @@ def test_maximum_encoding_length_pair_input(self):
                     self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
                     self.assertEqual(overflowing_tokens, overflow_longest_sequence)
                 else:
-                    truncated_sequence = information["input_ids"]
-                    overflowing_tokens = information["overflowing_tokens"]
-
-                    self.assertEqual(len(truncated_sequence), len(sequence) - 2)
-                    self.assertEqual(truncated_sequence, truncated_longest_sequence)
+                    # No overflowing tokens when using 'longest' in python tokenizers
+                    with self.assertRaises(ValueError) as context:
+                        information = tokenizer(
+                            seq_0,
+                            seq_1,
+                            max_length=len(sequence) - 2,
+                            add_special_tokens=False,
+                            stride=stride,
+                            truncation="longest_first",
+                            return_overflowing_tokens=True,
+                            # add_prefix_space=False,
+                        )

-                    self.assertEqual(
-                        len(overflowing_tokens), 2 + stride
-                    )  # No overflowing tokens when using 'longest' in python tokenizers
+                    self.assertTrue(
+                        context.exception.args[0].startswith(
+                            "Not possible to return overflowing tokens for pair of sequences with the "
+                            "`longest_first`. Please select another truncation strategy than `longest_first`, "
+                            "for instance `only_second` or `only_first`."
+                        )
+                    )

-                information = tokenizer.encode_plus(
-                    seq_0,
-                    seq_1,
-                    max_length=len(sequence) - 2,
-                    add_special_tokens=False,
-                    stride=stride,
-                    truncation=True,
-                    return_overflowing_tokens=True,
-                    # add_prefix_space=False,
-                )
                 # Overflowing tokens are handled quite differently in slow and fast tokenizers
                 if isinstance(tokenizer, PreTrainedTokenizerFast):
+                    information = tokenizer(
+                        seq_0,
+                        seq_1,
+                        max_length=len(sequence) - 2,
+                        add_special_tokens=False,
+                        stride=stride,
+                        truncation=True,
+                        return_overflowing_tokens=True,
+                        # add_prefix_space=False,
+                    )
                     truncated_sequence = information["input_ids"][0]
                     overflowing_tokens = information["input_ids"][1]
                     self.assertEqual(len(information["input_ids"]), 2)
@@ -1107,17 +1119,28 @@ def test_maximum_encoding_length_pair_input(self):
                     self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
                     self.assertEqual(overflowing_tokens, overflow_longest_sequence)
                 else:
-                    truncated_sequence = information["input_ids"]
-                    overflowing_tokens = information["overflowing_tokens"]
-
-                    self.assertEqual(len(truncated_sequence), len(sequence) - 2)
-                    self.assertEqual(truncated_sequence, truncated_longest_sequence)
+                    # No overflowing tokens when using 'longest' in python tokenizers
+                    with self.assertRaises(ValueError) as context:
+                        information = tokenizer(
+                            seq_0,
+                            seq_1,
+                            max_length=len(sequence) - 2,
+                            add_special_tokens=False,
+                            stride=stride,
+                            truncation=True,
+                            return_overflowing_tokens=True,
+                            # add_prefix_space=False,
+                        )

-                    self.assertEqual(
-                        len(overflowing_tokens), 2 + stride
-                    )  # No overflowing tokens when using 'longest' in python tokenizers
+                    self.assertTrue(
+                        context.exception.args[0].startswith(
+                            "Not possible to return overflowing tokens for pair of sequences with the "
+                            "`longest_first`. Please select another truncation strategy than `longest_first`, "
+                            "for instance `only_second` or `only_first`."
+                        )
+                    )

-                information_first_truncated = tokenizer.encode_plus(
+                information_first_truncated = tokenizer(
                     seq_0,
                     seq_1,
                     max_length=len(sequence) - 2,
@@ -1148,7 +1171,7 @@ def test_maximum_encoding_length_pair_input(self):
                     self.assertEqual(len(overflowing_tokens), 2 + stride)
                     self.assertEqual(overflowing_tokens, seq0_tokens[-(2 + stride) :])

-                information_second_truncated = tokenizer.encode_plus(
+                information_second_truncated = tokenizer(
                     seq_0,
                     seq_1,
                     max_length=len(sequence) - 2,
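The new single-input assertion (and the pair assertions against `seq0_tokens[-(2 + stride) :]`) pin down the window arithmetic. A short sketch with plain lists shows the invariant being tested, including the stride-sized overlap between the kept and overflowing windows; the numbers are illustrative, matching the tests' shape of truncating 2 tokens:

sequence = list(range(12))  # stand-in for tokenizer.encode(seq_0)
stride = 2

truncated = sequence[:-2]               # the tokenizer keeps all but the last 2 tokens
overflowing = sequence[-(2 + stride):]  # removed tokens plus `stride` tokens of context

assert len(overflowing) == 2 + stride
assert overflowing[:stride] == truncated[-stride:]  # the stride tokens appear in both windows
print(truncated)    # [0, 1, ..., 9]
print(overflowing)  # [8, 9, 10, 11]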
