-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[Large Tensor] Fixed RNN op #17632
[Large Tensor] Fixed RNN op #17632
Changes from all commits
7ec6480
94278ab
4fa7bd5
f865f65
671ee0a
eb09cf1
999328a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,7 +63,7 @@ struct RNNParam : public dmlc::Parameter<RNNParam> { | |
bool bidirectional, state_outputs; | ||
int mode; | ||
float p; | ||
int seq_length_, batch_size_, input_size_; | ||
index_t seq_length_, batch_size_, input_size_; | ||
|
||
bool use_sequence_length; | ||
dmlc::optional<int> projection_size; | ||
|
@@ -122,8 +122,8 @@ struct RNNParam : public dmlc::Parameter<RNNParam> { | |
} | ||
}; | ||
|
||
inline int GetRnnParamSize(int num_layer, | ||
int input_size, | ||
inline index_t GetRnnParamSize(int num_layer, | ||
index_t input_size, | ||
int state_size, | ||
int direction, | ||
int mode, | ||
|
@@ -140,14 +140,14 @@ inline int GetRnnParamSize(int num_layer, | |
size *= 3; | ||
break; | ||
} | ||
int size1 = (input_size + state_size + 2) * size; // first layer size | ||
int size2 = (state_size * direction + state_size + 2) * size; // other layers size | ||
index_t size1 = (input_size + state_size + 2) * size; // first layer size | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's prefer size_t for sizes. Or do you think these values can be negative too? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agreed - testing changes now. |
||
index_t size2 = (state_size * direction + state_size + 2) * size; // other layers size | ||
if (projection_size.has_value()) { | ||
int proj_size = projection_size.value(); | ||
index_t proj_size = projection_size.value(); | ||
size1 = (input_size + proj_size + 2) * size; | ||
size2 = (proj_size * direction + proj_size + 2) * size; | ||
} | ||
int param_size = size1 + (num_layer - 1) * size2; | ||
index_t param_size = size1 + (num_layer - 1) * size2; | ||
if (projection_size.has_value()) { | ||
param_size += projection_size.value() * state_size * num_layer * direction; | ||
} | ||
|
@@ -182,8 +182,8 @@ inline int GetRnnBiasSize(int num_layer, | |
* - output -> h[t](, c[t] additionally with Lstm) time by time(sz: NxH(x2)) | ||
* - intermediate y[1...T] as next layer's inputs(sz: TxNxHxD) | ||
*/ | ||
inline size_t GetRNNWorkspaceSize(int seq_length, | ||
int batch_size, | ||
inline size_t GetRNNWorkspaceSize(index_t seq_length, | ||
index_t batch_size, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can batch_size be negative? @apeforest what do you think? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
int hidden_size, | ||
int projection_size, | ||
int direction, | ||
|
@@ -214,8 +214,8 @@ inline size_t GetRNNWorkspaceSize(int seq_length, | |
|
||
inline size_t GetRNNReserveSpaceSize(int num_layer, | ||
int direction, | ||
int seq_length, | ||
int batch_size, | ||
index_t seq_length, | ||
index_t batch_size, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. size_t? If it's not a breaking change. |
||
int hidden_size, | ||
int mode) { | ||
size_t size = 0; | ||
|
@@ -279,9 +279,9 @@ void RNNForwardTraining(DType* ws, | |
bool state_outputs, | ||
const int num_layers, | ||
const int direction, | ||
const int seq_length, | ||
const int batch_size, | ||
const int input_size, | ||
const index_t seq_length, | ||
const index_t batch_size, | ||
const index_t input_size, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Did you check that "seq_length, batch_size, input_size" are index_t in functions LstmForwardTraining, GruForwardTraining, VanillaRNNForwardTraining? If so, can you let me know here; otherwise you may need to update them too. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Excellent point, updating now.
Comment on lines
+282
to
+284
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. size_t? If it's not a breaking change. |
||
const int state_size, | ||
DType* x_ptr, | ||
DType* hx_ptr, | ||
|
@@ -321,9 +321,9 @@ void RNNForwardInference(DType* ws, | |
bool state_outputs, | ||
const int num_layers, | ||
const int direction, | ||
const int seq_length, | ||
const int batch_size, | ||
const int input_size, | ||
const index_t seq_length, | ||
const index_t batch_size, | ||
const index_t input_size, | ||
Comment on lines
+324
to
+326
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. size_t? If it's not a breaking change. |
||
const int state_size, | ||
const int projection_size, | ||
DType* x_ptr, | ||
|
@@ -363,9 +363,9 @@ void RNNBackward(DType* ws, | |
DType* rs, | ||
const int num_layers, | ||
const int direction, | ||
const int seq_length, | ||
const int batch_size, | ||
const int input_size, | ||
const index_t seq_length, | ||
const index_t batch_size, | ||
const index_t input_size, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Did you check that "seq_length, batch_size, input_size" are index_t in functions LstmBackwardTraining, GruBackwardTraining, VanillaRNNBackwardTraining? If so, can you let me know here; otherwise you may need to update them too. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Excellent point, updating now.
Comment on lines
+366
to
+368
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. size_t? If it's not a breaking change. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we can keep the signed index_t, since all the functions being called use signed types and the omp loop requires a signed index as well. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agree with keeping it signed. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I agree @apeforest! I believe the omp loop’s required signed index was the root cause of the segfault when I made the size_t changes. |
||
const int state_size, | ||
DType* x_ptr, | ||
DType* hx_ptr, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you sure you don't have to change size1, size2, proj_size and param_size:
https://github.com/apache/incubator-mxnet/pull/17632/files#diff-6dfdca409e69cc495f286170fe1e553eR143
https://github.com/apache/incubator-mxnet/pull/17632/files#diff-6dfdca409e69cc495f286170fe1e553eR152
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, fixing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
size_t? Make sure the API signature doesn't change. If that's the case, then keep it index_t.