Skip to content

Commit 550cee2

Browse files
authored
Add table-transformer-detection ONNXRT example (#1314)
Signed-off-by: yuwenzho <[email protected]>
1 parent 2344905 commit 550cee2

File tree

8 files changed

+183
-75
lines changed

8 files changed

+183
-75
lines changed

examples/.config/model_params_onnxrt.json

+11-4
Original file line numberDiff line numberDiff line change
@@ -868,11 +868,18 @@
868868
"main_script": "main.py",
869869
"batch_size": 1
870870
},
871-
"table_transformer": {
871+
"table_transformer_structure_recognition": {
872872
"model_src_dir": "object_detection/table_transformer/quantization/ptq_static",
873-
"dataset_location": "/tf_dataset/dataset/PubTables-1M-Structure",
874-
"input_model": "/tf_dataset2/models/onnx/table-transformer/model.onnx",
875-
"main_script": "table-transformer/src/main.py",
873+
"dataset_location": "/tf_dataset/dataset/PubTables-1M",
874+
"input_model": "/tf_dataset2/models/onnx/table-transformer/pubtables1m_structure_detr_r18.onnx",
875+
"main_script": "patch",
876+
"batch_size": 1
877+
},
878+
"table_transformer_detection": {
879+
"model_src_dir": "object_detection/table_transformer/quantization/ptq_static",
880+
"dataset_location": "/tf_dataset/dataset/PubTables-1M",
881+
"input_model": "/tf_dataset2/models/onnx/table-transformer/pubtables1m_detection_detr_r18.onnx",
882+
"main_script": "patch",
876883
"batch_size": 1
877884
},
878885
"hf_codebert": {

examples/README.md

+7-1
Original file line numberDiff line numberDiff line change
@@ -1402,7 +1402,13 @@ Intel® Neural Compressor validated examples with multiple compression technique
14021402
<td><a href="./onnxrt/object_detection/ssd_mobilenet_v2/quantization/ptq_static">qlinearops</a> / <a href="./onnxrt/object_detection/ssd_mobilenet_v2/quantization/ptq_static">qdq</a></td>
14031403
</tr>
14041404
<tr>
1405-
<td>Table Transformer</td>
1405+
<td>Table Transformer Structure Recognition</td>
1406+
<td>Object Detection</td>
1407+
<td>Post-Training Static Quantization</td>
1408+
<td><a href="./onnxrt/object_detection/table_transformer/quantization/ptq_static">qlinearops</a></td>
1409+
</tr>
1410+
<tr>
1411+
<td>Table Transformer Detection</td>
14061412
<td>Object Detection</td>
14071413
<td>Post-Training Static Quantization</td>
14081414
<td><a href="./onnxrt/object_detection/table_transformer/quantization/ptq_static">qlinearops</a></td>
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Step-by-Step
22
============
33

4-
This example show how to export, quantize and evaluate the DETR R18 model for table structure recognition task based on PubTables-1M dataset.
4+
This example show how to export, quantize and evaluate 2 [DETR](https://huggingface.co/docs/transformers/model_doc/detr) R18 models on [PubTables-1M](https://huggingface.co/datasets/bsmock/pubtables-1m) dataset, one for table detection and one for table structure recognition, dubbed Table Transformers.
55

66
# Prerequisite
77

@@ -16,16 +16,20 @@ bash prepare.sh
1616
1717
## 2. Prepare Dataset
1818

19-
Download dataset according to this [doc](https://github.com/microsoft/table-transformer/tree/main#training-and-evaluation-data).
19+
Download PubTables-1M dataset according to this [doc](https://github.com/microsoft/table-transformer/tree/main#training-and-evaluation-data).
20+
After downloading and extracting, PubTables-1M dataset folder should contain `PubTables-1M-Structure` and `PubTables-1M-Detection` folders.
2021

2122
## 3. Prepare Model
2223

23-
```shell
24-
wget https://huggingface.co/bsmock/tatr-pubtables1m-v1.0/resolve/main/pubtables1m_structure_detr_r18.pth
24+
Prepare DETR R18 model for table structure recognition.
2525

26-
bash export.sh --input_model=/path/to/pubtables1m_structure_detr_r18.pth \
27-
--output_model=/path/to/export \ # model path as *.onnx
28-
--dataset_location=/path/to/dataset_folder # dataset_folder should contains 'words' sub-folder
26+
```
27+
python prepare_model.py --input_model=structure_detr --output_model=pubtables1m_structure_detr_r18.onnx --dataset_location=/path/to/pubtables-1m
28+
```
29+
30+
Prepare DETR R18 model for table detection.
31+
```
32+
python prepare_model.py --input_model=detection_detr --output_model=pubtables1m_detection_detr_r18.onnx --dataset_location=/path/to/pubtables-1m
2933
```
3034

3135
# Run
@@ -35,15 +39,15 @@ bash export.sh --input_model=/path/to/pubtables1m_structure_detr_r18.pth \
3539
Static quantization with QOperator format:
3640

3741
```bash
38-
bash run_tuning.sh --input_model=path/to/model \ # model path as *.onnx
42+
bash run_quant.sh --input_model=path/to/model \ # model path as *.onnx
3943
--output_model=path/to/save \ # model path as *.onnx
40-
--dataset_location=/path/to/dataset_folder # dataset_folder should contains 'words' sub-folder
44+
--dataset_location=/path/to/pubtables-1m # dataset_folder should contains `PubTables-1M-Structure` and/or `PubTables-1M-Detection` folders
4145
```
4246

4347
## 2. Benchmark
4448

4549
```bash
4650
bash run_benchmark.sh --input_model=path/to/model \ # model path as *.onnx
47-
--dataset_location=/path/to/dataset_folder # dataset_folder should contains 'words' sub-folder
51+
--dataset_location=/path/to/pubtables-1m # dataset_folder should contains `PubTables-1M-Structure` and/or `PubTables-1M-Detection` folders
4852
--mode=performance # or accuracy
4953
```

examples/onnxrt/object_detection/table_transformer/quantization/ptq_static/export.sh

-43
This file was deleted.

examples/onnxrt/object_detection/table_transformer/quantization/ptq_static/patch

+16-9
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ index 73ae39e..2049449 100644
110110
i = torch.arange(w, device=x.device)
111111
j = torch.arange(h, device=x.device)
112112
diff --git a/src/eval.py b/src/eval.py
113-
index e3a0565..5514db5 100644
113+
index e3a0565..d66b318 100644
114114
--- a/src/eval.py
115115
+++ b/src/eval.py
116116
@@ -4,6 +4,7 @@ Copyright (C) 2021 Microsoft Corporation
@@ -152,8 +152,14 @@ index e3a0565..5514db5 100644
152152

153153
if args.debug:
154154
for target, pred_logits, pred_boxes in zip(targets, outputs['pred_logits'], outputs['pred_boxes']):
155+
@@ -696,3 +703,4 @@ def eval_coco(args, model, criterion, postprocessors, data_loader_test, dataset_
156+
print("COCO metrics summary: AP50: {:.3f}, AP75: {:.3f}, AP: {:.3f}, AR: {:.3f}".format(
157+
pubmed_stats['coco_eval_bbox'][1], pubmed_stats['coco_eval_bbox'][2],
158+
pubmed_stats['coco_eval_bbox'][0], pubmed_stats['coco_eval_bbox'][8]))
159+
+ return pubmed_stats['coco_eval_bbox'][0]
160+
\ No newline at end of file
155161
diff --git a/src/main.py b/src/main.py
156-
index 74cd13c..1e5e5e9 100644
162+
index 74cd13c..c30377d 100644
157163
--- a/src/main.py
158164
+++ b/src/main.py
159165
@@ -41,6 +41,7 @@ def get_args():
@@ -209,7 +215,7 @@ index 74cd13c..1e5e5e9 100644
209215

210216
dataset_test = PDFTablesDataset(os.path.join(args.data_root_dir,
211217
"test"),
212-
@@ -169,6 +180,29 @@ def get_data(args):
218+
@@ -169,6 +180,28 @@ def get_data(args):
213219
num_workers=args.num_workers)
214220
return data_loader_test, dataset_test
215221

@@ -234,12 +240,11 @@ index 74cd13c..1e5e5e9 100644
234240
+ collate_fn=utils.collate_fn,
235241
+ num_workers=args.num_workers)
236242
+ return OXDataloader(data_loader_test, args.batch_size), dataset_test
237-
+
238243
+
239244
elif args.mode == "grits" or args.mode == "grits-all":
240245
dataset_test = PDFTablesDataset(os.path.join(args.data_root_dir,
241246
"test"),
242-
@@ -337,6 +371,20 @@ def train(args, model, criterion, postprocessors, device):
247+
@@ -337,6 +370,20 @@ def train(args, model, criterion, postprocessors, device):
243248

244249
print('Total training time: ', datetime.now() - start_time)
245250

@@ -260,7 +265,7 @@ index 74cd13c..1e5e5e9 100644
260265

261266
def main():
262267
cmd_args = get_args().__dict__
263-
@@ -350,7 +398,7 @@ def main():
268+
@@ -350,7 +397,7 @@ def main():
264269
print('-' * 100)
265270

266271
# Check for debug mode
@@ -269,7 +274,7 @@ index 74cd13c..1e5e5e9 100644
269274
print("Running evaluation/inference in DEBUG mode, processing will take longer. Saving output to: {}.".format(args.debug_save_dir))
270275
os.makedirs(args.debug_save_dir, exist_ok=True)
271276

272-
@@ -366,10 +414,33 @@ def main():
277+
@@ -366,10 +413,35 @@ def main():
273278

274279
if args.mode == "train":
275280
train(args, model, criterion, postprocessors, device)
@@ -278,7 +283,9 @@ index 74cd13c..1e5e5e9 100644
278283
data_loader_test, dataset_test = get_data(args)
279284
- eval_coco(args, model, criterion, postprocessors, data_loader_test, dataset_test, device)
280285
-
281-
+ eval_coco(args, args.input_onnx_model, criterion, postprocessors, data_loader_test, dataset_test, device)
286+
+ ap_result = eval_coco(args, args.input_onnx_model, criterion, postprocessors, data_loader_test, dataset_test, device)
287+
+ print("Batch size = %d" % args.batch_size)
288+
+ print("Accuracy: %.5f" % ap_result)
282289
+ elif args.mode == "export":
283290
+ data_loader_test, dataset_test = get_data(args)
284291
+ export(args, model, data_loader_test, device)
@@ -303,6 +310,6 @@ index 74cd13c..1e5e5e9 100644
303310
+ from neural_compressor.config import BenchmarkConfig
304311
+ config = BenchmarkConfig(warmup=10, iteration=100, cores_per_instance=4, num_of_instance=1)
305312
+ fit(args.input_onnx_model, config, b_dataloader=data_loader_test)
306-
313+
307314
if __name__ == "__main__":
308315
main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import argparse
2+
import os
3+
import subprocess
4+
import sys
5+
from urllib import request
6+
7+
MODEL_URLS = {"structure_detr": "https://huggingface.co/bsmock/tatr-pubtables1m-v1.0/resolve/main/pubtables1m_structure_detr_r18.pth",
8+
"detection_detr": "https://huggingface.co/bsmock/tatr-pubtables1m-v1.0/resolve/main/pubtables1m_detection_detr_r18.pth"}
9+
MAX_TIMES_RETRY_DOWNLOAD = 5
10+
11+
12+
def parse_arguments():
13+
parser = argparse.ArgumentParser()
14+
parser.add_argument("--input_model",
15+
type=str,
16+
required=False,
17+
choices=["structure_detr", "detection_detr"],
18+
default="structure_detr")
19+
parser.add_argument("--output_model", type=str, required=True)
20+
parser.add_argument("--dataset_location", type=str, required=True)
21+
return parser.parse_args()
22+
23+
24+
def progressbar(cur, total=100):
25+
percent = '{:.2%}'.format(cur / total)
26+
sys.stdout.write("\r[%-100s] %s" % ('#' * int(cur), percent))
27+
sys.stdout.flush()
28+
29+
30+
def schedule(blocknum, blocksize, totalsize):
31+
if totalsize == 0:
32+
percent = 0
33+
else:
34+
percent = min(1.0, blocknum * blocksize / totalsize) * 100
35+
progressbar(percent)
36+
37+
38+
def download_model(url, retry_times=5):
39+
model_name = url.split("/")[-1]
40+
if os.path.isfile(model_name):
41+
print(f"{model_name} exists, skip download")
42+
return True
43+
44+
print("download model...")
45+
retries = 0
46+
while retries < retry_times:
47+
try:
48+
request.urlretrieve(url, model_name, schedule)
49+
break
50+
except KeyboardInterrupt:
51+
return False
52+
except:
53+
retries += 1
54+
print(f"Download failed{', Retry downloading...' if retries < retry_times else '!'}")
55+
return retries < retry_times
56+
57+
58+
def export_model(input_model, output_model, dataset_location):
59+
print("\nexport model...")
60+
61+
if not os.path.exists("./table-transformer"):
62+
subprocess.run("bash prepare.sh", shell=True)
63+
64+
model_load_path = os.path.abspath(MODEL_URLS[input_model].split("/")[-1])
65+
output_model = os.path.join(os.path.dirname(model_load_path), output_model)
66+
if input_model == "detection_detr":
67+
data_root_dir = os.path.join(dataset_location, "PubTables-1M-Detection")
68+
data_type = "detection"
69+
config_file = "detection_config.json"
70+
elif input_model == "structure_detr":
71+
data_root_dir = os.path.join(dataset_location, "PubTables-1M-Structure")
72+
data_type = "structure"
73+
config_file = "structure_config.json"
74+
table_words_dir = os.path.join(data_root_dir, "words")
75+
76+
os.chdir("table-transformer/src")
77+
78+
command = f"python main.py \
79+
--model_load_path {model_load_path} \
80+
--output_model {output_model} \
81+
--data_root_dir {data_root_dir} \
82+
--table_words_dir {table_words_dir} \
83+
--mode export \
84+
--data_type {data_type} \
85+
--device cpu \
86+
--config_file {config_file}"
87+
88+
subprocess.run(command, shell=True)
89+
assert os.path.exists(output_model), f"Export failed! {output_model} doesn't exist!"
90+
91+
92+
def prepare_model(input_model, output_model, dataset_location):
93+
is_download_successful = download_model(MODEL_URLS[args.input_model], MAX_TIMES_RETRY_DOWNLOAD)
94+
if is_download_successful:
95+
export_model(input_model, output_model, dataset_location)
96+
97+
98+
if __name__ == "__main__":
99+
args = parse_arguments()
100+
prepare_model(args.input_model, args.output_model, args.dataset_location)

examples/onnxrt/object_detection/table_transformer/quantization/ptq_static/run_benchmark.sh

+17-4
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,28 @@ function run_benchmark {
3535
bash prepare.sh
3636
fi
3737

38+
if [[ "${input_model}" =~ "structure" ]]; then
39+
task_data_dir="PubTables-1M-Structure"
40+
data_type="structure"
41+
config_file="structure_config.json"
42+
fi
43+
if [[ "${input_model}" =~ "detection" ]]; then
44+
task_data_dir="PubTables-1M-Detection"
45+
data_type="detection"
46+
config_file="detection_config.json"
47+
fi
48+
49+
input_model=$(realpath "$input_model")
50+
3851
cd table-transformer/src
3952
python main.py \
4053
--input_onnx_model ${input_model} \
41-
--data_root_dir ${dataset_location} \
42-
--table_words_dir ${dataset_location}/words \
54+
--data_root_dir "${dataset_location}/${task_data_dir}" \
55+
--table_words_dir "${dataset_location}/${task_data_dir}/words" \
4356
--mode ${mode} \
44-
--data_type structure \
57+
--data_type ${data_type} \
4558
--device cpu \
46-
--config_file structure_config.json
59+
--config_file ${config_file}
4760
}
4861

4962
main "$@"

examples/onnxrt/object_detection/table_transformer/quantization/ptq_static/run_tuning.sh renamed to examples/onnxrt/object_detection/table_transformer/quantization/ptq_static/run_quant.sh

+18-4
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,30 @@ function run_tuning {
3535
bash prepare.sh
3636
fi
3737

38+
if [[ "${input_model}" =~ "structure" ]]; then
39+
task_data_dir="PubTables-1M-Structure"
40+
data_type="structure"
41+
config_file="structure_config.json"
42+
fi
43+
if [[ "${input_model}" =~ "detection" ]]; then
44+
task_data_dir="PubTables-1M-Detection"
45+
data_type="detection"
46+
config_file="detection_config.json"
47+
fi
48+
49+
input_model=$(realpath "$input_model")
50+
output_model=$(realpath "$output_model")
51+
3852
cd table-transformer/src
3953
python main.py \
4054
--input_onnx_model ${input_model} \
4155
--output_model ${output_model} \
42-
--data_root_dir ${dataset_location} \
43-
--table_words_dir ${dataset_location}/words \
56+
--data_root_dir "${dataset_location}/${task_data_dir}" \
57+
--table_words_dir "${dataset_location}/${task_data_dir}/words" \
4458
--mode quantize \
45-
--data_type structure \
59+
--data_type ${data_type} \
4660
--device cpu \
47-
--config_file structure_config.json
61+
--config_file ${config_file}
4862
}
4963

5064
main "$@"

0 commit comments

Comments
 (0)