diff --git a/captum/_utils/models/linear_model/train.py b/captum/_utils/models/linear_model/train.py index 2ba70ea32f..37f1507c94 100644 --- a/captum/_utils/models/linear_model/train.py +++ b/captum/_utils/models/linear_model/train.py @@ -1,7 +1,9 @@ # pyre-strict import time import warnings -from typing import Any, Callable, Dict, List, Optional +from functools import reduce +from types import ModuleType +from typing import Any, Callable, cast, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -17,6 +19,82 @@ def l2_loss(x1, x2, weights=None) -> torch.Tensor: return torch.sum((weights / weights.norm(p=1)) * ((x1 - x2) ** 2)) / 2.0 +class ConvergenceTracker: + def __init__(self, patience: int, threshold: float) -> None: + self.min_avg_loss: Optional[torch.Tensor] = None + self.convergence_counter: int = 0 + self.converged: bool = False + + self.threshold = threshold + self.patience = patience + + def update(self, average_loss: torch.Tensor) -> bool: + if self.min_avg_loss is not None: + # if we haven't improved by at least `threshold` + if average_loss > self.min_avg_loss or torch.isclose( + cast(torch.Tensor, self.min_avg_loss), average_loss, atol=self.threshold + ): + self.convergence_counter += 1 + if self.convergence_counter >= self.patience: + self.converged = True + return True + else: + self.convergence_counter = 0 + if self.min_avg_loss is None or self.min_avg_loss >= average_loss: + self.min_avg_loss = average_loss.clone() + return False + + +class LossWindow: + def __init__(self, window_size: int) -> None: + self.loss_window: List[torch.Tensor] = [] + self.window_size = window_size + + def append(self, loss: torch.Tensor) -> None: + if len(self.loss_window) >= self.window_size: + self.loss_window = self.loss_window[-self.window_size :] + self.loss_window.append(loss) + + def average(self) -> torch.Tensor: + return torch.mean(torch.stack(self.loss_window)) + + +def _init_linear_model(model: LinearModel, init_scheme: Optional[str] = None) -> None: + assert model.linear is not None + if init_scheme is not None: + assert init_scheme in ["xavier", "zeros"] + + with torch.no_grad(): + if init_scheme == "xavier": + # pyre-fixme[16]: `Optional` has no attribute `weight`. + torch.nn.init.xavier_uniform_(model.linear.weight) + else: + model.linear.weight.zero_() + + # pyre-fixme[16]: `Optional` has no attribute `bias`. + if model.linear.bias is not None: + model.linear.bias.zero_() + + +def _get_point( + datapoint: Tuple[torch.Tensor, ...], + device: Optional[str] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + if len(datapoint) == 2: + x, y = datapoint + w = None + else: + x, y, w = datapoint + + if device is not None: + x = x.to(device) + y = y.to(device) + if w is not None: + w = w.to(device) + + return x, y, w + + def sgd_train_linear_model( model: LinearModel, dataloader: DataLoader, @@ -102,31 +180,16 @@ def sgd_train_linear_model( This will return the final training loss (averaged with `running_loss_window`) """ - loss_window: List[torch.Tensor] = [] - min_avg_loss = None - convergence_counter = 0 - converged = False - - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. 
- def get_point(datapoint): - if len(datapoint) == 2: - x, y = datapoint - w = None - else: - x, y, w = datapoint - - if device is not None: - x = x.to(device) - y = y.to(device) - if w is not None: - w = w.to(device) - - return x, y, w + converge_tracker = ConvergenceTracker(patience, threshold) # get a point and construct the model data_iter = iter(dataloader) - x, y, w = get_point(next(data_iter)) + x, y, w = _get_point(next(data_iter), device) + + if running_loss_window is None: + running_loss_window = x.shape[0] * len(dataloader) + + loss_window = LossWindow(running_loss_window) model._construct_model_params( in_features=x.shape[1], @@ -135,21 +198,8 @@ def get_point(datapoint): ) model.train() - assert model.linear is not None - - if init_scheme is not None: - assert init_scheme in ["xavier", "zeros"] - - with torch.no_grad(): - if init_scheme == "xavier": - # pyre-fixme[16]: `Optional` has no attribute `weight`. - torch.nn.init.xavier_uniform_(model.linear.weight) - else: - model.linear.weight.zero_() - - # pyre-fixme[16]: `Optional` has no attribute `bias`. - if model.linear.bias is not None: - model.linear.bias.zero_() + # Initialize linear model weights if applicable + _init_linear_model(model, init_scheme) with torch.enable_grad(): optim = torch.optim.SGD(model.parameters(), lr=initial_lr) @@ -163,9 +213,6 @@ def get_point(datapoint): i = 0 while epoch < max_epoch: while True: # for x, y, w in dataloader - if running_loss_window is None: - running_loss_window = x.shape[0] * len(dataloader) - y = y.view(x.shape[0], -1) if w is not None: w = w.view(x.shape[0], -1) @@ -176,33 +223,20 @@ def get_point(datapoint): loss = loss_fn(y, out, w) if reg_term is not None: - reg = torch.norm(model.linear.weight, p=reg_term) + # pyre-fixme[16]: `Optional` has no attribute `weight`. 
+ reg = torch.norm(model.linear.weight, p=reg_term) # type: ignore loss += reg.sum() * alpha - if len(loss_window) >= running_loss_window: - loss_window = loss_window[1:] loss_window.append(loss.clone().detach()) - assert len(loss_window) <= running_loss_window - - average_loss = torch.mean(torch.stack(loss_window)) - if min_avg_loss is not None: - # if we haven't improved by at least `threshold` - if average_loss > min_avg_loss or torch.isclose( - min_avg_loss, average_loss, atol=threshold - ): - convergence_counter += 1 - if convergence_counter >= patience: - converged = True - break - else: - convergence_counter = 0 - if min_avg_loss is None or min_avg_loss >= average_loss: - min_avg_loss = average_loss.clone() + average_loss = loss_window.average() + if converge_tracker.update(average_loss): + break # converged if debug: print( - f"lr={optim.param_groups[0]['lr']}, Loss={loss}," - + "Aloss={average_loss}, min_avg_loss={min_avg_loss}" + f"lr={optim.param_groups[0]['lr']}, Loss={loss}, " + f"Aloss={average_loss}, " + f"min_avg_loss={converge_tracker.min_avg_loss}" ) loss.backward() @@ -215,19 +249,19 @@ def get_point(datapoint): temp = next(data_iter, None) if temp is None: break - x, y, w = get_point(temp) + x, y, w = _get_point(temp, device) - if converged: + if converge_tracker.converged: break epoch += 1 data_iter = iter(dataloader) - x, y, w = get_point(next(data_iter)) + x, y, w = _get_point(next(data_iter), device) t2 = time.time() return { "train_time": t2 - t1, - "train_loss": torch.mean(torch.stack(loss_window)).item(), + "train_loss": loss_window.average().item(), "train_iter": i, "train_epoch": epoch, } @@ -250,14 +284,38 @@ def forward(self, x): return (x - self.mean) / (self.std + self.eps) +def _import_sklearn() -> ModuleType: + try: + import sklearn + import sklearn.linear_model + import sklearn.svm + except ImportError: + raise ValueError("sklearn is not available. Please install sklearn >= 0.23") + + if not sklearn.__version__ >= "0.23.0": + warnings.warn( + "Must have sklearn version 0.23.0 or higher to use " + "sample_weight in Lasso regression.", + stacklevel=1, + ) + return sklearn + + +def _import_numpy() -> ModuleType: + try: + import numpy + except ImportError: + raise ValueError("numpy is not available. Please install numpy.") + return numpy + + def sklearn_train_linear_model( model: LinearModel, dataloader: DataLoader, construct_kwargs: Dict[str, Any], sklearn_trainer: str = "Lasso", norm_input: bool = False, - # pyre-fixme[2]: Parameter must be annotated. - **fit_kwargs, + **fit_kwargs: Any, ) -> Dict[str, float]: r""" Alternative method to train with sklearn. This does introduce some slight @@ -286,25 +344,9 @@ def sklearn_train_linear_model( fit_kwargs Other arguments to send to `sklearn_trainer`'s `.fit` method """ - from functools import reduce - - try: - import numpy as np - except ImportError: - raise ValueError("numpy is not available. Please install numpy.") - - try: - import sklearn - import sklearn.linear_model - import sklearn.svm - except ImportError: - raise ValueError("sklearn is not available. Please install sklearn >= 0.23") - - if not sklearn.__version__ >= "0.23.0": - warnings.warn( - "Must have sklearn version 0.23.0 or higher to use " - "sample_weight in Lasso regression." - ) + # Lazy imports + np = _import_numpy() + sklearn = _import_sklearn() num_batches = 0 xs, ys, ws = [], [], [] @@ -336,8 +378,8 @@ def sklearn_train_linear_model( t1 = time.time() # pyre-fixme[29]: `str` is not a function. 
- sklearn_model = reduce( - lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".") + sklearn_model = reduce( # type: ignore + lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".") # type: ignore # noqa: E501 )(**construct_kwargs) try: sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs) @@ -346,7 +388,8 @@ def sklearn_train_linear_model( warnings.warn( "Sample weight is not supported for the provided linear model!" " Trained model without weighting inputs. For Lasso, please" - " upgrade sklearn to a version >= 0.23.0." + " upgrade sklearn to a version >= 0.23.0.", + stacklevel=1, ) t2 = time.time() diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py index dc8447d1b3..7911cce735 100644 --- a/captum/attr/_core/lime.py +++ b/captum/attr/_core/lime.py @@ -5,6 +5,7 @@ import math import typing import warnings +from collections.abc import Iterator from typing import Any, Callable, cast, List, Optional, Tuple, Union import torch @@ -243,6 +244,7 @@ def __init__( ), "Must provide transform from original input space to interpretable space" @log_usage() + @torch.no_grad() def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -422,125 +424,136 @@ def attribute( >>> # model. >>> attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1) """ - with torch.no_grad(): - inp_tensor = ( - cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0] + inp_tensor = cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0] + device = inp_tensor.device + + interpretable_inps = [] + similarities = [] + outputs = [] + + curr_model_inputs = [] + expanded_additional_args = None + expanded_target = None + gen_perturb_func = self._get_perturb_generator_func(inputs, **kwargs) + + if show_progress: + attr_progress = progress( + total=math.ceil(n_samples / perturbations_per_eval), + desc=f"{self.get_name()} attribution", ) - device = inp_tensor.device - - interpretable_inps = [] - similarities = [] - outputs = [] - - curr_model_inputs = [] - expanded_additional_args = None - expanded_target = None - perturb_generator = None - if inspect.isgeneratorfunction(self.perturb_func): - perturb_generator = self.perturb_func(inputs, **kwargs) - - if show_progress: - attr_progress = progress( - total=math.ceil(n_samples / perturbations_per_eval), - desc=f"{self.get_name()} attribution", + attr_progress.update(0) + + batch_count = 0 + for _ in range(n_samples): + try: + interpretable_inp, curr_model_input = gen_perturb_func() + except StopIteration: + warnings.warn( + "Generator completed prior to given n_samples iterations!", + stacklevel=1, ) - attr_progress.update(0) - - batch_count = 0 - for _ in range(n_samples): - if perturb_generator: - try: - curr_sample = next(perturb_generator) - except StopIteration: - warnings.warn( - "Generator completed prior to given n_samples iterations!" 
- ) - break - else: - curr_sample = self.perturb_func(inputs, **kwargs) - batch_count += 1 - if self.perturb_interpretable_space: - interpretable_inps.append(curr_sample) - curr_model_inputs.append( - self.from_interp_rep_transform( # type: ignore - curr_sample, inputs, **kwargs - ) - ) - else: - curr_model_inputs.append(curr_sample) - interpretable_inps.append( - self.to_interp_rep_transform( # type: ignore - curr_sample, inputs, **kwargs - ) - ) - curr_sim = self.similarity_func( - inputs, curr_model_inputs[-1], interpretable_inps[-1], **kwargs - ) - similarities.append( - curr_sim.flatten() - if isinstance(curr_sim, Tensor) - else torch.tensor([curr_sim], device=device) - ) - - if len(curr_model_inputs) == perturbations_per_eval: - if expanded_additional_args is None: - expanded_additional_args = _expand_additional_forward_args( - additional_forward_args, len(curr_model_inputs) - ) - if expanded_target is None: - expanded_target = _expand_target(target, len(curr_model_inputs)) - - model_out = self._evaluate_batch( - curr_model_inputs, - expanded_target, - expanded_additional_args, - device, - ) - - if show_progress: - attr_progress.update() + break + batch_count += 1 + interpretable_inps.append(interpretable_inp) + curr_model_inputs.append(curr_model_input) - outputs.append(model_out) + curr_sim = self.similarity_func( + inputs, curr_model_input, interpretable_inp, **kwargs + ) + similarities.append( + curr_sim.flatten() + if isinstance(curr_sim, Tensor) + else torch.tensor([curr_sim], device=device) + ) - curr_model_inputs = [] + if len(curr_model_inputs) == perturbations_per_eval: + if expanded_additional_args is None: + expanded_additional_args = _expand_additional_forward_args( + additional_forward_args, len(curr_model_inputs) + ) + if expanded_target is None: + expanded_target = _expand_target(target, len(curr_model_inputs)) - if len(curr_model_inputs) > 0: - expanded_additional_args = _expand_additional_forward_args( - additional_forward_args, len(curr_model_inputs) - ) - expanded_target = _expand_target(target, len(curr_model_inputs)) model_out = self._evaluate_batch( curr_model_inputs, expanded_target, expanded_additional_args, device, ) + if show_progress: attr_progress.update() + outputs.append(model_out) - if show_progress: - attr_progress.close() - - # Argument 1 to "cat" has incompatible type - # "list[Tensor | tuple[Tensor, ...]]"; - # expected "tuple[Tensor, ...] 
| list[Tensor]" [arg-type] - combined_interp_inps = torch.cat(interpretable_inps).float() # type: ignore - combined_outputs = ( - torch.cat(outputs) - if len(outputs[0].shape) > 0 - else torch.stack(outputs) - ).float() - combined_sim = ( - torch.cat(similarities) - if len(similarities[0].shape) > 0 - else torch.stack(similarities) - ).float() - dataset = TensorDataset( - combined_interp_inps, combined_outputs, combined_sim + curr_model_inputs = [] + + if len(curr_model_inputs) > 0: + expanded_additional_args = _expand_additional_forward_args( + additional_forward_args, len(curr_model_inputs) + ) + expanded_target = _expand_target(target, len(curr_model_inputs)) + model_out = self._evaluate_batch( + curr_model_inputs, + expanded_target, + expanded_additional_args, + device, ) - self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count)) - return self.interpretable_model.representation() + if show_progress: + attr_progress.update() + outputs.append(model_out) + + if show_progress: + attr_progress.close() + + # Argument 1 to "cat" has incompatible type + # "list[Tensor | tuple[Tensor, ...]]"; + # expected "tuple[Tensor, ...] | list[Tensor]" [arg-type] + combined_interp_inps = torch.cat(interpretable_inps).float() # type: ignore + combined_outputs = ( + torch.cat(outputs) if len(outputs[0].shape) > 0 else torch.stack(outputs) + ).float() + combined_sim = ( + torch.cat(similarities) + if len(similarities[0].shape) > 0 + else torch.stack(similarities) + ).float() + dataset = TensorDataset(combined_interp_inps, combined_outputs, combined_sim) + self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count)) + return self.interpretable_model.representation() + + def _get_perturb_generator_func( + self, inputs: TensorOrTupleOfTensorsGeneric, **kwargs: Any + ) -> Callable[ + [], Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric] + ]: + perturb_generator: Optional[Iterator[TensorOrTupleOfTensorsGeneric]] + perturb_generator = None + if inspect.isgeneratorfunction(self.perturb_func): + perturb_generator = self.perturb_func(inputs, **kwargs) + + def generate_perturbation() -> ( + Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric] + ): + if perturb_generator: + curr_sample = next(perturb_generator) + else: + curr_sample = self.perturb_func(inputs, **kwargs) + + if self.perturb_interpretable_space: + interpretable_inp = curr_sample + curr_model_input = self.from_interp_rep_transform( # type: ignore + curr_sample, inputs, **kwargs + ) + else: + curr_model_input = curr_sample + interpretable_inp = self.to_interp_rep_transform( # type: ignore + curr_sample, inputs, **kwargs + ) + + return interpretable_inp, curr_model_input + + return generate_perturbation # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. 
def attribute_future(self) -> Callable: diff --git a/captum/metrics/_core/infidelity.py b/captum/metrics/_core/infidelity.py index 001fd18050..8963f168fc 100644 --- a/captum/metrics/_core/infidelity.py +++ b/captum/metrics/_core/infidelity.py @@ -408,88 +408,175 @@ def infidelity( >>> # Computes infidelity score for saliency maps >>> infid = infidelity(net, perturb_fn, input, attribution) """ + # perform argument formattings + inputs = _format_tensor_into_tuples(inputs) # type: ignore + if baselines is not None: + baselines = _format_baseline(baselines, cast(Tuple[Tensor, ...], inputs)) + additional_forward_args = _format_additional_forward_args(additional_forward_args) + attributions = _format_tensor_into_tuples(attributions) # type: ignore - def _generate_perturbations( - current_n_perturb_samples: int, - ) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]: - r""" - The perturbations are generated for each example - `current_n_perturb_samples` times. + # Make sure that inputs and corresponding attributions have matching sizes. + assert len(inputs) == len(attributions), ( + """The number of tensors in the inputs and + attributions must match. Found number of tensors in the inputs is: {} and in the + attributions: {}""" + ).format(len(inputs), len(attributions)) + for inp, attr in zip(inputs, attributions): + assert inp.shape == attr.shape, ( + """Inputs and attributions must have + matching shapes. One of the input tensor's shape is {} and the + attribution tensor's shape is: {}""" + # pyre-fixme[16]: Module `attr` has no attribute `shape`. + ).format(inp.shape, attr.shape) - For performance reasons we are not calling `perturb_func` on each example but - on a batch that contains `current_n_perturb_samples` - repeated instances per example. - """ + bsz = inputs[0].size(0) - # pyre-fixme[3]: Return type must be annotated. - def call_perturb_func(): - r""" """ - baselines_pert = None - inputs_pert: Union[Tensor, Tuple[Tensor, ...]] - if len(inputs_expanded) == 1: - inputs_pert = inputs_expanded[0] - if baselines_expanded is not None: - # pyre-fixme[24]: Generic type `tuple` expects at least 1 type - # parameter. - baselines_pert = cast(Tuple, baselines_expanded)[0] - else: - inputs_pert = inputs_expanded - baselines_pert = baselines_expanded - return ( - perturb_func(inputs_pert, baselines_pert) - if baselines_pert is not None - else perturb_func(inputs_pert) - ) + _next_infidelity_tensors = _make_next_infidelity_tensors_func( + forward_func, + bsz, + perturb_func, + inputs, + baselines, + attributions, + additional_forward_args, + target, + normalize, + ) + + with torch.no_grad(): + # if not normalize, directly return aggrgated MSE ((a-b)^2,) + # else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2) + agg_tensors = _divide_and_aggregate_metrics( + cast(Tuple[Tensor, ...], inputs), + n_perturb_samples, + _next_infidelity_tensors, + agg_func=_sum_infidelity_tensors, + max_examples_per_batch=max_examples_per_batch, + ) + + if normalize: + beta_num = agg_tensors[1] + beta_denorm = agg_tensors[0] - inputs_expanded = tuple( - torch.repeat_interleave(input, current_n_perturb_samples, dim=0) - for input in inputs + beta = safe_div(beta_num, beta_denorm) + + infidelity_values = ( + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. 
+ beta**2 * agg_tensors[0] + - 2 * beta * agg_tensors[1] + + agg_tensors[2] ) + else: + infidelity_values = agg_tensors[0] - baselines_expanded = baselines - if baselines is not None: - baselines_expanded = tuple( - ( - baseline.repeat_interleave(current_n_perturb_samples, dim=0) - if isinstance(baseline, torch.Tensor) - and baseline.shape[0] == input.shape[0] - and baseline.shape[0] > 1 - else baseline - ) + infidelity_values /= n_perturb_samples + + return infidelity_values + + +def _generate_perturbations( + current_n_perturb_samples: int, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + perturb_func: Callable, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType, +) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]: + r""" + The perturbations are generated for each example + `current_n_perturb_samples` times. + + For performance reasons we are not calling `perturb_func` on each example but + on a batch that contains `current_n_perturb_samples` + repeated instances per example. + """ + + # pyre-fixme[3]: Return type must be annotated. + def call_perturb_func(): + r""" """ + baselines_pert = None + inputs_pert: Union[Tensor, Tuple[Tensor, ...]] + if len(inputs_expanded) == 1: + inputs_pert = inputs_expanded[0] + if baselines_expanded is not None: # pyre-fixme[24]: Generic type `tuple` expects at least 1 type # parameter. - for input, baseline in zip(inputs, cast(Tuple, baselines)) + baselines_pert = cast(Tuple, baselines_expanded)[0] + else: + inputs_pert = inputs_expanded + baselines_pert = baselines_expanded + return ( + perturb_func(inputs_pert, baselines_pert) + if baselines_pert is not None + else perturb_func(inputs_pert) + ) + + inputs_expanded = tuple( + torch.repeat_interleave(input, current_n_perturb_samples, dim=0) + for input in inputs + ) + + baselines_expanded = baselines + if baselines is not None: + baselines_expanded = tuple( + ( + baseline.repeat_interleave(current_n_perturb_samples, dim=0) + if isinstance(baseline, torch.Tensor) + and baseline.shape[0] == input.shape[0] + and baseline.shape[0] > 1 + else baseline ) + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type + # parameter. + for input, baseline in zip(inputs, cast(Tuple, baselines)) + ) - return call_perturb_func() - - def _validate_inputs_and_perturbations( - inputs: Tuple[Tensor, ...], - inputs_perturbed: Tuple[Tensor, ...], - perturbations: Tuple[Tensor, ...], - ) -> None: - # asserts the sizes of the perturbations and inputs - assert len(perturbations) == len(inputs), ( - """The number of perturbed - inputs and corresponding perturbations must have the same number of - elements. Found number of inputs is: {} and perturbations: - {}""" - ).format(len(perturbations), len(inputs)) - - # asserts the shapes of the perturbations and perturbed inputs - for perturb, input_perturbed in zip(perturbations, inputs_perturbed): - assert perturb[0].shape == input_perturbed[0].shape, ( - """Perturbed input - and corresponding perturbation must have the same shape and - dimensionality. 
Found perturbation shape is: {} and the input shape - is: {}""" - ).format(perturb[0].shape, input_perturbed[0].shape) + return call_perturb_func() + + +def _validate_inputs_and_perturbations( + inputs: Tuple[Tensor, ...], + inputs_perturbed: Tuple[Tensor, ...], + perturbations: Tuple[Tensor, ...], +) -> None: + # asserts the sizes of the perturbations and inputs + assert len(perturbations) == len(inputs), ( + """The number of perturbed + inputs and corresponding perturbations must have the same number of + elements. Found number of inputs is: {} and perturbations: + {}""" + ).format(len(perturbations), len(inputs)) + + # asserts the shapes of the perturbations and perturbed inputs + for perturb, input_perturbed in zip(perturbations, inputs_perturbed): + assert perturb[0].shape == input_perturbed[0].shape, ( + """Perturbed input + and corresponding perturbation must have the same shape and + dimensionality. Found perturbation shape is: {} and the input shape + is: {}""" + ).format(perturb[0].shape, input_perturbed[0].shape) + + +def _make_next_infidelity_tensors_func( + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + forward_func: Callable, + bsz: int, + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + perturb_func: Callable, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType, + attributions: TensorOrTupleOfTensorsGeneric, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. + additional_forward_args: Any = None, + target: TargetType = None, + normalize: bool = False, +) -> Callable[[int], Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]]: def _next_infidelity_tensors( current_n_perturb_samples: int, ) -> Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]: perturbations, inputs_perturbed = _generate_perturbations( - current_n_perturb_samples + current_n_perturb_samples, perturb_func, inputs, baselines ) perturbations = _format_tensor_into_tuples(perturbations) @@ -564,60 +651,10 @@ def _next_infidelity_tensors( # returns (a-b)^2 if no need to normalize return ((attr_times_perturb_sums - perturbed_fwd_diffs).pow(2).sum(-1),) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def _sum_infidelity_tensors(agg_tensors, tensors): - return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors)) + return _next_infidelity_tensors - # perform argument formattings - inputs = _format_tensor_into_tuples(inputs) # type: ignore - if baselines is not None: - baselines = _format_baseline(baselines, cast(Tuple[Tensor, ...], inputs)) - additional_forward_args = _format_additional_forward_args(additional_forward_args) - attributions = _format_tensor_into_tuples(attributions) # type: ignore - # Make sure that inputs and corresponding attributions have matching sizes. - assert len(inputs) == len(attributions), ( - """The number of tensors in the inputs and - attributions must match. Found number of tensors in the inputs is: {} and in the - attributions: {}""" - ).format(len(inputs), len(attributions)) - for inp, attr in zip(inputs, attributions): - assert inp.shape == attr.shape, ( - """Inputs and attributions must have - matching shapes. One of the input tensor's shape is {} and the - attribution tensor's shape is: {}""" - # pyre-fixme[16]: Module `attr` has no attribute `shape`. 
- ).format(inp.shape, attr.shape) - - bsz = inputs[0].size(0) - with torch.no_grad(): - # if not normalize, directly return aggrgated MSE ((a-b)^2,) - # else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2) - agg_tensors = _divide_and_aggregate_metrics( - cast(Tuple[Tensor, ...], inputs), - n_perturb_samples, - _next_infidelity_tensors, - agg_func=_sum_infidelity_tensors, - max_examples_per_batch=max_examples_per_batch, - ) - - if normalize: - beta_num = agg_tensors[1] - beta_denorm = agg_tensors[0] - - beta = safe_div(beta_num, beta_denorm) - - infidelity_values = ( - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - beta**2 * agg_tensors[0] - - 2 * beta * agg_tensors[1] - + agg_tensors[2] - ) - else: - infidelity_values = agg_tensors[0] - - infidelity_values /= n_perturb_samples - - return infidelity_values +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. +def _sum_infidelity_tensors(agg_tensors, tensors): + return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors)) diff --git a/tests/attr/test_data_parallel.py b/tests/attr/test_data_parallel.py index bf89b9068a..2135e9e368 100644 --- a/tests/attr/test_data_parallel.py +++ b/tests/attr/test_data_parallel.py @@ -4,7 +4,7 @@ import copy import os from enum import Enum -from typing import Any, Callable, cast, Dict, Optional, Tuple, Type +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type import torch import torch.distributed as dist @@ -136,91 +136,22 @@ def data_parallel_test_assert(self) -> None: else: cuda_args[key] = args[key] - alt_device_ids = None cuda_model = copy.deepcopy(model).cuda() - # Initialize models based on DataParallelCompareMode - if mode is DataParallelCompareMode.cpu_cuda: - model_1, model_2 = model, cuda_model - args_1, args_2 = args, cuda_args - elif mode is DataParallelCompareMode.data_parallel_default: - model_1, model_2 = ( - cuda_model, - torch.nn.parallel.DataParallel(cuda_model), - ) - args_1, args_2 = cuda_args, cuda_args - elif mode is DataParallelCompareMode.data_parallel_alt_dev_ids: - alt_device_ids = [0] + [ - x for x in range(torch.cuda.device_count() - 1, 0, -1) - ] - model_1, model_2 = ( - cuda_model, - torch.nn.parallel.DataParallel( - cuda_model, device_ids=alt_device_ids - ), - ) - args_1, args_2 = cuda_args, cuda_args - elif mode is DataParallelCompareMode.dist_data_parallel: - - model_1, model_2 = ( - cuda_model, - torch.nn.parallel.DistributedDataParallel( - cuda_model, device_ids=[0], output_device=0 - ), - ) - args_1, args_2 = cuda_args, cuda_args - else: - raise AssertionError("DataParallel compare mode type is not valid.") - - attr_method_1: Attribution - attr_method_2: Attribution - if target_layer: - internal_algorithm = cast(Type[InternalAttribution], algorithm) - attr_method_1 = internal_algorithm( - model_1, get_target_layer(model_1, target_layer) - ) - # cuda_model is used to obtain target_layer since DataParallel - # adds additional wrapper. - # model_2 is always either the CUDA model itself or DataParallel - if alt_device_ids is None: - attr_method_2 = internal_algorithm( - model_2, get_target_layer(cuda_model, target_layer) - ) - else: - # LayerDeepLift and LayerDeepLiftShap do not take device ids - # as a parameter, since they must always have the DataParallel - # model object directly. - # Some neuron methods and GuidedGradCAM also require the - # model and cannot take a forward function. 
- if issubclass( - internal_algorithm, - ( - LayerDeepLift, - LayerDeepLiftShap, - LayerLRP, - NeuronDeepLift, - NeuronDeepLiftShap, - NeuronDeconvolution, - NeuronGuidedBackprop, - GuidedGradCam, - ), - ): - attr_method_2 = internal_algorithm( - model_2, - get_target_layer(cuda_model, target_layer), # type: ignore - ) - else: - attr_method_2 = internal_algorithm( - model_2.forward, - get_target_layer(cuda_model, target_layer), - device_ids=alt_device_ids, - ) - else: - attr_method_1 = algorithm(model_1) - attr_method_2 = algorithm(model_2) + # Set up test arguments based on DataParallelCompareMode + model_1, model_2, args_1, args_2, alt_device_ids = _get_dp_test_args( + cuda_model, model, cuda_args, args, mode + ) - if noise_tunnel: - attr_method_1 = NoiseTunnel(attr_method_1) - attr_method_2 = NoiseTunnel(attr_method_2) + # Construct attribution methods + attr_method_1, attr_method_2 = _get_dp_attr_methods( + algorithm, + target_layer, + model_1, + model_2, + cuda_model, + alt_device_ids, + noise_tunnel, + ) if attr_method_1.has_convergence_delta(): attributions_1, delta_1 = attr_method_1.attribute( return_convergence_delta=True, **args_1 @@ -266,6 +197,107 @@ def data_parallel_test_assert(self) -> None: return data_parallel_test_assert +def _get_dp_test_args( + cuda_model: Module, + model: Module, + cuda_args: Dict[str, Any], + args: Dict[str, Any], + mode: DataParallelCompareMode, +) -> Tuple[Module, Module, Dict[str, Any], Dict[str, Any], Optional[List[int]]]: + # Initialize models based on DataParallelCompareMode + alt_device_ids = None + if mode is DataParallelCompareMode.cpu_cuda: + model_1, model_2 = model, cuda_model + args_1, args_2 = args, cuda_args + elif mode is DataParallelCompareMode.data_parallel_default: + model_1, model_2 = ( + cuda_model, + torch.nn.parallel.DataParallel(cuda_model), + ) + args_1, args_2 = cuda_args, cuda_args + elif mode is DataParallelCompareMode.data_parallel_alt_dev_ids: + alt_device_ids = [0] + list(range(torch.cuda.device_count() - 1, 0, -1)) + model_1, model_2 = ( + cuda_model, + torch.nn.parallel.DataParallel(cuda_model, device_ids=alt_device_ids), + ) + args_1, args_2 = cuda_args, cuda_args + elif mode is DataParallelCompareMode.dist_data_parallel: + + model_1, model_2 = ( + cuda_model, + torch.nn.parallel.DistributedDataParallel( + cuda_model, device_ids=[0], output_device=0 + ), + ) + args_1, args_2 = cuda_args, cuda_args + else: + raise AssertionError("DataParallel compare mode type is not valid.") + + return model_1, model_2, args_1, args_2, alt_device_ids + + +def _get_dp_attr_methods( + algorithm: Type[Attribution], + target_layer: Optional[str], + model_1: Module, + model_2: Module, + cuda_model: Module, + alt_device_ids: Optional[List[int]], + noise_tunnel: bool, +) -> Tuple[Attribution, Attribution]: + attr_method_1: Attribution + attr_method_2: Attribution + if target_layer: + internal_algorithm = cast(Type[InternalAttribution], algorithm) + attr_method_1 = internal_algorithm( + model_1, get_target_layer(model_1, target_layer) + ) + # cuda_model is used to obtain target_layer since DataParallel + # adds additional wrapper. + # model_2 is always either the CUDA model itself or DataParallel + if alt_device_ids is None: + attr_method_2 = internal_algorithm( + model_2, get_target_layer(cuda_model, target_layer) + ) + else: + # LayerDeepLift and LayerDeepLiftShap do not take device ids + # as a parameter, since they must always have the DataParallel + # model object directly. 
+ # Some neuron methods and GuidedGradCAM also require the + # model and cannot take a forward function. + if issubclass( + internal_algorithm, + ( + LayerDeepLift, + LayerDeepLiftShap, + LayerLRP, + NeuronDeepLift, + NeuronDeepLiftShap, + NeuronDeconvolution, + NeuronGuidedBackprop, + GuidedGradCam, + ), + ): + attr_method_2 = internal_algorithm( + model_2, + get_target_layer(cuda_model, target_layer), # type: ignore + ) + else: + attr_method_2 = internal_algorithm( + model_2.forward, + get_target_layer(cuda_model, target_layer), + device_ids=alt_device_ids, + ) + else: + attr_method_1 = algorithm(model_1) + attr_method_2 = algorithm(model_2) + if noise_tunnel: + attr_method_1 = NoiseTunnel(attr_method_1) + attr_method_2 = NoiseTunnel(attr_method_2) + return attr_method_1, attr_method_2 + + if torch.cuda.is_available() and torch.cuda.device_count() != 0: class DataParallelTest(BaseTest, metaclass=DataParallelMeta):
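
The refactor in `captum/_utils/models/linear_model/train.py` extracts the convergence bookkeeping of `sgd_train_linear_model` into `ConvergenceTracker` and `LossWindow`. Below is a minimal sketch of how the two helpers compose, assuming this patch is applied so both classes are importable from `captum._utils.models.linear_model.train`; the toy loss sequence and hyperparameters are hypothetical.

```python
import torch

from captum._utils.models.linear_model.train import ConvergenceTracker, LossWindow

# Stop once the windowed average loss fails to improve by at least `threshold`
# for `patience` consecutive updates -- the check sgd_train_linear_model now
# delegates to these helpers.
tracker = ConvergenceTracker(patience=2, threshold=1e-4)
window = LossWindow(window_size=2)

toy_losses = [1.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]  # hypothetical per-step losses
for step, value in enumerate(toy_losses):
    loss = torch.tensor(value)
    window.append(loss.detach())
    if tracker.update(window.average()):
        # Reported once the flat tail has exceeded the patience budget.
        print(f"converged at step {step}")
        break
```

This mirrors the loop body in `sgd_train_linear_model`, which appends `loss.clone().detach()` to the window and breaks out of training when `converge_tracker.update(...)` returns `True`.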
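
In `lime.py`, `_get_perturb_generator_func` hides the difference between a generator-function `perturb_func` and a plain callable behind a single zero-argument closure, and `attribute` catches `StopIteration` when a generator is exhausted before `n_samples` draws. A minimal sketch of that wrapping pattern follows; `make_sampler` and `finite_source` are hypothetical names for illustration, not Captum APIs.

```python
import inspect
from typing import Callable, Iterator, Union


def make_sampler(perturb_func: Callable[[], Union[int, Iterator[int]]]) -> Callable[[], int]:
    # Instantiate the generator once if perturb_func is a generator function.
    gen = perturb_func() if inspect.isgeneratorfunction(perturb_func) else None

    def sample() -> int:
        # Draw from the generator when one exists; otherwise call the function directly.
        return next(gen) if gen is not None else perturb_func()

    return sample


def finite_source() -> Iterator[int]:
    yield from (1, 2, 3)


sampler = make_sampler(finite_source)
drawn = []
try:
    while True:
        drawn.append(sampler())
except StopIteration:
    pass  # analogous to the "Generator completed prior to given n_samples" warning
print(drawn)  # [1, 2, 3]
```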
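
In `infidelity.py`, the `normalize=True` branch now sits in the top-level flow: it combines three aggregated tensors as `beta**2 * agg_tensors[0] - 2 * beta * agg_tensors[1] + agg_tensors[2]` with `beta = agg_tensors[1] / agg_tensors[0]`. A small check of that polynomial expansion, assuming the aggregated tensors are (sum(a*a), sum(a*b), sum(b*b)) for a = attr_times_perturb_sums and b = perturbed_fwd_diffs; shapes are simplified to a single example, whereas the code keeps per-example sums.

```python
import torch

torch.manual_seed(0)
a = torch.randn(16)  # stands in for attr_times_perturb_sums across perturbations
b = torch.randn(16)  # stands in for perturbed_fwd_diffs across perturbations

agg = (a.pow(2).sum(), (a * b).sum(), b.pow(2).sum())
beta = agg[1] / agg[0]  # least-squares scale of a onto b

# beta^2 * sum(a^2) - 2 * beta * sum(a*b) + sum(b^2) == sum((beta*a - b)^2)
expanded = beta**2 * agg[0] - 2 * beta * agg[1] + agg[2]
direct = ((beta * a - b) ** 2).sum()
assert torch.allclose(expanded, direct)
```

Under that assumption, the normalized value is the residual after optimally rescaling the attribution-times-perturbation sums onto the forward-difference targets, and both branches then divide by `n_perturb_samples`.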