diff --git a/tf2onnx/tflite_utils.py b/tf2onnx/tflite_utils.py index 82cd2f5d8..8c4909e15 100644 --- a/tf2onnx/tflite_utils.py +++ b/tf2onnx/tflite_utils.py @@ -171,13 +171,14 @@ def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''): tensor_names[i] = name name_to_tensor[name] = tensor - if tensor.ShapeIsNone(): - output_shapes[name] = None - elif tensor.ShapeSignatureIsNone(): + if not tensor.ShapeSignatureIsNone(): + output_shapes[name] = tensor.ShapeSignatureAsNumpy().tolist() + elif not tensor.ShapeIsNone() and len(tensor.ShapeAsNumpy().tolist()) > 0: # The shape signature uses -1 to signify unknown dims. Old models don't have this and use Shape instead. + # Annoyingly, an empty shape can actually mean the rank is unknown. output_shapes[name] = tensor.ShapeAsNumpy().tolist() else: - output_shapes[name] = tensor.ShapeSignatureAsNumpy().tolist() + output_shapes[name] = None buf = model.Buffers(tensor.Buffer()) dtypes[name] = map_tflite_dtype_to_onnx(tensor.Type()) if not buf.DataIsNone() and tensor.Buffer() > 0: