Skip to content
38 changes: 28 additions & 10 deletions tensorflow/lite/micro/kernels/decode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,23 @@ TfLiteStatus SetOutputTensorData(TfLiteContext* context, const TfLiteNode* node,
return kTfLiteOk;
}

// Resolves a custom decode `type` to a DecodeState instance.
//
// Walks the custom DECODE registrations exposed by the MicroContext that
// belongs to `context`. The first entry whose `type` field equals `type` and
// whose factory pointer (`func`) is non-null is invoked with `context` and the
// MicroContext's alternate profiler; whatever that factory returns is handed
// back to the caller.
//
// Returns nullptr when no registrations have been set or when no usable entry
// matches `type`.
DecodeState* GetDecodeStateFromCustomRegistration(const TfLiteContext* context,
                                                  uint8_t type) {
  const MicroContext* micro_context = GetMicroContext(context);
  const auto* custom_registrations =
      micro_context->GetCustomDecodeRegistrations();

  if (custom_registrations != nullptr) {
    for (auto entry = custom_registrations->begin();
         entry != custom_registrations->end(); ++entry) {
      // Skip entries for other types or entries without a factory.
      if (entry->func == nullptr || entry->type != type) {
        continue;
      }
      return entry->func(context, micro_context->GetAlternateProfiler());
    }
  }

  return nullptr;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const size_t num_inputs = NumInputs(node);
const size_t num_outputs = NumOutputs(node);
Expand Down Expand Up @@ -113,21 +130,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
dsp = DecodeState::CreateDecodeStateHuffman(
context, micro_context->GetAlternateProfiler());
break;
case DecodeState::kDcmTypeCustom:
MicroPrintf("Custom decode type not yet supported");
break;
default:
MicroPrintf("unsupported decode type %u",
DecodeState::Type(*ancillary));
uint32_t type = DecodeState::Type(*ancillary);
if (type >= DecodeState::kDcmTypeCustomFirst &&
type <= DecodeState::kDcmTypeCustomLast) {
dsp = GetDecodeStateFromCustomRegistration(context, type);
} else {
MicroPrintf("unsupported decode type %u", type);
}
break;
}

status = SetOutputTensorData(context, node, i / 2, output);
if (status != kTfLiteOk) {
break;
}

