Skip to content

Commit 4919c04

Browse files
committed
feat(compression): enhance view.py to display decoded LUT indices
Add functionality to decode and display LUT indices for compressed tensor buffers in the model viewer. When viewing compressed models, buffers containing compressed data now show a "_lut_indices" field with the decoded indices displayed in a readable multi-line format. Also improve the numpy array pretty printer to display all values without ellipsis truncation, making large arrays more readable.
1 parent a415d3e commit 4919c04

File tree

1 file changed

+44
-1
lines changed
  • tensorflow/lite/micro/compression

1 file changed

+44
-1
lines changed

tensorflow/lite/micro/compression/view.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,28 @@ def unpack_compression_metadata(buffer):
350350
return {"subgraphs": result}
351351

352352

353+
def find_lut_info_for_buffer(buffer_index, model, compression_data):
354+
"""Find LUT metadata for a given buffer index.
355+
356+
Returns a dict with tensor_index, subgraph_index, and index_bitwidth if the
357+
buffer contains compressed indices, otherwise returns None.
358+
"""
359+
if compression_data is None:
360+
return None
361+
362+
for subgraph_idx, subgraph in enumerate(compression_data.metadata.subgraphs):
363+
for lut_tensor in subgraph.lutTensors:
364+
# Get the tensor to find which buffer contains the compressed indices
365+
tensor = model.subgraphs[subgraph_idx].tensors[lut_tensor.tensor]
366+
if tensor.buffer == buffer_index:
367+
return {
368+
"tensor_index": lut_tensor.tensor,
369+
"subgraph_index": subgraph_idx,
370+
"index_bitwidth": lut_tensor.indexBitwidth,
371+
}
372+
return None
373+
374+
353375
def unpack_buffers(model, compression_data):
354376
buffers = []
355377
for index, buffer in enumerate(model.buffers):
@@ -362,6 +384,25 @@ def unpack_buffers(model, compression_data):
362384
native["_compression_metadata"] = True
363385

364386
native["data"] = buffer.data
387+
388+
# Check if this buffer contains compressed indices
389+
lut_info = find_lut_info_for_buffer(index, model, compression_data)
390+
if lut_info and buffer.data is not None:
391+
# Decode the indices from the buffer
392+
bstring = bitarray.bitarray()
393+
bstring.frombytes(bytes(buffer.data))
394+
bitwidth = lut_info["index_bitwidth"]
395+
chunks = [bstring[i:i+bitwidth] for i in range(0, len(bstring) - bitwidth + 1, bitwidth)]
396+
indices = [bitarray.util.ba2int(chunk) for chunk in chunks]
397+
398+
# Convert indices to numpy array to match data field formatting
399+
indices_array = np.array(indices, dtype=np.uint8)
400+
401+
native["_lut_indices"] = {
402+
"tensor": lut_info["tensor_index"],
403+
"bitwidth": bitwidth,
404+
"indices": indices_array,
405+
}
365406

366407
buffers.append(native)
367408

@@ -404,7 +445,9 @@ def create_dictionary(flatbuffer: memoryview) -> dict:
404445

405446
@prettyprinter.register_pretty(np.ndarray)
406447
def pretty_numpy_array(array, ctx):
407-
string = np.array2string(array)
448+
# Format array without ellipsis, similar to how buffer data is displayed
449+
string = np.array2string(array, threshold=np.inf, max_line_width=78,
450+
separator=' ', suppress_small=True)
408451
lines = string.splitlines()
409452

410453
if len(lines) == 1:

0 commit comments

Comments
 (0)