-
Notifications
You must be signed in to change notification settings - Fork 514
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/autodiff/checkpoint #1239
Feat/autodiff/checkpoint #1239
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find the tests quite confusing, they normaly reflect how
the code is going to be used, but it seems like the "scenario" is defined in many methods and assumed to be static.
What I think should improve the code:
- The "normal" forward pass that generates the Checkpoint struct with maybe some assertions on the
Computed
andRecompute
tensors. - Simulate the backward pass that calculates the "retro forwards" and asserts the results.
The more the forward pass is similar to normal tensor operations, the better.
fn make_ids() -> [NodeID; 7] { | ||
[ | ||
NodeID::new(), | ||
NodeID::new(), | ||
NodeID::new(), | ||
NodeID::new(), | ||
NodeID::new(), | ||
NodeID::new(), | ||
NodeID::new(), | ||
] | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That method isn't helpful, we can inline each NodeId
where they are needed.
} | ||
|
||
/// Make the leaves for a div tree | ||
fn make_leaves<B: Backend>(device: &B::Device, ids: [NodeID; 4]) -> (InnerStates, NodeTree) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same, the NodeIds should be inlined.
#[cfg(test)] | ||
mod tests { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since you are already in a test module this is unecessary.
You should be happier with the new state of the production code. Tests are still convoluted but they were never really meant to outlive the development phase. I will migrate them to clean ones once I can use the autodiff api |
Pull Request Template
Checklist
run-checks all
script has been executed.Related Issues/PRs
WIP for #936
Changes
Most of the logic for the autodiff checkpointing strategy.
At the moment, the code is all independant from the rest of Burn, I suspect it won't pass the CI because everything is unused; thus why I put my PR as a draft.
It has already undergone several refactorings, so I'm pretty satisfied with the cleanness.
Next I will of course plug it in all autodiff operations.
Testing
Heavily tested in the tests.rs file