Skip to content

Commit c98ec28

Browse files
committed
Addressed review comments.
1 parent c122105 commit c98ec28

File tree

1 file changed

+20
-24
lines changed

1 file changed

+20
-24
lines changed

tensorflow/lite/micro/kernels/xtensa/lstm_eval.h

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -661,15 +661,14 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
661661
kernel_content.GetInternalTensor(tflite::kLstmInputTensor);
662662
TfLiteEvalTensor* recurrent = kernel_content.HiddenStateTensor();
663663

664-
int time_major = step_info.time_major();
665-
int num_batches = time_major == 0 ? 1 : step_info.batch_size();
666-
int input_dimension = step_info.input_dimension();
667-
int state_dimension = step_info.state_dimension();
668-
669664
const auto& size_info = op_data.size_info;
670-
if(size_info.batch_size > 1 && size_info.time_steps == 1) {
671-
num_batches = size_info.batch_size;
672-
}
665+
const int time_major = step_info.time_major();
666+
const int batch_size = size_info.batch_size;
667+
const int time_steps = size_info.time_steps;
668+
const int num_batches = time_major == 0 ? (time_steps == 1 ? batch_size : 1)
669+
: step_info.batch_size();
670+
const int input_dimension = step_info.input_dimension();
671+
const int state_dimension = step_info.state_dimension();
673672

674673
// Check offset validity to avoid memory overflow
675674
TFLITE_DCHECK_LE(step_info.InputOffset() + num_batches * input_dimension,
@@ -808,24 +807,21 @@ TfLiteStatus EvalLstm(const OpDataLSTM& op_data,
808807
// prepare for the next time step
809808
step_info.UpdateTime();
810809
}
810+
} else if(size_info.batch_size > 1 && size_info.time_steps == 1) {
811+
// Ramesh
812+
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
813+
step_info, op_data, kernel_content, buffers);
811814
} else {
812-
// batch first, unable to size the input data. single batch inference
813-
if(size_info.batch_size > 1 && size_info.time_steps == 1) {
814-
// Ramesh
815-
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
816-
step_info, op_data, kernel_content, buffers);
817-
} else {
818-
for (int b = 0; b < size_info.batch_size; b++) {
819-
for (int t = 0; t < size_info.time_steps; t++) {
820-
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
821-
step_info, op_data, kernel_content, buffers);
822-
// prepare for the next time step
823-
step_info.UpdateTime();
824-
}
825-
// prepare for the next batch
826-
step_info.UpdateBatch();
827-
step_info.ResetTime();
815+
for (int b = 0; b < size_info.batch_size; b++) {
816+
for (int t = 0; t < size_info.time_steps; t++) {
817+
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
818+
step_info, op_data, kernel_content, buffers);
819+
// prepare for the next time step
820+
step_info.UpdateTime();
828821
}
822+
// prepare for the next batch
823+
step_info.UpdateBatch();
824+
step_info.ResetTime();
829825
}
830826
}
831827
return kTfLiteOk;

0 commit comments

Comments
 (0)