diff --git a/pl_bolts/models/rl/reinforce_model.py b/pl_bolts/models/rl/reinforce_model.py index 4a99116178..e329c128e1 100644 --- a/pl_bolts/models/rl/reinforce_model.py +++ b/pl_bolts/models/rl/reinforce_model.py @@ -215,7 +215,7 @@ def loss(self, states, actions, scaled_rewards) -> torch.Tensor: # policy loss log_prob = log_softmax(logits, dim=1) - log_prob_actions = scaled_rewards * log_prob[range(self.batch_size), actions] + log_prob_actions = scaled_rewards * log_prob[range(len(log_prob)), actions] loss = -log_prob_actions.mean() return loss