clean up forward hooks on exception (#4778)
epwalsh authored Nov 10, 2020
1 parent fcc3a70 commit dc3a4f6
Showing 4 changed files with 16 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

+- Fixed a bug where forward hooks were not cleaned up with saliency interpreters if there
+  was an exception.
- Fixed the computation of saliency maps in the Interpret code when using mismatched indexing.
Previously, we would compute gradients from the top of the transformer, after aggregation from
wordpieces to tokens, which gives results that are not very informative. Now, we compute gradients
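For context on the new entry: the bug is a leaked forward hook. Each interpreter registers hooks on the embedding layer before computing gradients; if that computation raises, the old code never reaches `handle.remove()`, so the hook stays attached and keeps firing on every later forward pass through the predictor. A minimal, hypothetical PyTorch sketch of the failure mode and of the try/finally fix this commit applies (not the AllenNLP code itself):

```python
import torch

layer = torch.nn.Linear(4, 4)
captured = []

def save_output(module, inputs, output):
    captured.append(output.detach())

def flaky_gradient_computation():
    raise RuntimeError("simulated failure during the backward pass")

# Buggy pattern: removal only happens if the computation succeeds, so an
# exception leaves the hook permanently attached to the layer.
handle = layer.register_forward_hook(save_output)
try:
    flaky_gradient_computation()
    handle.remove()
except RuntimeError:
    pass

layer(torch.randn(1, 4))
print(len(captured))  # 1 -- the leaked hook still fires on unrelated forward passes

# Fixed pattern (what this commit does): remove the hook in a finally block.
handle = layer.register_forward_hook(save_output)
try:
    flaky_gradient_computation()
except RuntimeError:
    pass
finally:
    handle.remove()
```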
8 changes: 5 additions & 3 deletions allennlp/interpret/saliency_interpreters/integrated_gradient.py
@@ -88,9 +88,11 @@ def _integrate_gradients(self, instance: Instance) -> Dict[str, numpy.ndarray]:
            # Hook for modifying embedding value
            handles = self._register_hooks(alpha, embeddings_list, token_offsets)

-           grads = self.predictor.get_gradients([instance])[0]
-           for handle in handles:
-               handle.remove()
+           try:
+               grads = self.predictor.get_gradients([instance])[0]
+           finally:
+               for handle in handles:
+                   handle.remove()

            # Running sum of gradients
            if ig_grads == {}:
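The hunk above sits inside `_integrate_gradients`, which approximates integrated gradients by registering hooks that scale the embedding output by `alpha`, taking gradients at each scaled point, and keeping a running sum. A self-contained, hypothetical sketch of that approximation on a toy model (zero baseline, plain PyTorch; the real implementation works through the predictor and forward hooks):

```python
import torch

# Toy stand-ins; the real code operates on a model's embedding layer via hooks.
model = torch.nn.Linear(4, 1)
x = torch.randn(1, 4)

steps = 20
running_grad = torch.zeros_like(x)
for i in range(1, steps + 1):
    alpha = i / steps
    scaled = (alpha * x).detach().requires_grad_(True)  # point on the baseline -> input path
    model(scaled).sum().backward()
    running_grad += scaled.grad / steps                  # running average of gradients (Riemann sum)

attributions = running_grad * x  # scale by (input - baseline); the baseline here is zero
```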
8 changes: 5 additions & 3 deletions allennlp/interpret/saliency_interpreters/simple_gradient.py
@@ -30,9 +30,11 @@ def saliency_interpret_from_json(self, inputs: JsonDict) -> JsonDict:

            # Hook used for saving embeddings
            handles = self._register_hooks(embeddings_list, token_offsets)
-           grads = self.predictor.get_gradients([instance])[0]
-           for handle in handles:
-               handle.remove()
+           try:
+               grads = self.predictor.get_gradients([instance])[0]
+           finally:
+               for handle in handles:
+                   handle.remove()

            # Gradients come back in the reverse order that they were sent into the network
            embeddings_list.reverse()
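For reference, the public entry point changed here is `saliency_interpret_from_json`. A hypothetical usage sketch; the archive path and the input JSON keys are placeholders that depend on the model and predictor being interpreted:

```python
from allennlp.predictors import Predictor
from allennlp.interpret.saliency_interpreters import SimpleGradient

# "model.tar.gz" is a placeholder; point this at a real trained archive.
predictor = Predictor.from_path("model.tar.gz")
interpreter = SimpleGradient(predictor)

# The JSON keys ("sentence" here) must match what the predictor expects.
saliency = interpreter.saliency_interpret_from_json({"sentence": "a very well made film"})
print(saliency)  # nested dict of per-token gradient magnitudes, e.g. {"instance_1": {"grad_input_1": [...]}}
```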
6 changes: 4 additions & 2 deletions allennlp/interpret/saliency_interpreters/smooth_gradient.py
@@ -72,8 +72,10 @@ def _smooth_grads(self, instance: Instance) -> Dict[str, numpy.ndarray]:
        total_gradients: Dict[str, Any] = {}
        for _ in range(self.num_samples):
            handle = self._register_forward_hook(self.stdev)
-           grads = self.predictor.get_gradients([instance])[0]
-           handle.remove()
+           try:
+               grads = self.predictor.get_gradients([instance])[0]
+           finally:
+               handle.remove()

            # Sum gradients
            if total_gradients == {}:
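The loop patched above implements SmoothGrad: for each sample it registers a forward hook that perturbs the embedding output with Gaussian noise (scaled by `stdev`), takes gradients, and sums them. A minimal, hypothetical PyTorch sketch of the same idea on a toy model, with the hook removed in a finally block as in the fix:

```python
import torch

model = torch.nn.Sequential(torch.nn.Embedding(100, 8), torch.nn.Linear(8, 1))
tokens = torch.tensor([[3, 7, 42]])
stdev, num_samples = 0.01, 10

def add_noise(module, inputs, output):
    # Returning a value from a forward hook replaces the module's output.
    return output + torch.randn_like(output) * output.detach().std() * stdev

total_grad = torch.zeros_like(model[0].weight)
for _ in range(num_samples):
    handle = model[0].register_forward_hook(add_noise)
    try:
        model.zero_grad()
        model(tokens).sum().backward()
        total_grad += model[0].weight.grad  # gradient w.r.t. the embedding table, for simplicity
    finally:
        handle.remove()  # removed even if the forward/backward pass raises

smooth_grad = total_grad / num_samples  # averaged, noise-smoothed gradients
```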
