Merged
12 changes: 11 additions & 1 deletion tensorflow/lite/micro/kernels/prelu.cc
@@ -61,8 +61,18 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorData<int8_t>(output));
       return kTfLiteOk;
     } break;
+    case kTfLiteInt16: {
+      reference_ops::BroadcastPrelu4DSlow(
+          params, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int16_t>(input),
+          tflite::micro::GetTensorShape(alpha),
+          tflite::micro::GetTensorData<int8_t>(alpha),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int16_t>(output));
+      return kTfLiteOk;
+    } break;
     default:
-      MicroPrintf("Only float32 and uint8_t are supported currently, got %d.",
+      MicroPrintf("Input type '%s' is not supported.",
                   TfLiteTypeGetName(input->type));
       return kTfLiteError;
   }
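Note: reference_ops::BroadcastPrelu4DSlow computes PReLU, output = x for x >= 0 and alpha * x otherwise, with alpha broadcast against the input; the int8 and int16 paths differ only in the activation type, while alpha stays int8 in both. A minimal float-domain sketch of the operation the quantized kernel approximates (illustrative only, not the reference_ops implementation):

#include <cstddef>

// Float-domain PReLU with alpha broadcast along the innermost axis.
// Sketch only; the quantized kernels do the equivalent math with zero points
// and the fixed-point multipliers produced by CalculatePreluParams.
inline void PreluFloat(const float* input, const float* alpha,
                       std::size_t size, std::size_t channels, float* output) {
  for (std::size_t i = 0; i < size; ++i) {
    const float x = input[i];
    const float a = alpha[i % channels];  // per-channel slope
    output[i] = x >= 0.0f ? x : a * x;
  }
}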
5 changes: 5 additions & 0 deletions tensorflow/lite/micro/kernels/prelu_common.cc
@@ -96,6 +96,11 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context,
                     CalculatePreluParams(input, alpha, output, params));
 
+  if (output->type == kTfLiteInt16) {
+    // Make sure alpha type is Int8 when Output is Int16
+    TF_LITE_ENSURE(context, alpha->type == kTfLiteInt8);
+  }
+
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(alpha);
   micro_context->DeallocateTempTfLiteTensor(output);
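For reference, CalculatePreluParams folds the input, alpha, and output scales into the fixed-point multipliers used at Eval time, which is why the only extra constraint the int16 path needs is that alpha is int8-quantized. A sketch of the underlying scale algebra (variable names here are illustrative, not the actual PreluParams fields):

#include <cstdio>

int main() {
  // Scales taken from the Int16 test below; zero points are 0 there.
  const double input_scale = 2.0 / 65535.0;
  const double output_scale = 2.0 / 65535.0;
  const double alpha_scale = 2.0 / 255.0;
  // Positive branch: real_out = real_in, so the rescale is input/output.
  const double positive_multiplier = input_scale / output_scale;
  // Negative branch: real_out = alpha * real_in, so the alpha scale folds in.
  const double negative_multiplier = input_scale * alpha_scale / output_scale;
  std::printf("positive: %g, negative: %g\n", positive_multiplier,
              negative_multiplier);
  return 0;
}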
86 changes: 59 additions & 27 deletions tensorflow/lite/micro/kernels/prelu_test.cc
@@ -23,26 +23,20 @@ namespace tflite {
 namespace testing {
 namespace {
 
-template <typename T>
-void ValidatePreluGoldens(TfLiteTensor* tensors, int tensors_size,
-                          const T* golden, const int output_length,
-                          T* output_data) {
+const float kQuantizedTolerance = 2 * (1. / 256);
+
+void ExecutePReluTest(const int tensors_count, TfLiteTensor* tensors) {
   int inputs_array_data[] = {2, 0, 1};
   TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
   int outputs_array_data[] = {1, 2};
   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
 
   const TFLMRegistration registration = tflite::Register_PRELU();
-  micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
-                             outputs_array,
-                             /*builtin_data=*/nullptr);
+  micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array,
+                             outputs_array, /*builtin_data=*/nullptr);
 
   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
-
-  for (int i = 0; i < output_length; ++i) {
-    TF_LITE_MICRO_EXPECT_NEAR(golden[i], output_data[i], 1e-5f);
-  }
 }
 
 void TestPreluFloat(int* input_dims_data, const float* input_data,
@@ -62,19 +56,22 @@ void TestPreluFloat(int* input_dims_data, const float* input_data,
       CreateTensor(output_data, output_dims),
   };
 
-  ValidatePreluGoldens(tensors, tensors_size, expected_output_data,
-                       output_dims_count, output_data);
+  ExecutePReluTest(tensors_size, tensors);
+
+  for (int i = 0; i < output_dims_count; i++) {
+    TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
+  }
 }
 
-template <typename T>
+template <typename T, typename Slope>
 void TestPreluQuantized(int* input_dims_data, const float* input_data,
                         T* input_quantized, const float input_scale,
                         const int input_zero_point, int* alpha_dims_data,
-                        const float* alpha_data, T* alpha_quantized,
+                        const float* alpha_data, Slope* alpha_quantized,
                         const float alpha_scale, const int alpha_zero_point,
-                        const float* golden, T* golden_quantized,
-                        const float output_scale, const int output_zero_point,
-                        int* output_dims_data, T* output_data) {
+                        const float* golden, const float output_scale,
+                        const int output_zero_point, int* output_dims_data,
+                        T* output_quantized, float* output_data) {
   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
   TfLiteIntArray* alpha_dims = IntArrayFromInts(alpha_dims_data);
   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
@@ -87,15 +84,18 @@ void TestPreluQuantized(int* input_dims_data, const float* input_data,
                             input_scale, input_zero_point),
       CreateQuantizedTensor(alpha_data, alpha_quantized, alpha_dims,
                             alpha_scale, alpha_zero_point),
-      CreateQuantizedTensor(output_data, output_dims, output_scale,
+      CreateQuantizedTensor(output_quantized, output_dims, output_scale,
                             output_zero_point),
   };
 
-  Quantize(golden, golden_quantized, output_dims_count, output_scale,
-           output_zero_point);
+  ExecutePReluTest(tensors_size, tensors);
 
-  ValidatePreluGoldens(tensors, tensors_size, golden_quantized,
-                       output_dims_count, output_data);
+  Dequantize(output_quantized, output_dims_count, output_scale,
+             output_zero_point, output_data);
+
+  for (int i = 0; i < output_dims_count; i++) {
+    TF_LITE_MICRO_EXPECT_NEAR(golden[i], output_data[i], kQuantizedTolerance);
+  }
 }
 }  // namespace
 }  // namespace testing
@@ -147,13 +147,45 @@ TF_LITE_MICRO_TEST(QuantizedInt8PreluActivationsOpTest) {
   const int dims_count = 12;
   int8_t input_quantized[dims_count];
   int8_t alpha_quantized[3];
-  int8_t golden_quantized[dims_count];
   float scale = 2.0 / 255.0;
   int zero_point = 0;
-  int8_t output_data[dims_count];
-  tflite::testing::TestPreluQuantized(
+  int8_t output_data_q[dims_count];
+  float output_data_f[dims_count];
+  tflite::testing::TestPreluQuantized<int8_t, int8_t>(
       input_shape, input_values, input_quantized, scale, zero_point,
       alpha_shape, alpha_values, alpha_quantized, scale, zero_point, golden,
-      golden_quantized, scale, zero_point, output_shape, output_data);
+      scale, zero_point, output_shape, output_data_q, output_data_f);
 }
+
+TF_LITE_MICRO_TEST(QuantizedInt16PreluActivationsOpTest) {
+  int input_shape[] = {3, 2, 2, 3};
+  const float input_values[] = {
+      0.0f,   0.0f,   0.0f,    // Row 1, Column 1
+      0.5f,   0.5f,   0.5f,    // Row 1, Column 2
+      -1.0f,  -1.0f,  -1.0f,   // Row 2, Column 1
+      -0.25f, -0.25f, -0.25f,  // Row 2, Column 2
+  };
+  int alpha_shape[] = {3, 1, 1, 3};
+  const float alpha_values[] = {0.0f, 0.5f, -0.5f};
+  int output_shape[] = {3, 2, 2, 3};
+  const float golden[] = {
+      0.0f, 0.0f,    0.0f,    // Row 1, Column 1
+      0.5f, 0.5f,    0.5f,    // Row 1, Column 2
+      0.0f, -0.5f,   0.5f,    // Row 2, Column 1
+      0.0f, -0.125f, 0.125f,  // Row 2, Column 2
+  };
+  const int dims_count = 12;
+  int16_t input_quantized[dims_count];
+  int8_t alpha_quantized[3];
+  float scale_input_output = 2.0 / 65535.0;
+  float scale_alpha = 2.0 / 255.0;
+  int zero_point = 0;
+  int16_t output_data_q[dims_count];
+  float output_data_f[dims_count];
+  tflite::testing::TestPreluQuantized<int16_t, int8_t>(
+      input_shape, input_values, input_quantized, scale_input_output,
+      zero_point, alpha_shape, alpha_values, alpha_quantized, scale_alpha,
+      zero_point, golden, scale_input_output, zero_point, output_shape,
+      output_data_q, output_data_f);
+}
 TF_LITE_MICRO_TESTS_END
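As a sanity check on the Int16 goldens, applying PReLU with the per-channel alphas {0.0, 0.5, -0.5} to the inputs reproduces the expected rows, e.g. -1.0 maps to {0.0, -0.5, 0.5} and -0.25 to {0.0, -0.125, 0.125}. A standalone sketch that recomputes them, independent of the TFLM test harness:

#include <cstdio>

int main() {
  const float input[12] = {0.0f,  0.0f,  0.0f,  0.5f,   0.5f,   0.5f,
                           -1.0f, -1.0f, -1.0f, -0.25f, -0.25f, -0.25f};
  const float alpha[3] = {0.0f, 0.5f, -0.5f};  // broadcast over the last axis
  for (int i = 0; i < 12; ++i) {
    const float x = input[i];
    const float y = x >= 0.0f ? x : alpha[i % 3] * x;  // PReLU
    std::printf("%.3f%s", y, (i % 3 == 2) ? "\n" : " ");
  }
  return 0;
}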