@@ -159,7 +159,7 @@ index e3a0565..d66b318 100644
159
159
+ return pubmed_stats['coco_eval_bbox'][0]
160
160
\ No newline at end of file
161
161
diff --git a/src/main.py b/src/main.py
162
- index 74cd13c..83d6aeb 100644
162
+ index 74cd13c..c30377d 100644
163
163
--- a/src/main.py
164
164
+++ b/src/main.py
165
165
@@ -41,6 +41,7 @@ def get_args():
@@ -215,7 +215,7 @@ index 74cd13c..83d6aeb 100644
215
215
216
216
dataset_test = PDFTablesDataset(os.path.join(args.data_root_dir,
217
217
"test"),
218
- @@ -169,6 +180,29 @@ def get_data(args):
218
+ @@ -169,6 +180,28 @@ def get_data(args):
219
219
num_workers=args.num_workers)
220
220
return data_loader_test, dataset_test
221
221
@@ -240,12 +240,11 @@ index 74cd13c..83d6aeb 100644
240
240
+ collate_fn=utils.collate_fn,
241
241
+ num_workers=args.num_workers)
242
242
+ return OXDataloader(data_loader_test, args.batch_size), dataset_test
243
- +
244
243
+
245
244
elif args.mode == "grits" or args.mode == "grits-all":
246
245
dataset_test = PDFTablesDataset(os.path.join(args.data_root_dir,
247
246
"test"),
248
- @@ -337,6 +371 ,20 @@ def train(args, model, criterion, postprocessors, device):
247
+ @@ -337,6 +370 ,20 @@ def train(args, model, criterion, postprocessors, device):
249
248
250
249
print('Total training time: ', datetime.now() - start_time)
251
250
@@ -266,7 +265,7 @@ index 74cd13c..83d6aeb 100644
266
265
267
266
def main():
268
267
cmd_args = get_args().__dict__
269
- @@ -350,7 +398 ,7 @@ def main():
268
+ @@ -350,7 +397 ,7 @@ def main():
270
269
print('-' * 100)
271
270
272
271
# Check for debug mode
@@ -275,7 +274,7 @@ index 74cd13c..83d6aeb 100644
275
274
print("Running evaluation/inference in DEBUG mode, processing will take longer. Saving output to: {}.".format(args.debug_save_dir))
276
275
os.makedirs(args.debug_save_dir, exist_ok=True)
277
276
278
- @@ -366,10 +414 ,35 @@ def main():
277
+ @@ -366,10 +413 ,35 @@ def main():
279
278
280
279
if args.mode == "train":
281
280
train(args, model, criterion, postprocessors, device)
@@ -313,4 +312,4 @@ index 74cd13c..83d6aeb 100644
313
312
+ fit(args.input_onnx_model, config, b_dataloader=data_loader_test)
314
313
315
314
if __name__ == "__main__":
316
- main()
315
+ main()
0 commit comments