if (dsp != nullptr) {
status = SetOutputTensorData(context, node, i / 2, output);
if (status != kTfLiteOk) {
break;
}
status = dsp->Setup(*input, *ancillary, *output);
if (status != kTfLiteOk) {
break;
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/lite/micro/kernels/decode_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ class DecodeState {
static constexpr uint8_t kDcmTypeLUT = 0;
static constexpr uint8_t kDcmTypeHuffman = 1;
static constexpr uint8_t kDcmTypePrune = 2;
static constexpr uint8_t kDcmTypeCustom = 127;
static constexpr uint8_t kDcmTypeCustomFirst = 128;
static constexpr uint8_t kDcmTypeCustomLast = 255;

static constexpr size_t kDcmSizeInBytes = 16;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ TF_LITE_MICRO_TEST(DecodeHuffmanTable16BitsInt16Fail) {
tflite::testing::TestDecode<encodes.size() + ancillaries.size(),
outputs.size()>(
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(),
nullptr, kTfLiteError);
nullptr, nullptr, kTfLiteError);
}

TF_LITE_MICRO_TEST(DecodeHuffmanTable32BitsInt8) {
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/lite/micro/kernels/decode_state_prune_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ TF_LITE_MICRO_TEST(DecodePruneQuantizedInvalidZeroPointInt16) {
tflite::testing::TestDecode<kEncodes.size() + kAncillaries.size(),
kOutputs.size()>(
kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE(),
nullptr, kTfLiteError);
nullptr, nullptr, kTfLiteError);
}

TF_LITE_MICRO_TESTS_END
129 changes: 129 additions & 0 deletions tensorflow/lite/micro/kernels/decode_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,76 @@ constexpr int kEncodedShapeLUT[] = {1, sizeof(kEncodedLUT)};
constexpr int8_t kExpectLUT0[] = {1, 2, 3, 4, 4, 3, 2, 1};
constexpr int16_t kExpectLUT1[] = {5, 6, 7, 8, 8, 7, 6, 5};

//
// Custom DECODE test data
//
// Decode type id used by the custom DECODE test. It is written into the first
// byte of the DCM below and into the test's custom registration entry.
constexpr int kDecodeTypeCustom = 200;

// Single byte of ancillary payload. NOTE(review): presumably this is the byte
// that follows the DCM in the ancillary tensor and serves as the XOR key read
// by the custom decoder -- confirm against the AncillaryData layout.
constexpr int8_t kAncillaryDataCustom[] = {0x42};

// DCM header for the custom type; entries not listed are zero-initialized.
constexpr uint8_t kDcmCustom[tflite::DecodeState::kDcmSizeInBytes] = {
kDecodeTypeCustom, // type: custom
1, // DCM version: 1
};

// Align the tensor data the same as a Buffer in the TfLite schema
// Each byte equals the matching kExpectCustom byte XOR 0x42.
alignas(16) const uint8_t kEncodedCustom[] = {0x42, 0x43, 0x40, 0x46,
0x4A, 0x52, 0x62, 0x02};

// Tensor shapes as TfLiteIntArray
// (first element is the number of dimensions, followed by the dimensions)
constexpr int kOutputShapeCustom[] = {1, 8};
constexpr int kEncodedShapeCustom[] = {1, sizeof(kEncodedCustom)};

// Expected decoded output: kEncodedCustom with every byte XORed with 0x42.
constexpr int8_t kExpectCustom[] = {0x00, 0x01, 0x02, 0x04,
0x08, 0x10, 0x20, 0x40};

// Test-only DecodeState used to exercise custom DECODE registrations.
//
// Setup() accepts any tensors. Decode() XORs each input byte with the first
// ancillary byte located after the DCM (kDcmSizeInBytes into the ancillary
// tensor data) and writes the result to the output tensor.
class DecodeStateCustom : public tflite::DecodeState {
 public:
  DecodeStateCustom() = delete;

  DecodeStateCustom(const TfLiteContext* context,
                    tflite::MicroProfilerInterface* profiler)
      : DecodeState(context, profiler) {}

  // No per-tensor state is required by this decoder; always succeeds.
  virtual TfLiteStatus Setup(const TfLiteTensor& input,
                             const TfLiteTensor& ancillary,
                             const TfLiteTensor& output) override {
    return kTfLiteOk;
  }

  // Decodes `input` into `output`.
  // Returns kTfLiteOk on success; TF_LITE_ENSURE bails out of the function if
  // any of the three tensors has no data pointer.
  virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
                              const TfLiteEvalTensor& ancillary,
                              const TfLiteEvalTensor& output) override {
    const uint8_t* inp = tflite::micro::GetTensorData<uint8_t>(&input);
    TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), inp != nullptr);
    uint8_t* outp = tflite::micro::GetTensorData<uint8_t>(
        const_cast<TfLiteEvalTensor*>(&output));
    TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), outp != nullptr);
    const uint8_t* vp = tflite::micro::GetTensorData<uint8_t>(&ancillary);
    TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), vp != nullptr);
    // Skip the DCM header; the payload byte that follows is the XOR key.
    vp += kDcmSizeInBytes;

    // simple XOR de-obfuscation
    // The element count is taken from the first dimension of the input.
    std::transform(inp, inp + input.dims->data[0], outp,
                   [vp](uint8_t i) { return i ^ *vp; });

    return kTfLiteOk;
  }

  // Factory with the signature expected by
  // MicroContext::CustomDecodeRegistration::func. The instance is constructed
  // with placement new inside a function-local static buffer, so every call
  // returns the same storage, re-initialized.
  static DecodeState* CreateDecodeStateCustom(
      const TfLiteContext* context, tflite::MicroProfilerInterface* profiler) {
    // The storage must satisfy the alignment of the (polymorphic) class, not
    // a hard-coded 4 bytes: on 64-bit hosts the vtable pointer typically
    // requires 8-byte alignment and under-aligned placement new is undefined
    // behaviour.
    alignas(DecodeStateCustom) static uint8_t buffer[sizeof(DecodeStateCustom)];
    DecodeState* instance = new (buffer) DecodeStateCustom(context, profiler);
    return instance;
  }

 protected:
  virtual ~DecodeStateCustom() = default;

 private:
  TF_LITE_REMOVE_VIRTUAL_DELETE
};

} // namespace

TF_LITE_MICRO_TESTS_BEGIN
Expand Down Expand Up @@ -246,4 +316,63 @@ TF_LITE_MICRO_TEST(DecodeWithAltDecompressionMemory) {
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(), &amr);
}

