Commit d8ceaa8

craymichael authored and facebook-github-bot committed
Add additional gradient-based attribution methods to LLM Attribution (#1337)
Summary: Add `LayerGradientXActivation` and `LayerGradientShap` to the supported gradient-based LLM attribution methods.

Pull Request resolved: #1337

Test Plan: `pytest tests/attr -k TestLLMGradAttr`, with new test cases added via the parameterized library.

Reviewed By: cyrjano

Differential Revision: D62221000

Pulled By: craymichael

fbshipit-source-id: fb5f170e13a62355357d46d3ef7a2464e8eb80ab
1 parent: d89243b

File tree

4 files changed: +121 / -56 lines

captum/attr/_core/deep_lift.py (3 additions, 2 deletions)

@@ -110,7 +110,7 @@ def __init__(
                 Default: 1e-10
         """
         GradientAttribution.__init__(self, model)
-        self.model = model
+        self.model: nn.Module = model
         self.eps = eps
         self.forward_handles: List[RemovableHandle] = []
         self.backward_handles: List[RemovableHandle] = []
@@ -324,7 +324,8 @@ def attribute(  # type: ignore
         warnings.warn(
             """Setting forward, backward hooks and attributes on non-linear
                activations. The hooks and attributes will be removed
-               after the attribution is finished"""
+               after the attribution is finished""",
+            stacklevel=2,
         )
         # pyre-fixme[6]: For 1st argument expected `Tuple[Tensor, ...]` but got
         #  `TensorOrTupleOfTensorsGeneric`.
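
The second hunk adds `stacklevel=2` so the warning is reported at the code that called the warning-emitting Captum method rather than at the `warnings.warn` line itself. A minimal standalone sketch of that behavior (toy functions, not Captum code):

```python
import warnings


def emit_warning() -> None:
    # With the default stacklevel=1 the warning would be attributed to this
    # line; stacklevel=2 attributes it to the caller of emit_warning().
    warnings.warn("hooks will be removed after attribution", stacklevel=2)


def caller() -> None:
    emit_warning()  # the reported warning location becomes this line


if __name__ == "__main__":
    warnings.simplefilter("always")
    caller()
```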

captum/attr/_core/layer/layer_deep_lift.py (3 additions, 3 deletions)

@@ -351,9 +351,9 @@ def chunk_output_fn(out: TensorOrTupleOfTensorsGeneric) -> Sequence:
             grad_kwargs=grad_kwargs,
         )
 
-        attr_inputs = tuple(map(lambda attr: attr[0], attrs))
-        attr_baselines = tuple(map(lambda attr: attr[1], attrs))
-        gradients = tuple(map(lambda grad: grad[0], gradients))
+        attr_inputs = tuple(attr[0] for attr in attrs)
+        attr_baselines = tuple(attr[1] for attr in attrs)
+        gradients = tuple(grad[0] for grad in gradients)
 
         if custom_attribution_func is None:
             if self.multiplies_by_inputs:
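
The map/lambda to generator-expression rewrite above is stylistic and behavior-preserving; a quick standalone illustration with toy pairs (not Captum tensors):

```python
# Toy stand-ins for the (attr_input, attr_baseline) pairs unpacked in layer_deep_lift.py.
attrs = [("input_a", "baseline_a"), ("input_b", "baseline_b")]

# Before: build the tuple through map plus a lambda.
attr_inputs_before = tuple(map(lambda attr: attr[0], attrs))

# After: a generator expression, same result with less indirection.
attr_inputs_after = tuple(attr[0] for attr in attrs)

assert attr_inputs_before == attr_inputs_after == ("input_a", "input_b")
```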

captum/attr/_core/llm_attr.py (56 additions, 37 deletions)

@@ -10,6 +10,8 @@
 from captum._utils.typing import TokenizerLike
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
+from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
+from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
 from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
 from captum.attr._core.lime import Lime
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
@@ -452,7 +454,11 @@ class LLMGradientAttribution(Attribution):
         and returns LLMAttributionResult
     """
 
-    SUPPORTED_METHODS = (LayerIntegratedGradients,)
+    SUPPORTED_METHODS = (
+        LayerGradientShap,
+        LayerGradientXActivation,
+        LayerIntegratedGradients,
+    )
     SUPPORTED_INPUTS = (TextTokenInput,)
 
     def __init__(
@@ -473,53 +479,21 @@ class created with the llm model that follows huggingface style
 
         super().__init__(attr_method.forward_func)
 
-        # shallow copy is enough to avoid modifying original instance
-        self.attr_method: GradientAttribution = copy(attr_method)
-        self.attr_method.forward_func = self._forward_func
-
         # alias, we really need a model and don't support wrapper functions
         # coz we need call model.forward, model.generate, etc.
         self.model: nn.Module = cast(nn.Module, self.forward_func)
 
+        # shallow copy is enough to avoid modifying original instance
+        self.attr_method: GradientAttribution = copy(attr_method)
+        self.attr_method.forward_func = GradientForwardFunc(self)
+
         self.tokenizer: TokenizerLike = tokenizer
         self.device: torch.device = (
             cast(torch.device, self.model.device)
             if hasattr(self.model, "device")
            else next(self.model.parameters()).device
         )
 
-    def _forward_func(
-        self,
-        perturbed_tensor: Tensor,
-        inp: InterpretableInput,
-        target_tokens: Tensor,  # 1D tensor of target token ids
-        cur_target_idx: int,  # current target index
-    ) -> Tensor:
-        perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
-
-        if cur_target_idx:
-            # the input batch size can be expanded by attr method
-            output_token_tensor = (
-                target_tokens[:cur_target_idx]
-                .unsqueeze(0)
-                .expand(perturbed_input.size(0), -1)
-                .to(self.device)
-            )
-            new_input_tensor = torch.cat([perturbed_input, output_token_tensor], dim=1)
-        else:
-            new_input_tensor = perturbed_input
-
-        output_logits = self.model(new_input_tensor)
-
-        new_token_logits = output_logits.logits[:, -1]
-        log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
-
-        target_token = target_tokens[cur_target_idx]
-        token_log_probs = log_probs[..., target_token]
-
-        # the attribution target is limited to the log probability
-        return token_log_probs
-
     def _format_model_input(self, model_input: Tensor) -> Tensor:
         """
         Convert str to tokenized tensor
@@ -643,3 +617,48 @@ def attribute_future(self) -> Callable:
         raise NotImplementedError(
             "attribute_future is not implemented for LLMGradientAttribution"
         )
+
+
+class GradientForwardFunc(nn.Module):
+    """
+    A wrapper class for the forward function of a model in LLMGradientAttribution
+    """
+
+    def __init__(self, attr: LLMGradientAttribution) -> None:
+        super().__init__()
+        self.attr = attr
+        self.model: nn.Module = attr.model
+
+    def forward(
+        self,
+        perturbed_tensor: Tensor,
+        inp: InterpretableInput,
+        target_tokens: Tensor,  # 1D tensor of target token ids
+        cur_target_idx: int,  # current target index
+    ) -> Tensor:
+        perturbed_input = self.attr._format_model_input(
+            inp.to_model_input(perturbed_tensor)
+        )
+
+        if cur_target_idx:
+            # the input batch size can be expanded by attr method
+            output_token_tensor = (
+                target_tokens[:cur_target_idx]
+                .unsqueeze(0)
+                .expand(perturbed_input.size(0), -1)
+                .to(self.attr.device)
+            )
+            new_input_tensor = torch.cat([perturbed_input, output_token_tensor], dim=1)
+        else:
+            new_input_tensor = perturbed_input
+
+        output_logits = self.model(new_input_tensor)
+
+        new_token_logits = output_logits.logits[:, -1]
+        log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
+
+        target_token = target_tokens[cur_target_idx]
+        token_log_probs = log_probs[..., target_token]
+
+        # the attribution target is limited to the log probability
+        return token_log_probs
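
With this change, `LLMGradientAttribution` accepts any of the three methods listed in `SUPPORTED_METHODS` above. A minimal usage sketch with `LayerGradientXActivation`; the HuggingFace model name, prompt, and target below are illustrative assumptions, not part of this commit:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
from captum.attr._core.llm_attr import LLMGradientAttribution
from captum.attr._utils.interpretable_input import TextTokenInput

# Assumed HuggingFace-style causal LM exposing .forward and .generate.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Attribute against the token embedding layer, mirroring the tests' use of llm.emb.
attr_method = LayerGradientXActivation(model, model.get_input_embeddings())
llm_attr = LLMGradientAttribution(attr_method, tokenizer)

inp = TextTokenInput("The capital of France is", tokenizer)
res = llm_attr.attribute(inp, " Paris")

print(res.seq_attr.shape)    # one attribution score per input token
print(res.token_attr.shape)  # (num target tokens, num input tokens)
```

For `LayerGradientShap`, the new test cases below additionally pass a `baselines` tuple of token-id tensors as extra keyword arguments to `attribute(...)`, which forwards them to the underlying layer method.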

tests/attr/test_llm_attr.py (59 additions, 14 deletions)

@@ -3,19 +3,19 @@
 # pyre-strict
 
 import copy
-from typing import Any, cast, Dict, List, NamedTuple, Optional, Type, Union
+from typing import Any, cast, Dict, List, NamedTuple, Optional, Tuple, Type, Union
 
 import torch
-from captum._utils.models.linear_model import (  # @manual=//pytorch/captum/captum/_utils/models/linear_model:linear_model  # noqa: E501
-    SkLearnLasso,
-)
+from captum._utils.models.linear_model import SkLearnLasso
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
+from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
+from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
 from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
 from captum.attr._core.lime import Lime
 from captum.attr._core.llm_attr import LLMAttribution, LLMGradientAttribution
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
-from captum.attr._utils.attribution import PerturbationAttribution
+from captum.attr._utils.attribution import GradientAttribution, PerturbationAttribution
 from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
 from parameterized import parameterized, parameterized_class
 from tests.helpers import BaseTest
@@ -379,15 +379,30 @@ def test_futures_not_implemented(self) -> None:
 class TestLLMGradAttr(BaseTest):
     device: str
 
-    def test_llm_attr(self) -> None:
+    @parameterized.expand(
+        [
+            (LayerIntegratedGradients, None),
+            (LayerGradientXActivation, None),
+            (LayerGradientShap, (torch.tensor([[1, 0, 1, 0]]),)),
+        ]
+    )
+    def test_llm_attr(
+        self, AttrClass: Type[GradientAttribution], baselines: Optional[Tuple[Tensor]]
+    ) -> None:
         llm = DummyLLM()
         llm.to(self.device)
         tokenizer = DummyTokenizer()
-        attr = LayerIntegratedGradients(llm, llm.emb)
+        attr = AttrClass(llm, llm.emb)  # type: ignore[call-arg]
         llm_attr = LLMGradientAttribution(attr, tokenizer)
 
+        attr_kws: Dict[str, Any] = {}
+        if baselines is not None:
+            attr_kws["baselines"] = tuple(
+                baseline.to(self.device) for baseline in baselines
+            )
+
         inp = TextTokenInput("a b c", tokenizer)
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
 
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (4,))
@@ -402,15 +417,30 @@ def test_llm_attr(self) -> None:
         assert res.token_attr is not None  # make pyre/mypy happy
         self.assertEqual(token_attr.device.type, self.device)  # type: ignore
 
-    def test_llm_attr_without_target(self) -> None:
+    @parameterized.expand(
+        [
+            (LayerIntegratedGradients, None),
+            (LayerGradientXActivation, None),
+            (LayerGradientShap, (torch.tensor([[1, 0, 1, 0]]),)),
+        ]
+    )
+    def test_llm_attr_without_target(
+        self, AttrClass: Type[GradientAttribution], baselines: Optional[Tuple[Tensor]]
+    ) -> None:
         llm = DummyLLM()
         llm.to(self.device)
         tokenizer = DummyTokenizer()
-        attr = LayerIntegratedGradients(llm, llm.emb)
+        attr = AttrClass(llm, llm.emb)  # type: ignore[call-arg]
         llm_attr = LLMGradientAttribution(attr, tokenizer)
 
+        attr_kws: Dict[str, Any] = {}
+        if baselines is not None:
+            attr_kws["baselines"] = tuple(
+                baseline.to(self.device) for baseline in baselines
+            )
+
         inp = TextTokenInput("a b c", tokenizer)
-        res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"})
+        res = llm_attr.attribute(inp, gen_args={"mock_response": "x y z"}, **attr_kws)
 
         self.assertEqual(res.seq_attr.shape, (4,))
         assert res.token_attr is not None  # make pyre/mypy happy
@@ -424,15 +454,30 @@ def test_llm_attr_without_target(self) -> None:
         assert res.token_attr is not None  # make pyre/mypy happy
         self.assertEqual(token_attr.device.type, self.device)  # type: ignore
 
-    def test_llm_attr_with_skip_tokens(self) -> None:
+    @parameterized.expand(
+        [
+            (LayerIntegratedGradients, None),
+            (LayerGradientXActivation, None),
+            (LayerGradientShap, (torch.tensor([[1, 0, 1]]),)),
+        ]
+    )
+    def test_llm_attr_with_skip_tokens(
+        self, AttrClass: Type[GradientAttribution], baselines: Optional[Tuple[Tensor]]
+    ) -> None:
         llm = DummyLLM()
         llm.to(self.device)
         tokenizer = DummyTokenizer()
-        attr = LayerIntegratedGradients(llm, llm.emb)
+        attr = AttrClass(llm, llm.emb)  # type: ignore[call-arg]
         llm_attr = LLMGradientAttribution(attr, tokenizer)
 
+        attr_kws: Dict[str, Any] = {}
+        if baselines is not None:
+            attr_kws["baselines"] = tuple(
+                baseline.to(self.device) for baseline in baselines
+            )
+
         inp = TextTokenInput("a b c", tokenizer, skip_tokens=[0])
-        res = llm_attr.attribute(inp, "m n o p q")
+        res = llm_attr.attribute(inp, "m n o p q", **attr_kws)
 
         # 5 output tokens, 4 input tokens including sos
         self.assertEqual(res.seq_attr.shape, (3,))
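
The new tests rely on `parameterized.expand`, which turns each tuple in the list into its own generated test method, so every gradient method (and its optional baselines) runs as a separate case under `pytest -k TestLLMGradAttr`. A minimal standalone sketch of the pattern (toy test, not Captum's):

```python
import unittest

from parameterized import parameterized


class TestSquare(unittest.TestCase):
    # parameterized.expand generates one test method per tuple; generated
    # names include the case index (and a rendering of the arguments).
    @parameterized.expand(
        [
            (2, 4),
            (3, 9),
        ]
    )
    def test_square(self, value: int, expected: int) -> None:
        self.assertEqual(value * value, expected)


if __name__ == "__main__":
    unittest.main()
```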
