* w/ & w/p trpo runnable
jingweiz committed Aug 21, 2017
1 parent b410023 commit 7ec9b42
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions core/agents/acer_single_process.py
@@ -185,7 +185,6 @@ def _1st_order_trpo(self, detached_policy_loss_vb, detached_policy_vb, detached_
kk_dot_vb = torch.mm(k_vb, torch.t(k_vb))
z_star_vb = g_vb - ((kg_dot_vb - self.master.clip_1st_order_trpo) / kk_dot_vb).clamp(min=0) * k_vb

del kl_div_vb, k_vb, g_vb, kg_dot_vb, kk_dot_vb
return z_star_vb

def _backward(self, on_policy=True):
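For reference, z_star_vb above is the closed-form solution of the linearized trust-region problem from the ACER paper: minimize 0.5 * ||g - z||^2 subject to k^T z <= delta, where g is the gradient of the policy loss w.r.t. the policy output, k the gradient of the KL divergence from the average policy, and delta is clip_1st_order_trpo. A minimal standalone sketch with flattened 1-D tensors (the function and variable names here are illustrative, not from this repo):

import torch

def first_order_trpo_projection(g, k, delta):
    # z* = g - max(0, (k^T g - delta) / ||k||^2) * k
    kg = torch.dot(k, g)    # k^T g
    kk = torch.dot(k, k)    # ||k||^2
    return g - ((kg - delta) / kk).clamp(min=0) * k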
@@ -194,24 +193,30 @@ def _backward(self, on_policy=True):
if self.master.enable_continuous:
pass
else:
action_batch_vb = Variable(torch.from_numpy(np.array(self.rollout.action)).view(rollout_steps, -1, 1).long()) # [rollout_steps x batch_size x 1]
if self.master.use_cuda:
action_batch_vb = action_batch_vb.cuda()
# NOTE: here we use the detached policies, because when using 1st order trpo,
# NOTE: the policy losses are not directly backpropagated into the model,
# NOTE: but only backpropagated up to the output of the network;
# NOTE: to keep the code consistent, we also decouple the backprop
# NOTE: into two stages when not using the trpo policy update
detached_policy_vb = [Variable(self.rollout.policy_vb[i].data, requires_grad=True) for i in range(rollout_steps)] # [rollout_steps x batch_size x action_dim]
detached_policy_log_vb = [torch.log(detached_policy_vb[i]) for i in range(rollout_steps)]
# detached_entropy_vb = [- (detached_policy_log_vb[i] * detached_policy_vb[i]).sum(1) for i in range(rollout_steps)] # TODO: check if should keepdim
detached_policy_log_vb = [detached_policy_log_vb[i].gather(1, action_batch_vb[i]) for i in range(rollout_steps)]
if self.master.enable_1st_order_trpo:
z_star_vb = []
else:
policy_grad_vb = []
QretT_vb = self._get_QretT_vb(on_policy)

# compute loss
policy_loss_vb = Variable(torch.zeros(1, 1))
value_loss_vb = Variable(torch.zeros(1, 1), requires_grad=True)
for i in reversed(range(rollout_steps)):
# 1. importance sampling weights: rho_i = pi(a_i|s_i) / mu(a_i|s_i)
if on_policy: # 1 for on-policy
rho_vb = Variable(torch.ones(1, self.master.action_dim))
rho_vb[0,0] = 50 # 0.5 # NOTE: temporary hard-coded value; the on-policy weight is normally left at 1
else:
pass
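The NOTE block above describes the decoupled backprop this method relies on: gradients are first taken with respect to a detached copy of the policy output (step 1.1 further down), optionally projected by the trust region, and only then pushed through the undetached output into the model (step 1.2). A rough sketch of that pattern in current PyTorch, assuming a placeholder net, obs and loss (the commit itself uses the older Variable / grad_variables API):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3), nn.Softmax(dim=1))
obs = torch.randn(5, 4)
policy = net(obs)                                   # undetached: connected to the model

# stage 1: gradient of the policy loss w.r.t. a detached copy of the output
detached_policy = policy.detach().requires_grad_(True)
actions = torch.zeros(5, 1, dtype=torch.long)       # placeholder actions
loss = -torch.log(detached_policy.gather(1, actions)).mean()
g = torch.autograd.grad(loss, detached_policy)[0]   # plays the role of policy_grad_vb

# stage 2: push g (or the trust-region-projected z*) through the real graph
policy.backward(gradient=g)

With enable_1st_order_trpo, g would be replaced by the projected z* before the second stage.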

@@ -227,22 +232,26 @@ def _backward(self, on_policy=True):
bias_correction_coefficient_vb = (1 - self.master.clip_trace / rho_vb).clamp(min=0) * detached_policy_vb[i]
detached_policy_loss_vb -= (bias_correction_coefficient_vb * detached_policy_vb[i].log() * (self.rollout.q0_vb[i].detach() - self.rollout.value0_vb[i].detach())).sum(1, keepdim=True).mean(0)

# 1.1 backprop policy loss up to the network output
if self.master.enable_1st_order_trpo:
# policy update: d_theta = d_theta + (d_policy / d_theta) * z_star
z_star_vb.append(self._1st_order_trpo(detached_policy_loss_vb, detached_policy_vb[i], self.rollout.detached_avg_policy_vb[i]))
else:
policy_grad_vb.append(grad(outputs=detached_policy_loss_vb, inputs=detached_policy_vb[i], retain_graph=False, only_inputs=True)[0])

# now we have all the losses ready, we backprop
self.model.zero_grad()
# 1.2 backprop the policy loss from the network output to the whole model
if self.master.enable_1st_order_trpo:
# 1. backprop the policy loss
# NOTE: here we need to use the undetached policy_vb, because we need to backprop through the whole model
backward(variables=self.rollout.policy_vb, grad_variables=z_star_vb, retain_graph=True)
# 2. backprop the value loss
value_loss_vb.backward()
else:
# NOTE: here we could backprop both losses at once, but to keep the code consistent
# NOTE: and avoid keeping track of another, undetached policy loss,
# NOTE: we also decouple the backprop of the policy loss into two stages
# 1.2 backprop from the network output to the whole model
backward(variables=self.rollout.policy_vb, grad_variables=policy_grad_vb, retain_graph=True)
# 2. backprop the value loss
value_loss_vb.backward()
torch.nn.utils.clip_grad_norm(self.model.parameters(), self.master.clip_grad)

self._ensure_global_grads()
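For reference, the bias-correction line above (bias_correction_coefficient_vb) follows ACER's truncated importance sampling with bias correction; the truncated term itself is not visible in this hunk. A rough standalone sketch of the discrete-action policy loss, under assumed shapes (policy, rho, q: [batch x action_dim]; action, v, q_ret: [batch x 1]) and with c standing in for clip_trace:

import torch

def acer_policy_loss(policy, action, rho, q, v, q_ret, c=10.0):
    # truncated importance-sampling term for the taken action
    rho_a = rho.gather(1, action)                        # importance weight of the taken action
    log_pi_a = torch.log(policy.gather(1, action))
    loss = -(rho_a.clamp(max=c) * log_pi_a * (q_ret - v)).mean()
    # bias-correction term, summed over all actions
    coef = (1.0 - c / rho).clamp(min=0) * policy         # cf. bias_correction_coefficient_vb
    loss = loss - (coef * torch.log(policy) * (q - v)).sum(1, keepdim=True).mean()
    return loss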
