
Commit

* small mods
jingweiz committed Aug 21, 2017
1 parent f5ee762 commit b410023
Showing 1 changed file with 10 additions and 26 deletions.
36 changes: 10 additions & 26 deletions core/agents/acer_single_process.py
@@ -185,6 +185,7 @@ def _1st_order_trpo(self, detached_policy_loss_vb, detached_policy_vb, detached_
kk_dot_vb = torch.mm(k_vb, torch.t(k_vb))
z_star_vb = g_vb - ((kg_dot_vb - self.master.clip_1st_order_trpo) / kk_dot_vb).clamp(min=0) * k_vb

del kl_div_vb, k_vb, g_vb, kg_dot_vb, kk_dot_vb
return z_star_vb

def _backward(self, on_policy=True):
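
The hunk above implements the efficient first-order trust-region step from the ACER paper: z* = g - max(0, (k·g - delta) / (k·k)) · k, where g is the gradient of the detached policy loss with respect to the policy output, k is the gradient of the KL divergence against the averaged policy network, and delta is clip_1st_order_trpo; the del frees the intermediate Variables once z* is computed. A minimal sketch of the same projection in current PyTorch (the function name and tensor shapes are assumptions for illustration, not code from the repository):

    import torch

    def first_order_trpo_sketch(g, k, delta):
        # g: gradient of the policy loss w.r.t. the policy output, shape (1, num_actions)
        # k: gradient of the KL divergence w.r.t. the policy output, same shape
        # delta: trust-region threshold (clip_1st_order_trpo in the diff)
        kg = torch.mm(k, g.t())                   # k^T g, shape (1, 1)
        kk = torch.mm(k, k.t())                   # k^T k, shape (1, 1)
        scale = ((kg - delta) / kk).clamp(min=0)  # max(0, (k^T g - delta) / (k^T k))
        return g - scale * k                      # z*, fed back as the policy gradient
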
@@ -204,7 +205,7 @@ def _backward(self, on_policy=True):

# compute loss
policy_loss_vb = Variable(torch.zeros(1, 1))
value_loss_vb = Variable(torch.zeros(1, 1))
value_loss_vb = Variable(torch.zeros(1, 1), requires_grad=True)
if self.master.enable_1st_order_trpo: z_star_vb = []
for i in reversed(range(rollout_steps)):
# 1. importance sampling weights: /rho = /pi(|s_i) / /mu(|s_i)
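
The one-line change above makes the zero-initialized loss accumulator explicitly require gradients before the per-step value losses are added to it inside the rollout loop. The accumulation pattern is roughly the sketch below (names and the squared-error form are assumptions for illustration; the actual targets in ACER are Retrace returns):

    import torch

    def accumulate_value_loss_sketch(v_pred_list, q_ret_list):
        # v_pred_list: per-step value predictions still attached to the graph
        # q_ret_list: per-step targets (e.g. Retrace returns), treated as constants
        value_loss = torch.zeros(1, 1)               # accumulator, as in the diff
        for i in reversed(range(len(v_pred_list))):  # same reversed rollout loop as below
            td_error = q_ret_list[i].detach() - v_pred_list[i]
            value_loss = value_loss + 0.5 * td_error.pow(2).mean()
        return value_loss

In current PyTorch the explicit requires_grad=True on the accumulator is not strictly needed, since adding a graph-connected prediction already yields a tensor that requires grad.
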
@@ -230,37 +231,20 @@ def _backward(self, on_policy=True):
# policy update d_/theta = d_/theta + /partical/theta / /partical/theta * z*
z_star_vb.append(self._1st_order_trpo(detached_policy_loss_vb, detached_policy_vb[i], self.rollout.detached_avg_policy_vb[i]))

# single_step_policy_loss = -(rho.gather(1, actions[i]).clamp(max=args.trace_max) * log_prob * A.detach()).mean(0) # Average over batch
# # Off-policy bias correction
# if off_policy:
# # g = g + /sum_a [1 - c / /rho_a]_+ /pi(a|s_i; /theta) * /delta_theta * log(/pi(a|s_i; /theta)) * (Q(s_i, a; theta) - V(s_i; theta)
# bias_weight = (1 - args.trace_max / rho).clamp(min=0) * policies[i]
# single_step_policy_loss -= (bias_weight * policies[i].log() * (Qs[i].detach() - Vs[i].expand_as(Qs[i]).detach())).sum(1).mean(0)
# if args.trust_region:
# # Policy update d_/theta = d_/theta + /partical/theta / /partical/theta * z*
# policy_loss += _trust_region_loss(model, policies[i], average_policies[i], single_step_policy_loss, args.trust_region_threshold)
# else:
# # Policy update d_/theta = d_/theta + partical_/theta / /partical_/theta * g
# policy_loss += single_step_policy_loss
# # Entropy regularisation d_/theta = d_/theta + /beta * /delta H(/pi(s_i; /theta))
# policy_loss -= args.entropy_weight * -(policies[i].log() * policies[i]).sum(1).mean(0) # Sum over probabilities, average over batch

# now we have all the losses ready, we backprop
self.model.zero_grad()
# backprop the policy loss
if self.master.enable_1st_order_trpo:
# 1. backprop the policy loss
# NOTE: here need to use the undetached policy_vb, cos we need to backprop to the whole model
backward(variables=self.rollout.policy_vb, grad_variables=z_star_vb, retain_graph=True)
# backprop the value loss

print("=======================>")
# 2. backprop the value loss
value_loss_vb.backward()
else:
# here we can backprop both losses at once
loss_vb = value_loss_vb + policy_loss_vb
loss_vb.backward()
torch.nn.utils.clip_grad_norm(self.model.parameters(), self.master.clip_grad)

# compute loss
# loss_vb = Variable(torch.zeros(1))
# # TODO:
# loss_vb.backward()
# torch.nn.utils.clip_grad_norm(self.model.parameters(), self.master.clip_grad)
#
self._ensure_global_grads()
self.master.optimizer.step()
self.train_step += 1
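
The last hunk is the backward pass: with the first-order TRPO enabled, the projected gradients z* are pushed directly into the per-step (non-detached) policy outputs via torch.autograd.backward and the value loss is backpropagated separately; otherwise a single combined loss is used, followed by gradient clipping either way. A compressed sketch of that control flow with current API names (the function and argument names here are illustrative, not from the repository):

    import torch

    def backward_pass_sketch(model, policy_outputs, z_stars, policy_loss, value_loss,
                             clip_grad, use_1st_order_trpo=True):
        model.zero_grad()
        if use_1st_order_trpo:
            # Feed the trust-region-projected gradients z* straight into the
            # policy outputs; retain the graph so the value loss can still
            # backprop through the shared trunk. Current PyTorch uses
            # `grad_tensors` where the diff's older API used `grad_variables`.
            torch.autograd.backward(policy_outputs, grad_tensors=z_stars, retain_graph=True)
            value_loss.backward()
        else:
            # Without the trust region, both losses can be backpropagated at once.
            (policy_loss + value_loss).backward()
        # clip_grad_norm_ is the in-place successor of the deprecated clip_grad_norm.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
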
