Description
In the PPO.update() method in rsl_rl/algorithms/ppo.py, there appears to be an issue with how the batch size is calculated when using symmetric data augmentation with RNN models.
Problem
rsl_rl/rsl_rl/algorithms/ppo.py, line 218 in 8363520:
    original_batch_size = obs_batch.batch_size[0]
rsl_rl/rsl_rl/algorithms/ppo.py, line 237 in 8363520:
    num_aug = int(obs_batch.batch_size[0] / original_batch_size)
When using RNN models, the obs_batch returned from recurrent_mini_batch_generator is a TensorDict with dimensions [time_steps, batch_size, obs_dim]. The batch_size attribute of this TensorDict is [time_steps, batch_size], so:
- obs_batch.batch_size[0] returns the number of time steps, not the batch size
- obs_batch.batch_size[1] returns the actual batch size
However, the current implementation always uses obs_batch.batch_size[0], which means for RNN models it's using the time step count instead of the batch size, leading to incorrect calculations of num_aug (number of augmentations).
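To make the arithmetic concrete, here is a minimal sketch in plain Python that works on shape tuples only, so it does not touch the rsl_rl or TensorDict APIs. The helper num_augmentations, the shape values, and the assumption that augmentation stacks mirrored copies along the batch (environment) dimension are all hypothetical and are not taken from the repository.

```python
def num_augmentations(original_shape, augmented_shape, batch_dim):
    """Ratio of augmented to original size along the chosen dimension.

    Hypothetical helper for illustration; not part of rsl_rl.
    """
    return augmented_shape[batch_dim] // original_shape[batch_dim]


# Hypothetical recurrent mini-batch shape: [time_steps, batch_size]
original = (24, 64)

# Assumption: augmentation adds one mirrored copy along the batch dimension,
# so the batch size doubles while the number of time steps is unchanged.
augmented = (24, 128)

# Indexing dimension 0 compares time steps with time steps.
num_aug_dim0 = num_augmentations(original, augmented, batch_dim=0)

# Indexing dimension 1 compares the actual batch sizes.
num_aug_dim1 = num_augmentations(original, augmented, batch_dim=1)

print("using batch_size[0]:", num_aug_dim0)
print("using batch_size[1]:", num_aug_dim1)
```

Under these assumptions, indexing dimension 0 reports a single copy of the data, while indexing dimension 1 reports the two copies that are actually present. A fix along these lines would select the index based on whether the policy is recurrent.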