@@ -2,11 +2,13 @@

 import warnings

+from abc import ABC
+
 from copy import copy

 from textwrap import shorten

-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

 import matplotlib.colors as mcolors

@@ -319,7 +321,104 @@ def _convert_ids_to_pretty_tokens_fallback(
     return pretty_tokens


-class LLMAttribution(Attribution):
+class BaseLLMAttribution(Attribution, ABC):
+    """Base class for LLM Attribution methods"""
+
+    SUPPORTED_INPUTS: Tuple[Type[InterpretableInput], ...]
+    SUPPORTED_METHODS: Tuple[Type[Attribution], ...]
+
+    model: nn.Module
+    tokenizer: TokenizerLike
+    device: torch.device
+
+    def __init__(
+        self,
+        attr_method: Attribution,
+        tokenizer: TokenizerLike,
+    ) -> None:
+        assert isinstance(
+            attr_method, self.SUPPORTED_METHODS
+        ), f"{self.__class__.__name__} does not support {type(attr_method)}"
+
+        super().__init__(attr_method.forward_func)
+
+        # alias, we really need a model and don't support wrapper functions
+        # coz we need call model.forward, model.generate, etc.
+        self.model: nn.Module = cast(nn.Module, self.forward_func)
+
+        self.tokenizer: TokenizerLike = tokenizer
+        self.device: torch.device = (
+            cast(torch.device, self.model.device)
+            if hasattr(self.model, "device")
+            else next(self.model.parameters()).device
+        )
+
+    def _get_target_tokens(
+        self,
+        inp: InterpretableInput,
+        target: Union[str, torch.Tensor, None] = None,
+        skip_tokens: Union[List[int], List[str], None] = None,
+        gen_args: Optional[Dict[str, Any]] = None,
+    ) -> Tensor:
+        assert isinstance(
+            inp, self.SUPPORTED_INPUTS
+        ), f"LLMAttribution does not support input type {type(inp)}"
+
+        if target is None:
+            # generate when None
+            assert hasattr(self.model, "generate") and callable(self.model.generate), (
+                "The model does not have recognizable generate function."
+                "Target must be given for attribution"
+            )
+
+            if not gen_args:
+                gen_args = DEFAULT_GEN_ARGS
+
+            model_inp = self._format_model_input(inp.to_model_input())
+            output_tokens = self.model.generate(model_inp, **gen_args)
+            target_tokens = output_tokens[0][model_inp.size(1) :]
+        else:
+            assert gen_args is None, "gen_args must be None when target is given"
+            # Encode skip tokens
+            if skip_tokens:
+                if isinstance(skip_tokens[0], str):
+                    skip_tokens = cast(List[str], skip_tokens)
+                    skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
+            else:
+                skip_tokens = []
+            skip_tokens = cast(List[int], skip_tokens)
+
+            if isinstance(target, str):
+                encoded = self.tokenizer.encode(target)
+                target_tokens = torch.tensor(
+                    [token for token in encoded if token not in skip_tokens]
+                )
+            elif isinstance(target, torch.Tensor):
+                target_tokens = target[
+                    ~torch.isin(target, torch.tensor(skip_tokens, device=target.device))
+                ]
+            else:
+                raise TypeError(
+                    "target must either be str or Tensor, but the type of target is "
+                    "{}".format(type(target))
+                )
+        return target_tokens
+
+    def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
+        """
+        Convert str to tokenized tensor
+        to make LLMAttribution work with model inputs of both
+        raw text and text token tensors
+        """
+        # return tensor(1, n_tokens)
+        if isinstance(model_input, str):
+            return self.tokenizer.encode(model_input, return_tensors="pt").to(
+                self.device
+            )
+        return model_input.to(self.device)
+
+
+class LLMAttribution(BaseLLMAttribution):
     """
     Attribution class for large language models. It wraps a perturbation-based
     attribution algorthm to produce commonly interested attribution
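To see how the new shared constructor and the `_get_target_tokens` helper are exercised, here is a hedged usage sketch through the perturbation subclass. It assumes a HuggingFace-style causal LM; the "gpt2" model name, the template, and the target string are illustrative placeholders, not part of this commit.

    # Sketch only: model name, template values, and target are assumptions.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from captum.attr import FeatureAblation, LLMAttribution, TextTemplateInput

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # LLMAttribution now delegates the method check and model/tokenizer/device
    # bookkeeping to BaseLLMAttribution.__init__
    llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
    inp = TextTemplateInput("{} lives in {}, {} and is a {}.", ["Dave", "Palm Coast", "FL", "lawyer"])

    # target=None: the base class calls model.generate(**gen_args) to build the target
    res_generated = llm_attr.attribute(inp, gen_args={"max_new_tokens": 8})

    # explicit target: tokens are encoded and filtered; gen_args must stay None
    res_explicit = llm_attr.attribute(inp, target="He enjoys golf.")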
@@ -365,11 +464,7 @@ class created with the llm model that follows huggingface style
                 Default: "log_prob"
         """

-        assert isinstance(
-            attr_method, self.SUPPORTED_METHODS
-        ), f"LLMAttribution does not support {type(attr_method)}"
-
-        super().__init__(attr_method.forward_func)
+        super().__init__(attr_method, tokenizer)

         # shallow copy is enough to avoid modifying original instance
         self.attr_method: PerturbationAttribution = copy(attr_method)
@@ -379,17 +474,6 @@ class created with the llm model that follows huggingface style

         self.attr_method.forward_func = self._forward_func

-        # alias, we really need a model and don't support wrapper functions
-        # coz we need call model.forward, model.generate, etc.
-        self.model: nn.Module = cast(nn.Module, self.forward_func)
-
-        self.tokenizer: TokenizerLike = tokenizer
-        self.device: torch.device = (
-            cast(torch.device, self.model.device)
-            if hasattr(self.model, "device")
-            else next(self.model.parameters()).device
-        )
-
         assert attr_target in (
             "log_prob",
             "prob",
@@ -488,19 +572,6 @@ def _forward_func(

         return target_probs if self.attr_target != "log_prob" else target_log_probs

-    def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
-        """
-        Convert str to tokenized tensor
-        to make LLMAttribution work with model inputs of both
-        raw text and text token tensors
-        """
-        # return tensor(1, n_tokens)
-        if isinstance(model_input, str):
-            return self.tokenizer.encode(model_input, return_tensors="pt").to(
-                self.device
-            )
-        return model_input.to(self.device)
-
     def attribute(
         self,
         inp: InterpretableInput,
@@ -527,7 +598,7 @@ def attribute(
                 of integers of the token ids.
                 Default: None
             num_trials (int, optional): number of trials to run. Return is the average
-                attribibutions over all the trials.
+                attributions over all the trials.
                 Defaults: 1.
             gen_args (dict, optional): arguments for generating the target. Only used if
                 target is not given. When None, the default arguments are used,
@@ -542,49 +613,12 @@
             attr (LLMAttributionResult): Attribution result. token_attr will be None
                 if attr method is Lime or KernelShap.
         """
-
-        assert isinstance(
-            inp, self.SUPPORTED_INPUTS
-        ), f"LLMAttribution does not support input type {type(inp)}"
-
-        if target is None:
-            # generate when None
-            assert hasattr(self.model, "generate") and callable(self.model.generate), (
-                "The model does not have recognizable generate function."
-                "Target must be given for attribution"
-            )
-
-            if not gen_args:
-                gen_args = DEFAULT_GEN_ARGS
-
-            model_inp = self._format_model_input(inp.to_model_input())
-            output_tokens = self.model.generate(model_inp, **gen_args)
-            target_tokens = output_tokens[0][model_inp.size(1) :]
-        else:
-            assert gen_args is None, "gen_args must be None when target is given"
-            # Encode skip tokens
-            if skip_tokens:
-                if isinstance(skip_tokens[0], str):
-                    skip_tokens = cast(List[str], skip_tokens)
-                    skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
-            else:
-                skip_tokens = []
-            skip_tokens = cast(List[int], skip_tokens)
-
-            if isinstance(target, str):
-                encoded = self.tokenizer.encode(target)
-                target_tokens = torch.tensor(
-                    [token for token in encoded if token not in skip_tokens]
-                )
-            elif isinstance(target, torch.Tensor):
-                target_tokens = target[
-                    ~torch.isin(target, torch.tensor(skip_tokens, device=target.device))
-                ]
-            else:
-                raise TypeError(
-                    "target must either be str or Tensor, but the type of target is "
-                    "{}".format(type(target))
-                )
+        target_tokens = self._get_target_tokens(
+            inp,
+            target,
+            skip_tokens=skip_tokens,
+            gen_args=gen_args,
+        )

         attr = torch.zeros(
             [
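The explicit-target branch that attribute() now delegates to is easiest to read as a tiny worked example. A minimal sketch of the skip-token filtering performed by `_get_target_tokens`; the token ids below are made-up placeholders, not real vocabulary entries.

    import torch

    encoded = [50256, 314, 588, 11473]   # stand-in for tokenizer.encode(target)
    skip_tokens = [50256]                # e.g. special tokens to drop
    # str target: filtered with a list comprehension, as in the base class
    target_tokens = torch.tensor([t for t in encoded if t not in skip_tokens])
    # a Tensor target goes through ~torch.isin(...) instead, with the same effect
    assert target_tokens.tolist() == [314, 588, 11473]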
@@ -638,7 +672,7 @@ def attribute_future(self) -> Callable[[], LLMAttributionResult]:
         )


-class LLMGradientAttribution(Attribution):
+class LLMGradientAttribution(BaseLLMAttribution):
     """
     Attribution class for large language models. It wraps a gradient-based
     attribution algorthm to produce commonly interested attribution
@@ -670,27 +704,12 @@ class created with the llm model that follows huggingface style
                 interface convention
             tokenizer (Tokenizer): tokenizer of the llm model used in the attr_method
         """
-        assert isinstance(
-            attr_method, self.SUPPORTED_METHODS
-        ), f"LLMGradientAttribution does not support {type(attr_method)}"
-
-        super().__init__(attr_method.forward_func)
-
-        # alias, we really need a model and don't support wrapper functions
-        # coz we need call model.forward, model.generate, etc.
-        self.model: nn.Module = cast(nn.Module, self.forward_func)
+        super().__init__(attr_method, tokenizer)

         # shallow copy is enough to avoid modifying original instance
         self.attr_method: GradientAttribution = copy(attr_method)
         self.attr_method.forward_func = GradientForwardFunc(self)

-        self.tokenizer: TokenizerLike = tokenizer
-        self.device: torch.device = (
-            cast(torch.device, self.model.device)
-            if hasattr(self.model, "device")
-            else next(self.model.parameters()).device
-        )
-
     def _format_model_input(self, model_input: Union[Tensor, str]) -> Tensor:
         """
         Convert str to tokenized tensor
@@ -734,50 +753,12 @@ def attribute(

             attr (LLMAttributionResult): attribution result
         """
-
-        assert isinstance(
-            inp, self.SUPPORTED_INPUTS
-        ), f"LLMGradAttribution does not support input type {type(inp)}"
-
-        if target is None:
-            # generate when None
-            assert hasattr(self.model, "generate") and callable(self.model.generate), (
-                "The model does not have recognizable generate function."
-                "Target must be given for attribution"
-            )
-
-            if not gen_args:
-                gen_args = DEFAULT_GEN_ARGS
-
-            with torch.no_grad():
-                model_inp = self._format_model_input(inp.to_model_input())
-                output_tokens = self.model.generate(model_inp, **gen_args)
-                target_tokens = output_tokens[0][model_inp.size(1) :]
-        else:
-            assert gen_args is None, "gen_args must be None when target is given"
-            # Encode skip tokens
-            if skip_tokens:
-                if isinstance(skip_tokens[0], str):
-                    skip_tokens = cast(List[str], skip_tokens)
-                    skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
-            else:
-                skip_tokens = []
-            skip_tokens = cast(List[int], skip_tokens)
-
-            if isinstance(target, str):
-                encoded = self.tokenizer.encode(target)
-                target_tokens = torch.tensor(
-                    [token for token in encoded if token not in skip_tokens]
-                )
-            elif isinstance(target, torch.Tensor):
-                target_tokens = target[
-                    ~torch.isin(target, torch.tensor(skip_tokens, device=target.device))
-                ]
-            else:
-                raise TypeError(
-                    "target must either be str or Tensor, but the type of target is "
-                    "{}".format(type(target))
-                )
+        target_tokens = self._get_target_tokens(
+            inp,
+            target,
+            skip_tokens=skip_tokens,
+            gen_args=gen_args,
+        )

         attr_inp = inp.to_tensor().to(self.device)

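For the gradient path the public API is likewise unchanged; only the target and device bookkeeping now comes from BaseLLMAttribution. A hedged sketch, assuming a HuggingFace-style model whose token-embedding layer is reachable as model.transformer.wte; that layer path, the "gpt2" model name, and the prompt are assumptions for illustration only.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from captum.attr import LayerIntegratedGradients, LLMGradientAttribution, TextTokenInput

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # attribute with respect to the token embedding layer (layer path is model-specific)
    lig = LayerIntegratedGradients(model, model.transformer.wte)
    llm_attr = LLMGradientAttribution(lig, tokenizer)

    inp = TextTokenInput("Palm Coast is a city in", tokenizer)
    res = llm_attr.attribute(inp, target=" Florida")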