Skip to content

Commit 87a0a9e

Browse files
csauperfacebook-github-bot
authored andcommitted
Fix assorted unbound variables [4/n] (#1365)
Summary: Pull Request resolved: #1365 Fix unbound variables that flake8 is complaining about Reviewed By: cyrjano Differential Revision: D64261231 fbshipit-source-id: ecc85a72c8d309fe900e4f559566a7bd0ea95ef9
1 parent 07470af commit 87a0a9e

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

tests/attr/test_class_summarizer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ def test_classes(self) -> None:
5757
for batch_size in [None, 1, 4]:
5858
for sizes, classes in zip(sizes_to_test, list_of_classes):
5959

60-
def create_batch_labels(batch_idx):
60+
def create_batch_labels(
61+
batch_idx, batch_size=batch_size, classes=classes
62+
):
6163
if batch_size is None:
6264
# batch_size = 1
6365
return classes[batch_idx]

tests/attr/test_input_layer_wrapper.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
class InputLayerMeta(type):
4747
def __new__(metacls, name: str, bases: Tuple, attrs: Dict):
48+
global layer_methods_to_test_with_equiv
4849
for (
4950
layer_method,
5051
equiv_method,
@@ -56,7 +57,7 @@ def __new__(metacls, name: str, bases: Tuple, attrs: Dict):
5657
+ f"_{equiv_method.__name__}_{multi_layer}"
5758
)
5859
attrs[test_name] = (
59-
lambda self: self.layer_method_with_input_layer_patches(
60+
lambda self, layer_method=layer_method, equiv_method=equiv_method, multi_layer=multi_layer: self.layer_method_with_input_layer_patches( # noqa: E501
6061
layer_method, equiv_method, multi_layer
6162
)
6263
)
@@ -107,8 +108,14 @@ def layer_method_with_input_layer_patches(
107108

108109
real_attributions = equivalent_method.attribute(*args_to_use, target=0)
109110

110-
if not isinstance(a1, tuple):
111+
if isinstance(a1, list):
112+
a1 = tuple(a1)
113+
elif not isinstance(a1, tuple):
111114
a1 = (a1,)
115+
116+
if isinstance(a2, list):
117+
a2 = tuple(a2)
118+
elif not isinstance(a2, tuple):
112119
a2 = (a2,)
113120

114121
if not isinstance(real_attributions, tuple):

0 commit comments

Comments
 (0)