Skip to content

Commit dd12fdc

Browse files
Add support for flatbuffers 2.0
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent cc0dae4 commit dd12fdc

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

tests/test_tflite_postprocess.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@
2222

2323
# pylint: disable=missing-docstring
2424

25+
def endvector(builder, length):
26+
try:
27+
builder.EndVector(length)
28+
except TypeError:
29+
# flatbuffers 2.0 changes the API
30+
builder.EndVector()
2531

2632
class TFLiteDetectionPostProcessTests(Tf2OnnxBackendTestBase):
2733

@@ -90,7 +96,7 @@ def make_postprocess_model(self, max_detections=10, detections_per_class=100, ma
9096
# op_codes
9197
Model.ModelStartOperatorCodesVector(builder, 1)
9298
builder.PrependUOffsetTRelative(op_code)
93-
op_codes = builder.EndVector(1)
99+
op_codes = endvector(builder, 1)
94100

95101
# Make tensors
96102
# [names, shape, type tensors]
@@ -118,19 +124,19 @@ def make_postprocess_model(self, max_detections=10, detections_per_class=100, ma
118124
SubGraph.SubGraphStartTensorsVector(builder, len(ts))
119125
for tensor in reversed(ts):
120126
builder.PrependUOffsetTRelative(tensor)
121-
tensors = builder.EndVector(len(ts))
127+
tensors = endvector(builder, len(ts))
122128

123129
# inputs
124130
SubGraph.SubGraphStartInputsVector(builder, 3)
125131
for inp in reversed([0, 1, 2]):
126132
builder.PrependInt32(inp)
127-
inputs = builder.EndVector(3)
133+
inputs = endvector(builder, 3)
128134

129135
# outputs
130136
SubGraph.SubGraphStartOutputsVector(builder, 4)
131137
for out in reversed([3, 4, 5, 6]):
132138
builder.PrependInt32(out)
133-
outputs = builder.EndVector(4)
139+
outputs = endvector(builder, 4)
134140

135141
flexbuffer = \
136142
b'y_scale\x00nms_score_threshold\x00max_detections\x00x_scale\x00w_scale\x00nms_iou_threshold' \
@@ -164,7 +170,7 @@ def make_postprocess_model(self, max_detections=10, detections_per_class=100, ma
164170
# operators
165171
SubGraph.SubGraphStartOperatorsVector(builder, 1)
166172
builder.PrependUOffsetTRelative(operator)
167-
operators = builder.EndVector(1)
173+
operators = endvector(builder, 1)
168174

169175
# subgraph
170176
graph_name = builder.CreateString("TFLite graph")
@@ -179,20 +185,20 @@ def make_postprocess_model(self, max_detections=10, detections_per_class=100, ma
179185
# subgraphs
180186
Model.ModelStartSubgraphsVector(builder, 1)
181187
builder.PrependUOffsetTRelative(subgraph)
182-
subgraphs = builder.EndVector(1)
188+
subgraphs = endvector(builder, 1)
183189

184190
description = builder.CreateString("Model for tflite testing")
185191

186192
Buffer.BufferStartDataVector(builder, 0)
187-
data = builder.EndVector(0)
193+
data = endvector(builder, 0)
188194

189195
Buffer.BufferStart(builder)
190196
Buffer.BufferAddData(builder, data)
191197
buffer = Buffer.BufferEnd(builder)
192198

193199
Model.ModelStartBuffersVector(builder, 1)
194200
builder.PrependUOffsetTRelative(buffer)
195-
buffers = builder.EndVector(1)
201+
buffers = endvector(builder, 1)
196202

197203
# model
198204
Model.ModelStart(builder)

0 commit comments

Comments
 (0)