Skip to content

Commit c122105

Browse files
committed
Optimization in LSTM for batch > 1 cases on HiFi.
1 parent 3ea59f5 commit c122105

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -473,7 +473,7 @@ void LstmStepManager::UpdateBatch() {
473473
// Multi-batch for time_major input
474474
RuntimeShape LstmStepManager::InputShape() const {
475475
int batch_size = 1;
476-
if (size_info_.time_major) {
476+
if (size_info_.time_major || ((size_info_.batch_size > 1 && size_info_.time_steps == 1))) {
477477
batch_size = size_info_.batch_size;
478478
}
479479
const int dims[2] = {batch_size, size_info_.input_dimension};
@@ -485,7 +485,7 @@ RuntimeShape LstmStepManager::InputShape() const {
485485
// Multi-batch for time_major input
486486
RuntimeShape LstmStepManager::StateShape() const {
487487
int batch_size = 1;
488-
if (size_info_.time_major) {
488+
if (size_info_.time_major || (size_info_.batch_size > 1 && size_info_.time_steps == 1)) {
489489
batch_size = size_info_.batch_size;
490490
}
491491
const int dims[2] = {batch_size, size_info_.state_dimension};

tensorflow/lite/micro/kernels/xtensa/lstm_eval.h

Lines changed: 20 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -666,6 +666,11 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
666666
int input_dimension = step_info.input_dimension();
667667
int state_dimension = step_info.state_dimension();
668668

669+
const auto& size_info = op_data.size_info;
670+
if(size_info.batch_size > 1 && size_info.time_steps == 1) {
671+
num_batches = size_info.batch_size;
672+
}
673+
669674
// Check offset validity to avoid memory overflow
670675
TFLITE_DCHECK_LE(step_info.InputOffset() + num_batches * input_dimension,
671676
tflite::micro::GetTensorShape(input).FlatSize());
@@ -805,16 +810,22 @@ TfLiteStatus EvalLstm(const OpDataLSTM& op_data,
805810
}
806811
} else {
807812
// batch first, unable to size the input data. single batch inference
808-
for (int b = 0; b < size_info.batch_size; b++) {
809-
for (int t = 0; t < size_info.time_steps; t++) {
810-
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
811-
step_info, op_data, kernel_content, buffers);
812-
// prepare for the next time step
813-
step_info.UpdateTime();
813+
if(size_info.batch_size > 1 && size_info.time_steps == 1) {
814+
// Single time step with batch_size > 1: one LstmStep call processes all batches.
815+
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
816+
step_info, op_data, kernel_content, buffers);
817+
} else {
818+
for (int b = 0; b < size_info.batch_size; b++) {
819+
for (int t = 0; t < size_info.time_steps; t++) {
820+
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
821+
step_info, op_data, kernel_content, buffers);
822+
// prepare for the next time step
823+
step_info.UpdateTime();
824+
}
825+
// prepare for the next batch
826+
step_info.UpdateBatch();
827+
step_info.ResetTime();
814828
}
815-
// prepare for the next batch
816-
step_info.UpdateBatch();
817-
step_info.ResetTime();
818829
}
819830
}
820831
return kTfLiteOk;

0 commit comments

Comments
 (0)