diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py index 44efe782..1479c06a 100644 --- a/rsl_rl/algorithms/ppo.py +++ b/rsl_rl/algorithms/ppo.py @@ -83,11 +83,14 @@ def __init__( if isinstance(symmetry_cfg["data_augmentation_func"], str): symmetry_cfg["data_augmentation_func"] = string_to_callable(symmetry_cfg["data_augmentation_func"]) # Check valid configuration - if symmetry_cfg["use_data_augmentation"] and not callable(symmetry_cfg["data_augmentation_func"]): + if not callable(symmetry_cfg["data_augmentation_func"]): raise ValueError( - "Data augmentation enabled but the function is not callable:" - f" {symmetry_cfg['data_augmentation_func']}" + f"Symmetry configuration exists but the function is not callable: " + f"{symmetry_cfg['data_augmentation_func']}" ) + # Check if the policy is compatible with symmetry + if isinstance(policy, ActorCriticRecurrent): + raise ValueError("Symmetry augmentation is not supported for recurrent policies.") # Store symmetry configuration self.symmetry = symmetry_cfg else: