fix ddqn target(#11)
sdpkjc authored Nov 20, 2022
1 parent 1ad3807 commit bcd3c43
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions abcdrl/ddqn.py
@@ -130,9 +130,9 @@ def predict(self, obs: torch.Tensor) -> torch.Tensor:

     def learn(self, data: ReplayBuffer.Samples) -> dict[str, Any]:
         with torch.no_grad():
-            target_max, target_argmax = self.model_t.value(data.next_observations).max(dim=1)
+            _, target_argmax = self.model.value(data.next_observations).max(dim=1)
             # double dqn
-            target_max = self.model.value(data.next_observations).gather(1, target_argmax.unsqueeze(1)).squeeze()
+            target_max = self.model_t.value(data.next_observations).gather(1, target_argmax.unsqueeze(1)).squeeze()
             td_target = data.rewards.flatten() + self.kwargs["gamma"] * target_max * (1 - data.dones.flatten())

         old_val = self.model.value(data.observations).gather(1, data.actions).squeeze()
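
The change above swaps the roles of the two networks: before the fix, the target network picked the next action and the online network scored it; after the fix, the online network (self.model) picks the action and the target network (self.model_t) scores it, which is the usual Double DQN arrangement. The following is a minimal standalone sketch of that target computation, not the repository's code. The function name double_dqn_target and the toy tensors are assumptions made for illustration; the real method reads its inputs from a ReplayBuffer.Samples batch.

```python
import torch


def double_dqn_target(
    q_online_next: torch.Tensor,  # (batch, n_actions): online network on next observations
    q_target_next: torch.Tensor,  # (batch, n_actions): target network on next observations
    rewards: torch.Tensor,        # (batch,)
    dones: torch.Tensor,          # (batch,), 1.0 where the episode ended
    gamma: float,
) -> torch.Tensor:
    """Hypothetical helper: compute Double DQN TD targets from precomputed Q-values."""
    # Action selection: greedy action under the online network.
    next_actions = q_online_next.argmax(dim=1, keepdim=True)
    # Action evaluation: value of that action under the target network.
    next_values = q_target_next.gather(1, next_actions).squeeze(1)
    # Terminal transitions do not bootstrap from the next state.
    return rewards + gamma * next_values * (1.0 - dones)


# Toy batch of two transitions with two actions each (made-up numbers).
q_online_next = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
q_target_next = torch.tensor([[10.0, 20.0], [30.0, 40.0]])
rewards = torch.tensor([1.0, 1.0])
dones = torch.tensor([0.0, 1.0])

target = double_dqn_target(q_online_next, q_target_next, rewards, dones, gamma=0.5)
print(target)  # tensor([11.,  1.])
```

In the first row the online network prefers action 1, so the target uses the target network's value for action 1 (20.0) rather than its own maximum. The sketch calls squeeze(1) instead of the bare squeeze() in the diff so that a batch of size 1 keeps its batch dimension; that is a choice made for this example, not part of the commit.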
