Skip to content

Commit 8a8c9a4

Browse files
committed
update patch
Signed-off-by: yuwenzho <[email protected]>
1 parent aaba929 commit 8a8c9a4

File tree

1 file changed

+6
-7
lines changed
  • examples/onnxrt/object_detection/table_transformer/quantization/ptq_static

1 file changed

+6
-7
lines changed

examples/onnxrt/object_detection/table_transformer/quantization/ptq_static/patch

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ index e3a0565..d66b318 100644
159159
+ return pubmed_stats['coco_eval_bbox'][0]
160160
\ No newline at end of file
161161
diff --git a/src/main.py b/src/main.py
162-
index 74cd13c..83d6aeb 100644
162+
index 74cd13c..c30377d 100644
163163
--- a/src/main.py
164164
+++ b/src/main.py
165165
@@ -41,6 +41,7 @@ def get_args():
@@ -215,7 +215,7 @@ index 74cd13c..83d6aeb 100644
215215

216216
dataset_test = PDFTablesDataset(os.path.join(args.data_root_dir,
217217
"test"),
218-
@@ -169,6 +180,29 @@ def get_data(args):
218+
@@ -169,6 +180,28 @@ def get_data(args):
219219
num_workers=args.num_workers)
220220
return data_loader_test, dataset_test
221221

@@ -240,12 +240,11 @@ index 74cd13c..83d6aeb 100644
240240
+ collate_fn=utils.collate_fn,
241241
+ num_workers=args.num_workers)
242242
+ return OXDataloader(data_loader_test, args.batch_size), dataset_test
243-
+
244243
+
245244
elif args.mode == "grits" or args.mode == "grits-all":
246245
dataset_test = PDFTablesDataset(os.path.join(args.data_root_dir,
247246
"test"),
248-
@@ -337,6 +371,20 @@ def train(args, model, criterion, postprocessors, device):
247+
@@ -337,6 +370,20 @@ def train(args, model, criterion, postprocessors, device):
249248

250249
print('Total training time: ', datetime.now() - start_time)
251250

@@ -266,7 +265,7 @@ index 74cd13c..83d6aeb 100644
266265

267266
def main():
268267
cmd_args = get_args().__dict__
269-
@@ -350,7 +398,7 @@ def main():
268+
@@ -350,7 +397,7 @@ def main():
270269
print('-' * 100)
271270

272271
# Check for debug mode
@@ -275,7 +274,7 @@ index 74cd13c..83d6aeb 100644
275274
print("Running evaluation/inference in DEBUG mode, processing will take longer. Saving output to: {}.".format(args.debug_save_dir))
276275
os.makedirs(args.debug_save_dir, exist_ok=True)
277276

278-
@@ -366,10 +414,35 @@ def main():
277+
@@ -366,10 +413,35 @@ def main():
279278

280279
if args.mode == "train":
281280
train(args, model, criterion, postprocessors, device)
@@ -313,4 +312,4 @@ index 74cd13c..83d6aeb 100644
313312
+ fit(args.input_onnx_model, config, b_dataloader=data_loader_test)
314313

315314
if __name__ == "__main__":
316-
main()
315+
main()

0 commit comments

Comments
 (0)