
Commit 0092ea1

craymichael authored and facebook-github-bot committed
Consolidate LLM attr logic
Summary: Add base class BaseLLMAttribution to consolidate repeated logic between the perturbation- and gradient-based LLM attribution classes.

Differential Revision: D65008854
1 parent 2f82f65 commit 0092ea1

File tree: 2 files changed (+120, -136 lines)
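
In outline, the refactor moves the model/tokenizer setup and the target-token resolution that both wrappers duplicated into one shared base class. A minimal sketch of the resulting structure, simplified from the diff below; only the class names shown in the diff are real, the stand-in Attribution and the omitted details are illustrative:

from abc import ABC


class Attribution:
    # Stand-in for captum.attr.Attribution so the sketch is self-contained.
    def __init__(self, forward_func) -> None:
        self.forward_func = forward_func


class BaseLLMAttribution(Attribution, ABC):
    # Shared setup that LLMAttribution and LLMGradientAttribution used to repeat.
    def __init__(self, attr_method: Attribution, tokenizer) -> None:
        super().__init__(attr_method.forward_func)
        self.model = self.forward_func  # the wrapped callable is assumed to be the model
        self.tokenizer = tokenizer
        # _get_target_tokens() and _format_model_input() also live here (see the diff).


class LLMAttribution(BaseLLMAttribution):            # perturbation-based wrapper
    def __init__(self, attr_method: Attribution, tokenizer) -> None:
        super().__init__(attr_method, tokenizer)     # shared logic now comes from the base


class LLMGradientAttribution(BaseLLMAttribution):    # gradient-based wrapper
    def __init__(self, attr_method: Attribution, tokenizer) -> None:
        super().__init__(attr_method, tokenizer)
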

captum/attr/_core/llm_attr.py (117 additions, 136 deletions)
@@ -2,11 +2,13 @@
 
 import warnings
 
+from abc import ABC
+
 from copy import copy
 
 from textwrap import shorten
 
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
 
 import matplotlib.colors as mcolors
 
@@ -319,7 +321,104 @@ def _convert_ids_to_pretty_tokens_fallback(
     return pretty_tokens
 
 
-class LLMAttribution(Attribution):
+class BaseLLMAttribution(Attribution, ABC):
+    """Base class for LLM Attribution methods"""
+
+    SUPPORTED_INPUTS: Tuple[Type[InterpretableInput], ...]
+    SUPPORTED_METHODS: Tuple[Type[Attribution], ...]
+
+    model: nn.Module
+    tokenizer: TokenizerLike
+    device: torch.device
+
+    def __init__(
+        self,
+        attr_method: Attribution,
+        tokenizer: TokenizerLike,
+    ) -> None:
+        assert isinstance(
+            attr_method, self.SUPPORTED_METHODS
+        ), f"{self.__class__.__name__} does not support {type(attr_method)}"
+
+        super().__init__(attr_method.forward_func)
+
+        # alias, we really need a model and don't support wrapper functions
+        # coz we need call model.forward, model.generate, etc.
+        self.model: nn.Module = cast(nn.Module, self.forward_func)
+
+        self.tokenizer: TokenizerLike = tokenizer
+        self.device: torch.device = (
+            cast(torch.device, self.model.device)
+            if hasattr(self.model, "device")
+            else next(self.model.parameters()).device
+        )
+
+    def _get_target_tokens(
+        self,
+        inp: InterpretableInput,
+        target: Union[str, torch.Tensor, None] = None,
+        skip_tokens: Union[List[int], List[str], None] = None,
+        gen_args: Optional[Dict[str, Any]] = None,
+    ) -> Tensor:
+        assert isinstance(
+            inp, self.SUPPORTED_INPUTS
+        ), f"LLMAttribution does not support input type {type(inp)}"
+
+        if target is None:
+            # generate when None
+            assert hasattr(self.model, "generate") and callable(self.model.generate), (
+                "The model does not have recognizable generate function."
+                "Target must be given for attribution"
+            )
+
+            if not gen_args:
+                gen_args = DEFAULT_GEN_ARGS
+
+            model_inp = self._format_model_input(inp.to_model_input())
+            output_tokens = self.model.generate(model_inp, **gen_args)
+            target_tokens = output_tokens[0][model_inp.size(1) :]
+        else:
+            assert gen_args is None, "gen_args must be None when target is given"
+            # Encode skip tokens
+            if skip_tokens:
+                if isinstance(skip_tokens[0], str):
+                    skip_tokens = cast(List[str], skip_tokens)
+                    skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
+            else:
+                skip_tokens = []
+            skip_tokens = cast(List[int], skip_tokens)
+
+            if isinstance(target, str):
+                encoded = self.tokenizer.encode(target)
+                target_tokens = torch.tensor(
+                    [token for token in encoded if token not in skip_tokens]
+                )
+            elif isinstance(target, torch.Tensor):
+                target_tokens = target[
+                    ~torch.isin(target, torch.tensor(skip_tokens, device=target.device))
+                ]
+            else:
+                raise TypeError(
+                    "target must either be str or Tensor, but the type of target is "
+                    "{}".format(type(target))
+                )
+        return target_tokens
+
+    def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
+        """
+        Convert str to tokenized tensor
+        to make LLMAttribution work with model inputs of both
+        raw text and text token tensors
+        """
+        # return tensor(1, n_tokens)
+        if isinstance(model_input, str):
+            return self.tokenizer.encode(model_input, return_tensors="pt").to(
+                self.device
+            )
+        return model_input.to(self.device)
+
+
+class LLMAttribution(BaseLLMAttribution):
     """
     Attribution class for large language models. It wraps a perturbation-based
     attribution algorthm to produce commonly interested attribution
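
As a standalone illustration of the tensor branch in `_get_target_tokens` above: when `target` is already a token tensor, skip tokens are dropped with a boolean mask built by `torch.isin`. The token ids below are made up for the example:

import torch

target = torch.tensor([2, 198, 464, 3290, 198, 2])  # hypothetical target token ids
skip_tokens = [2, 198]                               # e.g. special-token ids to ignore

# Same filtering expression as in BaseLLMAttribution._get_target_tokens
target_tokens = target[~torch.isin(target, torch.tensor(skip_tokens, device=target.device))]
print(target_tokens)  # tensor([ 464, 3290]): skipped ids removed, order preserved
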
@@ -365,11 +464,7 @@ class created with the llm model that follows huggingface style
                 Default: "log_prob"
         """
 
-        assert isinstance(
-            attr_method, self.SUPPORTED_METHODS
-        ), f"LLMAttribution does not support {type(attr_method)}"
-
-        super().__init__(attr_method.forward_func)
+        super().__init__(attr_method, tokenizer)
 
         # shallow copy is enough to avoid modifying original instance
         self.attr_method: PerturbationAttribution = copy(attr_method)
@@ -379,17 +474,6 @@ class created with the llm model that follows huggingface style
 
         self.attr_method.forward_func = self._forward_func
 
-        # alias, we really need a model and don't support wrapper functions
-        # coz we need call model.forward, model.generate, etc.
-        self.model: nn.Module = cast(nn.Module, self.forward_func)
-
-        self.tokenizer: TokenizerLike = tokenizer
-        self.device: torch.device = (
-            cast(torch.device, self.model.device)
-            if hasattr(self.model, "device")
-            else next(self.model.parameters()).device
-        )
-
         assert attr_target in (
             "log_prob",
             "prob",
@@ -488,19 +572,6 @@ def _forward_func(
 
         return target_probs if self.attr_target != "log_prob" else target_log_probs
 
-    def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
-        """
-        Convert str to tokenized tensor
-        to make LLMAttribution work with model inputs of both
-        raw text and text token tensors
-        """
-        # return tensor(1, n_tokens)
-        if isinstance(model_input, str):
-            return self.tokenizer.encode(model_input, return_tensors="pt").to(
-                self.device
-            )
-        return model_input.to(self.device)
-
     def attribute(
         self,
         inp: InterpretableInput,
@@ -527,7 +598,7 @@ def attribute(
                 of integers of the token ids.
                 Default: None
             num_trials (int, optional): number of trials to run. Return is the average
-                attribibutions over all the trials.
+                attributions over all the trials.
                 Defaults: 1.
             gen_args (dict, optional): arguments for generating the target. Only used if
                 target is not given. When None, the default arguments are used,
@@ -542,49 +613,12 @@ def attribute(
             attr (LLMAttributionResult): Attribution result. token_attr will be None
                 if attr method is Lime or KernelShap.
         """
-
-        assert isinstance(
-            inp, self.SUPPORTED_INPUTS
-        ), f"LLMAttribution does not support input type {type(inp)}"
-
-        if target is None:
-            # generate when None
-            assert hasattr(self.model, "generate") and callable(self.model.generate), (
-                "The model does not have recognizable generate function."
-                "Target must be given for attribution"
-            )
-
-            if not gen_args:
-                gen_args = DEFAULT_GEN_ARGS
-
-            model_inp = self._format_model_input(inp.to_model_input())
-            output_tokens = self.model.generate(model_inp, **gen_args)
-            target_tokens = output_tokens[0][model_inp.size(1) :]
-        else:
-            assert gen_args is None, "gen_args must be None when target is given"
-            # Encode skip tokens
-            if skip_tokens:
-                if isinstance(skip_tokens[0], str):
-                    skip_tokens = cast(List[str], skip_tokens)
-                    skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
-            else:
-                skip_tokens = []
-            skip_tokens = cast(List[int], skip_tokens)
-
-            if isinstance(target, str):
-                encoded = self.tokenizer.encode(target)
-                target_tokens = torch.tensor(
-                    [token for token in encoded if token not in skip_tokens]
-                )
-            elif isinstance(target, torch.Tensor):
-                target_tokens = target[
-                    ~torch.isin(target, torch.tensor(skip_tokens, device=target.device))
-                ]
-            else:
-                raise TypeError(
-                    "target must either be str or Tensor, but the type of target is "
-                    "{}".format(type(target))
-                )
+        target_tokens = self._get_target_tokens(
+            inp,
+            target,
+            skip_tokens=skip_tokens,
+            gen_args=gen_args,
+        )
 
         attr = torch.zeros(
             [
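
For reference, both paths that `attribute` now routes through `_get_target_tokens` can be exercised from the public API roughly as follows. This is a hedged sketch: the model name and generation arguments are placeholders, and FeatureAblation, LLMAttribution, and TextTemplateInput are used as captum documents them, not as part of this diff:

from transformers import AutoModelForCausalLM, AutoTokenizer

from captum.attr import FeatureAblation, LLMAttribution, TextTemplateInput

tokenizer = AutoTokenizer.from_pretrained("gpt2")     # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
inp = TextTemplateInput("{} lives in {}", values=["Dave", "Palm Coast"])

# target given: _get_target_tokens just encodes it (gen_args must stay None)
res = llm_attr.attribute(inp, target="Florida")

# no target: _get_target_tokens calls model.generate(...) and attributes to that output
res = llm_attr.attribute(inp, gen_args={"max_new_tokens": 20, "do_sample": False})
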
@@ -638,7 +672,7 @@ def attribute_future(self) -> Callable[[], LLMAttributionResult]:
         )
 
 
-class LLMGradientAttribution(Attribution):
+class LLMGradientAttribution(BaseLLMAttribution):
     """
     Attribution class for large language models. It wraps a gradient-based
     attribution algorthm to produce commonly interested attribution
@@ -670,27 +704,12 @@ class created with the llm model that follows huggingface style
                 interface convention
             tokenizer (Tokenizer): tokenizer of the llm model used in the attr_method
         """
-        assert isinstance(
-            attr_method, self.SUPPORTED_METHODS
-        ), f"LLMGradientAttribution does not support {type(attr_method)}"
-
-        super().__init__(attr_method.forward_func)
-
-        # alias, we really need a model and don't support wrapper functions
-        # coz we need call model.forward, model.generate, etc.
-        self.model: nn.Module = cast(nn.Module, self.forward_func)
+        super().__init__(attr_method, tokenizer)
 
         # shallow copy is enough to avoid modifying original instance
         self.attr_method: GradientAttribution = copy(attr_method)
         self.attr_method.forward_func = GradientForwardFunc(self)
 
-        self.tokenizer: TokenizerLike = tokenizer
-        self.device: torch.device = (
-            cast(torch.device, self.model.device)
-            if hasattr(self.model, "device")
-            else next(self.model.parameters()).device
-        )
-
     def _format_model_input(self, model_input: Union[Tensor, str]) -> Tensor:
         """
         Convert str to tokenized tensor
@@ -734,50 +753,12 @@ def attribute(
 
             attr (LLMAttributionResult): attribution result
         """
-
-        assert isinstance(
-            inp, self.SUPPORTED_INPUTS
-        ), f"LLMGradAttribution does not support input type {type(inp)}"
-
-        if target is None:
-            # generate when None
-            assert hasattr(self.model, "generate") and callable(self.model.generate), (
-                "The model does not have recognizable generate function."
-                "Target must be given for attribution"
-            )
-
-            if not gen_args:
-                gen_args = DEFAULT_GEN_ARGS
-
-            with torch.no_grad():
-                model_inp = self._format_model_input(inp.to_model_input())
-                output_tokens = self.model.generate(model_inp, **gen_args)
-                target_tokens = output_tokens[0][model_inp.size(1) :]
-        else:
-            assert gen_args is None, "gen_args must be None when target is given"
-            # Encode skip tokens
-            if skip_tokens:
-                if isinstance(skip_tokens[0], str):
-                    skip_tokens = cast(List[str], skip_tokens)
-                    skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
-            else:
-                skip_tokens = []
-            skip_tokens = cast(List[int], skip_tokens)
-
-            if isinstance(target, str):
-                encoded = self.tokenizer.encode(target)
-                target_tokens = torch.tensor(
-                    [token for token in encoded if token not in skip_tokens]
-                )
-            elif isinstance(target, torch.Tensor):
-                target_tokens = target[
-                    ~torch.isin(target, torch.tensor(skip_tokens, device=target.device))
-                ]
-            else:
-                raise TypeError(
-                    "target must either be str or Tensor, but the type of target is "
-                    "{}".format(type(target))
-                )
+        target_tokens = self._get_target_tokens(
+            inp,
+            target,
+            skip_tokens=skip_tokens,
+            gen_args=gen_args,
+        )
 
         attr_inp = inp.to_tensor().to(self.device)
 
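
The gradient-based wrapper now goes through the same consolidated helper. A hedged companion sketch; the embedding-layer handle is model-specific and, like the model name, an assumption rather than something taken from this diff:

from transformers import AutoModelForCausalLM, AutoTokenizer

from captum.attr import LayerIntegratedGradients, LLMGradientAttribution, TextTokenInput

tokenizer = AutoTokenizer.from_pretrained("gpt2")     # placeholder model, as above
model = AutoModelForCausalLM.from_pretrained("gpt2")

# model.transformer.wte is GPT-2's token-embedding layer; other models name it differently
lig = LayerIntegratedGradients(model, model.transformer.wte)
llm_grad_attr = LLMGradientAttribution(lig, tokenizer)

inp = TextTokenInput("Dave lives in Palm Coast", tokenizer)
res = llm_grad_attr.attribute(inp, target="Florida")  # same target handling as the base class
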
captum/attr/_utils/interpretable_input.py (3 additions, 0 deletions)
@@ -104,6 +104,9 @@ class to create other types of customized input.
     is only allowed in certain attribution classes like LLMAttribution for now.)
     """
 
+    n_itp_features: int
+    values: List[str]
+
     @abstractmethod
     def to_tensor(self) -> Tensor:
         """
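
These two class-level annotations document attributes that concrete inputs already expose and that the consolidated LLM attribution code reads, for example to size and label the attribution matrix. A quick check with one of captum's shipped inputs, assuming the usual TextTemplateInput behavior:

from captum.attr import TextTemplateInput

inp = TextTemplateInput("{} lives in {}", values=["Dave", "Palm Coast"])
print(inp.n_itp_features)  # 2 (one interpretable feature per template slot)
print(inp.values)          # ['Dave', 'Palm Coast']
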