 # pyre-strict
 from copy import copy

-from typing import Any, Callable, cast, Dict, List, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

 import matplotlib.pyplot as plt
 import numpy as np

 import torch
+from captum._utils.typing import TokenizerLike
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
 from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
 from captum.attr._core.lime import Lime
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
-from captum.attr._utils.attribution import Attribution
+from captum.attr._utils.attribution import (
+    Attribution,
+    GradientAttribution,
+    PerturbationAttribution,
+)
 from captum.attr._utils.interpretable_input import (
     InterpretableInput,
     TextTemplateInput,
@@ -44,11 +49,12 @@ def __init__(
         self.output_tokens = output_tokens

     @property
-    def seq_attr_dict(self) -> Dict[str, Any]:
+    def seq_attr_dict(self) -> Dict[str, float]:
         return {k: v for v, k in zip(self.seq_attr.cpu().tolist(), self.input_tokens)}

-    # pyre-fixme[3]: Return type must be annotated.
-    def plot_token_attr(self, show: bool = False):
+    def plot_token_attr(
+        self, show: bool = False
+    ) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
         """
         Generate a matplotlib plot for visualising the attribution
         of the output tokens.
@@ -58,7 +64,11 @@ def plot_token_attr(self, show: bool = False):
                 Default: False
         """

-        # pyre-fixme[16]: `Optional` has no attribute `cpu`.
+        if self.token_attr is None:
+            raise ValueError(
+                "token_attr is None (no token-level attribution was performed), please "
+                "use plot_seq_attr instead for the sequence-level attribution plot"
+            )
         token_attr = self.token_attr.cpu()  # type: ignore

         # maximum absolute attribution value
@@ -83,7 +93,7 @@ def plot_token_attr(self, show: bool = False):
         )

         # Create colorbar
-        cbar = ax.figure.colorbar(im, ax=ax)  # type: ignore
+        cbar = fig.colorbar(im, ax=ax)  # type: ignore
         cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")

         # Show all ticks and label them with the respective list entries.
@@ -113,11 +123,13 @@ def plot_token_attr(self, show: bool = False):

         if show:
             plt.show()
+            return None  # mypy wants this
         else:
             return fig, ax

-    # pyre-fixme[3]: Return type must be annotated.
-    def plot_seq_attr(self, show: bool = False):
+    def plot_seq_attr(
+        self, show: bool = False
+    ) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
         """
         Generate a matplotlib plot for visualising the attribution
         of the output sequence.
@@ -150,6 +162,7 @@ def plot_seq_attr(self, show: bool = False):

         if show:
             plt.show()
+            return None  # mypy wants this
         else:
             return fig, ax

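Both plot methods now return `Union[None, Tuple[plt.Figure, plt.Axes]]`: with `show=True` they display the figure and return None, otherwise they hand back the figure and axes, and `plot_token_attr` raises a ValueError when no token-level attribution was computed. A minimal usage sketch under that contract; `res` and the output filename are hypothetical:

    # `res` is assumed to be an LLMAttributionResult returned by attribute()
    fig_ax = res.plot_token_attr(show=False)  # (fig, ax) when show=False
    if fig_ax is not None:
        fig, ax = fig_ax
        ax.set_title("Token attribution")                    # further styling on the axes
        fig.savefig("token_attr.png", bbox_inches="tight")   # hypothetical output path
    res.plot_seq_attr(show=True)  # displays the figure and returns None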
@@ -181,9 +194,8 @@ class LLMAttribution(Attribution):

     def __init__(
         self,
-        attr_method: Attribution,
-        # pyre-fixme[2]: Parameter must be annotated.
-        tokenizer,
+        attr_method: PerturbationAttribution,
+        tokenizer: TokenizerLike,
         attr_target: str = "log_prob",  # TODO: support callable attr_target
     ) -> None:
         """
@@ -208,24 +220,19 @@ class created with the llm model that follows huggingface style
         super().__init__(attr_method.forward_func)

         # shallow copy is enough to avoid modifying original instance
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.attr_method = copy(attr_method)
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.include_per_token_attr = isinstance(
+        self.attr_method: PerturbationAttribution = copy(attr_method)
+        self.include_per_token_attr: bool = isinstance(
             attr_method, self.SUPPORTED_PER_TOKEN_ATTR_METHODS
         )

         self.attr_method.forward_func = self._forward_func

         # alias, we really need a model and don't support wrapper functions
         # because we need to call model.forward, model.generate, etc.
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.model = cast(nn.Module, self.forward_func)
+        self.model: nn.Module = cast(nn.Module, self.forward_func)

-        # pyre-fixme[4]: Attribute must be annotated.
-        self.tokenizer = tokenizer
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.device = (
+        self.tokenizer: TokenizerLike = tokenizer
+        self.device: torch.device = (
             cast(torch.device, self.model.device)
             if hasattr(self.model, "device")
             else next(self.model.parameters()).device
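With the tightened constructor, `LLMAttribution` now expects a `PerturbationAttribution` instance (e.g. `FeatureAblation`, `ShapleyValueSampling`, `KernelShap`, `Lime`) plus any tokenizer satisfying the `TokenizerLike` protocol, which HuggingFace-style tokenizers do. A construction sketch under those assumptions; the model and tokenizer objects are not shown, and the import path is taken from this file's location:

    from captum.attr import FeatureAblation
    from captum.attr._core.llm_attr import LLMAttribution

    # `model` and `tokenizer` are assumed to be a HuggingFace-style causal LM
    # (an nn.Module with .generate and .device) and its tokenizer.
    fa = FeatureAblation(model)               # a PerturbationAttribution method
    llm_attr = LLMAttribution(fa, tokenizer)  # type-checks against the new signature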
@@ -239,15 +246,12 @@ class created with the llm model that follows huggingface style

     def _forward_func(
         self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        perturbed_tensor,
-        # pyre-fixme[2]: Parameter must be annotated.
-        inp,
-        # pyre-fixme[2]: Parameter must be annotated.
-        target_tokens,
+        perturbed_tensor: Union[None, Tensor],
+        inp: InterpretableInput,
+        target_tokens: Tensor,
         use_cached_outputs: bool = False,
-        _inspect_forward=None,
-    ) -> Union[int, Tensor]:
+        _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
+    ) -> Tensor:
         perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
         init_model_inp = perturbed_input

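`_inspect_forward` is now typed as `Optional[Callable[[str, str, List[float]], None]]`, i.e. a hook that takes two strings and a list of floats (by the surrounding code, apparently the prompt text, the target text, and the per-token log probabilities) and returns nothing. A sketch of a callback matching that type; the meaning of the arguments is inferred, not confirmed by the diff:

    from typing import List

    def log_forward_pass(prompt: str, target: str, log_probs: List[float]) -> None:
        # illustrative logging hook matching Callable[[str, str, List[float]], None]
        print(f"prompt={prompt!r} target={target!r} total_log_prob={sum(log_probs):.4f}")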
@@ -279,7 +283,9 @@ def _forward_func(
                 (model_inp, torch.tensor([[target_token]]).to(self.device)), dim=1
             )

-        total_log_prob = sum(log_prob_list)
+        # pyre-ignore[9] pyre/mypy thinks sum returns int here, but it will return
+        # Tensor
+        total_log_prob: Tensor = sum(log_prob_list)  # type: ignore
         # 1st element is the total prob, rest are the target tokens
         # add a leading dim for batch even though we only support a single instance for now
         if self.include_per_token_attr:
@@ -288,8 +294,6 @@ def _forward_func(
             ).unsqueeze(0)
         else:
             target_log_probs = total_log_prob  # type: ignore
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[int,
-        # Tensor]`.
         target_probs = torch.exp(target_log_probs)

         if _inspect_forward:
@@ -301,35 +305,31 @@ def _forward_func(

         return target_probs if self.attr_target != "log_prob" else target_log_probs

-    # pyre-fixme[3]: Return type must be annotated.
-    def _format_model_input(self, model_input: Union[str, Tensor]):
+    def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
         """
         Convert str to tokenized tensor
         to make LLMAttribution work with model inputs of both
         raw text and text token tensors
         """
         # return tensor(1, n_tokens)
         if isinstance(model_input, str):
-            return self.tokenizer.encode(model_input, return_tensors="pt").to(
-                self.device
-            )
+            # pyre-ignore[9] pyre/mypy thinks return type may be List, but it will be
+            # Tensor
+            return self.tokenizer.encode(  # type: ignore
+                model_input, return_tensors="pt"
+            ).to(self.device)
         return model_input.to(self.device)

     def attribute(
         self,
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
         num_trials: int = 1,
-        # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
-        # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting
-        # errors.
-        gen_args: Optional[Dict] = None,
+        gen_args: Optional[Dict[str, Any]] = None,
         use_cached_outputs: bool = True,
         # internal callback hook can be used for logging
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        _inspect_forward: Optional[Callable] = None,
-        # pyre-fixme[2]: Parameter must be annotated.
-        **kwargs,
+        _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
+        **kwargs: Any,
     ) -> LLMAttributionResult:
         """
         Args:
@@ -380,10 +380,14 @@ def attribute(
             target_tokens = torch.tensor(target_tokens)
         elif type(target) is torch.Tensor:
             target_tokens = target
+        else:
+            raise TypeError(
+                "target must either be str or Tensor, but the type of target is "
+                "{}".format(type(target))
+            )

         attr = torch.zeros(
             [
-                # pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
                 1 + len(target_tokens) if self.include_per_token_attr else 1,
                 inp.n_itp_features,
             ],
@@ -398,8 +402,6 @@ def attribute(
                 attr_input,
                 additional_forward_args=(
                     inp,
-                    # pyre-fixme[61]: `target_tokens` is undefined, or not always
-                    # defined.
                     target_tokens,
                     use_cached_outputs,
                     _inspect_forward,
@@ -424,7 +426,6 @@ def attribute(
                 attr[1:] if self.include_per_token_attr else None
             ),  # shape(n_output_token, n_input_features)
             inp.values,
-            # pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
             self.tokenizer.convert_ids_to_tokens(target_tokens),
         )

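With the new up-front TypeError for unsupported `target` types, `attribute` still returns an `LLMAttributionResult`, now with typed fields. A usage sketch with a `TextTemplateInput`; the template, values, and target string are made up, and `llm_attr` is the instance from the earlier sketch:

    from captum.attr._utils.interpretable_input import TextTemplateInput

    inp = TextTemplateInput(
        template="{} lives in {} and works as a {}.",   # hypothetical template
        values=["Dave", "Palm Coast", "carpenter"],
    )
    res = llm_attr.attribute(inp, target="He is happy.")  # target may be str or Tensor
    print(res.seq_attr_dict)  # Dict[str, float], one attribution per input feature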
@@ -454,14 +455,11 @@ class LLMGradientAttribution(Attribution):
     SUPPORTED_METHODS = (LayerIntegratedGradients,)
     SUPPORTED_INPUTS = (TextTokenInput,)

-    # pyre-fixme[3]: Return type must be annotated.
     def __init__(
         self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        attr_method,
-        # pyre-fixme[2]: Parameter must be annotated.
-        tokenizer,
-    ):
+        attr_method: GradientAttribution,
+        tokenizer: TokenizerLike,
+    ) -> None:
         """
         Args:
             attr_method (Attribution): instance of a supported perturbation attribution
@@ -476,19 +474,15 @@ class created with the llm model that follows huggingface style
         super().__init__(attr_method.forward_func)

         # shallow copy is enough to avoid modifying original instance
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.attr_method = copy(attr_method)
+        self.attr_method: GradientAttribution = copy(attr_method)
         self.attr_method.forward_func = self._forward_func

         # alias, we really need a model and don't support wrapper functions
         # because we need to call model.forward, model.generate, etc.
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.model = cast(nn.Module, self.forward_func)
+        self.model: nn.Module = cast(nn.Module, self.forward_func)

-        # pyre-fixme[4]: Attribute must be annotated.
-        self.tokenizer = tokenizer
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.device = (
+        self.tokenizer: TokenizerLike = tokenizer
+        self.device: torch.device = (
             cast(torch.device, self.model.device)
             if hasattr(self.model, "device")
             else next(self.model.parameters()).device
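`LLMGradientAttribution` gets the parallel typing: `attr_method` must be a `GradientAttribution` (only `LayerIntegratedGradients` is listed in `SUPPORTED_METHODS`) and the tokenizer a `TokenizerLike`. A construction sketch; the embedding-layer attribute is a guess and depends on the architecture:

    from captum.attr import LayerIntegratedGradients
    from captum.attr._core.llm_attr import LLMGradientAttribution

    # `model.model.embed_tokens` is an assumed embedding layer (Llama-style models);
    # GPT-2 would use `model.transformer.wte` instead.
    lig = LayerIntegratedGradients(model, model.model.embed_tokens)
    llm_grad_attr = LLMGradientAttribution(lig, tokenizer)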
@@ -526,9 +520,7 @@ def _forward_func(
         # the attribution target is limited to the log probability
         return token_log_probs

-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def _format_model_input(self, model_input):
+    def _format_model_input(self, model_input: Tensor) -> Tensor:
         """
         Convert str to tokenized tensor
         """
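Both `_format_model_input` methods are now annotated to return a `Tensor`. For the perturbation class this requires the `pyre-ignore[9]` seen earlier, because HuggingFace-style `encode` is typed as possibly returning a plain list, even though with `return_tensors="pt"` it yields a tensor at runtime. A small illustration of that behaviour, assuming `tokenizer` and `device` are already defined:

    ids = tokenizer.encode("a prompt string", return_tensors="pt")  # shape (1, n_tokens)
    ids = ids.to(device)  # mirrors what _format_model_input does with self.device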
@@ -538,12 +530,8 @@ def attribute(
         self,
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
-        # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
-        # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting
-        # errors.
-        gen_args: Optional[Dict] = None,
-        # pyre-fixme[2]: Parameter must be annotated.
-        **kwargs,
+        gen_args: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> LLMAttributionResult:
         """
         Args:
@@ -590,19 +578,21 @@ def attribute(
             target_tokens = torch.tensor(target_tokens)
         elif type(target) is torch.Tensor:
             target_tokens = target
+        else:
+            raise TypeError(
+                "target must either be str or Tensor, but the type of target is "
+                "{}".format(type(target))
+            )

         attr_inp = inp.to_tensor().to(self.device)

         attr_list = []
-        # pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
         for cur_target_idx, _ in enumerate(target_tokens):
             # attr in shape(batch_size, input+output_len, emb_dim)
             attr = self.attr_method.attribute(
                 attr_inp,
                 additional_forward_args=(
                     inp,
-                    # pyre-fixme[61]: `target_tokens` is undefined, or not always
-                    # defined.
                     target_tokens,
                     cur_target_idx,
                 ),
@@ -629,7 +619,7 @@ def attribute(
         # it attributes to all the elements of the output of the specified layer
         # so we need special handling for the inp types that don't care about all the elements
         if isinstance(inp, TextTokenInput) and inp.itp_mask is not None:
-            itp_mask = inp.itp_mask.to(self.device)
+            itp_mask = inp.itp_mask.to(attr.device)
             itp_mask = itp_mask.expand_as(attr)
             attr = attr[itp_mask].view(attr.size(0), -1)

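The interpretability mask is now moved to `attr.device` instead of `self.device`, so the boolean indexing works even if the layer attribution lands on a different device than the model's first parameter. The expand-and-view pattern keeps, per output token, only the columns marked as interpretable; a standalone sketch of the same pattern with toy sizes:

    import torch

    attr = torch.randn(3, 5)  # (n_output_tokens, n_positions), toy data
    itp_mask = torch.tensor([[False, True, True, False, True]])  # (1, n_positions)
    itp_mask = itp_mask.to(attr.device).expand_as(attr)
    selected = attr[itp_mask].view(attr.size(0), -1)  # -> shape (3, 3)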
@@ -642,7 +632,6 @@ def attribute(
             seq_attr,
             attr,  # shape(n_output_token, n_input_features)
             inp.values,
-            # pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
             self.tokenizer.convert_ids_to_tokens(target_tokens),
         )

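As in the perturbation class, the gradient path finishes by packing the sequence attribution, the per-token attribution matrix, the input values, and the decoded target tokens into an `LLMAttributionResult`. A closing usage sketch with `TextTokenInput`, the only input type in `SUPPORTED_INPUTS`; the prompt and target are invented and `llm_grad_attr` comes from the earlier sketch:

    from captum.attr._utils.interpretable_input import TextTokenInput

    inp = TextTokenInput("Palm Coast is a city in", tokenizer)
    res = llm_grad_attr.attribute(inp, target="Florida")
    res.plot_seq_attr(show=True)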