diff --git a/abcdrl/ddqn.py b/abcdrl/ddqn.py index 0ff2950..3a46407 100644 --- a/abcdrl/ddqn.py +++ b/abcdrl/ddqn.py @@ -130,9 +130,9 @@ def predict(self, obs: torch.Tensor) -> torch.Tensor: def learn(self, data: ReplayBuffer.Samples) -> dict[str, Any]: with torch.no_grad(): - target_max, target_argmax = self.model_t.value(data.next_observations).max(dim=1) + _, target_argmax = self.model.value(data.next_observations).max(dim=1) # double dqn - target_max = self.model.value(data.next_observations).gather(1, target_argmax.unsqueeze(1)).squeeze() + target_max = self.model_t.value(data.next_observations).gather(1, target_argmax.unsqueeze(1)).squeeze() td_target = data.rewards.flatten() + self.kwargs["gamma"] * target_max * (1 - data.dones.flatten()) old_val = self.model.value(data.observations).gather(1, data.actions).squeeze()