 import itertools
 import math
 import warnings
-from typing import Callable, cast, Iterable, Optional, Sequence, Tuple, Union
+from typing import Callable, cast, Iterable, List, Optional, Sequence, Tuple, Union

 import torch
 from captum._utils.common import (
     _is_tuple,
     _run_forward,
 )
+from captum._utils.exceptions import ShapleyValueFutureError
 from captum._utils.progress import progress
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._utils.attribution import PerturbationAttribution
     _tensorize_baseline,
 )
 from captum.log import log_usage
-from torch import dtype, Tensor
+from torch import dtype, Size, Tensor
+from torch.futures import collect_all, Future


 def _all_perm_generator(num_features: int, num_samples: int) -> Iterable[Sequence[int]]:
@@ -394,7 +396,6 @@ def attribute(
                     )
                     if show_progress:
                         attr_progress.update()
-
                     if agg_output_mode:
                         eval_diff = modified_eval - prev_results
                         prev_results = modified_eval
@@ -438,7 +439,6 @@ def attribute(

                         # (*output_shape, *input_feature_shape)
                         total_attrib[j] += cur_attr
-
             if show_progress:
                 attr_progress.close()

@@ -452,14 +452,298 @@ def attribute(
         # `Tuple[Tensor, ...]`.
         return formatted_attr

-    # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
-    def attribute_future(self) -> Callable:
+    def attribute_future(
+        self,
+        inputs: TensorOrTupleOfTensorsGeneric,
+        baselines: BaselineType = None,
+        target: TargetType = None,
+        additional_forward_args: Optional[Tuple[object, ...]] = None,
+        feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
+        n_samples: int = 25,
+        perturbations_per_eval: int = 1,
+        show_progress: bool = False,
+    ) -> Future[TensorOrTupleOfTensorsGeneric]:
         r"""
         This method is not implemented for ShapleyValueSampling.
         """
-        raise NotImplementedError(
-            "attribute_future is not implemented for ShapleyValueSampling"
+        is_inputs_tuple = _is_tuple(inputs)
+        inputs_tuple, baselines = _format_input_baseline(inputs, baselines)
+        additional_forward_args = _format_additional_forward_args(
+            additional_forward_args
         )
+        formatted_feature_mask = _format_feature_mask(feature_mask, inputs_tuple)
+        reshaped_feature_mask = _shape_feature_mask(
+            formatted_feature_mask, inputs_tuple
+        )
+
+        assert (
+            isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
+        ), "Ablations per evaluation must be at least 1."
+
+        with torch.no_grad():
+            baselines = _tensorize_baseline(inputs_tuple, baselines)
+            num_examples = inputs_tuple[0].shape[0]
+
+            total_features = _get_max_feature_index(reshaped_feature_mask) + 1
+
+            if show_progress:
+                attr_progress = progress(
+                    desc=f"{self.get_name()} attribution",
+                    total=self._get_n_evaluations(
+                        total_features, n_samples, perturbations_per_eval
+                    )
+                    + 1,  # add 1 for the initial eval
+                )
+                attr_progress.update(0)
+
+            initial_eval = self._strict_run_forward_future(
+                self.forward_func, baselines, target, additional_forward_args
+            )
+
+            if show_progress:
+                attr_progress.update()
+
+            prev_result_tuple = initial_eval.then(
+                lambda initial_eval=initial_eval: self._initial_eval_to_prev_results_tuple(
+                    initial_eval,
+                    num_examples,
+                    perturbations_per_eval,
+                    reshaped_feature_mask,
+                    inputs_tuple,
+                )
+            )
+
+            iter_count = 0
+            # Iterate for the number of samples, generate a permutation of the
+            # features, and evaluate the incremental increase for each feature.
+            for feature_permutation in self.permutation_generator(
+                total_features, n_samples
+            ):
+                prev_result_tuple = prev_result_tuple.then(
+                    lambda prev_result_tuple=prev_result_tuple: self._set_prev_results_to_initial_eval(
+                        prev_result_tuple
+                    )
+                )
+
+                iter_count += 1
+                for (
+                    current_inputs,
+                    current_add_args,
+                    current_target,
+                    current_masks,
+                ) in self._perturbation_generator(
+                    inputs_tuple,
+                    additional_forward_args,
+                    target,
+                    baselines,
+                    reshaped_feature_mask,
+                    feature_permutation,
+                    perturbations_per_eval,
+                ):
+                    if sum(torch.sum(mask).item() for mask in current_masks) == 0:
+                        warnings.warn(
+                            "Feature mask is missing some integers between 0 and "
+                            "num_features, for optimal performance, make sure each"
+                            " consecutive integer corresponds to a feature.",
+                            stacklevel=1,
+                        )
+                    # modified_eval dimensions: 1D tensor with length
+                    # equal to #num_examples * #features in batch
+                    modified_eval = self._strict_run_forward_future(
+                        self.forward_func,
+                        current_inputs,
+                        current_target,
+                        current_add_args,
+                    )
+                    if show_progress:
+                        attr_progress.update()
+
+                    assert isinstance(modified_eval, torch.Future), (
+                        "when using futures method, modified_eval should have "
+                        f"Future type rather than {type(modified_eval)}"
+                    )
+                    eval_futs = collect_all([prev_result_tuple, modified_eval])
+                    prev_result_tuple = eval_futs.then(
+                        lambda eval_futs=eval_futs, num_examples=num_examples, inputs_tuple=inputs_tuple, current_masks=current_masks: self._eval_fut_to_prev_results_tuple(
+                            eval_futs, num_examples, inputs_tuple, current_masks
+                        )
+                    )
+
+            if show_progress:
+                attr_progress.close()
+
+        # Divide total attributions by number of random permutations and return
+        # formatted attributions.
+        formatted_attr = prev_result_tuple.then(
+            lambda prev_result_tuple=prev_result_tuple: self._prev_result_tuple_to_formatted_attr(
+                prev_result_tuple, iter_count, is_inputs_tuple
+            )
+        )
+        # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
+        # `Tuple[Tensor, ...]`.
+        return formatted_attr
+
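(Editorial aside, not part of the patch: a minimal sketch of how the asynchronous path added above might be driven. The toy model, the async_forward wrapper, and all tensor sizes are hypothetical; the only assumptions taken from the diff are that forward_func must return a torch.futures.Future and that attribute_future returns a Future holding the formatted attributions.)

import torch
from torch.futures import Future
from captum.attr import ShapleyValueSampling

model = torch.nn.Linear(3, 1)  # hypothetical toy model

def async_forward(inp: torch.Tensor) -> Future:
    # Wrap a synchronous forward pass in an already-completed Future;
    # a real deployment might hand back a Future from an RPC or a worker pool.
    fut: Future = Future()
    fut.set_result(model(inp).sum())  # scalar output -> aggregate output mode
    return fut

sv = ShapleyValueSampling(async_forward)
inputs = torch.randn(4, 3)
attr_fut = sv.attribute_future(inputs, baselines=0.0, n_samples=10)
attributions = attr_fut.wait()  # block until the chained futures resolve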
+    def _initial_eval_to_prev_results_tuple(
+        self,
+        initial_eval: Future[Tensor],
+        num_examples: int,
+        perturbations_per_eval: int,
+        reshaped_feature_mask: TensorOrTupleOfTensorsGeneric,
+        inputs_tuple: Tuple[Tensor, ...],
+    ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
+        """Since the initial eval is a Future, it is easier to bundle prev_result,
+        agg_output_mode, output_shape, and total_attrib together while the Shapley
+        value feature attributions are being calculated."""
+        try:
+            initial_eval_processed = initial_eval.value()
+            prev_result = initial_eval_processed
+            if not isinstance(initial_eval_processed, Tensor):
+                raise AssertionError(
+                    "initial_eval_to_processed_initial_eval_fut: "
+                    "initial_eval should be a Tensor"
+                )
+            agg_output_mode = _find_output_mode_and_verify(
+                initial_eval_processed,
+                num_examples,
+                perturbations_per_eval,
+                reshaped_feature_mask,
+                allow_multi_outputs=True,
+            )
+            output_shape = initial_eval_processed.shape
+            total_attrib: List[Tensor] = [
+                torch.zeros(
+                    tuple(output_shape) + tuple(input.shape[1:]),
+                    dtype=torch.float,
+                    device=inputs_tuple[0].device,
+                )
+                for input in inputs_tuple
+            ]
+            result = (
+                initial_eval_processed,
+                prev_result,
+                output_shape,
+                total_attrib,
+                agg_output_mode,
+            )
+        except ShapleyValueFutureError as e:
+            raise ShapleyValueFutureError(
+                "_initial_eval_to_prev_results_tuple func failed"
+            ) from e
+        return result
+
+    def _set_prev_results_to_initial_eval(
+        self,
+        processed_initial_eval: Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]],
+    ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
+        """At the beginning of each feature permutation, prev_results is reset to
+        the initial eval; this helper performs that reset."""
+        (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode) = (
+            processed_initial_eval.value()
+        )
+        prev_results = initial_eval
+        return (initial_eval, prev_results, output_shape, total_attrib, agg_output_mode)
+
+    def _eval_fut_to_prev_results_tuple(
+        self,
+        eval_futs: Future[
+            List[
+                Union[
+                    Future[Tuple[Tensor, Tensor, Size, List[Tensor], bool]],
+                    Future[Tensor],
+                ]
+            ]
+        ],
+        num_examples: int,
+        inputs_tuple: Tuple[Tensor, ...],
+        current_masks: Tuple[Tensor, ...],
+    ) -> Tuple[Tensor, Tensor, Size, List[Tensor], bool]:
+        """Helper method responsible for calculating eval differences between the
+        modified eval and prev_results Tensor and storing them in total_attrib.
+        Returns prev_results_tuple with modified total_attrib and prev_results."""
+        prev_results_tuple = eval_futs.value()[0].value()
+        modified_eval = eval_futs.value()[1].value()
+        if not isinstance(modified_eval, Tensor) or not isinstance(
+            prev_results_tuple, tuple
+        ):
+            raise ShapleyValueFutureError(
+                "_eval_fut_to_prev_results_tuple func failed due to type mismatch"
+            )
+        (
+            initial_eval,
+            prev_results,
+            output_shape,
+            total_attrib,
+            agg_output_mode,
+        ) = prev_results_tuple
+        if agg_output_mode:
+            eval_diff = modified_eval - prev_results
+            prev_results = modified_eval
+        else:
+            # When perturb_per_eval > 1, every num_examples entries stand for
+            # one perturbation. Since the perturbations come from a consecutive
+            # permutation, the diff for each perturbation is its eval minus
+            # the eval of the previous perturbation.
+
+            all_eval = torch.cat((prev_results, modified_eval), dim=0)
+            eval_diff = all_eval[num_examples:] - all_eval[:-num_examples]
+            prev_results = all_eval[-num_examples:]
+
+        for j in range(len(total_attrib)):
+            # format eval_diff to shape
+            # (n_perturb, *output_shape, 1,.. 1)
+            # where n_perturb may not be perturb_per_eval
+            # Append n_input_feature dim of 1 to make the tensor
+            # have the same dim as the mask tensor.
+            formatted_eval_diff = eval_diff.reshape(
+                (-1,) + tuple(output_shape) + (len(inputs_tuple[j].shape) - 1) * (1,)
+            )
+
+            # mask in shape (n_perturb, *mask_shape_broadcastable_to_input)
+            # reshape to
+            # (
+            #     n_perturb,
+            #     *broadcastable_to_output_shape
+            #     *broadcastable_to_input_feature_shape
+            # )
+            cur_mask = current_masks[j]
+            cur_mask = cur_mask.reshape(
+                tuple(cur_mask.shape[:2])
+                + (len(output_shape) - 1) * (1,)
+                + tuple(cur_mask.shape[2:])
+            )
+
+            # aggregate n_perturb
+            cur_attr = (formatted_eval_diff * cur_mask.float()).sum(dim=0)
+            # (*output_shape, *input_feature_shape)
+            total_attrib[j] += cur_attr
+
+        result = (
+            initial_eval,
+            prev_results,
+            output_shape,
+            total_attrib,
+            agg_output_mode,
+        )
+        return result
+
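(Editorial aside, not part of the patch: a worked trace of the staggered slicing in the non-aggregated branch above. The values are hypothetical, with num_examples = 2 and two perturbations flattened into one modified_eval batch.)

import torch

num_examples = 2
prev_results = torch.tensor([10.0, 20.0])                # one eval per example before this batch
modified_eval = torch.tensor([11.0, 22.0, 13.0, 25.0])   # two perturbations, flattened per example

all_eval = torch.cat((prev_results, modified_eval), dim=0)
eval_diff = all_eval[num_examples:] - all_eval[:-num_examples]
# eval_diff -> [1., 2., 2., 3.]: perturbation 1 vs prev_results, then perturbation 2 vs perturbation 1
prev_results = all_eval[-num_examples:]
# prev_results -> [13., 25.], the last perturbation's evals, carried into the next batch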
+    def _prev_result_tuple_to_formatted_attr(
+        self,
+        prev_result_tuple: Future[
+            Tuple[Tensor, Tensor, Tuple[int], List[Tensor], bool]
+        ],
+        iter_count: int,
+        is_inputs_tuple: bool,
+    ) -> Union[Tensor, Tuple[Tensor, ...]]:
+        """Helper method to format total_attrib, which is a list of tensors,
+        into formatted attributions, which are either a single tensor
+        or a tuple of tensors."""
+
+        (
+            _,
+            _,
+            _,
+            total_attrib,
+            _,
+        ) = prev_result_tuple.value()
+        attrib = tuple(
+            tensor_attrib_total / iter_count for tensor_attrib_total in total_attrib
+        )
+        formatted_attr = _format_output(is_inputs_tuple, attrib)
+        return formatted_attr

     def _perturbation_generator(
         self,
@@ -574,6 +858,37 @@ def _strict_run_forward(self, *args, **kwargs) -> Tensor:
         # ref: https://github.com/pytorch/pytorch/pull/21215
         return torch.tensor([forward_output], dtype=cast(dtype, output_type))

+    # pyre-fixme[2]: Parameter must be annotated.
+    def _strict_run_forward_future(self, *args, **kwargs) -> Future[Tensor]:
+        """
+        A temp wrapper for the global _run_forward util to force forward-output
+        type assertion & conversion, taking into account the Future tensor type
+        """
+
+        def process_strict_run_forward(fut: Future[Tensor]) -> Tensor:
+            output = fut.value()
+            if isinstance(output, Tensor):
+                # format scalar to shape (1) so we can always assume non-empty output_shape
+                if not output.shape:
+                    output = output.reshape(1)
+                return output
+            output_type = type(output)
+            assert output_type is int or output_type is float, (
+                "the return of forward_func must be a Future of tensor, int, or float,"
+                f" received: {output_type}"
+            )
+            output = torch.tensor([output], dtype=cast(dtype, output_type))
+            return output
+
+        forward_output = _run_forward(*args, **kwargs)
+        assert isinstance(forward_output, torch.Future), (
+            "The return type of forward_func must be a Future,"
+            f" received: {type(forward_output)}"
+        )
+
+        return_output = forward_output.then(process_strict_run_forward)
+        return return_output
+
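(Editorial aside, not part of the patch: the lambdas in the chaining code above bind loop variables as default arguments, e.g. eval_futs=eval_futs. A short sketch of why this matters: Python closures capture variables by reference, so without the default binding every callback would observe only the final loop iteration's values by the time the futures fire.)

from torch.futures import Future

sources = [Future() for _ in range(3)]
chained = []
for i, src in enumerate(sources):
    # `i=i` freezes the current value; a bare closure would read i == 2
    # for every callback once the results arrive below.
    chained.append(src.then(lambda fut, i=i: fut.value() + i))

for src in sources:
    src.set_result(100)

print([f.wait() for f in chained])  # [100, 101, 102]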

 class ShapleyValues(ShapleyValueSampling):
     """