Commit 94141d6

craymichael authored and facebook-github-bot committed
Reduce complexity of 'LimeBase.attribute' (#1370)
Summary: Reduce complexity of 'LimeBase.attribute'

Differential Revision: D64372053
1 parent b9917aa commit 94141d6
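
The change moves the per-sample perturbation handling out of `attribute` into a new `_get_perturb_generator_func` helper that returns a single zero-argument closure, regardless of whether `self.perturb_func` is a plain callable or a generator function, and it swaps the method-level `with torch.no_grad():` block for a `@torch.no_grad()` decorator. A minimal standalone sketch of the callable-vs-generator pattern is shown below; `make_sample_fn` and `next_sample` are illustrative names, not Captum code, and the sketch omits the interpretable-space transforms that the real helper also applies.

import inspect
from typing import Any, Callable, Iterator, Optional


def make_sample_fn(
    perturb_func: Callable[..., Any], *args: Any, **kwargs: Any
) -> Callable[[], Any]:
    """Hide a plain callable or a generator function behind one uniform closure."""
    gen: Optional[Iterator[Any]] = None
    if inspect.isgeneratorfunction(perturb_func):
        # Instantiate the generator once; each closure call advances it and may
        # raise StopIteration if it is exhausted early.
        gen = perturb_func(*args, **kwargs)

    def next_sample() -> Any:
        # Generator case: pull the next yielded sample.
        # Callable case: call the function again for a fresh sample.
        return next(gen) if gen is not None else perturb_func(*args, **kwargs)

    return next_sample

The caller then loops `n_samples` times over the closure and treats `StopIteration` as an early stop, which matches the warning path in the diff below.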

File tree

1 file changed: +116 -103 lines changed


captum/attr/_core/lime.py

Lines changed: 116 additions & 103 deletions
@@ -5,6 +5,7 @@
 import math
 import typing
 import warnings
+from collections.abc import Iterator
 from typing import Any, Callable, cast, List, Optional, Tuple, Union

 import torch
@@ -243,6 +244,7 @@ def __init__(
             ), "Must provide transform from original input space to interpretable space"

     @log_usage()
+    @torch.no_grad()
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
@@ -422,125 +424,136 @@ def attribute(
         >>> # model.
         >>> attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1)
         """
-        with torch.no_grad():
-            inp_tensor = (
-                cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0]
+        inp_tensor = cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0]
+        device = inp_tensor.device
+
+        interpretable_inps = []
+        similarities = []
+        outputs = []
+
+        curr_model_inputs = []
+        expanded_additional_args = None
+        expanded_target = None
+        gen_perturb_func = self._get_perturb_generator_func(inputs, **kwargs)
+
+        if show_progress:
+            attr_progress = progress(
+                total=math.ceil(n_samples / perturbations_per_eval),
+                desc=f"{self.get_name()} attribution",
             )
-            device = inp_tensor.device
-
-            interpretable_inps = []
-            similarities = []
-            outputs = []
-
-            curr_model_inputs = []
-            expanded_additional_args = None
-            expanded_target = None
-            perturb_generator = None
-            if inspect.isgeneratorfunction(self.perturb_func):
-                perturb_generator = self.perturb_func(inputs, **kwargs)
-
-            if show_progress:
-                attr_progress = progress(
-                    total=math.ceil(n_samples / perturbations_per_eval),
-                    desc=f"{self.get_name()} attribution",
+            attr_progress.update(0)
+
+        batch_count = 0
+        for _ in range(n_samples):
+            try:
+                interpretable_inp, curr_model_input = gen_perturb_func()
+            except StopIteration:
+                warnings.warn(
+                    "Generator completed prior to given n_samples iterations!",
+                    stacklevel=1,
                 )
-                attr_progress.update(0)
-
-            batch_count = 0
-            for _ in range(n_samples):
-                if perturb_generator:
-                    try:
-                        curr_sample = next(perturb_generator)
-                    except StopIteration:
-                        warnings.warn(
-                            "Generator completed prior to given n_samples iterations!"
-                        )
-                        break
-                else:
-                    curr_sample = self.perturb_func(inputs, **kwargs)
-                batch_count += 1
-                if self.perturb_interpretable_space:
-                    interpretable_inps.append(curr_sample)
-                    curr_model_inputs.append(
-                        self.from_interp_rep_transform(  # type: ignore
-                            curr_sample, inputs, **kwargs
-                        )
-                    )
-                else:
-                    curr_model_inputs.append(curr_sample)
-                    interpretable_inps.append(
-                        self.to_interp_rep_transform(  # type: ignore
-                            curr_sample, inputs, **kwargs
-                        )
-                    )
-                curr_sim = self.similarity_func(
-                    inputs, curr_model_inputs[-1], interpretable_inps[-1], **kwargs
-                )
-                similarities.append(
-                    curr_sim.flatten()
-                    if isinstance(curr_sim, Tensor)
-                    else torch.tensor([curr_sim], device=device)
-                )
-
-                if len(curr_model_inputs) == perturbations_per_eval:
-                    if expanded_additional_args is None:
-                        expanded_additional_args = _expand_additional_forward_args(
-                            additional_forward_args, len(curr_model_inputs)
-                        )
-                    if expanded_target is None:
-                        expanded_target = _expand_target(target, len(curr_model_inputs))
-
-                    model_out = self._evaluate_batch(
-                        curr_model_inputs,
-                        expanded_target,
-                        expanded_additional_args,
-                        device,
-                    )
-
-                    if show_progress:
-                        attr_progress.update()
+                break
+            batch_count += 1
+            interpretable_inps.append(interpretable_inp)
+            curr_model_inputs.append(curr_model_input)

-                    outputs.append(model_out)
+            curr_sim = self.similarity_func(
+                inputs, curr_model_input, interpretable_inp, **kwargs
+            )
+            similarities.append(
+                curr_sim.flatten()
+                if isinstance(curr_sim, Tensor)
+                else torch.tensor([curr_sim], device=device)
+            )

-                    curr_model_inputs = []
+            if len(curr_model_inputs) == perturbations_per_eval:
+                if expanded_additional_args is None:
+                    expanded_additional_args = _expand_additional_forward_args(
+                        additional_forward_args, len(curr_model_inputs)
+                    )
+                if expanded_target is None:
+                    expanded_target = _expand_target(target, len(curr_model_inputs))

-            if len(curr_model_inputs) > 0:
-                expanded_additional_args = _expand_additional_forward_args(
-                    additional_forward_args, len(curr_model_inputs)
-                )
-                expanded_target = _expand_target(target, len(curr_model_inputs))
                 model_out = self._evaluate_batch(
                     curr_model_inputs,
                     expanded_target,
                     expanded_additional_args,
                     device,
                 )
+
                 if show_progress:
                     attr_progress.update()
+
                 outputs.append(model_out)

-            if show_progress:
-                attr_progress.close()
-
-            # Argument 1 to "cat" has incompatible type
-            # "list[Tensor | tuple[Tensor, ...]]";
-            # expected "tuple[Tensor, ...] | list[Tensor]" [arg-type]
-            combined_interp_inps = torch.cat(interpretable_inps).float()  # type: ignore
-            combined_outputs = (
-                torch.cat(outputs)
-                if len(outputs[0].shape) > 0
-                else torch.stack(outputs)
-            ).float()
-            combined_sim = (
-                torch.cat(similarities)
-                if len(similarities[0].shape) > 0
-                else torch.stack(similarities)
-            ).float()
-            dataset = TensorDataset(
-                combined_interp_inps, combined_outputs, combined_sim
+                curr_model_inputs = []
+
+        if len(curr_model_inputs) > 0:
+            expanded_additional_args = _expand_additional_forward_args(
+                additional_forward_args, len(curr_model_inputs)
+            )
+            expanded_target = _expand_target(target, len(curr_model_inputs))
+            model_out = self._evaluate_batch(
+                curr_model_inputs,
+                expanded_target,
+                expanded_additional_args,
+                device,
             )
-            self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count))
-            return self.interpretable_model.representation()
+            if show_progress:
+                attr_progress.update()
+            outputs.append(model_out)
+
+        if show_progress:
+            attr_progress.close()
+
+        # Argument 1 to "cat" has incompatible type
+        # "list[Tensor | tuple[Tensor, ...]]";
+        # expected "tuple[Tensor, ...] | list[Tensor]" [arg-type]
+        combined_interp_inps = torch.cat(interpretable_inps).float()  # type: ignore
+        combined_outputs = (
+            torch.cat(outputs) if len(outputs[0].shape) > 0 else torch.stack(outputs)
+        ).float()
+        combined_sim = (
+            torch.cat(similarities)
+            if len(similarities[0].shape) > 0
+            else torch.stack(similarities)
+        ).float()
+        dataset = TensorDataset(combined_interp_inps, combined_outputs, combined_sim)
+        self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count))
+        return self.interpretable_model.representation()
+
+    def _get_perturb_generator_func(
+        self, inputs: TensorOrTupleOfTensorsGeneric, **kwargs: Any
+    ) -> Callable[
+        [], Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
+    ]:
+        perturb_generator: Optional[Iterator[TensorOrTupleOfTensorsGeneric]]
+        perturb_generator = None
+        if inspect.isgeneratorfunction(self.perturb_func):
+            perturb_generator = self.perturb_func(inputs, **kwargs)
+
+        def generate_perturbation() -> (
+            Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]
+        ):
+            if perturb_generator:
+                curr_sample = next(perturb_generator)
+            else:
+                curr_sample = self.perturb_func(inputs, **kwargs)
+
+            if self.perturb_interpretable_space:
+                interpretable_inp = curr_sample
+                curr_model_input = self.from_interp_rep_transform(  # type: ignore
+                    curr_sample, inputs, **kwargs
+                )
+            else:
+                curr_model_input = curr_sample
+                interpretable_inp = self.to_interp_rep_transform(  # type: ignore
+                    curr_sample, inputs, **kwargs
+                )
+
+            return interpretable_inp, curr_model_input
+
+        return generate_perturbation

     # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
     def attribute_future(self) -> Callable:
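
For reference, `torch.no_grad` works both as a context manager and as a decorator, so decorating `attribute` with `@torch.no_grad()` disables gradient tracking for the whole method body, just as the removed `with torch.no_grad():` block did when it wrapped the entire body. A small standalone check of that equivalence is sketched below; `Scorer` and its methods are illustrative names, not Captum code.

import torch


class Scorer:
    def __init__(self) -> None:
        self.weight = torch.ones(3, requires_grad=True)

    # Decorator form: the whole method runs with gradient tracking disabled.
    @torch.no_grad()
    def score_decorated(self, x: torch.Tensor) -> torch.Tensor:
        return (self.weight * x).sum()

    # Context-manager form: equivalent when it wraps the entire method body.
    def score_wrapped(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return (self.weight * x).sum()


s = Scorer()
x = torch.arange(3.0)
# Both results are detached from the autograd graph.
assert not s.score_decorated(x).requires_grad
assert not s.score_wrapped(x).requires_grad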
