Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions captum/insights/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class FilterConfig(NamedTuple):
arg: config.value # type: ignore
for arg, config in ATTRIBUTION_METHOD_CONFIG[
IntegratedGradients.get_name()
].items()
].params.items()
}
prediction: str = "all"
classes: List[str] = []
Expand Down Expand Up @@ -221,6 +221,12 @@ def _calculate_attribution(
attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[self._config.attribution_method]
attribution_method = attribution_cls(net)
args = self._config.attribution_arguments
param_config = ATTRIBUTION_METHOD_CONFIG[self._config.attribution_method]
if param_config.post_process:
for k, v in args.items():
if k in param_config.post_process:
args[k] = param_config.post_process[k](v)

# TODO support multiple baselines
baseline = baselines[0] if baselines and len(baselines) > 0 else None
label = (
Expand Down Expand Up @@ -329,7 +335,9 @@ def _serve_colab(self, blocking=False, debug=False, port=None):
def _get_labels_from_scores(
self, scores: Tensor, indices: Tensor
) -> List[OutputScore]:
pred_scores = []
pred_scores: List[OutputScore] = []
if indices.nelement() < 2:
return pred_scores
for i in range(len(indices)):
score = scores[i]
pred_scores.append(
Expand Down Expand Up @@ -542,6 +550,8 @@ def get_insights_config(self):
return {
"classes": self.classes,
"methods": list(ATTRIBUTION_NAMES_TO_METHODS.keys()),
"method_arguments": namedtuple_to_dict(ATTRIBUTION_METHOD_CONFIG),
"method_arguments": namedtuple_to_dict(
{k: v.params for (k, v) in ATTRIBUTION_METHOD_CONFIG.items()}
),
"selected_method": self._config.attribution_method,
}
56 changes: 47 additions & 9 deletions captum/insights/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
from typing import Dict, List, NamedTuple, Optional, Tuple
from typing import Dict, List, NamedTuple, Optional, Tuple, Callable, Any, Union

from captum.attr import (
Deconvolution,
Expand All @@ -9,6 +9,7 @@
InputXGradient,
IntegratedGradients,
Saliency,
Occlusion,
)
from captum.attr._utils.approximation_methods import SUPPORTED_METHODS

Expand All @@ -25,6 +26,13 @@ class StrEnumConfig(NamedTuple):
type: str = "enum"


class StrConfig(NamedTuple):
value: str
type: str = "string"


Config = Union[NumberConfig, StrEnumConfig, StrConfig]

SUPPORTED_ATTRIBUTION_METHODS = [
Deconvolution,
DeepLift,
Expand All @@ -33,20 +41,50 @@ class StrEnumConfig(NamedTuple):
IntegratedGradients,
Saliency,
FeatureAblation,
Occlusion,
]


class ConfigParameters(NamedTuple):
params: Dict[str, Config]
help_info: Optional[str] = None # TODO fill out help for each method
post_process: Optional[Dict[str, Callable[[Any], Any]]] = None


ATTRIBUTION_NAMES_TO_METHODS = {
# mypy bug - treating it as a type instead of a class
cls.get_name(): cls # type: ignore
for cls in SUPPORTED_ATTRIBUTION_METHODS
}

ATTRIBUTION_METHOD_CONFIG: Dict[str, Dict[str, tuple]] = {
IntegratedGradients.get_name(): {
"n_steps": NumberConfig(value=25, limit=(2, None)),
"method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"),
},
FeatureAblation.get_name(): {
"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
},

def _str_to_tuple(s):
if isinstance(s, tuple):
return s
return tuple([int(i) for i in s.split()])


ATTRIBUTION_METHOD_CONFIG: Dict[str, ConfigParameters] = {
IntegratedGradients.get_name(): ConfigParameters(
params={
"n_steps": NumberConfig(value=25, limit=(2, None)),
"method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"),
},
post_process={"n_steps": int},
),
FeatureAblation.get_name(): ConfigParameters(
params={"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100))},
),
Occlusion.get_name(): ConfigParameters(
params={
"sliding_window_shapes": StrConfig(value=""),
"strides": StrConfig(value=""),
"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
},
post_process={
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these post-processing steps necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Occlusion takes in a tuple or an int e.g. (3, 5, 5), so we need a way to convert a string representation into the appropriate format.

"sliding_window_shapes": _str_to_tuple,
"strides": _str_to_tuple,
"perturbations_per_eval": int,
},
),
}
120 changes: 69 additions & 51 deletions captum/insights/frontend/src/App.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import "./App.css";
const ConfigType = Object.freeze({
Number: "number",
Enum: "enum",
String: "string",
});

const Plot = createPlotlyComponent(Plotly);
Expand Down Expand Up @@ -153,62 +154,71 @@ class FilterContainer extends React.Component {
}
}

class ClassFilter extends React.Component {
render() {
return (
<ReactTags
tags={this.props.classes}
autofocus={false}
suggestions={this.props.suggestedClasses}
handleDelete={this.props.handleClassDelete}
handleAddition={this.props.handleClassAdd}
minQueryLength={0}
placeholder="add new class..."
function ClassFilter(props) {
return (
<ReactTags
tags={props.classes}
autofocus={false}
suggestions={props.suggestedClasses}
handleDelete={props.handleClassDelete}
handleAddition={props.handleClassAdd}
minQueryLength={0}
placeholder="add new class..."
/>
);
}

function NumberArgument(props) {
var min = props.limit[0];
var max = props.limit[1];
return (
<div>
{props.name}:
<input
className={cx([styles.input, styles["input--narrow"]])}
name={props.name}
type="number"
value={props.value}
min={min}
max={max}
onChange={props.handleInputChange}
/>
);
}
</div>
);
}

class NumberArgument extends React.Component {
render() {
var min = this.props.limit[0];
var max = this.props.limit[1];
return (
<div>
{this.props.name + ": "}
<input
className={cx([styles.input, styles["input--narrow"]])}
name={this.props.name}
type="number"
value={this.props.value}
min={min}
max={max}
onChange={this.props.handleInputChange}
/>
</div>
);
}
function EnumArgument(props) {
const options = props.limit.map((item, key) => (
<option value={item}>{item}</option>
));
return (
<div>
{props.name}:
<select
className={styles.select}
name={props.name}
value={props.value}
onChange={props.handleInputChange}
>
{options}
</select>
</div>
);
}

class EnumArgument extends React.Component {
render() {
const options = this.props.limit.map((item, key) => (
<option value={item}>{item}</option>
));
return (
<div>
{this.props.name + ": "}
<select
className={styles.select}
name={this.props.name}
value={this.props.value}
onChange={this.props.handleInputChange}
>
{options}
</select>
</div>
);
}
function StringArgument(props) {
return (
<div>
{props.name}:
<input
className={cx([styles.input, styles["input--narrow"]])}
name={props.name}
type="text"
value={props.value}
onChange={props.handleInputChange}
/>
</div>
);
}

class Filter extends React.Component {
Expand All @@ -232,6 +242,14 @@ class Filter extends React.Component {
handleInputChange={this.props.handleArgumentChange}
/>
);
case ConfigType.String:
return (
<StringArgument
name={name}
value={config.value}
handleInputChange={this.props.handleArgumentChange}
/>
);
}
};

Expand Down
19 changes: 8 additions & 11 deletions captum/insights/frontend/widget/src/Widget.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ class Widget extends React.Component {
config: {
classes: [],
methods: [],
method_arguments: {}
method_arguments: {},
},
loading: false,
callback: null
callback: null,
};
this.backbone = this.props.backbone;
}
Expand Down Expand Up @@ -47,14 +47,11 @@ class Widget extends React.Component {

_fetchInit = () => {
this.setState({
config: this.backbone.model.get("insights_config")
config: this.backbone.model.get("insights_config"),
});
};

fetchData = filterConfig => {
filterConfig.approximation_steps = parseInt(
filterConfig.approximation_steps
);
fetchData = (filterConfig) => {
this.setState({ loading: true }, () => {
this.backbone.model.save({ config: filterConfig, output: [] });
});
Expand All @@ -64,7 +61,7 @@ class Widget extends React.Component {
this.setState({ callback: callback }, () => {
this.backbone.model.save({
label_details: { labelIndex, instance },
attribution: {}
attribution: {},
});
});
};
Expand All @@ -90,16 +87,16 @@ var CaptumInsightsModel = widgets.DOMWidgetModel.extend({
_model_module: "jupyter-captum-insights",
_view_module: "jupyter-captum-insights",
_model_module_version: "0.1.0",
_view_module_version: "0.1.0"
})
_view_module_version: "0.1.0",
}),
});

var CaptumInsightsView = widgets.DOMWidgetView.extend({
initialize() {
const $app = document.createElement("div");
ReactDOM.render(<Widget backbone={this} />, $app);
this.el.append($app);
}
},
});

export { Widget as default, CaptumInsightsModel, CaptumInsightsView };