Skip to content

Commit

Permalink
fixed end of batch size mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
unknown committed Nov 19, 2020
1 parent fa5a944 commit 4e8ad0c
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pl_bolts/models/rl/reinforce_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def loss(self, states, actions, scaled_rewards) -> torch.Tensor:

# policy loss
log_prob = log_softmax(logits, dim=1)
log_prob_actions = scaled_rewards * log_prob[range(self.batch_size), actions]
log_prob_actions = scaled_rewards * log_prob[range(len(log_prob)), actions]
loss = -log_prob_actions.mean()

return loss
Expand Down

0 comments on commit 4e8ad0c

Please sign in to comment.