diff --git a/captum/insights/api.py b/captum/insights/api.py
index 36c0da13db..7125334187 100644
--- a/captum/insights/api.py
+++ b/captum/insights/api.py
@@ -89,7 +89,7 @@ class FilterConfig(NamedTuple):
arg: config.value # type: ignore
for arg, config in ATTRIBUTION_METHOD_CONFIG[
IntegratedGradients.get_name()
- ].items()
+ ].params.items()
}
prediction: str = "all"
classes: List[str] = []
@@ -221,6 +221,12 @@ def _calculate_attribution(
attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[self._config.attribution_method]
attribution_method = attribution_cls(net)
args = self._config.attribution_arguments
+ param_config = ATTRIBUTION_METHOD_CONFIG[self._config.attribution_method]
+ if param_config.post_process:
+ for k, v in args.items():
+ if k in param_config.post_process:
+ args[k] = param_config.post_process[k](v)
+
# TODO support multiple baselines
baseline = baselines[0] if baselines and len(baselines) > 0 else None
label = (
@@ -329,7 +335,9 @@ def _serve_colab(self, blocking=False, debug=False, port=None):
def _get_labels_from_scores(
self, scores: Tensor, indices: Tensor
) -> List[OutputScore]:
- pred_scores = []
+ pred_scores: List[OutputScore] = []
+ if indices.nelement() < 2:
+ return pred_scores
for i in range(len(indices)):
score = scores[i]
pred_scores.append(
@@ -542,6 +550,8 @@ def get_insights_config(self):
return {
"classes": self.classes,
"methods": list(ATTRIBUTION_NAMES_TO_METHODS.keys()),
- "method_arguments": namedtuple_to_dict(ATTRIBUTION_METHOD_CONFIG),
+ "method_arguments": namedtuple_to_dict(
+ {k: v.params for (k, v) in ATTRIBUTION_METHOD_CONFIG.items()}
+ ),
"selected_method": self._config.attribution_method,
}
diff --git a/captum/insights/config.py b/captum/insights/config.py
index 897835e8f6..d1abccb5a7 100644
--- a/captum/insights/config.py
+++ b/captum/insights/config.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-from typing import Dict, List, NamedTuple, Optional, Tuple
+from typing import Dict, List, NamedTuple, Optional, Tuple, Callable, Any, Union
from captum.attr import (
Deconvolution,
@@ -9,6 +9,7 @@
InputXGradient,
IntegratedGradients,
Saliency,
+ Occlusion,
)
from captum.attr._utils.approximation_methods import SUPPORTED_METHODS
@@ -25,6 +26,13 @@ class StrEnumConfig(NamedTuple):
type: str = "enum"
+class StrConfig(NamedTuple):
+ value: str
+ type: str = "string"
+
+
+Config = Union[NumberConfig, StrEnumConfig, StrConfig]
+
SUPPORTED_ATTRIBUTION_METHODS = [
Deconvolution,
DeepLift,
@@ -33,20 +41,50 @@ class StrEnumConfig(NamedTuple):
IntegratedGradients,
Saliency,
FeatureAblation,
+ Occlusion,
]
+
+class ConfigParameters(NamedTuple):
+ params: Dict[str, Config]
+ help_info: Optional[str] = None # TODO fill out help for each method
+ post_process: Optional[Dict[str, Callable[[Any], Any]]] = None
+
+
ATTRIBUTION_NAMES_TO_METHODS = {
# mypy bug - treating it as a type instead of a class
cls.get_name(): cls # type: ignore
for cls in SUPPORTED_ATTRIBUTION_METHODS
}
-ATTRIBUTION_METHOD_CONFIG: Dict[str, Dict[str, tuple]] = {
- IntegratedGradients.get_name(): {
- "n_steps": NumberConfig(value=25, limit=(2, None)),
- "method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"),
- },
- FeatureAblation.get_name(): {
- "perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
- },
+
+def _str_to_tuple(s):
+ if isinstance(s, tuple):
+ return s
+ return tuple([int(i) for i in s.split()])
+
+
+ATTRIBUTION_METHOD_CONFIG: Dict[str, ConfigParameters] = {
+ IntegratedGradients.get_name(): ConfigParameters(
+ params={
+ "n_steps": NumberConfig(value=25, limit=(2, None)),
+ "method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"),
+ },
+ post_process={"n_steps": int},
+ ),
+ FeatureAblation.get_name(): ConfigParameters(
+ params={"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100))},
+ ),
+ Occlusion.get_name(): ConfigParameters(
+ params={
+ "sliding_window_shapes": StrConfig(value=""),
+ "strides": StrConfig(value=""),
+ "perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
+ },
+ post_process={
+ "sliding_window_shapes": _str_to_tuple,
+ "strides": _str_to_tuple,
+ "perturbations_per_eval": int,
+ },
+ ),
}
diff --git a/captum/insights/frontend/src/App.js b/captum/insights/frontend/src/App.js
index 99a271b0af..461965ddf8 100644
--- a/captum/insights/frontend/src/App.js
+++ b/captum/insights/frontend/src/App.js
@@ -8,6 +8,7 @@ import "./App.css";
const ConfigType = Object.freeze({
Number: "number",
Enum: "enum",
+ String: "string",
});
const Plot = createPlotlyComponent(Plotly);
@@ -153,62 +154,71 @@ class FilterContainer extends React.Component {
}
}
-class ClassFilter extends React.Component {
- render() {
- return (
-