Conversation
I think this design is too heavy. I don't like putting a core functionality like `loss.backward()` into a callback. It makes it too hard to see what's going on in the trainer. Instead, can we use the normal `TrainerCallback`, and give it some extra methods, like `pre_backward()` and `post_backward()`? Can you solve your problem with that?
I can't do adversarial training without …

You could leave one …
@dirkgr I have made the revisions you suggested :)
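For reference, a minimal sketch of the hook design suggested above: a callback exposing `pre_backward()`/`post_backward()` methods, with the trainer itself still owning the `loss.backward()` call. The signatures and the trainer-side helper are assumptions for illustration, not AllenNLP's actual API.

```python
from typing import Iterable

import torch


class BackwardHooksCallback:
    """Illustrative only: a callback with hooks around backpropagation."""

    def pre_backward(self, trainer, loss: torch.Tensor) -> None:
        """Called just before loss.backward(), e.g. to inspect or scale the loss."""

    def post_backward(self, trainer, loss: torch.Tensor) -> None:
        """Called just after loss.backward(), e.g. to manipulate gradients."""


def backward_with_hooks(
    trainer, loss: torch.Tensor, callbacks: Iterable[BackwardHooksCallback]
) -> None:
    # The trainer keeps ownership of the actual backward call, so it stays
    # visible in the training loop; callbacks only act around it.
    for callback in callbacks:
        callback.pre_backward(trainer, loss)
    loss.backward()
    for callback in callbacks:
        callback.post_backward(trainer, loss)
```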
Other than the changelog entry, this is great!
CHANGELOG.md
Outdated
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `BackwardCallback`, a training callback which allows for control over backpropagation and gradient manipulation.
This comment isn't accurate anymore, is it?
if not backward_called:
    trainer._scaler.scale(loss).backward()  # type: ignore
    return True
return False
Is it an error if this gets called with `backward_called == True`? Should we throw an exception in that case?
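One way to make that case fail loudly instead of being silently skipped is sketched below. The `OnBackwardException` name appears in the merged commit message further down; the helper function is purely illustrative.

```python
import torch


class OnBackwardException(Exception):
    """Raised when backpropagation has already been run for the current batch."""


def run_backward(loss: torch.Tensor, backward_called: bool) -> bool:
    # Raise rather than return False when backward() was already called,
    # so conflicting callbacks surface as an error instead of a silent no-op.
    if backward_called:
        raise OnBackwardException("loss.backward() was already called for this batch.")
    loss.backward()
    return True
```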
tests/training/trainer_test.py
Outdated
if not backward_called:
    loss.backward()
    for param in trainer.model.parameters():
        param.grad *= 0.0
Is that really the best way to do that?
Suggested change: replace `param.grad *= 0.0` with `param.zero_()`.
I don't know for sure, but I would guess that `zero_()` is faster.
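For reference, a minimal sketch of in-place zeroing, assuming the intent of the test is to zero the gradients: in PyTorch, `param.zero_()` would zero the parameter tensor itself, so the in-place call goes on `param.grad`. The helper name is made up for illustration.

```python
import torch


def zero_gradients(model: torch.nn.Module) -> None:
    """Zero every parameter gradient in place."""
    for param in model.parameters():
        if param.grad is not None:
            # zero_() mutates the existing gradient tensor and avoids the
            # extra elementwise multiply that `param.grad *= 0.0` performs.
            param.grad.zero_()
```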
on_backward trainer callback
This is great! Will make a great update for tomorrow's meeting, too!
* added BackwardCallback
* finished tests
* fixed linting issue
* revised design per Dirk's suggestion
* added OnBackwardException, changed loss to batch_outputs, etc.

Co-authored-by: Arjun Subramonian <[email protected]>
Additions proposed in this pull request:

- `on_backward` training callback, which allows for control over backpropagation and gradient manipulation.
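As a usage illustration, a callback that takes over backpropagation might look roughly like the sketch below. The `on_backward` hook comes from this PR; its exact signature (`trainer`, `batch_outputs`, `backward_called`) is inferred from the snippets and commit message above and should be treated as an assumption rather than the final API.

```python
from typing import Any, Dict

import torch

from allennlp.training.callbacks import TrainerCallback


class ScaledBackwardCallback(TrainerCallback):
    """Illustrative only: runs backward itself, then rescales the gradients."""

    def on_backward(
        self,
        trainer,
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs: Any,
    ) -> bool:
        if backward_called:
            # Another callback already handled backpropagation for this batch.
            return False
        batch_outputs["loss"].backward()
        # Arbitrary gradient manipulation, e.g. as a building block for
        # adversarial training.
        for param in trainer.model.parameters():
            if param.grad is not None:
                param.grad *= 0.5
        return True  # tell the trainer that backward() has already run
```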