// Runs the DECODE kernel on one encoded/ancillary tensor pair whose DCM type
// is kDecodeTypeCustom, with a custom registration supplying
// DecodeStateCustom::CreateDecodeStateCustom as the factory, and compares the
// output against kExpectCustom.
TF_LITE_MICRO_TEST(DecodeWithCustomRegistration) {
// Align the tensor data the same as a Buffer in the TfLite schema
alignas(16) int8_t output_data[std::size(kExpectCustom)] = {};
// Ancillary tensor contents: built from the DCM and the custom payload.
alignas(16) const AncillaryData<int8_t, std::size(kAncillaryDataCustom)>
kAncillaryData = {{kDcmCustom}, {kAncillaryDataCustom}};

// 1-D shape covering the whole ancillary structure, in bytes.
constexpr int kAncillaryShapeCustom[] = {1, sizeof(kAncillaryData)};

// Encoded input tensor description.
const TfLiteIntArray* const encoded_dims =
tflite::testing::IntArrayFromInts(kEncodedShapeCustom);
static const TensorInDatum tid_encode = {
kEncodedCustom,
*encoded_dims,
};
static constexpr std::initializer_list<const TensorInDatum*> encodes = {
&tid_encode,
};

// Ancillary input tensor description.
const TfLiteIntArray* const ancillary_dims =
tflite::testing::IntArrayFromInts(kAncillaryShapeCustom);
static const TensorInDatum tid_ancillary = {
&kAncillaryData,
*ancillary_dims,
};
static constexpr std::initializer_list<const TensorInDatum*> ancillaries = {
&tid_ancillary};

// Output tensor description (int8).
const TfLiteIntArray* const output_dims =
tflite::testing::IntArrayFromInts(kOutputShapeCustom);
constexpr int kOutputZeroPointsData[] = {0};
const TfLiteIntArray* const kOutputZeroPoints =
tflite::testing::IntArrayFromInts(kOutputZeroPointsData);
const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
static const TensorOutDatum tod = {
output_data, *output_dims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints,
0, {},
};
static constexpr std::initializer_list<const TensorOutDatum*> outputs = {
&tod};

const std::initializer_list<const void*> expected = {kExpectCustom};

// Single custom registration: maps kDecodeTypeCustom to the test factory.
// This list is a local that stays alive for the whole TestDecode call below.
const std::initializer_list<tflite::MicroContext::CustomDecodeRegistration>
cdr = {
{
kDecodeTypeCustom,
0, // reserved
0, // reserved
0, // reserved
DecodeStateCustom::CreateDecodeStateCustom,
},
};

// The nullptr argument means no alternate memory regions are supplied; &cdr
// supplies the custom registrations. The expected status argument is omitted.
tflite::testing::TestDecode<encodes.size() + ancillaries.size(),
outputs.size()>(
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(),
nullptr, &cdr);
}

