Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions captum/attr/_utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ class VisualizationDataRecord:
"true_class",
"attr_class",
"attr_score",
"raw_input",
"raw_input_ids",
"convergence_score",
]

Expand All @@ -449,7 +449,7 @@ def __init__(
true_class,
attr_class,
attr_score,
raw_input,
raw_input_ids,
convergence_score,
) -> None:
self.word_attributions = word_attributions
Expand All @@ -458,7 +458,7 @@ def __init__(
self.true_class = true_class
self.attr_class = attr_class
self.attr_score = attr_score
self.raw_input = raw_input
self.raw_input_ids = raw_input_ids
self.convergence_score = convergence_score


Expand Down Expand Up @@ -541,7 +541,7 @@ def visualize_text(
format_classname(datarecord.attr_class),
format_classname("{0:.2f}".format(datarecord.attr_score)),
format_word_importances(
datarecord.raw_input, datarecord.word_attributions
datarecord.raw_input_ids, datarecord.word_attributions
),
"<tr>",
]
Expand Down
4 changes: 0 additions & 4 deletions captum/concept/_core/tcav.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,10 +733,6 @@ def _tcav_sub_computation(
sign_count_score = (
torch.sum(tcav_score > 0.0, dim=1).float() / tcav_score.shape[1]
)
# n_experiments x n_concepts
sign_count_score = (
torch.sum(tcav_score > 0.0, dim=1).float() / tcav_score.shape[1]
)

magnitude_score = torch.sum(
torch.abs(tcav_score * (tcav_score > 0.0).float()), dim=1
Expand Down
22 changes: 10 additions & 12 deletions tests/utils/test_av.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,9 @@ def test_generate_activation(self) -> None:
low, high = 0, 16
mymodel = BasicLinearReLULinear(num_features)
mydata = RangeDataset(low, high, num_features)
layers: List[str] = []
for name, _module in mymodel.named_modules():
layers.append(name)
layers: List[str] = list(filter(None, layers))
layers: List[str] = [
value[0] for value in mymodel.named_modules() if value[0]
]

# First AV generation on last 2 layers
inputs = torch.stack((mydata[1], mydata[8], mydata[14]))
Expand Down Expand Up @@ -422,10 +421,9 @@ def test_generate_dataset_activations(self) -> None:
batch_size = high // 2
mymodel = BasicLinearReLULinear(num_features)
mydata = RangeDataset(low, high, num_features)
layers: List[str] = []
for name, _module in mymodel.named_modules():
layers.append(name)
layers: List[str] = list(filter(None, layers))
layers: List[str] = [
value[0] for value in mymodel.named_modules() if value[0]
]

# First AV generation on last 2 layers
AV.generate_dataset_activations(
Expand Down Expand Up @@ -466,10 +464,10 @@ def test_equal_activation(self) -> None:
low, high = 0, 16
mymodel = BasicLinearReLULinear(num_features)
mydata = RangeDataset(low, high, num_features)
layers: List[str] = []
for name, _module in mymodel.named_modules():
layers.append(name)
layers: List[str] = list(filter(None, layers))
layers: List[str] = [
value[0] for value in mymodel.named_modules() if value[0]
]

# First AV generation on last 2 layers
test_input = mydata[1].unsqueeze(0)
act = AV.generate_activation(
Expand Down