@@ -666,6 +666,11 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
666666 int input_dimension = step_info.input_dimension ();
667667 int state_dimension = step_info.state_dimension ();
668668
669+ const auto & size_info = op_data.size_info ;
670+ if (size_info.batch_size > 1 && size_info.time_steps == 1 ) {
671+ num_batches = size_info.batch_size ;
672+ }
673+
669674 // Check offset validity to avoid memory overflow
670675 TFLITE_DCHECK_LE (step_info.InputOffset () + num_batches * input_dimension,
671676 tflite::micro::GetTensorShape (input).FlatSize ());
@@ -805,16 +810,22 @@ TfLiteStatus EvalLstm(const OpDataLSTM& op_data,
805810 }
806811 } else {
807812 // batch first, unable to size the input data. single batch inference
808- for (int b = 0 ; b < size_info.batch_size ; b++) {
809- for (int t = 0 ; t < size_info.time_steps ; t++) {
810- lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
811- step_info, op_data, kernel_content, buffers);
812- // prepare for the next time step
813- step_info.UpdateTime ();
813+ if (size_info.batch_size > 1 && size_info.time_steps == 1 ) {
814+          // Batch-first input with a single time step: LstmStep processes
814+          // all batches in one call (num_batches is widened to
814+          // size_info.batch_size inside LstmStep for this case), so the
814+          // per-batch/per-time loop below is unnecessary.
815+ lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
816+ step_info, op_data, kernel_content, buffers);
817+ } else {
818+ for (int b = 0 ; b < size_info.batch_size ; b++) {
819+ for (int t = 0 ; t < size_info.time_steps ; t++) {
820+ lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
821+ step_info, op_data, kernel_content, buffers);
822+ // prepare for the next time step
823+ step_info.UpdateTime ();
824+ }
825+ // prepare for the next batch
826+ step_info.UpdateBatch ();
827+ step_info.ResetTime ();
814828 }
815- // prepare for the next batch
816- step_info.UpdateBatch ();
817- step_info.ResetTime ();
818829 }
819830 }
820831 return kTfLiteOk ;
0 commit comments