diff --git a/test.py b/test.py index cd82c6c..e426ed3 100644 --- a/test.py +++ b/test.py @@ -75,7 +75,7 @@ def main(args): ac_i2t_top10_best = ac_top10_i2t ac_t2i_top1_best = ac_top1_t2i ac_t2i_top10_best = ac_top10_t2i - dst_best = os.path.join(args.checkpoint_dir, 'model_best', str(epoch)) + '.pth.tar' + dst_best = os.path.join(args.model_path, 'model_best', str(epoch)) + '.pth.tar' shutil.copyfile(model_file, dst_best) logging.info('epoch:{}'.format(epoch))