@@ -32,7 +32,7 @@ class BASNetTest(TestCase):
3232 def test_basnet_construction (self ):
3333 backbone = ResNet18Backbone ()
3434 model = BASNet (
35- input_shape = [288 , 288 , 3 ], backbone = backbone , num_classes = 1
35+ input_shape = [64 , 64 , 3 ], backbone = backbone , num_classes = 1
3636 )
3737 model .compile (
3838 optimizer = "adam" ,
@@ -44,17 +44,17 @@ def test_basnet_construction(self):
4444 def test_basnet_call (self ):
4545 backbone = ResNet18Backbone ()
4646 model = BASNet (
47- input_shape = [288 , 288 , 3 ], backbone = backbone , num_classes = 1
47+ input_shape = [64 , 64 , 3 ], backbone = backbone , num_classes = 1
4848 )
49- images = np .random .uniform (size = (2 , 288 , 288 , 3 ))
49+ images = np .random .uniform (size = (2 , 64 , 64 , 3 ))
5050 _ = model (images )
5151 _ = model .predict (images )
5252
5353 @pytest .mark .large
5454 @pytest .mark .filterwarnings ("ignore::UserWarning" )
5555 def test_weights_change (self ):
56- input_size = [288 , 288 , 3 ]
57- target_size = [288 , 288 , 1 ]
56+ input_size = [64 , 64 , 3 ]
57+ target_size = [64 , 64 , 1 ]
5858
5959 images = np .ones ([1 ] + input_size )
6060 labels = np .random .uniform (size = [1 ] + target_size )
@@ -64,7 +64,7 @@ def test_weights_change(self):
6464
6565 backbone = ResNet18Backbone ()
6666 model = BASNet (
67- input_shape = [288 , 288 , 3 ], backbone = backbone , num_classes = 1
67+ input_shape = [64 , 64 , 3 ], backbone = backbone , num_classes = 1
6868 )
6969 model_metrics = ["accuracy" ]
7070 if keras_3 ():
@@ -77,7 +77,7 @@ def test_weights_change(self):
7777 )
7878
7979 original_weights = model .refinement_head .get_weights ()
80- model .fit (ds , epochs = 1 )
80+ model .fit (ds , epochs = 1 , batch_size = 1 )
8181 updated_weights = model .refinement_head .get_weights ()
8282
8383 for w1 , w2 in zip (original_weights , updated_weights ):
@@ -98,11 +98,11 @@ def test_with_model_preset_forward_pass(self):
9898
9999 @pytest .mark .large
100100 def test_saved_model (self ):
101- target_size = [288 , 288 , 3 ]
101+ target_size = [64 , 64 , 3 ]
102102
103103 backbone = ResNet18Backbone ()
104104 model = BASNet (
105- input_shape = [288 , 288 , 3 ], backbone = backbone , num_classes = 1
105+ input_shape = [64 , 64 , 3 ], backbone = backbone , num_classes = 1
106106 )
107107
108108 input_batch = np .ones (shape = [2 ] + target_size )
0 commit comments