Skip to content

Commit a384d85

Browse files
committed
add table-transformer-detection example
Signed-off-by: yuwenzho <[email protected]>
1 parent 4dc6805 commit a384d85

File tree

7 files changed

+165
-64
lines changed

7 files changed

+165
-64
lines changed

examples/.config/model_params_onnxrt.json

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -868,10 +868,17 @@
868868
"main_script": "main.py",
869869
"batch_size": 1
870870
},
871-
"table_transformer": {
871+
"table_transformer_structure_recognition": {
872872
"model_src_dir": "object_detection/table_transformer/quantization/ptq_static",
873-
"dataset_location": "/tf_dataset/dataset/PubTables-1M-Structure",
874-
"input_model": "/tf_dataset2/models/onnx/table-transformer/model.onnx",
873+
"dataset_location": "/tf_dataset/dataset/PubTables-1M",
874+
"input_model": "/tf_dataset2/models/onnx/table-transformer/pubtables1m_structure_detr_r18.onnx",
875+
"main_script": "table-transformer/src/main.py",
876+
"batch_size": 1
877+
},
878+
"table_transformer_detection": {
879+
"model_src_dir": "object_detection/table_transformer/quantization/ptq_static",
880+
"dataset_location": "/tf_dataset/dataset/PubTables-1M",
881+
"input_model": "/tf_dataset2/models/onnx/table-transformer/pubtables1m_detection_detr_r18.onnx",
875882
"main_script": "table-transformer/src/main.py",
876883
"batch_size": 1
877884
},

examples/README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1402,7 +1402,13 @@ Intel® Neural Compressor validated examples with multiple compression technique
14021402
<td><a href="./onnxrt/object_detection/ssd_mobilenet_v2/quantization/ptq_static">qlinearops</a> / <a href="./onnxrt/object_detection/ssd_mobilenet_v2/quantization/ptq_static">qdq</a></td>
14031403
</tr>
14041404
<tr>
1405-
<td>Table Transformer</td>
1405+
<td>Table Transformer Structure Recognition</td>
1406+
<td>Object Detection</td>
1407+
<td>Post-Training Static Quantization</td>
1408+
<td><a href="./onnxrt/object_detection/table_transformer/quantization/ptq_static">qlinearops</a></td>
1409+
</tr>
1410+
<tr>
1411+
<td>Table Transformer Detection</td>
14061412
<td>Object Detection</td>
14071413
<td>Post-Training Static Quantization</td>
14081414
<td><a href="./onnxrt/object_detection/table_transformer/quantization/ptq_static">qlinearops</a></td>

examples/onnxrt/object_detection/table_transformer/quantization/ptq_static/README.md

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Step-by-Step
22
============
33

4-
This example show how to export, quantize and evaluate the DETR R18 model for table structure recognition task based on PubTables-1M dataset.
4+
This example show how to export, quantize and evaluate 2 [DETR](https://huggingface.co/docs/transformers/model_doc/detr) R18 models on [PubTables-1M](https://huggingface.co/datasets/bsmock/pubtables-1m) dataset, one for table detection and one for table structure recognition, dubbed Table Transformers.
55

66
# Prerequisite
77

@@ -16,16 +16,20 @@ bash prepare.sh
1616
1717
## 2. Prepare Dataset
1818

19-
Download dataset according to this [doc](https://github.com/microsoft/table-transformer/tree/main#training-and-evaluation-data).
19+
Download PubTables-1M dataset according to this [doc](https://github.com/microsoft/table-transformer/tree/main#training-and-evaluation-data).
20+
After downloading and extracting, PubTables-1M dataset folder should contain `PubTables-1M-Structure` and `PubTables-1M-Detection` folders.
2021

2122
## 3. Prepare Model
2223

23-
```shell
24-
wget https://huggingface.co/bsmock/tatr-pubtables1m-v1.0/resolve/main/pubtables1m_structure_detr_r18.pth
24+
Prepare DETR R18 model for table structure recognition.
2525

26-
bash export.sh --input_model=/path/to/pubtables1m_structure_detr_r18.pth \
27-
--output_model=/path/to/export \ # model path as *.onnx
28-
--dataset_location=/path/to/dataset_folder # dataset_folder should contains 'words' sub-folder
26+
```
27+
python prepare_model.py --input_model=structure_detr --output_model=pubtables1m_structure_detr_r18.onnx --dataset_location=/path/to/pubtables-1m
28+
```
29+
30+
Prepare DETR R18 model for table detection.
31+
```
32+
python prepare_model.py --input_model=detection_detr --output_model=pubtables1m_detection_detr_r18.onnx --dataset_location=/path/to/pubtables-1m
2933
```
3034

3135
# Run
@@ -37,13 +41,13 @@ Static quantization with QOperator format:
3741
```bash
3842
bash run_tuning.sh --input_model=path/to/model \ # model path as *.onnx
3943
--output_model=path/to/save \ # model path as *.onnx
40-
--dataset_location=/path/to/dataset_folder # dataset_folder should contains 'words' sub-folder
44+
--dataset_location=/path/to/pubtables-1m # dataset_folder should contains `PubTables-1M-Structure` and/or `PubTables-1M-Detection` folders
4145
```
4246

4347
## 2. Benchmark
4448

4549
```bash
4650
bash run_benchmark.sh --input_model=path/to/model \ # model path as *.onnx
47-
--dataset_location=/path/to/dataset_folder # dataset_folder should contains 'words' sub-folder
51+
--dataset_location=/path/to/pubtables-1m # dataset_folder should contains `PubTables-1M-Structure` and/or `PubTables-1M-Detection` folders
4852
--mode=performance # or accuracy
4953
```

examples/onnxrt/object_detection/table_transformer/quantization/ptq_static/export.sh

Lines changed: 0 additions & 43 deletions
This file was deleted.
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import argparse
2+
import os
3+
import subprocess
4+
import sys
5+
from urllib import request
6+
7+
MODEL_URLS = {"structure_detr": "https://huggingface.co/bsmock/tatr-pubtables1m-v1.0/resolve/main/pubtables1m_structure_detr_r18.pth",
8+
"detection_detr": "https://huggingface.co/bsmock/tatr-pubtables1m-v1.0/resolve/main/pubtables1m_detection_detr_r18.pth"}
9+
MAX_TIMES_RETRY_DOWNLOAD = 5
10+
11+
12+
def parse_arguments():
13+
parser = argparse.ArgumentParser()
14+
parser.add_argument("--input_model",
15+
type=str,
16+
required=False,
17+
choices=["structure_detr", "detection_detr"],
18+
default="structure_detr")
19+
parser.add_argument("--output_model", type=str, required=True)
20+
parser.add_argument("--dataset_location", type=str, required=True)
21+
return parser.parse_args()
22+
23+
24+
def progressbar(cur, total=100):
25+
percent = '{:.2%}'.format(cur / total)
26+
sys.stdout.write("\r[%-100s] %s" % ('#' * int(cur), percent))
27+
sys.stdout.flush()
28+
29+
30+
def schedule(blocknum, blocksize, totalsize):
31+
if totalsize == 0:
32+
percent = 0
33+
else:
34+
percent = min(1.0, blocknum * blocksize / totalsize) * 100
35+
progressbar(percent)
36+
37+
38+
def download_model(url, retry_times=5):
39+
model_name = url.split("/")[-1]
40+
if os.path.isfile(model_name):
41+
print(f"{model_name} exists, skip download")
42+
return True
43+
44+
print("download model...")
45+
retries = 0
46+
while retries < retry_times:
47+
try:
48+
request.urlretrieve(url, model_name, schedule)
49+
break
50+
except KeyboardInterrupt:
51+
return False
52+
except:
53+
retries += 1
54+
print(f"Download failed{', Retry downloading...' if retries < retry_times else '!'}")
55+
return retries < retry_times
56+
57+
58+
def export_model(input_model, output_model, dataset_location):
59+
print("\nexport model...")
60+
61+
if not os.path.exists("./table-transformer"):
62+
subprocess.run("bash prepare.sh", shell=True)
63+
64+
model_load_path = os.path.abspath(MODEL_URLS[input_model].split("/")[-1])
65+
output_model = os.path.join(os.path.dirname(model_load_path), output_model)
66+
if input_model == "detection_detr":
67+
data_root_dir = os.path.join(dataset_location, "PubTables-1M-Detection")
68+
data_type = "detection"
69+
config_file = "detection_config.json"
70+
elif input_model == "structure_detr":
71+
data_root_dir = os.path.join(dataset_location, "PubTables-1M-Structure")
72+
data_type = "structure"
73+
config_file = "structure_config.json"
74+
table_words_dir = os.path.join(data_root_dir, "words")
75+
76+
os.chdir("table-transformer/src")
77+
78+
command = f"python main.py \
79+
--model_load_path {model_load_path} \
80+
--output_model {output_model} \
81+
--data_root_dir {data_root_dir} \
82+
--table_words_dir {table_words_dir} \
83+
--mode export \
84+
--data_type {data_type} \
85+
--device cpu \
86+
--config_file {config_file}"
87+
88+
subprocess.run(command, shell=True)
89+
assert os.path.exists(output_model), f"Export failed! {output_model} doesn't exist!"
90+
91+
92+
def prepare_model(input_model, output_model, dataset_location):
93+
is_download_successful = download_model(MODEL_URLS[args.input_model], MAX_TIMES_RETRY_DOWNLOAD)
94+
if is_download_successful:
95+
export_model(input_model, output_model, dataset_location)
96+
97+
98+
if __name__ == "__main__":
99+
args = parse_arguments()
100+
prepare_model(args.input_model, args.output_model, args.dataset_location)

examples/onnxrt/object_detection/table_transformer/quantization/ptq_static/run_benchmark.sh

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,28 @@ function run_benchmark {
3535
bash prepare.sh
3636
fi
3737

38+
if [[ "${input_model}" =~ "structure" ]]; then
39+
task_data_dir="PubTables-1M-Structure"
40+
data_type="structure"
41+
config_file="structure_config.json"
42+
fi
43+
if [[ "${input_model}" =~ "detection" ]]; then
44+
task_data_dir="PubTables-1M-Detection"
45+
data_type="detection"
46+
config_file="detection_config.json"
47+
fi
48+
49+
input_model=$(realpath "$input_model")
50+
3851
cd table-transformer/src
3952
python main.py \
4053
--input_onnx_model ${input_model} \
41-
--data_root_dir ${dataset_location} \
42-
--table_words_dir ${dataset_location}/words \
54+
--data_root_dir "${dataset_location}/${task_data_dir}" \
55+
--table_words_dir "${dataset_location}/${task_data_dir}/words" \
4356
--mode ${mode} \
44-
--data_type structure \
57+
--data_type ${data_type} \
4558
--device cpu \
46-
--config_file structure_config.json
59+
--config_file ${config_file}
4760
}
4861

4962
main "$@"

examples/onnxrt/object_detection/table_transformer/quantization/ptq_static/run_tuning.sh

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,30 @@ function run_tuning {
3535
bash prepare.sh
3636
fi
3737

38+
if [[ "${input_model}" =~ "structure" ]]; then
39+
task_data_dir="PubTables-1M-Structure"
40+
data_type="structure"
41+
config_file="structure_config.json"
42+
fi
43+
if [[ "${input_model}" =~ "detection" ]]; then
44+
task_data_dir="PubTables-1M-Detection"
45+
data_type="detection"
46+
config_file="detection_config.json"
47+
fi
48+
49+
input_model=$(realpath "$input_model")
50+
output_model=$(realpath "$output_model")
51+
3852
cd table-transformer/src
3953
python main.py \
4054
--input_onnx_model ${input_model} \
4155
--output_model ${output_model} \
42-
--data_root_dir ${dataset_location} \
43-
--table_words_dir ${dataset_location}/words \
56+
--data_root_dir "${dataset_location}/${task_data_dir}" \
57+
--table_words_dir "${dataset_location}/${task_data_dir}/words" \
4458
--mode quantize \
45-
--data_type structure \
59+
--data_type ${data_type} \
4660
--device cpu \
47-
--config_file structure_config.json
61+
--config_file ${config_file}
4862
}
4963

5064
main "$@"

0 commit comments

Comments
 (0)