
Commit

Merge pull request pytorch#386 from r-aristov/RL-loss-normalized
Normalized loss in actor-critic and REINFORCE examples.
msaroufim authored Mar 9, 2022
2 parents 4aee9d0 + 41ae197 commit 886b74e
Showing 2 changed files with 4 additions and 0 deletions.
3 changes: 3 additions & 0 deletions reinforcement_learning/actor_critic.py
@@ -124,6 +124,9 @@ def finish_episode():

    # sum up all the values of policy_losses and value_losses
    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()

    # normalize loss by number of rewards
    loss /= rewards.numel()

    # perform backprop
    loss.backward()
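
For context, here is a minimal, self-contained sketch of what the change amounts to: the summed actor-critic loss is divided by the number of per-step rewards so its scale stays roughly constant as episodes get longer. The tensors below are stand-ins for the quantities actor_critic.py accumulates over one episode (policy_losses, value_losses, rewards); their shapes and values are illustrative assumptions, not the example's real data.

import torch

# Stand-ins for one episode's accumulated quantities (illustrative only).
num_steps = 10
policy_losses = [torch.randn(1, requires_grad=True) for _ in range(num_steps)]
value_losses = [torch.randn(1, requires_grad=True) for _ in range(num_steps)]
rewards = torch.randn(num_steps)  # one reward per environment step

# Sum the per-step policy and value losses, then divide by the number of
# reward steps so the gradient reflects a per-step average rather than a
# sum that grows with episode length.
loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
loss = loss / rewards.numel()
loss.backward()
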
1 change: 1 addition & 0 deletions reinforcement_learning/reinforce.py
@@ -72,6 +72,7 @@ def finish_episode():
        policy_loss.append(-log_prob * R)
    optimizer.zero_grad()
    policy_loss = torch.cat(policy_loss).sum()
    policy_loss /= rewards.numel()
    policy_loss.backward()
    optimizer.step()
    del policy.rewards[:]
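
The same idea applied to the REINFORCE update, again as a rough sketch with made-up tensors rather than the example's actual log-probabilities and rewards: dividing the summed policy loss by the number of reward steps turns it into a per-step average, which makes the effective learning rate less sensitive to episode length.

import torch

# Illustrative per-step log-probabilities and returns for one episode.
num_steps = 20
log_probs = [torch.randn(1, requires_grad=True) for _ in range(num_steps)]
returns = torch.randn(num_steps)  # discounted, standardized returns (assumed)
rewards = torch.ones(num_steps)   # one recorded reward per step

policy_loss = torch.cat([-lp * R for lp, R in zip(log_probs, returns)]).sum()
policy_loss = policy_loss / rewards.numel()  # per-step average, as in the diff
policy_loss.backward()
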
