|
5 | 5 | import math |
6 | 6 | import typing |
7 | 7 | import warnings |
| 8 | +from collections.abc import Iterator |
8 | 9 | from typing import Any, Callable, cast, List, Optional, Tuple, Union |
9 | 10 |
|
10 | 11 | import torch |
@@ -243,6 +244,7 @@ def __init__( |
243 | 244 | ), "Must provide transform from original input space to interpretable space" |
244 | 245 |
|
245 | 246 | @log_usage() |
| 247 | + @torch.no_grad() |
246 | 248 | def attribute( |
247 | 249 | self, |
248 | 250 | inputs: TensorOrTupleOfTensorsGeneric, |
@@ -422,125 +424,136 @@ def attribute( |
422 | 424 | >>> # model. |
423 | 425 | >>> attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1) |
424 | 426 | """ |
425 | | - with torch.no_grad(): |
426 | | - inp_tensor = ( |
427 | | - cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0] |
| 427 | + inp_tensor = cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0] |
| 428 | + device = inp_tensor.device |
| 429 | + |
| 430 | + interpretable_inps = [] |
| 431 | + similarities = [] |
| 432 | + outputs = [] |
| 433 | + |
| 434 | + curr_model_inputs = [] |
| 435 | + expanded_additional_args = None |
| 436 | + expanded_target = None |
| 437 | + gen_perturb_func = self._get_perturb_generator_func(inputs, **kwargs) |
| 438 | + |
| 439 | + if show_progress: |
| 440 | + attr_progress = progress( |
| 441 | + total=math.ceil(n_samples / perturbations_per_eval), |
| 442 | + desc=f"{self.get_name()} attribution", |
428 | 443 | ) |
429 | | - device = inp_tensor.device |
430 | | - |
431 | | - interpretable_inps = [] |
432 | | - similarities = [] |
433 | | - outputs = [] |
434 | | - |
435 | | - curr_model_inputs = [] |
436 | | - expanded_additional_args = None |
437 | | - expanded_target = None |
438 | | - perturb_generator = None |
439 | | - if inspect.isgeneratorfunction(self.perturb_func): |
440 | | - perturb_generator = self.perturb_func(inputs, **kwargs) |
441 | | - |
442 | | - if show_progress: |
443 | | - attr_progress = progress( |
444 | | - total=math.ceil(n_samples / perturbations_per_eval), |
445 | | - desc=f"{self.get_name()} attribution", |
| 444 | + attr_progress.update(0) |
| 445 | + |
| 446 | + batch_count = 0 |
| 447 | + for _ in range(n_samples): |
| 448 | + try: |
| 449 | + interpretable_inp, curr_model_input = gen_perturb_func() |
| 450 | + except StopIteration: |
| 451 | + warnings.warn( |
| 452 | + "Generator completed prior to given n_samples iterations!", |
| 453 | + stacklevel=1, |
446 | 454 | ) |
447 | | - attr_progress.update(0) |
448 | | - |
449 | | - batch_count = 0 |
450 | | - for _ in range(n_samples): |
451 | | - if perturb_generator: |
452 | | - try: |
453 | | - curr_sample = next(perturb_generator) |
454 | | - except StopIteration: |
455 | | - warnings.warn( |
456 | | - "Generator completed prior to given n_samples iterations!" |
457 | | - ) |
458 | | - break |
459 | | - else: |
460 | | - curr_sample = self.perturb_func(inputs, **kwargs) |
461 | | - batch_count += 1 |
462 | | - if self.perturb_interpretable_space: |
463 | | - interpretable_inps.append(curr_sample) |
464 | | - curr_model_inputs.append( |
465 | | - self.from_interp_rep_transform( # type: ignore |
466 | | - curr_sample, inputs, **kwargs |
467 | | - ) |
468 | | - ) |
469 | | - else: |
470 | | - curr_model_inputs.append(curr_sample) |
471 | | - interpretable_inps.append( |
472 | | - self.to_interp_rep_transform( # type: ignore |
473 | | - curr_sample, inputs, **kwargs |
474 | | - ) |
475 | | - ) |
476 | | - curr_sim = self.similarity_func( |
477 | | - inputs, curr_model_inputs[-1], interpretable_inps[-1], **kwargs |
478 | | - ) |
479 | | - similarities.append( |
480 | | - curr_sim.flatten() |
481 | | - if isinstance(curr_sim, Tensor) |
482 | | - else torch.tensor([curr_sim], device=device) |
483 | | - ) |
484 | | - |
485 | | - if len(curr_model_inputs) == perturbations_per_eval: |
486 | | - if expanded_additional_args is None: |
487 | | - expanded_additional_args = _expand_additional_forward_args( |
488 | | - additional_forward_args, len(curr_model_inputs) |
489 | | - ) |
490 | | - if expanded_target is None: |
491 | | - expanded_target = _expand_target(target, len(curr_model_inputs)) |
492 | | - |
493 | | - model_out = self._evaluate_batch( |
494 | | - curr_model_inputs, |
495 | | - expanded_target, |
496 | | - expanded_additional_args, |
497 | | - device, |
498 | | - ) |
499 | | - |
500 | | - if show_progress: |
501 | | - attr_progress.update() |
| 455 | + break |
| 456 | + batch_count += 1 |
| 457 | + interpretable_inps.append(interpretable_inp) |
| 458 | + curr_model_inputs.append(curr_model_input) |
502 | 459 |
|
503 | | - outputs.append(model_out) |
| 460 | + curr_sim = self.similarity_func( |
| 461 | + inputs, curr_model_input, interpretable_inp, **kwargs |
| 462 | + ) |
| 463 | + similarities.append( |
| 464 | + curr_sim.flatten() |
| 465 | + if isinstance(curr_sim, Tensor) |
| 466 | + else torch.tensor([curr_sim], device=device) |
| 467 | + ) |
504 | 468 |
|
505 | | - curr_model_inputs = [] |
| 469 | + if len(curr_model_inputs) == perturbations_per_eval: |
| 470 | + if expanded_additional_args is None: |
| 471 | + expanded_additional_args = _expand_additional_forward_args( |
| 472 | + additional_forward_args, len(curr_model_inputs) |
| 473 | + ) |
| 474 | + if expanded_target is None: |
| 475 | + expanded_target = _expand_target(target, len(curr_model_inputs)) |
506 | 476 |
|
507 | | - if len(curr_model_inputs) > 0: |
508 | | - expanded_additional_args = _expand_additional_forward_args( |
509 | | - additional_forward_args, len(curr_model_inputs) |
510 | | - ) |
511 | | - expanded_target = _expand_target(target, len(curr_model_inputs)) |
512 | 477 | model_out = self._evaluate_batch( |
513 | 478 | curr_model_inputs, |
514 | 479 | expanded_target, |
515 | 480 | expanded_additional_args, |
516 | 481 | device, |
517 | 482 | ) |
| 483 | + |
518 | 484 | if show_progress: |
519 | 485 | attr_progress.update() |
| 486 | + |
520 | 487 | outputs.append(model_out) |
521 | 488 |
|
522 | | - if show_progress: |
523 | | - attr_progress.close() |
524 | | - |
525 | | - # Argument 1 to "cat" has incompatible type |
526 | | - # "list[Tensor | tuple[Tensor, ...]]"; |
527 | | - # expected "tuple[Tensor, ...] | list[Tensor]" [arg-type] |
528 | | - combined_interp_inps = torch.cat(interpretable_inps).float() # type: ignore |
529 | | - combined_outputs = ( |
530 | | - torch.cat(outputs) |
531 | | - if len(outputs[0].shape) > 0 |
532 | | - else torch.stack(outputs) |
533 | | - ).float() |
534 | | - combined_sim = ( |
535 | | - torch.cat(similarities) |
536 | | - if len(similarities[0].shape) > 0 |
537 | | - else torch.stack(similarities) |
538 | | - ).float() |
539 | | - dataset = TensorDataset( |
540 | | - combined_interp_inps, combined_outputs, combined_sim |
| 489 | + curr_model_inputs = [] |
| 490 | + |
| 491 | + if len(curr_model_inputs) > 0: |
| 492 | + expanded_additional_args = _expand_additional_forward_args( |
| 493 | + additional_forward_args, len(curr_model_inputs) |
| 494 | + ) |
| 495 | + expanded_target = _expand_target(target, len(curr_model_inputs)) |
| 496 | + model_out = self._evaluate_batch( |
| 497 | + curr_model_inputs, |
| 498 | + expanded_target, |
| 499 | + expanded_additional_args, |
| 500 | + device, |
541 | 501 | ) |
542 | | - self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count)) |
543 | | - return self.interpretable_model.representation() |
| 502 | + if show_progress: |
| 503 | + attr_progress.update() |
| 504 | + outputs.append(model_out) |
| 505 | + |
| 506 | + if show_progress: |
| 507 | + attr_progress.close() |
| 508 | + |
| 509 | + # Argument 1 to "cat" has incompatible type |
| 510 | + # "list[Tensor | tuple[Tensor, ...]]"; |
| 511 | + # expected "tuple[Tensor, ...] | list[Tensor]" [arg-type] |
| 512 | + combined_interp_inps = torch.cat(interpretable_inps).float() # type: ignore |
| 513 | + combined_outputs = ( |
| 514 | + torch.cat(outputs) if len(outputs[0].shape) > 0 else torch.stack(outputs) |
| 515 | + ).float() |
| 516 | + combined_sim = ( |
| 517 | + torch.cat(similarities) |
| 518 | + if len(similarities[0].shape) > 0 |
| 519 | + else torch.stack(similarities) |
| 520 | + ).float() |
| 521 | + dataset = TensorDataset(combined_interp_inps, combined_outputs, combined_sim) |
| 522 | + self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count)) |
| 523 | + return self.interpretable_model.representation() |
| 524 | + |
| 525 | + def _get_perturb_generator_func( |
| 526 | + self, inputs: TensorOrTupleOfTensorsGeneric, **kwargs: Any |
| 527 | + ) -> Callable[ |
| 528 | + [], Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric] |
| 529 | + ]: |
| 530 | + perturb_generator: Optional[Iterator[TensorOrTupleOfTensorsGeneric]] |
| 531 | + perturb_generator = None |
| 532 | + if inspect.isgeneratorfunction(self.perturb_func): |
| 533 | + perturb_generator = self.perturb_func(inputs, **kwargs) |
| 534 | + |
| 535 | + def generate_perturbation() -> ( |
| 536 | + Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric] |
| 537 | + ): |
| 538 | + if perturb_generator: |
| 539 | + curr_sample = next(perturb_generator) |
| 540 | + else: |
| 541 | + curr_sample = self.perturb_func(inputs, **kwargs) |
| 542 | + |
| 543 | + if self.perturb_interpretable_space: |
| 544 | + interpretable_inp = curr_sample |
| 545 | + curr_model_input = self.from_interp_rep_transform( # type: ignore |
| 546 | + curr_sample, inputs, **kwargs |
| 547 | + ) |
| 548 | + else: |
| 549 | + curr_model_input = curr_sample |
| 550 | + interpretable_inp = self.to_interp_rep_transform( # type: ignore |
| 551 | + curr_sample, inputs, **kwargs |
| 552 | + ) |
| 553 | + |
| 554 | + return interpretable_inp, curr_model_input |
| 555 | + |
| 556 | + return generate_perturbation |
544 | 557 |
|
545 | 558 | # pyre-fixme[24] Generic type `Callable` expects 2 type parameters. |
546 | 559 | def attribute_future(self) -> Callable: |
|
0 commit comments