TF_LITE_MICRO_TESTS_END
9 changes: 8 additions & 1 deletion tensorflow/lite/micro/kernels/decode_test_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ TfLiteStatus ExecuteDecodeTest(
TfLiteTensor* tensors, const TFLMRegistration& registration,
const std::initializer_list<const void*>& expected,
const std::initializer_list<MicroContext::AlternateMemoryRegion>* amr =
nullptr,
const std::initializer_list<MicroContext::CustomDecodeRegistration>* cdr =
nullptr) {
int kInputArrayData[kNumInputs + 1] = {kNumInputs};
for (size_t i = 0; i < kNumInputs; i++) {
Expand All @@ -104,6 +106,9 @@ TfLiteStatus ExecuteDecodeTest(
if (amr != nullptr) {
runner.GetFakeMicroContext()->SetDecompressionMemory(*amr);
}
if (cdr != nullptr) {
runner.GetFakeMicroContext()->SetCustomDecodeRegistrations(*cdr);
}

if (runner.InitAndPrepare() != kTfLiteOk || runner.Invoke() != kTfLiteOk) {
return kTfLiteError;
Expand Down Expand Up @@ -149,6 +154,8 @@ void TestDecode(
const TFLMRegistration& registration,
const std::initializer_list<MicroContext::AlternateMemoryRegion>* amr =
nullptr,
const std::initializer_list<MicroContext::CustomDecodeRegistration>* cdr =
nullptr,
const TfLiteStatus expected_status = kTfLiteOk) {
TfLiteTensor tensors[kNumInputs + kNumOutputs] = {};

Expand Down Expand Up @@ -182,7 +189,7 @@ void TestDecode(
}

TfLiteStatus s = ExecuteDecodeTest<kNumInputs, kNumOutputs>(
tensors, registration, expected, amr);
tensors, registration, expected, amr, cdr);
TF_LITE_MICRO_EXPECT_EQ(s, expected_status);
}

Expand Down
2 changes: 2 additions & 0 deletions tensorflow/lite/micro/kernels/kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/micro/micro_common.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_graph.h"

#ifdef USE_TFLM_COMPRESSION

Expand Down
32 changes: 31 additions & 1 deletion tensorflow/lite/micro/micro_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ namespace tflite {
// TODO(b/149795762): kTfLiteAbort cannot be part of the tflite TfLiteStatus.
const TfLiteStatus kTfLiteAbort = static_cast<TfLiteStatus>(15);

class DecodeState; // can't use decode_state.h due to circular include

// MicroContext is eventually going to become the API between TFLM and the
// kernels, replacing all the functions in TfLiteContext. The end state is code
// kernels to have code like:
Expand Down Expand Up @@ -136,7 +138,7 @@ class MicroContext {
};

// Set the alternate decompression memory regions.
// Can only be called during the MicroInterpreter kInit state.
// Can only be called during the kInit state.
virtual TfLiteStatus SetDecompressionMemory(
const std::initializer_list<AlternateMemoryRegion>& regions);

Expand Down Expand Up @@ -169,12 +171,40 @@ class MicroContext {
return nullptr;
}

struct CustomDecodeRegistration {
uint8_t type; // custom decode type
uint8_t reserved1; // reserved
uint8_t reserved2; // reserved
uint8_t reserved3; // reserved
tflite::DecodeState* (*func)(const TfLiteContext*, MicroProfilerInterface*);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of func, a name such as factory or create_state would better describe what the function does.

};

// Set the custom DECODE operator registrations.
// Can only be called during the kInit state.
virtual TfLiteStatus SetCustomDecodeRegistrations(
const std::initializer_list<CustomDecodeRegistration>& registrations) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SetCustomDecodeRegistrations could pose a lifetime safety issue. Storing a pointer to a std::initializer_list is dangerous because the underlying array is often a temporary object that is destroyed immediately after the function returns. Please update the API to accept a pointer to a persistent array and an element count (e.g., const CustomDecodeRegistration* registrations, size_t size) so the caller is explicitly responsible for ensuring the data's lifetime matches that of the interpreter. The function's comment should also state this lifetime expectation.

With this change, a helper function like the following would be useful:

template <size_t N>
TfLiteStatus SetCustomDecodeRegistrations(const CustomDecodeRegistration (&regs)[N]) {
  return SetCustomDecodeRegistrations(regs, N);
}

We might want to make a similar change to decompress_regions_ as well, but that could be a separate PR.

if (custom_decode_registrations_ != nullptr) {
return kTfLiteError;
}
custom_decode_registrations_ = &registrations;
return kTfLiteOk;
}

// Get the custom DECODE operator registrations.
// Returns the pointer stored by SetCustomDecodeRegistrations(), or nullptr if
// no registrations have been stored (the member is initialized to nullptr).
virtual const std::initializer_list<CustomDecodeRegistration>*
GetCustomDecodeRegistrations() const {
return custom_decode_registrations_;
}

private:
const std::initializer_list<AlternateMemoryRegion>* decompress_regions_ =
nullptr;
// array of size_t elements with length equal to decompress_regions_.size()
size_t* decompress_regions_allocations_ = nullptr;

const std::initializer_list<CustomDecodeRegistration>*
custom_decode_registrations_ = nullptr;

TF_LITE_REMOVE_VIRTUAL_DELETE
};

Expand Down
13 changes: 12 additions & 1 deletion tensorflow/lite/micro/micro_interpreter_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class MicroInterpreterContext : public MicroContext {
#endif // USE_TFLM_COMPRESSION

// Set the alternate decompression memory regions.
// Can only be called during the MicroInterpreter kInit state.
// Can only be called during the kInit state.
TfLiteStatus SetDecompressionMemory(
const std::initializer_list<AlternateMemoryRegion>& regions) override;

Expand All @@ -159,6 +159,17 @@ class MicroInterpreterContext : public MicroContext {
// decompression subsystem.
MicroProfilerInterface* GetAlternateProfiler() const override;

// Registers the custom DECODE operator registrations with this context.
// Only permitted while state_ is InterpreterState::kInit: in any other state
// kTfLiteError is returned and the base-class setter is not called. Otherwise
// the result of MicroContext::SetCustomDecodeRegistrations() is returned.
virtual TfLiteStatus SetCustomDecodeRegistrations(
    const std::initializer_list<CustomDecodeRegistration>& registrations)
    override {
  const bool in_init_state = (state_ == InterpreterState::kInit);
  return in_init_state
             ? MicroContext::SetCustomDecodeRegistrations(registrations)
             : kTfLiteError;
}

private:
MicroAllocator& allocator_;
MicroInterpreterGraph& graph_;
Expand Down
Loading