
Commit 492ae0e

craymichael authored and facebook-github-bot committed
LLM offsets logic consolidate w/ checks and test case fix (#1422)
Summary:
Pull Request resolved: #1422

Consolidate the offsets logic, with extra checks, into one function. It may later be used to group data in gradient LLM attribution. A test case is fixed as a result of the checks.

Reviewed By: cyrjano
Differential Revision: D65010820
fbshipit-source-id: a88cde9decf1c850dcd16dc2c5aacf5c4e8cd4f2
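The helper consolidated here wraps the Hugging Face fast-tokenizer offsets feature: encoding with return_offsets_mapping=True yields token ids plus the character span each token covers in the input text. A minimal sketch of that behavior (the gpt2 checkpoint is illustrative, not part of this commit):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # a "fast" tokenizer
    enc = tokenizer("Hello world", return_offsets_mapping=True, add_special_tokens=False)
    print(enc["input_ids"])       # e.g. [15496, 995]
    print(enc["offset_mapping"])  # e.g. [(0, 5), (5, 11)], character spans into the text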
1 parent: 638b920

File tree

2 files changed: +47, -16 lines


captum/attr/_core/llm_attr.py

Lines changed: 45 additions & 16 deletions
@@ -226,7 +226,42 @@ def _clean_up_pretty_token(token: str) -> str:
     return token.replace("\n", "\\n").strip()


-def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List[str]:
+def _encode_with_offsets(
+    txt: str,
+    tokenizer: TokenizerLike,
+    add_special_tokens: bool = True,
+    **kwargs: Any,
+) -> Tuple[List[int], List[Tuple[int, int]]]:
+    enc = tokenizer(
+        txt,
+        return_offsets_mapping=True,
+        add_special_tokens=add_special_tokens,
+        **kwargs,
+    )
+    input_ids = cast(List[int], enc["input_ids"])
+    offset_mapping = cast(List[Tuple[int, int]], enc["offset_mapping"])
+    assert len(input_ids) == len(offset_mapping), (
+        f"{len(input_ids)} != {len(offset_mapping)}: {txt} -> "
+        f"{input_ids}, {offset_mapping}"
+    )
+    # For the case where offsets are not set properly (the end and start are
+    # equal for all tokens - fall back on the start of the next span in the
+    # offset mapping)
+    offset_mapping_corrected = []
+    for i, (start, end) in enumerate(offset_mapping):
+        if start == end:
+            if (i + 1) < len(offset_mapping):
+                end = offset_mapping[i + 1][0]
+            else:
+                end = len(txt)
+        offset_mapping_corrected.append((start, end))
+    return input_ids, offset_mapping_corrected
+
+
+def _convert_ids_to_pretty_tokens(
+    ids: Tensor,
+    tokenizer: TokenizerLike,
+) -> List[str]:
     """
     Convert ids to tokens without ugly unicode characters (e.g., Ġ). See:
     https://github.com/huggingface/transformers/issues/4786 and
@@ -241,32 +276,26 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List
     > used spaces in its process
     """
     txt = tokenizer.decode(ids)
+    input_ids: Optional[List[int]] = None
     # Don't add special tokens (they're either already there, or we don't want them)
-    enc = tokenizer(txt, return_offsets_mapping=True, add_special_tokens=False)
-    input_ids = cast(List[int], enc["input_ids"])
-    offset_mapping = cast(List[Tuple[int, int]], enc["offset_mapping"])
+    input_ids, offset_mapping = _encode_with_offsets(
+        txt, tokenizer, add_special_tokens=False
+    )

     pretty_tokens = []
     end_prev = -1
     idx = 0
-    for i, (input_id, offset) in enumerate(zip(input_ids, offset_mapping)):
+    for i, offset in enumerate(offset_mapping):
         start, end = offset
-        if start == end:
-            # For the case where offsets are not set properly (the end and start are
-            # equal for all tokens - fall back on the start of the next span in the
-            # offset mapping)
-            if (i + 1) < len(input_ids):
-                end = offset_mapping[i + 1][0]
-            else:
-                end = len(txt)
-        if input_id != ids[idx]:
+        if input_ids[i] != ids[idx]:
             # When the re-encoded string doesn't match the original encoding we skip
             # this token and hope for the best, falling back on a naive method. This
             # can happen when a tokenizer might add a token that corresponds to
             # a space only when add_special_tokens=False.
             warnings.warn(
-                f"(i={i}) input_id {input_id} != ids[idx] {ids[idx]} (corresponding "
-                f"to text: {repr(txt[start:end])}). Skipping this token.",
+                f"(i={i}, idx={idx}) input_ids[i] {input_ids[i]} != ids[idx] "
+                f"{ids[idx]} (corresponding to text: {repr(txt[start:end])}). "
+                "Skipping this token.",
                 stacklevel=2,
             )
             continue
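The offset-correction fallback that _encode_with_offsets now owns can be exercised in isolation. A standalone sketch (correct_offsets is a hypothetical helper mirroring the loop above, not part of the commit):

    from typing import List, Tuple

    def correct_offsets(
        offset_mapping: List[Tuple[int, int]], text_len: int
    ) -> List[Tuple[int, int]]:
        """Repair degenerate (start == end) spans using the next span's start."""
        corrected = []
        for i, (start, end) in enumerate(offset_mapping):
            if start == end:
                # Fall back on the start of the next span, or the end of the
                # text for the final token.
                end = offset_mapping[i + 1][0] if i + 1 < len(offset_mapping) else text_len
            corrected.append((start, end))
        return corrected

    # A tokenizer that reports empty spans for every token:
    print(correct_offsets([(0, 0), (6, 6)], text_len=11))  # [(0, 6), (6, 11)]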

tests/attr/test_llm_attr.py

Lines changed: 2 additions & 0 deletions
@@ -127,6 +127,8 @@ def __call__(

         if return_offsets_mapping:
             offset_mapping = []
+            if add_special_tokens:
+                offset_mapping.append((0, 0))
             idx = 0
             for token in text.split(" "):
                 offset_mapping.append((idx - (0 if idx == 0 else 1), idx + len(token)))
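The two added lines make the mock tokenizer match what real fast tokenizers do: special tokens carry an empty (0, 0) span in the offset mapping. A sketch of that real behavior (the bert-base-uncased checkpoint is illustrative, not used by the test):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    enc = tok("hi", return_offsets_mapping=True, add_special_tokens=True)
    # Special tokens ([CLS], [SEP]) map to empty (0, 0) spans:
    print(enc["offset_mapping"])  # e.g. [(0, 0), (0, 2), (0, 0)]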
