This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
* [v1.x] [Large Tensor] Backport of Fixed RNN op (#17632)
  * Changed relevant function args to index_t
  * Added nightly test for RNN
  * Added fix for LSTM, GRU, RNN-ReLU, RNN-tanh
  * Used const instead of literals
  * Added nightly tests for RNN ReLU & tanh, LSTM, and GRU
  * Added type assertion to force evaluation of the output NDArray
  * Incorporated latest round of comments
* [v1.x] Backport of Fix LSTM and GRU layers gradient calculations (#18203)
  * Fixed input gradient calculation for bidirectional LSTM. For a bidirectional LSTM with more than two layers, the input gradient calculation was incorrect. The cause was that the y derivative (dy) tensor was overwritten by the calculated x derivative (dx) tensor before the right-to-left layer could use dy for its own gradient calculations. The proposed fix uses additional space to avoid the overwrite.
  * Fixed gradient calculation for GRU. For a GRU with more than two layers, the i2h_weight gradient was incorrect for the middle layers (all except the first and last). The wrong calculations were caused by assigning the output pointer to the input instead of calculating a new input pointer.
  * Enabled tests for GRU and LSTM gradients
  * Fixed comments
  * Changed loop iteration deduction
  * Added more test cases for fused RNN layers

Co-authored-by: Connor Goggins <[email protected]>
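The bidirectional LSTM bug described above is an aliasing problem: the left-to-right pass wrote its input gradient (dx) into the buffer still holding dy, so the right-to-left pass read corrupted values. The following is a minimal, hypothetical Python sketch of that pattern and its fix (it is not MXNet's actual C++ code; the function names, list-based "tensors", and scalar weights are stand-ins for illustration only):

```python
def backward_buggy(dy, w_l2r, w_r2l):
    """Simulates the bug: dx from the left-to-right pass is written
    into the same buffer that holds dy, so the right-to-left pass
    reads already-overwritten values."""
    buf = list(dy)                      # shared buffer holding dy
    for i in range(len(buf)):
        buf[i] = buf[i] * w_l2r         # l2r pass overwrites dy with dx
    # r2l pass now reads corrupted dy
    dx_r2l = [v * w_r2l for v in reversed(buf)]
    return buf, dx_r2l


def backward_fixed(dy, w_l2r, w_r2l):
    """Simulates the fix: extra workspace keeps the l2r result
    separate, so dy is intact when the right-to-left pass runs."""
    workspace = [v * w_l2r for v in dy]            # dx kept in new space
    dx_r2l = [v * w_r2l for v in reversed(dy)]     # dy still valid here
    return workspace, dx_r2l
```

With `dy = [1, 2, 3]`, both versions agree on the left-to-right dx, but only the fixed version computes the right-to-left gradient from the original dy; the buggy one scales values that were already multiplied by `w_l2r`.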
1 parent 36bd144 · commit 8986e3f
Showing 4 changed files with 291 additions and 280 deletions.