Skip to content

Commit

Permalink
remove unnecessary code from hessian computation
Browse files Browse the repository at this point in the history
  • Loading branch information
Ofir Gordon authored and Ofir Gordon committed Nov 22, 2023
1 parent a4054cc commit e1199ce
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ def compute(self) -> List[float]:
# Compute the approximation per node's output
score_approx_per_output = []
for grad in gradients:
grad = tf.reshape(grad, [grad.shape[0], -1])
score_approx_per_output.append(tf.reduce_mean(tf.reduce_sum(tf.pow(grad, 2.0))))
score_approx_per_output.append(tf.reduce_sum(tf.pow(grad, 2.0)))

# Free gradients
del grad
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ def compute(self) -> List[float]:
requires_grad=True,
device=device))
break
hess_v = torch.reshape(hess_v, [hess_v.shape[0], -1])
hessian_trace_approx = torch.mean(torch.sum(torch.pow(hess_v, 2.0)))
hessian_trace_approx = torch.sum(torch.pow(hess_v, 2.0))

# If the change to the mean Hessian approximation is insignificant we stop the calculation
if j > MIN_HESSIAN_ITER:
Expand Down

0 comments on commit e1199ce

Please sign in to comment.