diff --git a/references/detection/train.py b/references/detection/train.py index d3b394b8bd0..229278eb9b4 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -158,10 +158,7 @@ def main(args): device = torch.device(args.device) if args.use_deterministic_algorithms: - torch.backends.cudnn.benchmark = False torch.use_deterministic_algorithms(True) - else: - torch.backends.cudnn.benchmark = True # Data loading code print("Loading data") @@ -253,8 +250,6 @@ def main(args): scaler.load_state_dict(checkpoint["scaler"]) if args.test_only: - # We disable the cudnn benchmarking because it can noticeably affect the accuracy - torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True evaluate(model, data_loader_test, device=device) return