
Commit

1st order trpo to only return z, backward is done outside on the whole rollout instead of individually

jingweiz committed Aug 21, 2017
1 parent ff10e13 commit f5ee762
Showing 1 changed file with 32 additions and 24 deletions.
56 changes: 32 additions & 24 deletions core/agents/acer_single_process.py
@@ -172,23 +172,20 @@ def _get_QretT_vb(self, on_policy=True):

return QretT_vb

def _1st_order_trpo(self, detached_policy_loss_vb, policy_vb, detached_policy_vb, detached_avg_policy_vb):
def _1st_order_trpo(self, detached_policy_loss_vb, detached_policy_vb, detached_avg_policy_vb):
# KL divergence k = \nabla_{\phi_{\theta}} D_{KL}[ \pi(\cdot|\phi_{\theta_a}) || \pi(\cdot|\phi_{\theta}) ]
kl_div_vb = F.kl_div(detached_policy_vb.log(), detached_avg_policy_vb, size_average=False)
# NOTE: k & g are all w.r.t. the network output, which is policy_vb
# NOTE: gradient from this part does not need to be propagated back into the model
# NOTE: that's why we are only using detached policies here
# NOTE: k & g are all w.r.t. the network output, which is detached_policy_vb
# NOTE: gradient from this part will not flow back into the model
# NOTE: that's why we are only using detached policy variables here
k_vb = grad(outputs=kl_div_vb, inputs=detached_policy_vb, retain_graph=False, only_inputs=True)[0]
g_vb = grad(outputs=detached_policy_loss_vb, inputs=detached_policy_vb, retain_graph=False, only_inputs=True)[0]

kg_dot_vb = torch.mm(k_vb, torch.t(g_vb))
kk_dot_vb = torch.mm(k_vb, torch.t(k_vb))
z_star_vb = g_vb - ((kg_dot_vb - self.master.clip_1st_order_trpo) / kk_dot_vb).clamp(min=0) * k_vb

# NOTE: we still need to backprop the value loss afterwards, so the graph needs to be retained here
self.model.zero_grad()
backward(variables=policy_vb, grad_variables=z_star_vb, retain_graph=True)
# NOTE: must not call zero_grad before backprop value loss
return z_star_vb

def _backward(self, on_policy=True):
# preparation
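A note on the clamp above: z_star_vb is the closed-form solution of the linearized (1st-order) trust-region problem from the ACER paper, with g = g_vb, k = k_vb and \delta = self.master.clip_1st_order_trpo; sketched in LaTeX as background, not as part of the diff:

\[
z^{*} \;=\; \operatorname*{arg\,min}_{z}\ \tfrac{1}{2}\,\lVert g - z \rVert_{2}^{2}
\quad \text{s.t.} \quad k^{\top} z \le \delta
\qquad\Longrightarrow\qquad
z^{*} \;=\; g \;-\; \max\!\left(0,\ \frac{k^{\top} g - \delta}{\lVert k \rVert_{2}^{2}}\right) k
\]

The max(0, .) factor is exactly the .clamp(min=0) term, so after this commit the function only returns z* per step and leaves all backpropagation to _backward.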
@@ -208,6 +205,7 @@ def _backward(self, on_policy=True):
# compute loss
policy_loss_vb = Variable(torch.zeros(1, 1))
value_loss_vb = Variable(torch.zeros(1, 1))
if self.master.enable_1st_order_trpo: z_star_vb = []
for i in reversed(range(rollout_steps)):
# 1. importance sampling weights: \rho = \pi(\cdot|s_i) / \mu(\cdot|s_i)
if on_policy: # 1 for on-policy
@@ -230,22 +228,32 @@

if self.master.enable_1st_order_trpo:
# policy update: d_\theta = d_\theta + \partial\phi_{\theta}/\partial\theta * z*
policy_loss_vb += self._1st_order_trpo(detached_policy_loss_vb, self.rollout.policy_vb[i], detached_policy_vb[i], self.rollout.detached_avg_policy_vb[i])

single_step_policy_loss = -(rho.gather(1, actions[i]).clamp(max=args.trace_max) * log_prob * A.detach()).mean(0) # Average over batch
# Off-policy bias correction
if off_policy:
# g = g + \sum_a [1 - c/\rho_a]_+ \pi(a|s_i; \theta) * \nabla_\theta log(\pi(a|s_i; \theta)) * (Q(s_i, a; \theta) - V(s_i; \theta))
bias_weight = (1 - args.trace_max / rho).clamp(min=0) * policies[i]
single_step_policy_loss -= (bias_weight * policies[i].log() * (Qs[i].detach() - Vs[i].expand_as(Qs[i]).detach())).sum(1).mean(0)
if args.trust_region:
# Policy update d_\theta = d_\theta + \partial\phi_{\theta}/\partial\theta * z*
policy_loss += _trust_region_loss(model, policies[i], average_policies[i], single_step_policy_loss, args.trust_region_threshold)
else:
# Policy update d_\theta = d_\theta + \partial\phi_{\theta}/\partial\theta * g
policy_loss += single_step_policy_loss
# Entropy regularisation d_\theta = d_\theta + \beta * \nabla_\theta H(\pi(s_i; \theta))
policy_loss -= args.entropy_weight * -(policies[i].log() * policies[i]).sum(1).mean(0) # Sum over probabilities, average over batch
z_star_vb.append(self._1st_order_trpo(detached_policy_loss_vb, detached_policy_vb[i], self.rollout.detached_avg_policy_vb[i]))

# single_step_policy_loss = -(rho.gather(1, actions[i]).clamp(max=args.trace_max) * log_prob * A.detach()).mean(0) # Average over batch
# # Off-policy bias correction
# if off_policy:
# # g = g + \sum_a [1 - c/\rho_a]_+ \pi(a|s_i; \theta) * \nabla_\theta log(\pi(a|s_i; \theta)) * (Q(s_i, a; \theta) - V(s_i; \theta))
# bias_weight = (1 - args.trace_max / rho).clamp(min=0) * policies[i]
# single_step_policy_loss -= (bias_weight * policies[i].log() * (Qs[i].detach() - Vs[i].expand_as(Qs[i]).detach())).sum(1).mean(0)
# if args.trust_region:
# # Policy update d_\theta = d_\theta + \partial\phi_{\theta}/\partial\theta * z*
# policy_loss += _trust_region_loss(model, policies[i], average_policies[i], single_step_policy_loss, args.trust_region_threshold)
# else:
# # Policy update d_\theta = d_\theta + \partial\phi_{\theta}/\partial\theta * g
# policy_loss += single_step_policy_loss
# # Entropy regularisation d_\theta = d_\theta + \beta * \nabla_\theta H(\pi(s_i; \theta))
# policy_loss -= args.entropy_weight * -(policies[i].log() * policies[i]).sum(1).mean(0) # Sum over probabilities, average over batch

# now we have all the losses ready, we backprop
self.model.zero_grad()
# backprop the policy loss
if self.master.enable_1st_order_trpo:
# NOTE: here we need to use the undetached policy_vb, because we need to backprop through the whole model
backward(variables=self.rollout.policy_vb, grad_variables=z_star_vb, retain_graph=True)
# backprop the value loss

print("=======================>")

# compute loss
# loss_vb = Variable(torch.zeros(1))
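To make the commit message concrete, here is a minimal, self-contained sketch (a toy, not code from this repo) of the pattern the diff switches to: z* is collected per rollout step and a single backward call is made over the whole rollout. It assumes the pre-0.4 PyTorch API this file already uses (Variable, and backward(variables=..., grad_variables=...) from torch.autograd); the toy parameter theta, the fake observations, and the placeholder z* values are illustrative only.

import torch
import torch.nn.functional as F
from torch.autograd import Variable, backward

theta = Variable(torch.randn(3, 4), requires_grad=True)       # stand-in for the model's parameters

rollout_policy_vb = []    # undetached per-step policy outputs (graph attached)
z_star_vb = []            # per-step projected gradients, computed outside the graph

for t in range(5):                                            # toy 5-step rollout
    x_t = Variable(torch.randn(1, 3))                         # fake observation
    policy_vb = F.softmax(x_t.mm(theta))                      # fake policy head
    rollout_policy_vb.append(policy_vb)
    z_star_vb.append(Variable(torch.ones(policy_vb.size())))  # placeholder for the real z*

# one backward call over the whole rollout; per-step gradients accumulate into theta.grad
backward(variables=rollout_policy_vb, grad_variables=z_star_vb, retain_graph=True)
print(theta.grad.size())                                      # torch.Size([3, 4])

retain_graph=True mirrors the diff above, where the value loss still has to be backpropped through the same graph afterwards.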
