@@ -184,7 +184,7 @@ def _init_test_generalized_rcnn_transform(self):
184
184
transform = GeneralizedRCNNTransform (min_size , max_size , image_mean , image_std )
185
185
return transform
186
186
187
- def _init_test_rpn (self ):
187
+ def _init_test_rpn (self , score_threshold = 0.0 ):
188
188
anchor_sizes = ((32 ,), (64 ,), (128 ,), (256 ,), (512 ,))
189
189
aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
190
190
rpn_anchor_generator = AnchorGenerator (anchor_sizes , aspect_ratios )
@@ -197,7 +197,7 @@ def _init_test_rpn(self):
197
197
rpn_pre_nms_top_n = dict (training = 2000 , testing = 1000 )
198
198
rpn_post_nms_top_n = dict (training = 2000 , testing = 1000 )
199
199
rpn_nms_thresh = 0.7
200
- rpn_score_thresh = 0.0
200
+ rpn_score_thresh = score_threshold
201
201
202
202
rpn = RegionProposalNetwork (
203
203
rpn_anchor_generator , rpn_head ,
@@ -260,7 +260,7 @@ def test_rpn(self):
260
260
class RPNModule (torch .nn .Module ):
261
261
def __init__ (self_module ):
262
262
super (RPNModule , self_module ).__init__ ()
263
- self_module .rpn = self ._init_test_rpn ()
263
+ self_module .rpn = self ._init_test_rpn (0.5 )
264
264
265
265
def forward (self_module , images , features ):
266
266
images = ImageList (images , [i .shape [- 2 :] for i in images ])
0 commit comments