105 changes: 78 additions & 27 deletions rsl_rl/algorithms/ppo.py
@@ -216,7 +216,10 @@ def update(self): # noqa: C901
# original batch size
# we assume policy group is always there and needs augmentation
original_batch_size = obs_batch.batch_size[0]

original_actions_batch_size = actions_batch.shape[0]
if self.policy.is_recurrent:
    # recurrent mini-batches are time-major, so the batch dimension is index 1
    original_batch_size = obs_batch.batch_size[1]
    original_actions_batch_size = actions_batch.shape[1]
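For context: feed-forward mini-batches in rsl_rl are laid out `[batch, ...]`, while recurrent mini-batches hold padded trajectories in time-major order, `[time, batch, ...]`, which is why the batch size moves to index 1 above. A minimal sketch of the two layouts (the shapes are illustrative, not taken from the PR):

```python
import torch

# Feed-forward: observations are [batch, obs_dim] -> batch size lives on dim 0.
ff_obs = torch.zeros(64, 48)
assert ff_obs.shape[0] == 64

# Recurrent: padded trajectories are time-major, [time, batch, obs_dim]
# -> batch size lives on dim 1, matching obs_batch.batch_size[1] above.
rnn_obs = torch.zeros(24, 64, 48)
assert rnn_obs.shape[1] == 64
```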
# check if we should normalize advantages per mini batch
if self.normalize_advantage_per_mini_batch:
with torch.no_grad():
@@ -234,14 +237,36 @@ def update(self): # noqa: C901
)
# compute number of augmentations per sample
# we assume policy group is always there and needs augmentation
if self.policy.is_recurrent:
    num_aug = int(obs_batch.batch_size[1] / original_batch_size)
    # repeat the rest of the batch
    # -- actor
    old_actions_log_prob_batch = old_actions_log_prob_batch.repeat(1, num_aug, 1)
    # -- critic
    target_values_batch = target_values_batch.repeat(1, num_aug, 1)
    advantages_batch = advantages_batch.repeat(1, num_aug, 1)
    returns_batch = returns_batch.repeat(1, num_aug, 1)
    # -- hidden states and masks
    hid_states_batch = list(hid_states_batch)  # the tuple is immutable, so convert once
    for idx, hid_state in enumerate(hid_states_batch):
        if isinstance(hid_state, list):
            # repeat every tensor in the list (e.g. an LSTM's (h, c) pair)
            hid_states_batch[idx] = [tensor.repeat(1, num_aug, 1) for tensor in hid_state]
        elif isinstance(hid_state, torch.Tensor):
            # repeat the tensor directly (e.g. a GRU hidden state)
            hid_states_batch[idx] = hid_state.repeat(1, num_aug, 1)
        else:
            raise ValueError(f"Unsupported hidden state type: {type(hid_state)}")
    masks_batch = masks_batch.repeat(1, num_aug)
else:
    num_aug = int(obs_batch.batch_size[0] / original_batch_size)
    # repeat the rest of the batch
    # -- actor
    old_actions_log_prob_batch = old_actions_log_prob_batch.repeat(num_aug, 1)
    # -- critic
    target_values_batch = target_values_batch.repeat(num_aug, 1)
    advantages_batch = advantages_batch.repeat(num_aug, 1)
    returns_batch = returns_batch.repeat(num_aug, 1)
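All of the `repeat` calls above tile the `num_aug` augmented copies along the batch dimension and leave every other dimension alone; hidden states (`[num_layers, batch, hidden]`) and masks (`[time, batch]`) follow the same pattern. A small sketch with illustrative shapes:

```python
import torch

num_aug = 2  # e.g. the original plus one mirrored copy

# Time-major value targets [T, B, 1]: repeating along dim 1 stacks the
# augmented copies after the original batch.
returns = torch.randn(24, 64, 1)
print(returns.repeat(1, num_aug, 1).shape)   # torch.Size([24, 128, 1])

# RNN hidden states [num_layers, B, H] repeat the same way.
hidden = torch.randn(2, 64, 256)
print(hidden.repeat(1, num_aug, 1).shape)    # torch.Size([2, 128, 256])

# Masks [T, B] have no feature dimension, hence repeat(1, num_aug).
masks = torch.ones(24, 64, dtype=torch.bool)
print(masks.repeat(1, num_aug).shape)        # torch.Size([24, 128])
```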

# Recompute actions log prob and entropy for current batch of transitions
# Note: we need to do this because we updated the policy with the new parameters
@@ -252,9 +277,14 @@ def update(self): # noqa: C901
value_batch = self.policy.evaluate(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
# -- entropy
# we only keep the entropy of the first augmentation (the original one)
if self.policy.is_recurrent:
    mu_batch = self.policy.action_mean[:, :original_actions_batch_size]
    sigma_batch = self.policy.action_std[:, :original_actions_batch_size]
    entropy_batch = self.policy.entropy[:, :original_actions_batch_size]
else:
    mu_batch = self.policy.action_mean[:original_batch_size]
    sigma_batch = self.policy.action_std[:original_batch_size]
    entropy_batch = self.policy.entropy[:original_batch_size]
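Only the first augmentation's statistics enter the KL and entropy terms, so the entropy bonus is not scaled up by `num_aug`; the slice axis again follows the batch layout. A sketch with illustrative shapes:

```python
import torch

original_bs, num_aug, act_dim = 64, 2, 12

# Feed-forward: distribution stats are [B * num_aug, act_dim]; the original,
# un-augmented samples occupy the first original_bs rows.
mean_ff = torch.randn(original_bs * num_aug, act_dim)
print(mean_ff[:original_bs].shape)        # torch.Size([64, 12])

# Recurrent: stats are time-major, [T, B * num_aug, act_dim]; slice dim 1.
mean_rnn = torch.randn(24, original_bs * num_aug, act_dim)
print(mean_rnn[:, :original_bs].shape)    # torch.Size([24, 64, 12])
```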

# KL
if self.desired_kl is not None and self.schedule == "adaptive":
@@ -324,23 +354,44 @@ def update(self): # noqa: C901
# compute number of augmentations per sample
num_aug = int(obs_batch.shape[0] / original_batch_size)

if self.policy.is_recurrent:
    # clone the stored hidden states so the inference pass does not modify them
    if isinstance(hid_states_batch[0], list):
        tmp_hid_states_batch = [hid_state.detach().clone() for hid_state in hid_states_batch[0]]
    elif isinstance(hid_states_batch[0], torch.Tensor):
        tmp_hid_states_batch = hid_states_batch[0].detach().clone()
    else:
        raise ValueError(f"Unsupported hidden state type: {type(hid_states_batch[0])}")
    # actions predicted by the actor for symmetrically-augmented observations
    mean_actions_batch = self.policy.act_inference(
        obs_batch.detach().clone(),
        masks=masks_batch.detach().clone(),
        hidden_states=tmp_hid_states_batch,
    )

    # compute the symmetrically augmented actions
    # note: we are assuming the first augmentation is the original one.
    # We do not use the action_batch from earlier since that action was sampled from the distribution.
    # However, the symmetry loss is computed using the mean of the distribution.
    action_mean_orig = mean_actions_batch[:, :original_actions_batch_size]
    _, actions_mean_symm_batch = data_augmentation_func(
        obs=None, actions=action_mean_orig, env=self.symmetry["_env"]
    )

    # compute the loss (we skip the first augmentation as it is the original one)
    mse_loss = torch.nn.MSELoss()
    symmetry_loss = mse_loss(
        mean_actions_batch[:, original_actions_batch_size:],
        actions_mean_symm_batch.detach()[:, original_actions_batch_size:],
    )
else:
    # actions predicted by the actor for symmetrically-augmented observations
    mean_actions_batch = self.policy.act_inference(obs_batch.detach().clone())

    # compute the symmetrically augmented actions
    # note: we are assuming the first augmentation is the original one.
    # We do not use the action_batch from earlier since that action was sampled from the distribution.
    # However, the symmetry loss is computed using the mean of the distribution.
    action_mean_orig = mean_actions_batch[:original_batch_size]
    _, actions_mean_symm_batch = data_augmentation_func(
        obs=None, actions=action_mean_orig, env=self.symmetry["_env"]
    )

    # compute the loss (we skip the first augmentation as it is the original one)
    mse_loss = torch.nn.MSELoss()
    symmetry_loss = mse_loss(
        mean_actions_batch[original_batch_size:], actions_mean_symm_batch.detach()[original_batch_size:]
    )
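`data_augmentation_func` comes from the symmetry configuration and is environment-specific; the PR only relies on its interface. A hypothetical sketch of that interface (the flip/negation below is a stand-in for a real mirror map, not the library's implementation):

```python
import torch

def mirror_augmentation(obs=None, actions=None, env=None):
    """Hypothetical symmetry function with the interface assumed above.

    Returns (obs_aug, actions_aug) with the original samples first and the
    mirrored copies stacked after them along the batch dimension. Using
    dim=-2 covers both layouts: dim 0 for [B, D] and dim 1 for [T, B, D].
    """
    obs_aug = None if obs is None else torch.cat([obs, obs.flip(-1)], dim=-2)
    actions_aug = None if actions is None else torch.cat([actions, -actions], dim=-2)
    return obs_aug, actions_aug
```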
# add the loss to the total loss
if self.symmetry["use_mirror_loss"]:
loss += self.symmetry["mirror_loss_coeff"] * symmetry_loss
4 changes: 2 additions & 2 deletions rsl_rl/modules/actor_critic_recurrent.py
@@ -163,10 +163,10 @@ def act(self, obs, masks=None, hidden_states=None):
self.update_distribution(out_mem)
return self.distribution.sample()

def act_inference(self, obs, masks=None, hidden_states=None):
obs = self.get_actor_obs(obs)
obs = self.actor_obs_normalizer(obs)
out_mem = self.memory_a(obs, masks, hidden_states).squeeze(0)
if self.state_dependent_std:
return self.actor(out_mem)[..., 0, :]
else:
    return self.actor(out_mem)
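With the extended signature, the symmetry path can replay a whole padded trajectory batch through the actor's memory instead of stepping the module's live hidden state. A usage sketch; the fall-back behavior described in the comment reflects my reading of rsl_rl's `Memory.forward` and should be treated as an assumption:

```python
from rsl_rl.modules import ActorCriticRecurrent

def symmetry_inference(policy: ActorCriticRecurrent, obs, masks, hidden_states):
    # Batched replay through the actor's memory: hidden_states are the states
    # saved alongside the mini-batch, masks flag each trajectory's valid steps.
    return policy.act_inference(obs, masks=masks, hidden_states=hidden_states)

# Single-step inference is unchanged, since both new arguments default to None
# and the memory then falls back to its internal hidden state:
#   action_mean = policy.act_inference(current_obs)
```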