Skip to content

Commit c67a0c7

Browse files
Reduce memory consumption for BasNet tests (#2325)
1 parent d04fbcc commit c67a0c7

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

keras_cv/models/segmentation/basnet/basnet_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class BASNetTest(TestCase):
3232
def test_basnet_construction(self):
3333
backbone = ResNet18Backbone()
3434
model = BASNet(
35-
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
35+
input_shape=[64, 64, 3], backbone=backbone, num_classes=1
3636
)
3737
model.compile(
3838
optimizer="adam",
@@ -44,17 +44,17 @@ def test_basnet_construction(self):
4444
def test_basnet_call(self):
4545
backbone = ResNet18Backbone()
4646
model = BASNet(
47-
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
47+
input_shape=[64, 64, 3], backbone=backbone, num_classes=1
4848
)
49-
images = np.random.uniform(size=(2, 288, 288, 3))
49+
images = np.random.uniform(size=(2, 64, 64, 3))
5050
_ = model(images)
5151
_ = model.predict(images)
5252

5353
@pytest.mark.large
5454
@pytest.mark.filterwarnings("ignore::UserWarning")
5555
def test_weights_change(self):
56-
input_size = [288, 288, 3]
57-
target_size = [288, 288, 1]
56+
input_size = [64, 64, 3]
57+
target_size = [64, 64, 1]
5858

5959
images = np.ones([1] + input_size)
6060
labels = np.random.uniform(size=[1] + target_size)
@@ -64,7 +64,7 @@ def test_weights_change(self):
6464

6565
backbone = ResNet18Backbone()
6666
model = BASNet(
67-
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
67+
input_shape=[64, 64, 3], backbone=backbone, num_classes=1
6868
)
6969
model_metrics = ["accuracy"]
7070
if keras_3():
@@ -77,7 +77,7 @@ def test_weights_change(self):
7777
)
7878

7979
original_weights = model.refinement_head.get_weights()
80-
model.fit(ds, epochs=1)
80+
model.fit(ds, epochs=1, batch_size=1)
8181
updated_weights = model.refinement_head.get_weights()
8282

8383
for w1, w2 in zip(original_weights, updated_weights):
@@ -98,11 +98,11 @@ def test_with_model_preset_forward_pass(self):
9898

9999
@pytest.mark.large
100100
def test_saved_model(self):
101-
target_size = [288, 288, 3]
101+
target_size = [64, 64, 3]
102102

103103
backbone = ResNet18Backbone()
104104
model = BASNet(
105-
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
105+
input_shape=[64, 64, 3], backbone=backbone, num_classes=1
106106
)
107107

108108
input_batch = np.ones(shape=[2] + target_size)

0 commit comments

Comments
 (0)