@@ -661,15 +661,14 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
661661 kernel_content.GetInternalTensor (tflite::kLstmInputTensor );
662662 TfLiteEvalTensor* recurrent = kernel_content.HiddenStateTensor ();
663663
664- int time_major = step_info.time_major ();
665- int num_batches = time_major == 0 ? 1 : step_info.batch_size ();
666- int input_dimension = step_info.input_dimension ();
667- int state_dimension = step_info.state_dimension ();
668-
669664 const auto & size_info = op_data.size_info ;
670- if (size_info.batch_size > 1 && size_info.time_steps == 1 ) {
671- num_batches = size_info.batch_size ;
672- }
665+ const int time_major = step_info.time_major ();
666+ const int batch_size = size_info.batch_size ;
667+ const int time_steps = size_info.time_steps ;
668+ const int num_batches = time_major == 0 ? (time_steps == 1 ? batch_size : 1 )
669+ : step_info.batch_size ();
670+ const int input_dimension = step_info.input_dimension ();
671+ const int state_dimension = step_info.state_dimension ();
673672
674673 // Check offset validity to avoid memory overflow
675674 TFLITE_DCHECK_LE (step_info.InputOffset () + num_batches * input_dimension,
@@ -808,24 +807,21 @@ TfLiteStatus EvalLstm(const OpDataLSTM& op_data,
808807 // prepare for the next time step
809808 step_info.UpdateTime ();
810809 }
810+ } else if (size_info.batch_size > 1 && size_info.time_steps == 1 ) {
811+ // Ramesh
812+ lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
813+ step_info, op_data, kernel_content, buffers);
811814 } else {
812- // batch first, unable to size the input data. single batch inference
813- if (size_info.batch_size > 1 && size_info.time_steps == 1 ) {
814- // Ramesh
815- lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
816- step_info, op_data, kernel_content, buffers);
817- } else {
818- for (int b = 0 ; b < size_info.batch_size ; b++) {
819- for (int t = 0 ; t < size_info.time_steps ; t++) {
820- lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
821- step_info, op_data, kernel_content, buffers);
822- // prepare for the next time step
823- step_info.UpdateTime ();
824- }
825- // prepare for the next batch
826- step_info.UpdateBatch ();
827- step_info.ResetTime ();
815+ for (int b = 0 ; b < size_info.batch_size ; b++) {
816+ for (int t = 0 ; t < size_info.time_steps ; t++) {
817+ lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
818+ step_info, op_data, kernel_content, buffers);
819+ // prepare for the next time step
820+ step_info.UpdateTime ();
828821 }
822+ // prepare for the next batch
823+ step_info.UpdateBatch ();
824+ step_info.ResetTime ();
829825 }
830826 }
831827 return kTfLiteOk ;
0 commit comments