From 4e8ad0c5eff5e5a9cc539405624f236a9e7b32c8 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 19 Nov 2020 17:06:02 +0530 Subject: [PATCH] fixed end of batch size mismatch --- pl_bolts/models/rl/reinforce_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/models/rl/reinforce_model.py b/pl_bolts/models/rl/reinforce_model.py index 4a99116178..e329c128e1 100644 --- a/pl_bolts/models/rl/reinforce_model.py +++ b/pl_bolts/models/rl/reinforce_model.py @@ -215,7 +215,7 @@ def loss(self, states, actions, scaled_rewards) -> torch.Tensor: # policy loss log_prob = log_softmax(logits, dim=1) - log_prob_actions = scaled_rewards * log_prob[range(self.batch_size), actions] + log_prob_actions = scaled_rewards * log_prob[range(len(log_prob)), actions] loss = -log_prob_actions.mean() return loss