Skip to content

Commit

Permalink
remove const_cast which makes it's not thread safe. (#1463)
Browse files Browse the repository at this point in the history
  • Loading branch information
HectorSVC authored Jul 23, 2019
1 parent 6be93f1 commit 31838fc
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,19 +207,20 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
}

IAllocatorUniquePtr<T> x_reversed_data;
T* x_data = const_cast<T*>(X->template Data<T>());
const T* x_data = X->template Data<T>();
if (reverse_) {
// reverse input data
x_reversed_data = GetScratchBuffer<T>(seq_length * batch_size * input_size);
ReverseBySequence(gsl::narrow_cast<int32_t>(seq_length),
gsl::narrow_cast<int32_t>(batch_size),
gsl::narrow_cast<int32_t>(input_size),
reinterpret_cast<CudaT*>(x_data),
reinterpret_cast<const CudaT*>(x_data),
reinterpret_cast<CudaT*>(x_reversed_data.get()),
seq_length * batch_size * input_size);
x_data = x_reversed_data.get();
}

const T* x_data_input = reverse_ ? x_reversed_data.get() : x_data;

const T* hx_data = (initial_h == nullptr) ? nullptr : initial_h->template Data<T>();
const T* cx_data = (initial_c == nullptr) ? nullptr : initial_c->template Data<T>();
T* y_h_data = (Y_h == nullptr) ? nullptr : Y_h->template MutableData<T>();
Expand Down Expand Up @@ -248,7 +249,7 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
rnn_desc_,
gsl::narrow_cast<int>(seq_length),
x_desc.data(),
x_data,
x_data_input,
hx_desc,
hx_data,
cx_desc,
Expand All @@ -272,7 +273,7 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(),
rnn_desc_,
x_desc,
x_data,
x_data_input,
hx_desc,
hx_data,
cx_desc,
Expand Down

0 comments on commit 31838fc

Please sign in to comment.