diff --git a/src/cudamatrix/cu-kernels-ansi.h b/src/cudamatrix/cu-kernels-ansi.h index 75ebcf79d74..d1463e6b9ca 100644 --- a/src/cudamatrix/cu-kernels-ansi.h +++ b/src/cudamatrix/cu-kernels-ansi.h @@ -6,6 +6,7 @@ // 2013 Xiaohui Zhang // 2013-2015 Guoguo Chen // 2016-2018 Shiyin Kang +// 2019 Yiwen Shao // See ../../COPYING for clarification regarding multiple authors // @@ -201,34 +202,6 @@ void cudaD_add_vec_vec(int Gr, int Bl, double alpha, double* v, const double* x, const double* y, double beta, int dim); void cudaF_add_vec_vec(int Gr, int Bl, float alpha, float* v, const float* x, const float* y, float beta, int dim); -void cudaD_apply_ceiling(dim3 Gr, dim3 Bl, double* mat, double ceiling_val, - MatrixDim d); -void cudaF_apply_ceiling(dim3 Gr, dim3 Bl, float* mat, float ceiling_val, - MatrixDim d); -void cudaD_apply_exp(dim3 Gr, dim3 Bl, double* mat, MatrixDim d); -void cudaF_apply_exp(dim3 Gr, dim3 Bl, float* mat, MatrixDim d); -void cudaD_apply_exp_limited(dim3 Gr, dim3 Bl, double* mat, MatrixDim d, - double lower_limit, double upper_limit); -void cudaF_apply_exp_limited(dim3 Gr, dim3 Bl, float* mat, MatrixDim d, - float lower_limit, float upper_limit); -void cudaD_apply_exp_special(dim3 Gr, dim3 Bl, double* out, MatrixDim out_dim, - const double* in, int in_stride); -void cudaF_apply_exp_special(dim3 Gr, dim3 Bl, float* out, MatrixDim out_dim, - const float* in, int in_stride); -void cudaD_apply_floor(dim3 Gr, dim3 Bl, double* mat, double floor_val, - MatrixDim d); -void cudaF_apply_floor(dim3 Gr, dim3 Bl, float* mat, float floor_val, - MatrixDim d); -void cudaD_apply_heaviside(dim3 Gr, dim3 Bl, double* mat, MatrixDim d); -void cudaF_apply_heaviside(dim3 Gr, dim3 Bl, float* mat, MatrixDim d); -void cudaD_apply_log(dim3 Gr, dim3 Bl, double *mat, MatrixDim d); -void cudaF_apply_log(dim3 Gr, dim3 Bl, float *mat, MatrixDim d); -void cudaD_apply_pow_abs(dim3 Gr, dim3 Bl, double* mat, double power, - bool include_sign, MatrixDim d); -void cudaF_apply_pow_abs(dim3 Gr, dim3 Bl, float* mat, float power, - bool include_sign, MatrixDim d); -void cudaD_apply_pow(dim3 Gr, dim3 Bl, double* mat, double power, MatrixDim d); -void cudaF_apply_pow(dim3 Gr, dim3 Bl, float* mat, float power, MatrixDim d); void cudaD_block_add_mat_mat(dim3 Gr, dim3 Bl, CuBlockMatrixData *B_cu_data, int num_blocks, const double *C_data, int C_num_cols, int C_row_stride, int C_col_stride, @@ -506,6 +479,36 @@ void cudaD_heaviside(dim3 Gr, dim3 Bl, double *y, const double *x, MatrixDim d, int src_stride); void cudaF_heaviside(dim3 Gr, dim3 Bl, float *y, const float *x, MatrixDim d, int src_stride); +void cudaD_exp(dim3 Gr, dim3 Bl, double *y, const double *x, MatrixDim d, + int src_stride); +void cudaF_exp(dim3 Gr, dim3 Bl, float *y, const float *x, MatrixDim d, + int src_stride); +void cudaD_pow(dim3 Gr, dim3 Bl, double *y, const double *x, double power, MatrixDim d, + int src_stride); +void cudaF_pow(dim3 Gr, dim3 Bl, float *y, const float *x, float power, MatrixDim d, + int src_stride); +void cudaD_ceiling(dim3 Gr, dim3 Bl, double* y, const double* x, double ceiling_val, + MatrixDim dim, int src_stride); +void cudaF_ceiling(dim3 Gr, dim3 Bl, float* y, const float* x, float ceiling_val, + MatrixDim dim, int src_stride); +void cudaD_floor(dim3 Gr, dim3 Bl, double* y, const double* x, double floor_val, + MatrixDim dim, int src_stride); +void cudaF_floor(dim3 Gr, dim3 Bl, float* y, const float* x, float floor_val, + MatrixDim dim, int src_stride); +void cudaD_exp_limited(dim3 Gr, dim3 Bl, double* y, const double* x, + 
double lower_limit, double upper_limit, MatrixDim d, int src_stride); +void cudaF_exp_limited(dim3 Gr, dim3 Bl, float* y, const float* x, + float lower_limit, float upper_limit, MatrixDim d, int src_stride); +void cudaD_exp_special(dim3 Gr, dim3 Bl, double* y, const double* x, + MatrixDim d, int src_stride); +void cudaF_exp_special(dim3 Gr, dim3 Bl, float* y, const float* x, + MatrixDim d, int src_stride); +void cudaD_log(dim3 Gr, dim3 Bl, double* y, const double* x, MatrixDim d, int src_stride); +void cudaF_log(dim3 Gr, dim3 Bl, float* y, const float* x, MatrixDim d, int src_stride); +void cudaD_pow_abs(dim3 Gr, dim3 Bl, double* y, const double* x, double power, + bool include_sign, MatrixDim dim, int src_stride); +void cudaF_pow_abs(dim3 Gr, dim3 Bl, float* y, const float* x, float power, + bool include_sign, MatrixDim dim, int src_stride); void cuda_int32_add(dim3 Gr, dim3 Bl, int32_cuda *mat, int32_cuda value, MatrixDim d); void cuda_int32_set_const(dim3 Gr, dim3 Bl, int32_cuda *mat, int32_cuda value, diff --git a/src/cudamatrix/cu-kernels.cu b/src/cudamatrix/cu-kernels.cu index b89fc54b6ce..9cba04c3c99 100644 --- a/src/cudamatrix/cu-kernels.cu +++ b/src/cudamatrix/cu-kernels.cu @@ -8,6 +8,7 @@ // 2013-2015 Guoguo Chen // 2016-2018 Shiyin Kang // 2017 Hossein Hadian, Daniel Galvez +// 2019 Yiwen Shao // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -293,25 +294,6 @@ static void _add_smat_trans(Real* mat, MatrixDim mat_dim, Real alpha, } } -/// For each element x of the matrix, set it to -/// (x < 0 ? exp(x) : x + 1). -/// Use block/grid sizes for simple matrix ops -template -__global__ -static void _apply_exp_special(T* out, MatrixDim out_dim, const T* in, - int in_stride) { - const int i = blockIdx.x * blockDim.x + threadIdx.x; - const int j = blockIdx.y * blockDim.y + threadIdx.y; - if (i < out_dim.rows && j < out_dim.cols) { - T x = in[i * in_stride + j]; - if (x < T(0)) { - out[i * out_dim.stride + j] = exp(x); - } else { - out[i * out_dim.stride + j] = x + T(1); - } - } -} - /// Fill the array 'data' with the sequence [base ... base + length) /// Use 1D block and 1D grid template @@ -389,37 +371,6 @@ static void _trace_mat_smat(const Real* mat, MatrixDim mat_dim, } } -template -__global__ -static void _apply_exp(Real* mat, MatrixDim d) { - int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x; - int32_cuda j = blockIdx.y * blockDim.y + threadIdx.y; - int32_cuda index = i + j * d.stride; - if (i < d.cols && j < d.rows) { - mat[index] = exp(mat[index]); - } -} - -template -__global__ -static void _apply_exp_limited(Real* mat, MatrixDim d, - Real lower_limit, Real upper_limit) { - int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x; - int32_cuda j = blockIdx.y * blockDim.y + threadIdx.y; - int32_cuda index = i + j * d.stride; - if (i < d.cols && j < d.rows) { - Real x = mat[index]; - // I'm writing !(x >= lower_limit) instead of (x < lower_limit) so that - // nan's will be set to the lower-limit. 
- if (!(x >= lower_limit)) - x = lower_limit; - else if (x > upper_limit) - x = upper_limit; - mat[index] = exp(x); - } -} - - template __global__ static void _scale_diag_packed(Real* mat, Real value, int dim) { @@ -500,16 +451,6 @@ static void _scale(Real* mat, Real value, MatrixDim d) { mat[index] = mat[index] * value; } -template -__global__ -static void _apply_log(Real* mat, MatrixDim d) { - int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x; - int32_cuda j = blockIdx.y * blockDim.y + threadIdx.y; - int32_cuda index = i + j * d.stride; - if (i < d.cols && j < d.rows) - mat[index] = log(mat[index]); -} - template __global__ static void _mul_elements(Real* mat, const Real* A, MatrixDim dst_d, @@ -1879,83 +1820,6 @@ static void _vec_apply_ceiling(Real *v, Real ceiling_val, float *count, } } -template -__global__ -static void _apply_pow(Real* mat, Real power, MatrixDim d) { - int i = blockIdx.x * blockDim.x + threadIdx.x; // col index - int j = blockIdx.y * blockDim.y + threadIdx.y; // row index - int index = i + j * d.stride; - if (i < d.cols && j < d.rows) { - if (power == 1.0) - return; - if (power == 2.0) { - mat[index] = mat[index] * mat[index]; - } else if (power == 0.5) { - if (!(mat[index] >= 0.0)) - return; - mat[index] = sqrt(mat[index]); - } else { - mat[index] = pow(mat[index], power); - } - } -} - -template -__global__ -static void _apply_pow_abs(Real* mat, Real power, bool include_sign, - MatrixDim d) { - int i = blockIdx.x * blockDim.x + threadIdx.x; // col index - int j = blockIdx.y * blockDim.y + threadIdx.y; // row index - int index = i + j * d.stride; - if (i < d.cols && j < d.rows) { - if (include_sign == true && mat[index] < 0) { - if (power == 1.0) - mat[index] = -std::abs(mat[index]); - if (power == 2.0) { - mat[index] = -mat[index] * mat[index]; - } else if (power == 0.5) { - mat[index] = -sqrt(std::abs(mat[index])); - } else { - mat[index] = -pow(std::abs(mat[index]), power); - } - } else { - if (power == 1.0) - mat[index] = std::abs(mat[index]); - if (power == 2.0) { - mat[index] = mat[index] * mat[index]; - } else if (power == 0.5) { - mat[index] = sqrt(std::abs(mat[index])); - } else if (power < 0.0 && mat[index] == 0.0) { - mat[index] = 0.0; - } else { - mat[index] = pow(std::abs(mat[index]), power); - } - } - } -} - -template -__global__ -static void _apply_heaviside(Real* mat, MatrixDim d) { - int i = blockIdx.x * blockDim.x + threadIdx.x; // col index - int j = blockIdx.y * blockDim.y + threadIdx.y; // row index - int index = i + j * d.stride; - if (i < d.cols && j < d.rows) - mat[index] = (mat[index] > 0.0 ? 
1.0 : 0.0); -} - -template -__global__ -static void _apply_floor(Real* mat, Real floor_val, MatrixDim d) { - int i = blockIdx.x * blockDim.x + threadIdx.x; // col index - int j = blockIdx.y * blockDim.y + threadIdx.y; // row index - int index = i + j * d.stride; - - if (i < d.cols && j < d.rows) { - mat[index] = max(mat[index], floor_val); - } -} - template __global__ static void _copy_cols(Real* dst, const Real *src, @@ -2117,18 +1981,6 @@ static void _add_to_rows(Real alpha, Real* const * dst, const Real *src, } } -template -__global__ -static void _apply_ceiling(Real* mat, Real ceiling_val, MatrixDim d) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - int j = blockIdx.y * blockDim.y + threadIdx.y; - int index = i + j * d.stride; - - if (i < d.cols && j < d.rows) { - mat[index] = min(mat[index], ceiling_val); - } -} - template __global__ static void _invert_elements(Real* data, MatrixDim d) { @@ -2516,7 +2368,7 @@ static void _diff_parametric_relu(Real* eout, const Real* e, const Real* y, template __global__ -static void _heaviside(Real*y, const Real*x, MatrixDim d, int src_stride) { +static void _heaviside(Real* y, const Real* x, MatrixDim d, int src_stride) { int i = blockIdx.x * blockDim.x + threadIdx.x; int j = blockIdx.y * blockDim.y + threadIdx.y; int dst_index = i + j * d.stride, src_index = i + j * src_stride; @@ -2526,6 +2378,120 @@ static void _heaviside(Real*y, const Real*x, MatrixDim d, int src_stride) { } } +template +__global__ +static void _exp(Real* y, const Real* x, MatrixDim d, int src_stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int j = blockIdx.y * blockDim.y + threadIdx.y; + int dst_index = i + j * d.stride, src_index = i + j * src_stride; + if (i < d.cols && j < d.rows) { + Real res = exp(x[src_index]); + y[dst_index] = res; + } +} + +template +__global__ +static void _pow(Real* y, const Real* x, Real power, MatrixDim d, int src_stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int j = blockIdx.y * blockDim.y + threadIdx.y; + int dst_index = i + j * d.stride, src_index = i + j * src_stride; + if (i < d.cols && j < d.rows) { + y[dst_index] = pow(x[src_index], power); + } +} + +template +__global__ +static void _ceiling(Real* y, const Real* x, Real ceiling_val, MatrixDim d, int src_stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int j = blockIdx.y * blockDim.y + threadIdx.y; + int dst_index = i + j * d.stride, src_index = i + j * src_stride; + + if (i < d.cols && j < d.rows) { + y[dst_index] = min(x[src_index], ceiling_val); + } +} + +template +__global__ +static void _floor(Real* y, const Real* x, Real floor_val, MatrixDim d, int src_stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; // col index + int j = blockIdx.y * blockDim.y + threadIdx.y; // row index + int dst_index = i + j * d.stride, src_index = i + j * src_stride; + + if (i < d.cols && j < d.rows) { + y[dst_index] = max(x[src_index], floor_val); + } +} + +template +__global__ +static void _exp_limited(Real* y, const Real* x, Real lower_limit, Real upper_limit, + MatrixDim d, int src_stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int j = blockIdx.y * blockDim.y + threadIdx.y; + int dst_index = i + j * d.stride, src_index = i + j * src_stride; + if (i < d.cols && j < d.rows) { + const Real x_i = x[src_index]; + // I'm writing !(x >= lower_limit) instead of (x < lower_limit) so that + // nan's will be set to the lower-limit. 
+ if (!(x_i >= lower_limit)) + y[dst_index] = exp(lower_limit); + else if (x_i > upper_limit) + y[dst_index] = exp(upper_limit); + else + y[dst_index] = exp(x_i); + } +} + +/// For each element x of the matrix, set it to +/// (x < 0 ? exp(x) : x + 1). +/// Use block/grid sizes for simple matrix ops +template +__global__ +static void _exp_special(Real* y, const Real* x, MatrixDim d, + int src_stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int j = blockIdx.y * blockDim.y + threadIdx.y; + int dst_index = i + j * d.stride, src_index = i + j * src_stride; + if (i < d.cols && j < d.rows) { + const Real in = x[src_index]; + if (in < Real(0)) { + y[dst_index] = exp(in); + } else { + y[dst_index] = in + Real(1); + } + } +} + +template +__global__ +static void _log(Real* y, const Real* x, MatrixDim d, int src_stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int j = blockIdx.y * blockDim.y + threadIdx.y; + int dst_index = i + j * d.stride, src_index = i + j * src_stride; + if (i < d.cols && j < d.rows) + y[dst_index] = log(x[src_index]); +} + +template +__global__ +static void _pow_abs(Real* y, const Real* x, Real power, bool include_sign, + MatrixDim d, int src_stride) { + int i = blockIdx.x * blockDim.x + threadIdx.x; // col index + int j = blockIdx.y * blockDim.y + threadIdx.y; // row index + int dst_index = i + j * d.stride, src_index = i + j * src_stride; + if (i < d.cols && j < d.rows) { + if (include_sign == true && x[src_index] < 0) { + y[dst_index] = -pow(std::abs(x[src_index]), power); + } + else { + y[dst_index] = pow(std::abs(x[src_index]), power); + } + } +} + template __global__ static void _softmax_reduce(Real*y, const Real*x, MatrixDim d, int src_stride) { @@ -3709,28 +3675,6 @@ void cudaFD_copy_from_tp(dim3 Gr, dim3 Bl, float* A, const double* B, _copy_from_tp<<>>(A,B,dmat); } -void cudaF_apply_exp(dim3 Gr, dim3 Bl, float* mat, MatrixDim d) { - _apply_exp<<>>(mat,d); -} - -void cudaF_apply_exp_limited(dim3 Gr, dim3 Bl, float* mat, MatrixDim d, - float lower_limit, float upper_limit) { - _apply_exp_limited<<>>(mat, d, lower_limit, upper_limit); -} - -void cudaF_apply_pow(dim3 Gr, dim3 Bl, float* mat, float power, MatrixDim d) { - _apply_pow<<>>(mat, power, d); -} - -void cudaF_apply_pow_abs(dim3 Gr, dim3 Bl, float* mat, float power, - bool include_sign, MatrixDim d) { - _apply_pow_abs<<>>(mat, power, include_sign, d); -} - -void cudaF_apply_heaviside(dim3 Gr, dim3 Bl, float* mat, MatrixDim d) { - _apply_heaviside<<>>(mat, d); -} - void cudaF_copy_cols(dim3 Gr, dim3 Bl, float* dst, const float* src, const MatrixIndexT_cuda* reorder, MatrixDim dst_dim, int src_stride) { @@ -3787,16 +3731,6 @@ void cudaF_add_to_rows_direct(dim3 Gr, dim3 Bl, float alpha, float* const * dst, _add_to_rows<<>>(alpha, dst, src, src_dim); } -void cudaF_apply_floor(dim3 Gr, dim3 Bl, float* mat, float floor_val, - MatrixDim d) { - _apply_floor<<>>(mat, floor_val, d); -} - -void cudaF_apply_ceiling(dim3 Gr, dim3 Bl, float* mat, float ceiling_val, - MatrixDim d) { - _apply_ceiling<<>>(mat, ceiling_val, d); -} - void cudaF_set_diag(int Gr, int Bl, float* mat, float value, MatrixDim d) { _set_diag<<>>(mat,value,d); } @@ -3829,10 +3763,6 @@ void cudaF_scale(dim3 Gr, dim3 Bl, float* mat, float value, MatrixDim d) { _scale<<>>(mat,value,d); } -void cudaF_apply_log(dim3 Gr, dim3 Bl, float* mat, MatrixDim d) { - _apply_log<<>>(mat,d); -} - void cudaF_mul_elements(dim3 Gr, dim3 Bl, float* mat, const float* A, MatrixDim dst_d, int src_stride) { _mul_elements<<>>(mat,A,dst_d,src_stride); @@ -4257,6 
+4187,45 @@ void cudaF_heaviside(dim3 Gr, dim3 Bl, float* y, const float* x, MatrixDim d, _heaviside<<>>(y, x, d, src_stride); } +void cudaF_exp(dim3 Gr, dim3 Bl, float* y, const float* x, MatrixDim d, + int src_stride) { + _exp<<>>(y, x, d, src_stride); +} + +void cudaF_pow(dim3 Gr, dim3 Bl, float* y, const float* x, float power, MatrixDim d, + int src_stride) { + _pow<<>>(y, x, power, d, src_stride); +} + +void cudaF_ceiling(dim3 Gr, dim3 Bl, float* y, const float* x, float ceiling_val, + MatrixDim d, int src_stride) { + _ceiling<<>>(y, x, ceiling_val, d, src_stride); +} + +void cudaF_floor(dim3 Gr, dim3 Bl, float* y, const float* x, float floor_val, + MatrixDim d, int src_stride) { + _floor<<>>(y, x, floor_val, d, src_stride); +} + +void cudaF_exp_limited(dim3 Gr, dim3 Bl, float* y, const float* x, + float lower_limit, float upper_limit, MatrixDim d, int src_stride) { + _exp_limited<<>>(y, x, lower_limit, upper_limit, d, src_stride); +} + +void cudaF_exp_special(dim3 Gr, dim3 Bl, float* y, const float* x, MatrixDim d, + int src_stride) { + _exp_special<<>>(y, x, d, src_stride); +} + +void cudaF_log(dim3 Gr, dim3 Bl, float* y, const float* x, MatrixDim d, int src_stride) { + _log<<>>(y, x, d, src_stride); +} + +void cudaF_pow_abs(dim3 Gr, dim3 Bl, float* y, const float* x, float power, + bool include_sign, MatrixDim d, int src_stride) { + _pow_abs<<>>(y, x, power, include_sign, d, src_stride); +} + void cudaF_softmax_reduce(size_t Gr, size_t Bl, float* y, const float* x, MatrixDim d, int src_stride) { _softmax_reduce<<>>(y, x, d, src_stride); @@ -4420,30 +4389,6 @@ void cudaDF_copy_from_tp(dim3 Gr, dim3 Bl, double* A, const float* B, _copy_from_tp<<>>(A,B,dmat); } -void cudaD_apply_exp(dim3 Gr, dim3 Bl, double* mat, MatrixDim d) { - _apply_exp<<>>(mat,d); -} - -void cudaD_apply_exp_limited(dim3 Gr, dim3 Bl, double* mat, MatrixDim d, - double lower_limit, double upper_limit) { - _apply_exp_limited<<>>(mat, d, lower_limit, upper_limit); -} - - - -void cudaD_apply_pow(dim3 Gr, dim3 Bl, double* mat, double power, MatrixDim d) { - _apply_pow<<>>(mat, power, d); -} - -void cudaD_apply_pow_abs(dim3 Gr, dim3 Bl, double* mat, double power, - bool include_sign, MatrixDim d) { - _apply_pow_abs<<>>(mat, power, include_sign, d); -} - -void cudaD_apply_heaviside(dim3 Gr, dim3 Bl, double* mat, MatrixDim d) { - _apply_heaviside<<>>(mat, d); -} - void cudaD_copy_cols(dim3 Gr, dim3 Bl, double* dst, const double* src, const MatrixIndexT_cuda* reorder, MatrixDim dst_dim, int src_stride) { @@ -4501,16 +4446,6 @@ void cudaD_add_to_rows_direct(dim3 Gr, dim3 Bl, double alpha, _add_to_rows<<>>(alpha, dst, src, src_dim); } -void cudaD_apply_floor(dim3 Gr, dim3 Bl, double* mat, double floor_val, - MatrixDim d) { - _apply_floor<<>>(mat, floor_val, d); -} - -void cudaD_apply_ceiling(dim3 Gr, dim3 Bl, double* mat, double ceiling_val, - MatrixDim d) { - _apply_ceiling<<>>(mat, ceiling_val, d); -} - void cudaD_set_diag(int Gr, int Bl, double* mat, double value, MatrixDim d) { _set_diag<<>>(mat,value,d); } @@ -4544,10 +4479,6 @@ void cudaD_scale(dim3 Gr, dim3 Bl, double* mat, double value, MatrixDim d) { _scale<<>>(mat,value,d); } -void cudaD_apply_log(dim3 Gr, dim3 Bl, double* mat, MatrixDim d) { - _apply_log<<>>(mat,d); -} - void cudaD_mul_elements(dim3 Gr, dim3 Bl, double* mat, const double* A, MatrixDim dst_d, int src_stride) { _mul_elements<<>>(mat,A,dst_d,src_stride); @@ -4960,6 +4891,45 @@ void cudaD_heaviside(dim3 Gr, dim3 Bl, double* y, const double* x, MatrixDim d, _heaviside<<>>(y, x, d, src_stride); } +void 
cudaD_exp(dim3 Gr, dim3 Bl, double* y, const double* x, MatrixDim d, + int src_stride) { + _exp<<>>(y, x, d, src_stride); +} + +void cudaD_pow(dim3 Gr, dim3 Bl, double* y, const double* x, double power, MatrixDim d, + int src_stride) { + _pow<<>>(y, x, power, d, src_stride); +} + +void cudaD_ceiling(dim3 Gr, dim3 Bl, double* y, const double* x, double ceiling_val, + MatrixDim d, int src_stride) { + _ceiling<<>>(y, x, ceiling_val, d, src_stride); +} + +void cudaD_floor(dim3 Gr, dim3 Bl, double* y, const double* x, double floor_val, + MatrixDim d, int src_stride) { + _floor<<>>(y, x, floor_val, d, src_stride); +} + +void cudaD_exp_limited(dim3 Gr, dim3 Bl, double* y, const double* x, + double lower_limit, double upper_limit, MatrixDim d, int src_stride) { + _exp_limited<<>>(y, x, lower_limit, upper_limit, d, src_stride); +} + +void cudaD_exp_special(dim3 Gr, dim3 Bl, double* y, const double* x, MatrixDim d, + int src_stride) { + _exp_special<<>>(y, x, d, src_stride); +} + +void cudaD_log(dim3 Gr, dim3 Bl, double* y, const double* x, MatrixDim d, int src_stride) { + _log<<>>(y, x, d, src_stride); +} + +void cudaD_pow_abs(dim3 Gr, dim3 Bl, double* y, const double* x, double power, + bool include_sign, MatrixDim d, int src_stride) { + _pow_abs<<>>(y, x, power, include_sign, d, src_stride); +} + void cudaD_softmax_reduce(size_t Gr, size_t Bl, double* y, const double* x, MatrixDim d, int src_stride) { _softmax_reduce<<>>(y, x, d, src_stride); @@ -5348,14 +5318,6 @@ void cudaF_add_smat_trans(dim3 Gr, dim3 Bl, float* mat, MatrixDim mat_dim, _add_smat_trans<<>>(mat, mat_dim, alpha, smat_row_ptr, smat_col_idx, smat_val); } -void cudaD_apply_exp_special(dim3 Gr, dim3 Bl, double* out, MatrixDim out_dim, - const double* in, int in_stride) { - _apply_exp_special<<>>(out, out_dim, in, in_stride); -} -void cudaF_apply_exp_special(dim3 Gr, dim3 Bl, float* out, MatrixDim out_dim, - const float* in, int in_stride) { - _apply_exp_special<<>>(out, out_dim, in, in_stride); -} void cuda_compress_uint8_sign(dim3 Gr, dim3 Bl, const BaseFloat *src, MatrixDim dim, unsigned char *dest, int dest_stride) { diff --git a/src/cudamatrix/cu-kernels.h b/src/cudamatrix/cu-kernels.h index f93c1e2b2e0..731cebace66 100644 --- a/src/cudamatrix/cu-kernels.h +++ b/src/cudamatrix/cu-kernels.h @@ -7,6 +7,7 @@ // 2013 Xiaohui Zhang // 2013-2015 Guoguo Chen // 2016-2018 Shiyin Kang +// 2019 Yiwen Shao // See ../../COPYING for clarification regarding multiple authors // @@ -345,74 +346,6 @@ inline void cuda_add_vec_vec(int Gr, int Bl, float alpha, float* v, int dim) { cudaF_add_vec_vec(Gr, Bl, alpha, v, x, y, beta, dim); } -inline void cuda_apply_ceiling(dim3 Gr, dim3 Bl, double* mat, - double ceiling_val, MatrixDim dim) { - cudaD_apply_ceiling(Gr, Bl, mat, ceiling_val, dim); -} -inline void cuda_apply_ceiling(dim3 Gr, dim3 Bl, float* mat, float ceiling_val, - MatrixDim dim) { - cudaF_apply_ceiling(Gr, Bl, mat, ceiling_val, dim); -} -inline void cuda_apply_exp(dim3 Gr, dim3 Bl, double* mat, MatrixDim d) { - cudaD_apply_exp(Gr, Bl, mat, d); -} -inline void cuda_apply_exp(dim3 Gr, dim3 Bl, float* mat, MatrixDim d) { - cudaF_apply_exp(Gr, Bl, mat, d); -} -inline void cuda_apply_exp_limited(dim3 Gr, dim3 Bl, double* mat, MatrixDim d, - double lower_limit, double upper_limit) { - cudaD_apply_exp_limited(Gr, Bl, mat, d, lower_limit, upper_limit); -} -inline void cuda_apply_exp_limited(dim3 Gr, dim3 Bl, float* mat, MatrixDim d, - float lower_limit, float upper_limit) { - cudaF_apply_exp_limited(Gr, Bl, mat, d, lower_limit, upper_limit); -} 
-inline void cuda_apply_exp_special(dim3 Gr, dim3 Bl, double* out, - MatrixDim out_dim, const double* in, - int in_stride) { - cudaD_apply_exp_special(Gr, Bl, out, out_dim, in, in_stride); -} -inline void cuda_apply_exp_special(dim3 Gr, dim3 Bl, float* out, - MatrixDim out_dim, const float* in, - int in_stride) { - cudaF_apply_exp_special(Gr, Bl, out, out_dim, in, in_stride); -} -inline void cuda_apply_floor(dim3 Gr, dim3 Bl, double* mat, double floor_val, - MatrixDim dim) { - cudaD_apply_floor(Gr, Bl, mat, floor_val, dim); -} -inline void cuda_apply_floor(dim3 Gr, dim3 Bl, float* mat, float floor_val, - MatrixDim dim) { - cudaF_apply_floor(Gr, Bl, mat, floor_val, dim); -} -inline void cuda_apply_heaviside(dim3 Gr, dim3 Bl, double* mat, MatrixDim dim) { - cudaD_apply_heaviside(Gr, Bl, mat, dim); -} -inline void cuda_apply_heaviside(dim3 Gr, dim3 Bl, float* mat, MatrixDim dim) { - cudaF_apply_heaviside(Gr, Bl, mat, dim); -} -inline void cuda_apply_log(dim3 Gr, dim3 Bl, double *mat, MatrixDim d) { - cudaD_apply_log(Gr, Bl, mat, d); -} -inline void cuda_apply_log(dim3 Gr, dim3 Bl, float *mat, MatrixDim d) { - cudaF_apply_log(Gr, Bl, mat, d); -} -inline void cuda_apply_pow_abs(dim3 Gr, dim3 Bl, double* mat, double power, - bool include_sign, MatrixDim dim) { - cudaD_apply_pow_abs(Gr, Bl, mat, power, include_sign, dim); -} -inline void cuda_apply_pow_abs(dim3 Gr, dim3 Bl, float* mat, float power, - bool include_sign, MatrixDim dim) { - cudaF_apply_pow_abs(Gr, Bl, mat, power, include_sign, dim); -} -inline void cuda_apply_pow(dim3 Gr, dim3 Bl, double* mat, double power, - MatrixDim dim) { - cudaD_apply_pow(Gr, Bl, mat, power, dim); -} -inline void cuda_apply_pow(dim3 Gr, dim3 Bl, float* mat, float power, - MatrixDim dim) { - cudaF_apply_pow(Gr, Bl, mat, power, dim); -} inline cublasStatus_t cuda_axpy(cublasHandle_t handle, int n, double alpha, const double *x, int incx, double *y, int incy) { @@ -939,19 +872,81 @@ inline void cuda_group_spec_pnorm(dim3 Gr, dim3 Bl, double *y, const double *x, double power) { cudaD_group_spec_pnorm(Gr, Bl, y, x, d, src_stride, group_size, power); } -inline void cuda_group_spec_pnorm(dim3 Gr, dim3 Bl, float *y, const float *x, +inline void cuda_group_spec_pnorm(dim3 Gr, dim3 Bl, float* y, const float* x, MatrixDim d, int src_stride, int group_size, float power) { cudaF_group_spec_pnorm(Gr, Bl, y, x, d, src_stride, group_size, power); } -inline void cuda_heaviside(dim3 Gr, dim3 Bl, double *y, const double *x, +inline void cuda_heaviside(dim3 Gr, dim3 Bl, double* y, const double* x, MatrixDim d, int src_stride) { cudaD_heaviside(Gr, Bl, y, x, d, src_stride); } -inline void cuda_heaviside(dim3 Gr, dim3 Bl, float *y, const float *x, +inline void cuda_heaviside(dim3 Gr, dim3 Bl, float* y, const float* x, MatrixDim d, int src_stride) { cudaF_heaviside(Gr, Bl, y, x, d, src_stride); } +inline void cuda_exp(dim3 Gr, dim3 Bl, double* y, const double* x, + MatrixDim d, int src_stride) { + cudaD_exp(Gr, Bl, y, x, d, src_stride); +} +inline void cuda_exp(dim3 Gr, dim3 Bl, float* y, const float* x, + MatrixDim d, int src_stride) { + cudaF_exp(Gr, Bl, y, x, d, src_stride); +} +inline void cuda_pow(dim3 Gr, dim3 Bl, double* y, const double* x, double power, + MatrixDim d, int src_stride) { + cudaD_pow(Gr, Bl, y, x, power, d, src_stride); +} +inline void cuda_pow(dim3 Gr, dim3 Bl, float* y, const float* x, float power, + MatrixDim d, int src_stride) { + cudaF_pow(Gr, Bl, y, x, power, d, src_stride); +} +inline void cuda_ceiling(dim3 Gr, dim3 Bl, double* y, const double* x, double 
ceiling_val, + MatrixDim dim, int src_stride) { + cudaD_ceiling(Gr, Bl, y, x, ceiling_val, dim, src_stride); +} +inline void cuda_ceiling(dim3 Gr, dim3 Bl, float* y, const float* x, float ceiling_val, + MatrixDim dim, int src_stride) { + cudaF_ceiling(Gr, Bl, y, x, ceiling_val, dim, src_stride); +} +inline void cuda_floor(dim3 Gr, dim3 Bl, double* y, const double* x, double floor_val, + MatrixDim dim, int src_stride) { + cudaD_floor(Gr, Bl, y, x, floor_val, dim, src_stride); +} +inline void cuda_floor(dim3 Gr, dim3 Bl, float* y, const float* x, float floor_val, + MatrixDim dim, int src_stride) { + cudaF_floor(Gr, Bl, y, x, floor_val, dim, src_stride); +} +inline void cuda_exp_limited(dim3 Gr, dim3 Bl, double* y, const double* x, + double lower_limit, double upper_limit, MatrixDim d, int src_stride) { + cudaD_exp_limited(Gr, Bl, y, x, lower_limit, upper_limit, d, src_stride); +} +inline void cuda_exp_limited(dim3 Gr, dim3 Bl, float* y, const float* x, + float lower_limit, float upper_limit, MatrixDim d, int src_stride) { + cudaF_exp_limited(Gr, Bl, y, x, lower_limit, upper_limit, d, src_stride); +} +inline void cuda_exp_special(dim3 Gr, dim3 Bl, double* y, const double* x, + MatrixDim d, int src_stride) { + cudaD_exp_special(Gr, Bl, y, x, d, src_stride); +} +inline void cuda_exp_special(dim3 Gr, dim3 Bl, float* y, const float* x, + MatrixDim d, int src_stride) { + cudaF_exp_special(Gr, Bl, y, x, d, src_stride); +} +inline void cuda_log(dim3 Gr, dim3 Bl, double* y, const double* x, MatrixDim d, int src_stride) { + cudaD_log(Gr, Bl, y, x, d, src_stride); +} +inline void cuda_log(dim3 Gr, dim3 Bl, float* y, const float* x, MatrixDim d, int src_stride) { + cudaF_log(Gr, Bl, y, x, d, src_stride); +} +inline void cuda_pow_abs(dim3 Gr, dim3 Bl, double* y, const double* x, double power, + bool include_sign, MatrixDim dim, int src_stride) { + cudaD_pow_abs(Gr, Bl, y, x, power, include_sign, dim, src_stride); +} +inline void cuda_pow_abs(dim3 Gr, dim3 Bl, float* y, const float* x, float power, + bool include_sign, MatrixDim dim, int src_stride) { + cudaF_pow_abs(Gr, Bl, y, x, power, include_sign, dim, src_stride); +} inline void cuda_invert_elements(dim3 Gr, dim3 Bl, double *data, MatrixDim d) { cudaD_invert_elements(Gr, Bl, data, d); } diff --git a/src/cudamatrix/cu-matrix-speed-test.cc b/src/cudamatrix/cu-matrix-speed-test.cc index c67eaf220b8..230112b1bd0 100644 --- a/src/cudamatrix/cu-matrix-speed-test.cc +++ b/src/cudamatrix/cu-matrix-speed-test.cc @@ -505,7 +505,7 @@ template void TestCuMatrixSoftmax(int32 dim) { Timer tim; int32 iter = 0; for (;tim.Elapsed() < time_in_secs; iter++) { - N.ApplySoftMaxPerRow(M); + N.SoftMaxPerRow(M); } BaseFloat fdim = dim; @@ -523,7 +523,7 @@ template void TestCuMatrixLogSoftmax(int32 dim) { Timer tim; int32 iter = 0; for (;tim.Elapsed() < time_in_secs; iter++) { - N.ApplyLogSoftMaxPerRow(M); + N.LogSoftMaxPerRow(M); } BaseFloat fdim = dim; diff --git a/src/cudamatrix/cu-matrix-test.cc b/src/cudamatrix/cu-matrix-test.cc index 83ed24b9847..be8483e48f5 100644 --- a/src/cudamatrix/cu-matrix-test.cc +++ b/src/cudamatrix/cu-matrix-test.cc @@ -174,7 +174,6 @@ static void UnitTestCuMatrixApplyExpSpecial() { H.ApplyExpSpecial(); Matrix H2(D); - KALDI_ASSERT(ApproxEqual(H,H2)); } @@ -201,18 +200,14 @@ static void UnitTestCuMatrixApplyExpLimited() { Matrix H(M, N); H.SetRandn(); - BaseFloat lower_limit = -0.2, upper_limit = 0.2; CuMatrix D(H); - D.ApplyExpLimited(lower_limit, upper_limit); - H.ApplyFloor(lower_limit); H.ApplyCeiling(upper_limit); H.ApplyExp(); - Matrix 
H2(D); KALDI_ASSERT(ApproxEqual(H,H2)); @@ -2389,11 +2384,11 @@ static void UnitTestCuSoftmax() { //gpu if (i % 2 == 0) { - Do.ApplySoftMaxPerRow(Di); + Do.SoftMaxPerRow(Di); } else { // in-place Do.CopyFromMat(Di); - Do.ApplySoftMaxPerRow(Do); + Do.SoftMaxPerRow(Do); } //cpu Ho.CopyFromMat(Hi); @@ -2426,11 +2421,11 @@ static void UnitTestCuLogSoftmax() { //gpu if (i % 2 == 0) { - Do.ApplyLogSoftMaxPerRow(Di); + Do.LogSoftMaxPerRow(Di); } else { // in-place. Do.CopyFromMat(Di); - Do.ApplyLogSoftMaxPerRow(Do); + Do.LogSoftMaxPerRow(Do); } //cpu Ho.CopyFromMat(Hi); diff --git a/src/cudamatrix/cu-matrix.cc b/src/cudamatrix/cu-matrix.cc index ae091370edd..32224e7dd7e 100644 --- a/src/cudamatrix/cu-matrix.cc +++ b/src/cudamatrix/cu-matrix.cc @@ -8,6 +8,7 @@ // 2013-2015 Guoguo Chen // 2016-2017 Shiyin Kang // 2017 Hossein Hadian +// 2019 Yiwen Shao // See ../../COPYING for clarification regarding multiple authors // @@ -634,27 +635,6 @@ void CuMatrixBase::Scale(Real value) { } } -template -void CuMatrixBase::ApplyLog() { - #if HAVE_CUDA == 1 - if (CuDevice::Instantiate().Enabled()) { - if (num_rows_ == 0) return; - CuTimer tim; - - dim3 dimGrid, dimBlock; - GetBlockSizesForSimpleMatrixOperation(NumRows(), NumCols(), - &dimGrid, &dimBlock); - - cuda_apply_log(dimGrid, dimBlock, data_, Dim()); - CU_SAFE_CALL(cudaGetLastError()); - - CuDevice::Instantiate().AccuProfile(__func__, tim); - } else - #endif - { - Mat().ApplyLog(); - } -} template void CuMatrixBase::MulElements(const CuMatrixBase& A) { @@ -1707,7 +1687,7 @@ void CuMatrix::CompObjfAndDeriv(const std::vector >& s } template // Y->this, X->src -void CuMatrixBase::ApplySoftMaxPerRow(const CuMatrixBase &src) { +void CuMatrixBase::SoftMaxPerRow(const CuMatrixBase &src) { KALDI_ASSERT(SameDim(*this, src)); #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { @@ -1730,7 +1710,7 @@ void CuMatrixBase::ApplySoftMaxPerRow(const CuMatrixBase &src) { } template // Y->this, X->src -void CuMatrixBase::ApplyLogSoftMaxPerRow(const CuMatrixBase &src) { +void CuMatrixBase::LogSoftMaxPerRow(const CuMatrixBase &src) { KALDI_ASSERT(SameDim(*this, src)); #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { @@ -1969,7 +1949,7 @@ void CuMatrixBase::DiffXent(const CuArrayBase &tgt, for(int32 r = 0; r < num_rows; r++) { int32 col_tgt = tgt.Data()[r]; Real &value = Mat()(r, col_tgt); - log_post_tgt->Vec()(r) = Log(value); + log_post_tgt->Vec()(r) = kaldi::Log(value); value -= 1.0; } } @@ -2425,61 +2405,72 @@ void CuMatrixBase::CopyColFromVec(const CuVectorBase &v, } template -void CuMatrixBase::ApplyPow(Real power) { +void CuMatrixBase::Heaviside(const CuMatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CuTimer tim; dim3 dimGrid, dimBlock; GetBlockSizesForSimpleMatrixOperation(NumRows(), NumCols(), &dimGrid, &dimBlock); - cuda_apply_pow(dimGrid, dimBlock, data_, power, Dim()); + cuda_heaviside(dimGrid, dimBlock, this->data_, src.data_, this->Dim(), + src.Stride()); CU_SAFE_CALL(cudaGetLastError()); + CuDevice::Instantiate().AccuProfile(__func__, tim); } else -#endif + #endif { - Mat().ApplyPow(power); + Mat().Heaviside(src.Mat()); } } template -void CuMatrixBase::ApplyPowAbs(Real power, bool include_sign) { +void CuMatrixBase::Exp(const CuMatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CuTimer tim; dim3 dimGrid, dimBlock; GetBlockSizesForSimpleMatrixOperation(NumRows(), NumCols(), &dimGrid, &dimBlock); - 
cuda_apply_pow_abs(dimGrid, dimBlock, data_, power, include_sign, Dim()); + cuda_exp(dimGrid, dimBlock, this->data_, src.data_, this->Dim(), + src.Stride()); CU_SAFE_CALL(cudaGetLastError()); + CuDevice::Instantiate().AccuProfile(__func__, tim); } else -#endif + #endif { - Mat().ApplyPowAbs(power, include_sign); + Mat().Exp(src.Mat()); } } template -void CuMatrixBase::ApplyHeaviside() { +void CuMatrixBase::Log(const CuMatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { + if (num_rows_ == 0) return; CuTimer tim; dim3 dimGrid, dimBlock; GetBlockSizesForSimpleMatrixOperation(NumRows(), NumCols(), &dimGrid, &dimBlock); - cuda_apply_heaviside(dimGrid, dimBlock, data_, Dim()); + + cuda_log(dimGrid, dimBlock, this->data_, src.data_, this->Dim(), + src.Stride()); CU_SAFE_CALL(cudaGetLastError()); + CuDevice::Instantiate().AccuProfile(__func__, tim); } else -#endif + #endif { - Mat().ApplyHeaviside(); + Mat().Log(src.Mat()); } } template -void CuMatrixBase::Heaviside(const CuMatrixBase &src) { +void CuMatrixBase::Pow(const CuMatrixBase &src, Real power) { KALDI_ASSERT(SameDim(*this, src)); #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { @@ -2487,38 +2478,41 @@ void CuMatrixBase::Heaviside(const CuMatrixBase &src) { dim3 dimGrid, dimBlock; GetBlockSizesForSimpleMatrixOperation(NumRows(), NumCols(), &dimGrid, &dimBlock); - cuda_heaviside(dimGrid, dimBlock, this->data_, src.data_, this->Dim(), - src.Stride()); + cuda_pow(dimGrid, dimBlock, this->data_, src.data_, power, this->Dim(), + src.Stride()); CU_SAFE_CALL(cudaGetLastError()); CuDevice::Instantiate().AccuProfile(__func__, tim); } else #endif { - Mat().Heaviside(src.Mat()); + Mat().Pow(src.Mat(), power); } } template -void CuMatrixBase::ApplyExp() { +void CuMatrixBase::PowAbs(const CuMatrixBase &src, Real power, bool include_sign) { + KALDI_ASSERT(SameDim(*this, src)); #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CuTimer tim; dim3 dimGrid, dimBlock; GetBlockSizesForSimpleMatrixOperation(NumRows(), NumCols(), &dimGrid, &dimBlock); - cuda_apply_exp(dimGrid, dimBlock, data_, Dim()); + cuda_pow_abs(dimGrid, dimBlock, this->data_, src.data_, power, include_sign, + this->Dim(), src.Stride()); CU_SAFE_CALL(cudaGetLastError()); CuDevice::Instantiate().AccuProfile(__func__, tim); } else #endif { - Mat().ApplyExp(); + Mat().PowAbs(src.Mat(), power, include_sign); } } - + template -void CuMatrixBase::ApplyExpLimited(Real lower_limit, Real upper_limit) { +void CuMatrixBase::ExpLimited(const CuMatrixBase &src, Real lower_limit, Real upper_limit) { + KALDI_ASSERT(SameDim(*this, src)); KALDI_ASSERT(upper_limit > lower_limit); #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { @@ -2526,82 +2520,72 @@ void CuMatrixBase::ApplyExpLimited(Real lower_limit, Real upper_limit) { dim3 dimGrid, dimBlock; GetBlockSizesForSimpleMatrixOperation(NumRows(), NumCols(), &dimGrid, &dimBlock); - cuda_apply_exp_limited(dimGrid, dimBlock, data_, Dim(), lower_limit, upper_limit); + cuda_exp_limited(dimGrid, dimBlock, this->data_, src.data_, lower_limit, upper_limit, + this->Dim(), src.Stride()); CU_SAFE_CALL(cudaGetLastError()); CuDevice::Instantiate().AccuProfile(__func__, tim); } else #endif { - int32 num_rows = num_rows_, num_cols = num_cols_; - for (int32 r = 0; r < num_rows; r++) { - Real *row_data = this->RowData(r); - for (int32 c = 0; c < num_cols; c++) { - Real x = row_data[c]; - if (!(x >= lower_limit)) - x = lower_limit; - if (x > upper_limit) - x = upper_limit; - 
row_data[c] = Exp(x); - } - } + Mat().ExpLimited(src.Mat(), lower_limit, upper_limit); } } template -void CuMatrixBase::ApplyExpSpecial() { +void CuMatrixBase::ExpSpecial(const CuMatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CuTimer tim; - - const int warpSize = 32; - dim3 dimBlock(CU1DBLOCK / warpSize, warpSize); - dim3 dimGrid(n_blocks(NumRows(), dimBlock.x), - n_blocks(NumCols(), dimBlock.y)); - - cuda_apply_exp_special(dimGrid, dimBlock, Data(), Dim(), Data(), Stride()); + dim3 dimGrid, dimBlock; + GetBlockSizesForSimpleMatrixOperation(NumRows(), NumCols(), + &dimGrid, &dimBlock); + cuda_exp_special(dimGrid, dimBlock, this->data_, src.data_, Dim(), src.Stride()); CU_SAFE_CALL(cudaGetLastError()); CuDevice::Instantiate().AccuProfile(__func__, tim); } else #endif { - Mat().ApplyExpSpecial(); + Mat().ExpSpecial(src.Mat()); } } template -void CuMatrixBase::ApplyFloor(Real floor_val) { +void CuMatrixBase::Floor(const CuMatrixBase &src, Real floor_val) { + KALDI_ASSERT(SameDim(*this, src)); #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CuTimer tim; dim3 dimGrid, dimBlock; GetBlockSizesForSimpleMatrixOperation(NumRows(), NumCols(), &dimGrid, &dimBlock); - cuda_apply_floor(dimGrid, dimBlock, data_, floor_val, Dim()); + cuda_floor(dimGrid, dimBlock, data_, src.data_, floor_val, this->Dim(), src.Stride()); CU_SAFE_CALL(cudaGetLastError()); CuDevice::Instantiate().AccuProfile(__func__, tim); } else #endif { - Mat().ApplyFloor(floor_val); + Mat().Floor(src.Mat(), floor_val); } } template -void CuMatrixBase::ApplyCeiling(Real ceiling_val) { +void CuMatrixBase::Ceiling(const CuMatrixBase &src, Real ceiling_val) { + KALDI_ASSERT(SameDim(*this, src)); #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { CuTimer tim; dim3 dimGrid, dimBlock; GetBlockSizesForSimpleMatrixOperation(NumRows(), NumCols(), &dimGrid, &dimBlock); - cuda_apply_ceiling(dimGrid, dimBlock, data_, ceiling_val, Dim()); + cuda_ceiling(dimGrid, dimBlock, this->data_, src.data_, ceiling_val, this->Dim(), src.Stride()); CU_SAFE_CALL(cudaGetLastError()); CuDevice::Instantiate().AccuProfile(__func__, tim); } else #endif { - Mat().ApplyCeiling(ceiling_val); + Mat().Ceiling(src.Mat(), ceiling_val); } } diff --git a/src/cudamatrix/cu-matrix.h b/src/cudamatrix/cu-matrix.h index 85aa4c049e7..7bc7e1806a1 100644 --- a/src/cudamatrix/cu-matrix.h +++ b/src/cudamatrix/cu-matrix.h @@ -6,6 +6,7 @@ // 2013 Xiaohui Zhang // 2013-2015 Guoguo Chen // 2017 Shiyin Kang +// 2019 Yiwen Shao // See ../../COPYING for clarification regarding multiple authors // @@ -283,6 +284,48 @@ class CuMatrixBase { /// in general, there are different ways to deal with the situation when x==0.] void Heaviside(const CuMatrixBase &src); + void Exp(const CuMatrixBase &src); + + void Log(const CuMatrixBase &src); + + void Pow(const CuMatrixBase &src, Real power); + + /// Apply power to the absolute value of each element. + /// If include_sign is true, the result will be multiplied with + /// the sign of the input value. + /// If the power is negative and the input to the power is zero, + /// The output will be set zero. If include_sign is true, it will + /// multiply the result by the sign of the input. 
+ void PowAbs(const CuMatrixBase &src, Real power, bool include_sign=false); + + void Floor(const CuMatrixBase &src, Real floor_val); + + void Ceiling(const CuMatrixBase &src, Real ceiling_val); + + /// This is equivalent to running: + /// Floor(src, lower_limit); + /// Ceiling(src, upper_limit); + /// Exp(src) + void ExpLimited(const CuMatrixBase &src, Real lower_limit, Real upper_limit); + + /// For each element x of the matrix, set it to + /// (x < 0 ? exp(x) : x + 1). This function is used + /// in our RNNLM training. + void ExpSpecial(const CuMatrixBase &src); + + /// Softmax nonlinearity + /// Y = Softmax(X) : Yij = e^Xij / sum_k(e^Xik), done to each row, + /// with attention to avoiding overflow or underflow. + /// Supports in-place operation (i.e. this == &src). + void SoftMaxPerRow(const CuMatrixBase &src); + + /// LogSoftmax nonlinearity + /// Y = LogSoftmax(X) : Yij = Xij - log(sum_k(e^Xik)), done to each row, + /// with attention to avoiding overflow or underflow. + /// Supports in-place operation (i.e. this == &src). + void LogSoftMaxPerRow(const CuMatrixBase &src); + + /// Apply the function y = log(1 + exp(x)), to each element. /// Note: the derivative of this function is the sigmoid function. /// This is like a soft ReLU. @@ -384,44 +427,51 @@ class CuMatrixBase { /// The output is symmetric. void SymInvertPosDef(); - void ApplyPow(Real power); - /// Apply power to the absolute value of each element. - /// If include_sign is true, the result will be multiplied with - /// the sign of the input value. - /// If the power is negative and the input to the power is zero, - /// The output will be set zero. If include_sign is true, it will - /// multiply the result by the sign of the input. - void ApplyPowAbs(Real power, bool include_sign=false); - /// For each element, sets x = (x > 0 ? 1.0 : 0.0). - /// See also Heaviside(). - void ApplyHeaviside(); - void ApplyFloor(Real floor_val); - void ApplyCeiling(Real ceiling_val); - void ApplyExp(); - - - /// This is equivalent to running: - /// ApplyFloor(lower_limit); - /// ApplyCeiling(upper_limit); - /// ApplyExp() - void ApplyExpLimited(Real lower_limit, Real upper_limit); - - /// For each element x of the matrix, set it to - /// (x < 0 ? exp(x) : x + 1). This function is used - /// in our RNNLM training. - void ApplyExpSpecial(); - - /// Softmax nonlinearity - /// Y = Softmax(X) : Yij = e^Xij / sum_k(e^Xik), done to each row, - /// with attention to avoiding overflow or underflow. - /// Supports in-place operation (i.e. this == &src). - void ApplySoftMaxPerRow(const CuMatrixBase &src); - - /// LogSoftmax nonlinearity - /// Y = LogSoftmax(X) : Yij = Xij - log(sum_k(e^Xik)), done to each row, - /// with attention to avoiding overflow or underflow. - /// Supports in-place operation (i.e. this == &src). 
- void ApplyLogSoftMaxPerRow(const CuMatrixBase &src); + inline void ApplyPow(Real power) { + this -> Pow(*this, power); + }; + + + inline void ApplyPowAbs(Real power, bool include_sign=false) { + this -> PowAbs(*this, power, include_sign); + }; + + inline void ApplyHeaviside() { + this -> Heaviside(*this); + }; + + inline void ApplyFloor(Real floor_val) { + this -> Floor(*this, floor_val); + }; + + inline void ApplyCeiling(Real ceiling_val) { + this -> Ceiling(*this, ceiling_val); + }; + + inline void ApplyExp() { + this -> Exp(*this); + }; + + + inline void ApplyExpLimited(Real lower_limit, Real upper_limit) { + this -> ExpLimited(*this, lower_limit, upper_limit); + }; + + inline void ApplyExpSpecial() { + this -> ExpSpecial(*this); + }; + + inline void ApplySoftMaxPerRow() { + this -> SoftMaxPerRow(*this); + }; + + inline void ApplyLogSoftMaxPerRow() { + this -> LogSoftMaxPerRow(*this); + }; + + inline void ApplyLog() { + this -> Log(*this); + }; /// Find the id of the maximal element for each row (resizes the 'id' /// array to the appropriate size). @@ -434,7 +484,6 @@ class CuMatrixBase { /// Zeroes all elements for which col > row. void SetZeroAboveDiag(); void Scale(Real value); - void ApplyLog(); /// Multiply two matrices elementwise: C = C .* A void MulElements(const CuMatrixBase &A); diff --git a/src/cudamatrix/cu-vector.cc b/src/cudamatrix/cu-vector.cc index 5f030e7ca03..5ee5d578511 100644 --- a/src/cudamatrix/cu-vector.cc +++ b/src/cudamatrix/cu-vector.cc @@ -4,6 +4,7 @@ // 2012-2014 Johns Hopkins University (author: Daniel Povey) // 2017 Daniel Galvez // 2016-2018 Shiyin Kang +// 2019 Yiwen Shao // See ../../COPYING for clarification regarding multiple authors // @@ -349,7 +350,7 @@ void CuVectorBase::ApplySoftMax() { } template -void CuVectorBase::ApplyFloor(Real floor_val, MatrixIndexT *floored_count) { +void CuVectorBase::Floor(const CuVectorBase &src, Real floor_val, MatrixIndexT *floored_count) { #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { int dimBlock(CU1DBLOCK); @@ -360,8 +361,8 @@ void CuVectorBase::ApplyFloor(Real floor_val, MatrixIndexT *floored_count) // We are calling a function meant for matrices, by viewing the // vector as a matrix with a single row. 
::MatrixDim dim = {1, Dim(), 1}; - cuda_apply_floor(dimGrid, dimBlock, data_, floor_val, dim); - CuDevice::Instantiate().AccuProfile("CuVectorBase::ApplyFloorNoCount", tim); + cuda_floor(dimGrid, dimBlock, this->data_, src.Data(), floor_val, dim, 1); + CuDevice::Instantiate().AccuProfile("CuVectorBase::FloorNoCount", tim); } else { if (dim_ == 0) { *floored_count = 0; return; } CuTimer tim; @@ -371,17 +372,18 @@ void CuVectorBase::ApplyFloor(Real floor_val, MatrixIndexT *floored_count) cuda_vec_apply_floor(dimGrid, dimBlock, data_, floor_val, count_vec.Data(), dim_); CU_SAFE_CALL(cudaGetLastError()); *floored_count = count_vec.Sum(); - CuDevice::Instantiate().AccuProfile("CuVectorBase::ApplyFloor", tim); + CuDevice::Instantiate().AccuProfile("CuVectorBase::Floor", tim); } } else #endif { - Vec().ApplyFloor(floor_val, floored_count); + Vec().Floor(src.Vec(), floor_val, floored_count); } } template -void CuVectorBase::ApplyCeiling(Real ceiling_val, MatrixIndexT *ceiled_count) { +void CuVectorBase::Ceiling(const CuVectorBase &src, Real ceiling_val, + MatrixIndexT *ceiled_count) { #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { int dimBlock(CU1DBLOCK); @@ -392,9 +394,9 @@ void CuVectorBase::ApplyCeiling(Real ceiling_val, MatrixIndexT *ceiled_cou // We are calling a function meant for matrices, by viewing the // vector as a matrix with a single row. ::MatrixDim dim = {1, Dim(), 1}; - cuda_apply_ceiling(dimGrid, dimBlock, data_, ceiling_val, dim); + cuda_ceiling(dimGrid, dimBlock, this->data_, src.Data(), ceiling_val, dim, 1); - CuDevice::Instantiate().AccuProfile("CuVectorBase::ApplyCeilingNoCount", tim); + CuDevice::Instantiate().AccuProfile("CuVectorBase::CeilingNoCount", tim); } else { if (dim_ == 0) { *ceiled_count = 0; return; } CuTimer tim; @@ -404,17 +406,17 @@ void CuVectorBase::ApplyCeiling(Real ceiling_val, MatrixIndexT *ceiled_cou cuda_vec_apply_ceiling(dimGrid, dimBlock, data_, ceiling_val, count_vec.Data(), dim_); CU_SAFE_CALL(cudaGetLastError()); *ceiled_count = count_vec.Sum(); - CuDevice::Instantiate().AccuProfile("CuVectorBase::ApplyCeiling", tim); + CuDevice::Instantiate().AccuProfile("CuVectorBase::Ceiling", tim); } } else #endif { - Vec().ApplyCeiling(ceiling_val, ceiled_count); + Vec().Ceiling(src.Vec(), ceiling_val, ceiled_count); } } template -void CuVectorBase::ApplyPow(Real power) { +void CuVectorBase::Pow(const CuVectorBase &src, Real power) { #if HAVE_CUDA == 1 if (CuDevice::Instantiate().Enabled()) { if (dim_ == 0) return; @@ -425,13 +427,13 @@ void CuVectorBase::ApplyPow(Real power) { dim3 dimGrid(n_blocks(Dim(), CU1DBLOCK), 1); ::MatrixDim fake_matrix_dim = { 1, Dim(), 1 }; // num_cols is Dim(), num_rows is 1, stride is 1 (it's a don't-care). 
- cuda_apply_pow(dimGrid, dimBlock, data_, power, fake_matrix_dim); + cuda_pow(dimGrid, dimBlock, this->data_, src.Data(), power, fake_matrix_dim, 1); CU_SAFE_CALL(cudaGetLastError()); CuDevice::Instantiate().AccuProfile("CuVectorBase::ApplyPow", tim); } else #endif { - Vec().ApplyPow(power); + Vec().Pow(src.Vec(), power); } } diff --git a/src/cudamatrix/cu-vector.h b/src/cudamatrix/cu-vector.h index d769b614f86..9c532b52f39 100644 --- a/src/cudamatrix/cu-vector.h +++ b/src/cudamatrix/cu-vector.h @@ -6,6 +6,7 @@ // 2013 Xiaohui Zhang // 2015 Guoguo Chen // 2017 Daniel Galvez +// 2019 Yiwen Shao // See ../../COPYING for clarification regarding multiple authors // @@ -131,13 +132,26 @@ class CuVectorBase { const MatrixTransposeType trans, const CuArrayBase &elements); + void Floor(const CuVectorBase &src, Real floor_val, MatrixIndexT *floored_count = NULL); + void Ceiling(const CuVectorBase &src, Real ceiling_val, MatrixIndexT *ceiled_count = NULL); + void Pow(const CuVectorBase &src, Real power); + + inline void ApplyFloor(Real floor_val, MatrixIndexT *floored_count = NULL) { + this -> Floor(*this, floor_val, floored_count); + }; + + inline void ApplyCeiling(Real ceiling_val, MatrixIndexT *ceiled_count = NULL) { + this -> Ceiling(*this, ceiling_val, ceiled_count); + }; + + inline void ApplyPow(Real power) { + this -> Pow(*this, power); + }; + void ApplySoftMax(); void ApplyLogSoftMax(); void ApplyExp(); void ApplyLog(); - void ApplyFloor(Real floor_val, MatrixIndexT *floored_count = NULL); - void ApplyCeiling(Real ceiling_val, MatrixIndexT *ceiled_count = NULL); - void ApplyPow(Real power); Real Sum() const; void SetRandn(); diff --git a/src/matrix/kaldi-matrix.cc b/src/matrix/kaldi-matrix.cc index fcfe0616b64..2354e792fa4 100644 --- a/src/matrix/kaldi-matrix.cc +++ b/src/matrix/kaldi-matrix.cc @@ -5,6 +5,7 @@ // Yanmin Qian; Petr Schwarz; Jan Silovsky; // Haihua Xu // 2017 Shiyin Kang +// 2019 Yiwen Shao // See ../../COPYING for clarification regarding multiple authors // @@ -93,7 +94,7 @@ void MatrixBase::Invert(Real *log_det, Real *det_sign, prod *= (*this)(i, i); if (i == num_rows_ - 1 || std::fabs(prod) < 1.0e-10 || std::fabs(prod) > 1.0e+10) { - if (log_det != NULL) *log_det += Log(std::fabs(prod)); + if (log_det != NULL) *log_det += kaldi::Log(std::fabs(prod)); if (det_sign != NULL) *det_sign *= (prod > 0 ? 1.0 : -1.0); prod = 1.0; } @@ -2098,90 +2099,135 @@ void Matrix::Transpose() { } template -void MatrixBase::ApplyFloor(Real floor_val) { +void MatrixBase::Heaviside(const MatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; - for (MatrixIndexT i = 0; i < num_rows; i++) { - Real *data = this->RowData(i); - for (MatrixIndexT j = 0; j < num_cols; j++) - data[j] = (data[j] < floor_val ? floor_val : data[j]); + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) + row_data[col] = (src_row_data[col] > 0 ? 1.0 : 0.0); } } template -void MatrixBase::ApplyCeiling(Real ceiling_val) { +void MatrixBase::Exp(const MatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; - for (MatrixIndexT i = 0; i < num_rows; i++) { - Real *data = this->RowData(i); - for (MatrixIndexT j = 0; j < num_cols; j++) - data[j] = (data[j] > ceiling_val ? 
ceiling_val : data[j]); + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) + row_data[col] = kaldi::Exp(src_row_data[col]); } } template -void MatrixBase::ApplyLog() { - for (MatrixIndexT i = 0; i < num_rows_; i++) { - Row(i).ApplyLog(); +void MatrixBase::Pow(const MatrixBase &src, Real power) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) { + row_data[col] = pow(src_row_data[col], power); + } } } template -void MatrixBase::ApplyExp() { - for (MatrixIndexT i = 0; i < num_rows_; i++) { - Row(i).ApplyExp(); +void MatrixBase::PowAbs(const MatrixBase &src, Real power, bool include_sign) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col ++) { + if (include_sign == true && src_row_data[col] < 0) { + row_data[col] = -pow(std::abs(src_row_data[col]), power); + } else { + row_data[col] = pow(std::abs(src_row_data[col]), power); + } + } } } template -void MatrixBase::ApplyExpSpecial() { - int32 num_rows = num_rows_, num_cols = num_cols_, - stride = stride_; - Real *data = data_; - for (MatrixIndexT i = 0; i < num_rows; ++i) { - for (MatrixIndexT j = 0; j < num_cols; ++j) { - Real &x = *(data + j + stride * i); - x = x < Real(0) ? Exp(x) : x + Real(1); - } +void MatrixBase::Floor(const MatrixBase &src, Real floor_val) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) + row_data[col] = (src_row_data[col] < floor_val ? floor_val : src_row_data[col]); } } template -void MatrixBase::ApplyPow(Real power) { - for (MatrixIndexT i = 0; i < num_rows_; i++) { - Row(i).ApplyPow(power); +void MatrixBase::Ceiling(const MatrixBase &src, Real ceiling_val) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) + row_data[col] = (src_row_data[col] > ceiling_val ? 
ceiling_val : src_row_data[col]); } } template -void MatrixBase::ApplyPowAbs(Real power, bool include_sign) { - for (MatrixIndexT i = 0; i < num_rows_; i++) { - Row(i).ApplyPowAbs(power, include_sign); +void MatrixBase::Log(const MatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) + row_data[col] = kaldi::Log(src_row_data[col]); } } template -void MatrixBase::ApplyHeaviside() { +void MatrixBase::ExpSpecial(const MatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; - for (MatrixIndexT i = 0; i < num_rows; i++) { - Real *data = this->RowData(i); - for (MatrixIndexT j = 0; j < num_cols; j++) - data[j] = (data[j] > 0 ? 1.0 : 0.0); + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) + row_data[col] = (src_row_data[col] < Real(0) ? kaldi::Exp(src_row_data[col]) : (src_row_data[col] + Real(1))); } } template -void MatrixBase::Heaviside(const MatrixBase &src) { +void MatrixBase::ExpLimited(const MatrixBase &src, Real lower_limit, Real upper_limit) { KALDI_ASSERT(SameDim(*this, src)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; Real *row_data = data_; const Real *src_row_data = src.Data(); for (MatrixIndexT row = 0; row < num_rows; row++,row_data += stride_, src_row_data += src.stride_) { - for (MatrixIndexT col = 0; col < num_cols; col++) - row_data[col] = (src_row_data[col] > 0 ? 1.0 : 0.0); + for (MatrixIndexT col = 0; col < num_cols; col++) { + const Real x = src_row_data[col]; + if (!(x >= lower_limit)) + row_data[col] = kaldi::Exp(lower_limit); + else if (x > upper_limit) + row_data[col] = kaldi::Exp(upper_limit); + else + row_data[col] = kaldi::Exp(x); + } } } - template bool MatrixBase::Power(Real power) { KALDI_ASSERT(num_rows_ > 0 && num_rows_ == num_cols_); @@ -2695,10 +2741,10 @@ Real MatrixBase::LogSumExp(Real prune) const { for (MatrixIndexT j = 0; j < num_cols_; j++) { BaseFloat f = (*this)(i, j); if (f >= cutoff) - sum_relto_max_elem += Exp(f - max_elem); + sum_relto_max_elem += kaldi::Exp(f - max_elem); } } - return max_elem + Log(sum_relto_max_elem); + return max_elem + kaldi::Log(sum_relto_max_elem); } template @@ -2707,9 +2753,9 @@ Real MatrixBase::ApplySoftMax() { // the 'max' helps to get in good numeric range. for (MatrixIndexT i = 0; i < num_rows_; i++) for (MatrixIndexT j = 0; j < num_cols_; j++) - sum += ((*this)(i, j) = Exp((*this)(i, j) - max)); + sum += ((*this)(i, j) = kaldi::Exp((*this)(i, j) - max)); this->Scale(1.0 / sum); - return max + Log(sum); + return max + kaldi::Log(sum); } template @@ -2739,7 +2785,7 @@ void MatrixBase::SoftHinge(const MatrixBase &src) { Real x = src_row_data[c], y; if (x > 10.0) y = x; // avoid exponentiating large numbers; function // approaches y=x. 
diff --git a/src/matrix/kaldi-matrix.h b/src/matrix/kaldi-matrix.h
index d7ee8eb388f..5ee60f63fdf 100644
--- a/src/matrix/kaldi-matrix.h
+++ b/src/matrix/kaldi-matrix.h
@@ -4,6 +4,7 @@
 //                      Saarland University;  Petr Schwarz;  Yanmin Qian;
 //                      Karel Vesely;  Go Vivace Inc.;  Haihua Xu
 //                2017  Shiyin Kang
+//                2019  Yiwen Shao
 
 // See ../../COPYING for clarification regarding multiple authors
 //
@@ -337,37 +338,42 @@ class MatrixBase {
                  const MatrixIndexT *indexes,
                  MatrixBase<Real> *dst) const;
 
-  /// Applies floor to all matrix elements
-  void ApplyFloor(Real floor_val);
-
-  /// Applies floor to all matrix elements
-  void ApplyCeiling(Real ceiling_val);
-
-  /// Calculates log of all the matrix elemnts
-  void ApplyLog();
-
-  /// Exponentiate each of the elements.
-  void ApplyExp();
-
-  /// For each element x of the matrix, set it to
-  /// (x < 0 ? exp(x) : x + 1).  This function is used
-  /// in our RNNLM training.
-  void ApplyExpSpecial();
-
-  /// Applies power to all matrix elements
-  void ApplyPow(Real power);
-
-  /// Apply power to the absolute value of each element.
-  /// Include the sign of the input element if include_sign == true.
-  /// If the power is negative and the input to the power is zero,
-  /// The output will be set zero.
-  void ApplyPowAbs(Real power, bool include_sign=false);
-
-  /// Applies the Heaviside step function (x > 0 ? 1 : 0) to all matrix elements
-  /// Note: in general you can make different choices for x = 0, but for now
-  /// please leave it as it (i.e. returning zero) because it affects the
-  /// RectifiedLinearComponent in the neural net code.
-  void ApplyHeaviside();
+  inline void ApplyPow(Real power) {
+    this->Pow(*this, power);
+  }
+
+  inline void ApplyPowAbs(Real power, bool include_sign = false) {
+    this->PowAbs(*this, power, include_sign);
+  }
+
+  inline void ApplyHeaviside() {
+    this->Heaviside(*this);
+  }
+
+  inline void ApplyFloor(Real floor_val) {
+    this->Floor(*this, floor_val);
+  }
+
+  inline void ApplyCeiling(Real ceiling_val) {
+    this->Ceiling(*this, ceiling_val);
+  }
+
+  inline void ApplyExp() {
+    this->Exp(*this);
+  }
+
+  inline void ApplyExpSpecial() {
+    this->ExpSpecial(*this);
+  }
+
+  inline void ApplyExpLimited(Real lower_limit, Real upper_limit) {
+    this->ExpLimited(*this, lower_limit, upper_limit);
+  }
+
+  inline void ApplyLog() {
+    this->Log(*this);
+  }
 
   /// Eigenvalue Decomposition of a square NxN matrix into the form (*this) = P D
   /// P^{-1}.  Be careful: the relationship of D to the eigenvalues we output is
@@ -483,6 +489,35 @@ class MatrixBase {
   /// because it affects the RectifiedLinearComponent in the neural net code.
   void Heaviside(const MatrixBase<Real> &src);
 
+  /// Sets each element to the exponential of the corresponding element of src.
+  void Exp(const MatrixBase<Real> &src);
+
+  /// Sets each element to the corresponding element of src taken to 'power'.
+  void Pow(const MatrixBase<Real> &src, Real power);
+
+  /// Sets each element to the natural log of the corresponding element of src.
+  void Log(const MatrixBase<Real> &src);
+
+  /// Apply power to the absolute value of each element of src.
+  /// If include_sign is true, the result is multiplied by
+  /// the sign of the input value.
+  /// If the power is negative and the input to the power is zero,
+  /// the output will be set to zero.
+  void PowAbs(const MatrixBase<Real> &src, Real power, bool include_sign = false);
+
+  void Floor(const MatrixBase<Real> &src, Real floor_val);
+
+  void Ceiling(const MatrixBase<Real> &src, Real ceiling_val);
+
+  /// For each element x of src, set the corresponding element of *this to
+  /// (x < 0 ? exp(x) : x + 1).  This function is used
+  /// in our RNNLM training.
+  void ExpSpecial(const MatrixBase<Real> &src);
+
+  /// This is equivalent to running:
+  ///   Floor(src, lower_limit);
+  ///   Ceiling(src, upper_limit);
+  ///   Exp(src);
+  void ExpLimited(const MatrixBase<Real> &src, Real lower_limit, Real upper_limit);
+
   /// Set each element to y = log(1 + exp(x))
   void SoftHinge(const MatrixBase<Real> &src);
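
The ExpSpecial and ExpLimited declarations above carry their semantics in the doc comments; in particular, ExpLimited is documented as floor, then ceiling, then exp. A sketch that checks this equivalence on ordinary inputs, assuming a Kaldi build (all names are local to the example). Note that the two paths differ on NaN: ExpLimited maps NaN to exp(lower_limit) through its !(x >= lower_limit) test, while Floor leaves NaN in place.

#include "matrix/kaldi-matrix.h"

int main() {
  using namespace kaldi;
  // Sketch only: checks ExpLimited == Floor + Ceiling + Exp on non-NaN input.
  Matrix<BaseFloat> in(2, 3), a(2, 3), b(2, 3);
  in.SetRandn();
  in.Scale(100.0);                // push values well outside [-10, 10]
  a.ExpLimited(in, -10.0, 10.0);  // fused clamp-then-exp
  b.Floor(in, -10.0);             // the three-step equivalent from the
  b.Ceiling(b, 10.0);             // doc comment, using aliased calls
  b.Exp(b);
  KALDI_ASSERT(a.ApproxEqual(b));
  return 0;
}
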
diff --git a/src/matrix/kaldi-vector.cc b/src/matrix/kaldi-vector.cc
index c8ea35112ea..2671bf5224b 100644
--- a/src/matrix/kaldi-vector.cc
+++ b/src/matrix/kaldi-vector.cc
@@ -6,7 +6,7 @@
 //                      Haihua Xu;  Wei Shi
 //                2015  Guoguo Chen
 //                2017  Daniel Galvez
-
+//                2019  Yiwen Shao
 // See ../../COPYING for clarification regarding multiple authors
 //
@@ -448,32 +448,20 @@ void VectorBase<Real>::CopyRowFromSp(const SpMatrix<OtherReal> &mat, MatrixIndexT
 
 #ifdef HAVE_MKL
 template<>
-void VectorBase<float>::ApplyPow(float power) { vsPowx(dim_, data_, power, data_); }
+void VectorBase<float>::Pow(const VectorBase<float> &v, float power) {
+  vsPowx(dim_, v.data_, power, data_);
+}
 template<>
-void VectorBase<double>::ApplyPow(double power) { vdPowx(dim_, data_, power, data_); }
+void VectorBase<double>::Pow(const VectorBase<double> &v, double power) {
+  vdPowx(dim_, v.data_, power, data_);
+}
 #else
-// takes elements to a power.  Throws exception if could not (but only for power != 1 and power != 2).
+
+// takes elements to a power.  Does not check output.
 template<typename Real>
-void VectorBase<Real>::ApplyPow(Real power) {
-  if (power == 1.0) return;
-  if (power == 2.0) {
-    for (MatrixIndexT i = 0; i < dim_; i++)
-      data_[i] = data_[i] * data_[i];
-  } else if (power == 0.5) {
-    for (MatrixIndexT i = 0; i < dim_; i++) {
-      if (!(data_[i] >= 0.0))
-        KALDI_ERR << "Cannot take square root of negative value "
-                  << data_[i];
-      data_[i] = std::sqrt(data_[i]);
-    }
-  } else {
-    for (MatrixIndexT i = 0; i < dim_; i++) {
-      data_[i] = pow(data_[i], power);
-      if (data_[i] == HUGE_VAL) {  // HUGE_VAL is what errno returns on error.
-        KALDI_ERR << "Could not raise element " << i << " to power "
-                  << power << ": returned value = " << data_[i];
-      }
-    }
+void VectorBase<Real>::Pow(const VectorBase<Real> &v, Real power) {
+  for (MatrixIndexT i = 0; i < dim_; i++) {
+    data_[i] = pow(v.data_[i], power);
   }
 }
 #endif
@@ -814,17 +802,19 @@ void VectorBase<Real>::ApplyAbs() {
 }
 
 template<typename Real>
-void VectorBase<Real>::ApplyFloor(Real floor_val, MatrixIndexT *floored_count) {
+void VectorBase<Real>::Floor(const VectorBase<Real> &v, Real floor_val, MatrixIndexT *floored_count) {
   if (floored_count == nullptr) {
     for (MatrixIndexT i = 0; i < dim_; i++) {
-      data_[i] = std::max(data_[i], floor_val);
+      data_[i] = std::max(v.data_[i], floor_val);
     }
   } else {
     MatrixIndexT num_floored = 0;
    for (MatrixIndexT i = 0; i < dim_; i++) {
-      if (data_[i] < floor_val) {
+      if (v.data_[i] < floor_val) {
        data_[i] = floor_val;
        num_floored++;
+      } else {
+        data_[i] = v.data_[i];
      }
    }
    *floored_count = num_floored;
@@ -832,17 +822,19 @@
 }
 
 template<typename Real>
-void VectorBase<Real>::ApplyCeiling(Real ceil_val, MatrixIndexT *ceiled_count) {
+void VectorBase<Real>::Ceiling(const VectorBase<Real> &v, Real ceil_val, MatrixIndexT *ceiled_count) {
   if (ceiled_count == nullptr) {
     for (MatrixIndexT i = 0; i < dim_; i++) {
-      data_[i] = std::min(data_[i], ceil_val);
+      data_[i] = std::min(v.data_[i], ceil_val);
     }
   } else {
     MatrixIndexT num_changed = 0;
    for (MatrixIndexT i = 0; i < dim_; i++) {
-      if (data_[i] > ceil_val) {
+      if (v.data_[i] > ceil_val) {
        data_[i] = ceil_val;
        num_changed++;
+      } else {
+        data_[i] = v.data_[i];
      }
    }
    *ceiled_count = num_changed;
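
Note the argument order in the MKL branch: vsPowx/vdPowx(n, a, b, y) read the input array a and write y = a^b, so the source vector v supplies the input and data_ receives the output. Also, unlike the deleted generic code, Pow() no longer raises an error for out-of-range results: taking a negative value to the power 0.5, for example, now yields NaN where ApplyPow used to call KALDI_ERR. A usage sketch, assuming a Kaldi build (names are illustrative):

#include "matrix/kaldi-vector.h"

int main() {
  using namespace kaldi;
  // Usage sketch (illustrative, assumes a Kaldi build).
  Vector<BaseFloat> v(5), w(5);
  v.SetRandn();
  v.ApplyPow(2.0);  // in-place wrapper; forwards to Pow(*this, 2.0)
  w.Pow(v, 0.5);    // out-of-place: w(i) = sqrt(v(i)); no range checking
  return 0;
}
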
diff --git a/src/matrix/kaldi-vector.h b/src/matrix/kaldi-vector.h
index 383d8ca2862..6097ff07f20 100644
--- a/src/matrix/kaldi-vector.h
+++ b/src/matrix/kaldi-vector.h
@@ -7,6 +7,7 @@
 //                      Wei Shi;
 //                2015  Guoguo Chen
 //                2017  Daniel Galvez
+//                2019  Yiwen Shao
 
 // See ../../COPYING for clarification regarding multiple authors
 //
@@ -119,6 +120,15 @@ class VectorBase {
   template<typename OtherReal>
   void CopyFromVec(const CuVectorBase<OtherReal> &v);
 
+  /// Applies floor to the elements of v, writing the result to *this.
+  /// Returns the number of elements floored in floored_count if it is non-null.
+  void Floor(const VectorBase<Real> &v, Real floor_val, MatrixIndexT *floored_count = nullptr);
+
+  /// Applies ceiling to the elements of v, writing the result to *this.
+  /// Returns the number of elements changed in ceiled_count if it is non-null.
+  void Ceiling(const VectorBase<Real> &v, Real ceil_val, MatrixIndexT *ceiled_count = nullptr);
+
+  /// Takes each element of v to the given power, writing the result to *this.
+  void Pow(const VectorBase<Real> &v, Real power);
 
   /// Apply natural log to all elements.  Throw if any element of
   /// the vector is negative (but doesn't complain about zero; the
@@ -136,11 +146,15 @@
 
   /// Applies floor to all elements. Returns number of elements
   /// floored in floored_count if it is non-null.
-  void ApplyFloor(Real floor_val, MatrixIndexT *floored_count = nullptr);
+  inline void ApplyFloor(Real floor_val, MatrixIndexT *floored_count = nullptr) {
+    this->Floor(*this, floor_val, floored_count);
+  }
 
   /// Applies ceiling to all elements. Returns number of elements
   /// changed in ceiled_count if it is non-null.
-  void ApplyCeiling(Real ceil_val, MatrixIndexT *ceiled_count = nullptr);
+  inline void ApplyCeiling(Real ceil_val, MatrixIndexT *ceiled_count = nullptr) {
+    this->Ceiling(*this, ceil_val, ceiled_count);
+  }
 
   /// Applies floor to all elements. Returns number of elements floored.
   MatrixIndexT ApplyFloor(const VectorBase<Real> &floor_vec);
@@ -162,7 +176,9 @@
   void Sigmoid(const VectorBase<Real> &src);
 
   /// Take all elements of vector to a power.
-  void ApplyPow(Real power);
+  inline void ApplyPow(Real power) {
+    this->Pow(*this, power);
+  }
 
   /// Take the absolute value of all elements of a vector to a power.
   /// Include the sign of the input element if include_sign == true.
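
The optional count argument carries over to the out-of-place forms, which is why Floor() and Ceiling() in kaldi-vector.cc gained an explicit else branch: when counting, the unclamped elements must also be copied from v. A sketch, assuming a Kaldi build (values chosen so two elements are floored):

#include "matrix/kaldi-vector.h"

int main() {
  using namespace kaldi;
  // Sketch only: out-of-place Floor with a count, plus an in-place wrapper.
  Vector<BaseFloat> v(4), w(4);
  v(0) = -2.0; v(1) = -0.5; v(2) = 0.5; v(3) = 2.0;
  MatrixIndexT floored_count;
  w.Floor(v, 0.0, &floored_count);  // w = [0, 0, 0.5, 2]
  KALDI_ASSERT(floored_count == 2);
  v.ApplyCeiling(1.0);              // in-place wrapper; v = [-2, -0.5, 0.5, 1]
  return 0;
}
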
diff --git a/src/nnet/nnet-activation.h b/src/nnet/nnet-activation.h
index 74b0ebad650..ad9acac26bc 100644
--- a/src/nnet/nnet-activation.h
+++ b/src/nnet/nnet-activation.h
@@ -49,7 +49,7 @@ class Softmax : public Component {
   void PropagateFnc(const CuMatrixBase<BaseFloat> &in,
                     CuMatrixBase<BaseFloat> *out) {
     // y = e^x_j/sum_j(e^x_j)
-    out->ApplySoftMaxPerRow(in);
+    out->SoftMaxPerRow(in);
   }
 
   void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in,
@@ -81,7 +81,7 @@ class HiddenSoftmax : public Component {
   void PropagateFnc(const CuMatrixBase<BaseFloat> &in,
                     CuMatrixBase<BaseFloat> *out) {
     // y = e^x_j/sum_j(e^x_j)
-    out->ApplySoftMaxPerRow(in);
+    out->SoftMaxPerRow(in);
   }
 
   void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in,
@@ -167,7 +167,7 @@ class BlockSoftmax : public Component {
       CuSubMatrix<BaseFloat> out_bl = out->ColRange(block_offset[bl], block_dims[bl]);
       // y = e^x_j/sum_j(e^x_j),
-      out_bl.ApplySoftMaxPerRow(in_bl);
+      out_bl.SoftMaxPerRow(in_bl);
     }
   }
 
diff --git a/src/nnet2/nnet-component.cc b/src/nnet2/nnet-component.cc
index eafeaceb9fe..f0919acfac8 100644
--- a/src/nnet2/nnet-component.cc
+++ b/src/nnet2/nnet-component.cc
@@ -909,7 +909,7 @@ void SoftmaxComponent::Propagate(const ChunkInfo &in_info,
   // for that row, we do
   // x_i = exp(x_i) / sum_j exp(x_j).
 
-  out->ApplySoftMaxPerRow(in);
+  out->SoftMaxPerRow(in);
 
   // This floor on the output helps us deal with
   // almost-zeros in a way that doesn't lead to overflow.
@@ -956,7 +956,7 @@ void LogSoftmaxComponent::Propagate(const ChunkInfo &in_info,
   // Applies log softmax function to each row of the output. For each row, we do
   // x_i = x_i - log(sum_j exp(x_j))
-  out->ApplyLogSoftMaxPerRow(in);
+  out->LogSoftMaxPerRow(in);
 
   // Just to be consistent with SoftmaxComponent::Propagate()
   out->ApplyFloor(Log(1.0e-20));
diff --git a/src/nnet3/attention.cc b/src/nnet3/attention.cc
index bd8cb6bf85c..ddfddbaf74a 100644
--- a/src/nnet3/attention.cc
+++ b/src/nnet3/attention.cc
@@ -133,7 +133,7 @@ void AttentionForward(BaseFloat key_scale,
   // compute the soft-max function.  Up till this point, 'c'
   // actually contained what in attention.h we called 'b', which is
   // the input to the softmax.
-  c->ApplySoftMaxPerRow(*c);
+  c->SoftMaxPerRow(*c);
 
   // the part of the output that is weighted
diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc
index 32f49745c0c..53c8d46578b 100644
--- a/src/nnet3/nnet-simple-component.cc
+++ b/src/nnet3/nnet-simple-component.cc
@@ -3548,7 +3548,7 @@ void* SoftmaxComponent::Propagate(const ComponentPrecomputedIndexes *indexes,
   // Apply softmax function to each row of the output...
   // for that row, we do
   // x_i = exp(x_i) / sum_j exp(x_j).
-  out->ApplySoftMaxPerRow(in);
+  out->SoftMaxPerRow(in);
 
   // This floor on the output helps us deal with
   // almost-zeros in a way that doesn't lead to overflow.
@@ -3601,7 +3601,7 @@ void* LogSoftmaxComponent::Propagate(const ComponentPrecomputedIndexes *indexes,
                                      CuMatrixBase<BaseFloat> *out) const {
   // Applies log softmax function to each row of the output. For each row, we do
   // x_i = x_i - log(sum_j exp(x_j))
-  out->ApplyLogSoftMaxPerRow(in);
+  out->LogSoftMaxPerRow(in);
   return NULL;
 }
diff --git a/src/nnetbin/cuda-gpu-available.cc b/src/nnetbin/cuda-gpu-available.cc
index 69637d3601a..2036ea82056 100644
--- a/src/nnetbin/cuda-gpu-available.cc
+++ b/src/nnetbin/cuda-gpu-available.cc
@@ -35,7 +35,7 @@ using namespace kaldi;
 void TestGpuComputation() {
   CuMatrix<BaseFloat> m(100,100);
   m.SetRandn();
-  m.ApplySoftMaxPerRow(m);
+  m.SoftMaxPerRow(m);
 }
 #endif
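
The nnet call sites change only in name: ApplySoftMaxPerRow and ApplyLogSoftMaxPerRow become SoftMaxPerRow and LogSoftMaxPerRow, and since the source was already passed explicitly, in-place use still just passes the matrix itself. A CPU-side sketch mirroring TestGpuComputation() above, assuming a Kaldi build (with CUDA configured, CuDevice selection would normally come first):

#include "cudamatrix/cu-matrix.h"

int main() {
  using namespace kaldi;
  // Sketch only: the renamed per-row softmax, out-of-place and aliased.
  CuMatrix<BaseFloat> m(100, 100), p(100, 100);
  m.SetRandn();
  p.SoftMaxPerRow(m);  // out-of-place: each row of p now sums to 1
  m.SoftMaxPerRow(m);  // aliased, as in TestGpuComputation()
  return 0;
}
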