address comments
PeixuanZuo committed Feb 10, 2023
1 parent 3354bbd commit 10a7350
Showing 4 changed files with 67 additions and 70 deletions.
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/rocm/math/softmax_ck.cuh
@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#endif
+#endif // USE_COMPOSABLE_KERNEL

#include "core/providers/rocm/math/softmax_common.h"

@@ -41,21 +41,21 @@ using Nop = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 4;
constexpr int NumReduceDim = 1;

-template <typename input_t, typename output_t, typename acc_t>
+template <typename InputT, typename OutputT, typename AccT>
auto GetCKSoftmaxTypeStringAndOps() {
-using InDataType = typename DataTypeAdaptor<input_t>::type;
-using OutDataType = typename DataTypeAdaptor<output_t>::type;
-using AccDataType = typename DataTypeAdaptor<acc_t>::type;
+using InDataType = typename DataTypeAdaptor<InputT>::type;
+using OutDataType = typename DataTypeAdaptor<OutputT>::type;
+using AccDataType = typename DataTypeAdaptor<AccT>::type;
using DeviceSoftmax = ck::tensor_operation::device::
DeviceSoftmax<InDataType, AccDataType, OutDataType, Nop, Nop, Rank>;
using InstanceFactory = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceSoftmax>;

-std::vector<std::pair<std::string, tunable::Op<SoftmaxParams<input_t, output_t>>>> ret;
+std::vector<std::pair<std::string, tunable::Op<SoftmaxParams<InputT, OutputT>>>> ret;
for (auto&& impl : InstanceFactory::GetInstances()) {
auto type_string = onnxruntime::MakeString(impl->GetTypeString());
auto invoker = impl->MakeInvokerPointer();

-auto ck_softmax_op = [impl = std::move(impl), invoker = std::move(invoker)](const SoftmaxParams<input_t, output_t>* params) -> Status {
+auto ck_softmax_op = [impl = std::move(impl), invoker = std::move(invoker)](const SoftmaxParams<InputT, OutputT>* params) -> Status {
AccDataType alpha{1.0f};
AccDataType beta{0.0f};

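For orientation, each pair returned by GetCKSoftmaxTypeStringAndOps maps a Composable Kernel DeviceSoftmax instance name to a callable that runs it. A minimal enumeration sketch follows; the wrapper function ListCKSoftmaxInstances and the float/float/float type choice are illustrative, not part of the commit.

```cpp
#include <iostream>
#include <utility>

#include "core/providers/rocm/math/softmax_ck.cuh"  // GetCKSoftmaxTypeStringAndOps

// Sketch only: print the CK softmax instances available for float in/out with
// float accumulation. Each `op` would normally be handed to TunableOp::RegisterOp,
// as softmax_tunable_op.cuh does below.
void ListCKSoftmaxInstances() {
#ifdef USE_COMPOSABLE_KERNEL
  for (auto&& [type_string, op] : onnxruntime::rocm::GetCKSoftmaxTypeStringAndOps<float, float, float>()) {
    std::cout << type_string << "\n";
    (void)op;  // unused in this sketch
  }
#endif  // USE_COMPOSABLE_KERNEL
}
```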
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/rocm/math/softmax_common.h
@@ -10,9 +10,9 @@
namespace onnxruntime {
namespace rocm {

-template <typename input_t, typename output_t>
+template <typename InputT, typename OutputT>
struct SoftmaxParams : onnxruntime::rocm::tunable::OpParams {
-SoftmaxParams(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements,
+SoftmaxParams(hipStream_t stream, OutputT* output, const InputT* input, int softmax_elements,
int input_stride, int output_stride, int batch_count, bool is_log_softmax)
: OpParams(stream), output(output), input(input), softmax_elements(softmax_elements), input_stride(input_stride), output_stride(output_stride), batch_count(batch_count), is_log_softmax(is_log_softmax) {}

@@ -21,8 +21,8 @@ struct SoftmaxParams : onnxruntime::rocm::tunable::OpParams {
return sig;
}

-output_t* output;
-const input_t* input;
+OutputT* output;
+const InputT* input;
int softmax_elements;
int input_stride;
int output_stride;
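SoftmaxParams is a plain argument bundle; the strides are element counts between consecutive softmax rows, so for a contiguous [batch_count, softmax_elements] tensor both strides equal softmax_elements. A hedged construction sketch follows; the helper name FillSoftmaxArgs and the concrete sizes are illustrative.

```cpp
#include "core/providers/rocm/math/softmax_common.h"

// Sketch only, assuming a contiguous 2-D layout where input row i starts at
// d_input + i * softmax_elements (and likewise for the output).
void FillSoftmaxArgs(hipStream_t stream, float* d_output, const float* d_input,
                     int batch_count, int softmax_elements) {
  onnxruntime::rocm::SoftmaxParams<float, float> params(
      stream, d_output, d_input, softmax_elements,
      /*input_stride=*/softmax_elements,   // elements between consecutive input rows
      /*output_stride=*/softmax_elements,  // elements between consecutive output rows
      batch_count,
      /*is_log_softmax=*/false);           // true would request log-softmax instead
  (void)params;  // would be passed to one of the softmax ops shown below
}
```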
52 changes: 26 additions & 26 deletions onnxruntime/core/providers/rocm/math/softmax_impl.cu
@@ -29,8 +29,8 @@
namespace onnxruntime {
namespace rocm {

-template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
+template <typename InputT, typename OutputT, typename AccT, bool IsLogSoftmax>
+Status dispatch_warpwise_softmax_forward(hipStream_t stream, OutputT* dst, const InputT* src, int softmax_elements, int softmax_elements_stride, int batch_count) {
if (softmax_elements == 0) {
return Status::OK();
} else {
@@ -53,9 +53,9 @@ Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, cons
switch (log2_elements) {
#define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) \
case L2E: \
-softmax_warp_forward<input_t, output_t, acc_t, L2E, is_log_softmax> \
+softmax_warp_forward<InputT, OutputT, AccT, L2E, IsLogSoftmax> \
<<<dim3(blocks), dim3(threads), 0, stream>>>(dst, src, batch_count, \
-softmax_elements_stride, softmax_elements); \
+softmax_elements_stride, softmax_elements); \
break;
LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1
LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2
@@ -75,43 +75,43 @@ Status dispatch_warpwise_softmax_forward(hipStream_t stream, output_t* dst, cons
return HIP_CALL(hipGetLastError());
}

-#define SPECIALIZED_SOFTMAX_IMPL(input_t, output_t, acc_t) \
-template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, false>( \
-hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, \
-int softmax_elements_stride, int batch_count); \
-template Status dispatch_warpwise_softmax_forward<input_t, output_t, acc_t, true>( \
-hipStream_t stream, output_t * dst, const input_t* src, int softmax_elements, \
+#define SPECIALIZED_SOFTMAX_IMPL(InputT, OutputT, AccT) \
+template Status dispatch_warpwise_softmax_forward<InputT, OutputT, AccT, false>( \
+hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \
+int softmax_elements_stride, int batch_count); \
+template Status dispatch_warpwise_softmax_forward<InputT, OutputT, AccT, true>( \
+hipStream_t stream, OutputT * dst, const InputT* src, int softmax_elements, \
int softmax_elements_stride, int batch_count);

SPECIALIZED_SOFTMAX_IMPL(float, float, float)
SPECIALIZED_SOFTMAX_IMPL(half, half, float)
SPECIALIZED_SOFTMAX_IMPL(double, double, double)
SPECIALIZED_SOFTMAX_IMPL(BFloat16, BFloat16, float)

-template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax>
-Status dispatch_blockwise_softmax_forward(hipStream_t stream, output_t* output, const input_t* input, int softmax_elements,
+template <typename InputT, typename OutputT, typename AccT, bool IsLogSoftmax>
+Status dispatch_blockwise_softmax_forward(hipStream_t stream, OutputT* output, const InputT* input, int softmax_elements,
int input_stride, int output_stride, int batch_count) {
dim3 grid(batch_count);
-constexpr int ILP = sizeof(float4) / sizeof(input_t);
+constexpr int ILP = sizeof(float4) / sizeof(InputT);
dim3 block = SoftMax_getBlockSize(ILP, softmax_elements);
-if (is_log_softmax) {
-softmax_block_forward<ILP, input_t, acc_t, output_t, LogSoftMaxForwardEpilogue>
-<<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
-softmax_elements, input_stride, output_stride);
+if (IsLogSoftmax) {
+softmax_block_forward<ILP, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue>
+<<<grid, block, block.x * sizeof(AccT), stream>>>(output, const_cast<InputT*>(input),
+softmax_elements, input_stride, output_stride);
} else {
-softmax_block_forward<ILP, input_t, acc_t, output_t, SoftMaxForwardEpilogue>
-<<<grid, block, block.x * sizeof(acc_t), stream>>>(output, const_cast<input_t*>(input),
-softmax_elements, input_stride, output_stride);
+softmax_block_forward<ILP, InputT, AccT, OutputT, SoftMaxForwardEpilogue>
+<<<grid, block, block.x * sizeof(AccT), stream>>>(output, const_cast<InputT*>(input),
+softmax_elements, input_stride, output_stride);
}
return HIP_CALL(hipGetLastError());
}

-#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \
-template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, false>( \
-hipStream_t stream, output_t * output, const input_t* input, int softmax_elements, \
-int input_stride, int output_stride, int batch_count); \
-template Status dispatch_blockwise_softmax_forward<input_t, output_t, acc_t, true>( \
-hipStream_t stream, output_t * output, const input_t* input, int softmax_elements, \
+#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(InputT, OutputT, AccT) \
+template Status dispatch_blockwise_softmax_forward<InputT, OutputT, AccT, false>( \
+hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \
+int input_stride, int output_stride, int batch_count); \
+template Status dispatch_blockwise_softmax_forward<InputT, OutputT, AccT, true>( \
+hipStream_t stream, OutputT * output, const InputT* input, int softmax_elements, \
int input_stride, int output_stride, int batch_count);

SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float)
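As a usage sketch for the renamed template parameters, the blockwise entry point can be called directly on contiguous device buffers. RunBlockwiseSoftmax below is an illustrative wrapper, not from the commit, and it assumes the function's declaration is visible through the provider's softmax header.

```cpp
#include "core/providers/rocm/math/softmax.h"  // assumed to declare dispatch_blockwise_softmax_forward

namespace onnxruntime {
namespace rocm {

// Sketch only: plain (non-log) softmax over the rows of a contiguous [batch, n]
// float matrix that already resides on the device. The returned Status carries
// the result of hipGetLastError() after the kernel launch.
Status RunBlockwiseSoftmax(hipStream_t stream, float* d_out, const float* d_in,
                           int n, int batch) {
  // Template arguments after this commit: <InputT, OutputT, AccT, IsLogSoftmax>.
  return dispatch_blockwise_softmax_forward<float, float, float, /*IsLogSoftmax=*/false>(
      stream, d_out, d_in, /*softmax_elements=*/n,
      /*input_stride=*/n, /*output_stride=*/n, /*batch_count=*/batch);
}

}  // namespace rocm
}  // namespace onnxruntime
```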
63 changes: 30 additions & 33 deletions onnxruntime/core/providers/rocm/math/softmax_tunable_op.cuh
@@ -15,63 +15,60 @@
namespace onnxruntime {
namespace rocm {

-template <typename input_t, typename output_t, typename acc_t, int VecSize>
-Status SoftmaxBlockwiseOp(const SoftmaxParams<input_t, output_t>* params) {
+template <typename InputT, typename OutputT, typename AccT, int VecSize>
+Status SoftmaxBlockwiseOp(const SoftmaxParams<InputT, OutputT>* params) {
dim3 grid(params->batch_count);
dim3 block = SoftMax_getBlockSize(VecSize, params->softmax_elements);
if (params->is_log_softmax) {
-softmax_block_forward<VecSize, input_t, acc_t, output_t, LogSoftMaxForwardEpilogue>
-<<<grid, block, block.x * sizeof(acc_t), params->stream>>>(params->output, const_cast<input_t*>(params->input),
-params->softmax_elements, params->input_stride,
-params->output_stride);
+softmax_block_forward<VecSize, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue>
+<<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
+params->softmax_elements, params->input_stride,
+params->output_stride);
} else {
-softmax_block_forward<VecSize, input_t, acc_t, output_t, SoftMaxForwardEpilogue>
-<<<grid, block, block.x * sizeof(acc_t), params->stream>>>(params->output, const_cast<input_t*>(params->input),
-params->softmax_elements, params->input_stride,
-params->output_stride);
+softmax_block_forward<VecSize, InputT, AccT, OutputT, SoftMaxForwardEpilogue>
+<<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
+params->softmax_elements, params->input_stride,
+params->output_stride);
}
return HIP_CALL(hipGetLastError());
}

-template <typename input_t, typename output_t, typename acc_t>
-Status SoftmaxBlockwiseStaticSelection(const SoftmaxParams<input_t, output_t>* params) {
+template <typename InputT, typename OutputT, typename AccT>
+Status SoftmaxBlockwiseStaticSelection(const SoftmaxParams<InputT, OutputT>* params) {
dim3 grid(params->batch_count);
-constexpr int ILP = sizeof(float4) / sizeof(input_t);
+constexpr int ILP = sizeof(float4) / sizeof(InputT);
dim3 block = SoftMax_getBlockSize(ILP, params->softmax_elements);
if (params->is_log_softmax) {
-softmax_block_forward<ILP, input_t, acc_t, output_t, LogSoftMaxForwardEpilogue>
-<<<grid, block, block.x * sizeof(acc_t), params->stream>>>(params->output, const_cast<input_t*>(params->input),
-params->softmax_elements, params->input_stride,
-params->output_stride);
+softmax_block_forward<ILP, InputT, AccT, OutputT, LogSoftMaxForwardEpilogue>
+<<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
+params->softmax_elements, params->input_stride,
+params->output_stride);
} else {
-softmax_block_forward<ILP, input_t, acc_t, output_t, SoftMaxForwardEpilogue>
-<<<grid, block, block.x * sizeof(acc_t), params->stream>>>(params->output, const_cast<input_t*>(params->input),
-params->softmax_elements, params->input_stride,
-params->output_stride);
+softmax_block_forward<ILP, InputT, AccT, OutputT, SoftMaxForwardEpilogue>
+<<<grid, block, block.x * sizeof(AccT), params->stream>>>(params->output, const_cast<InputT*>(params->input),
+params->softmax_elements, params->input_stride,
+params->output_stride);
}
return HIP_CALL(hipGetLastError());
}

-template <typename input_t, typename output_t, typename acc_t>
-class SoftmaxTunableOp : public onnxruntime::rocm::tunable::TunableOp<SoftmaxParams<input_t, output_t>> {
+template <typename InputT, typename OutputT, typename AccT>
+class SoftmaxTunableOp : public onnxruntime::rocm::tunable::TunableOp<SoftmaxParams<InputT, OutputT>> {
public:
SoftmaxTunableOp() {
-this->RegisterOp(SoftmaxBlockwiseStaticSelection<input_t, output_t, acc_t>);
-this->RegisterOp(SoftmaxBlockwiseOp<input_t, output_t, acc_t, 1>);
-this->RegisterOp(SoftmaxBlockwiseOp<input_t, output_t, acc_t, 2>);
-this->RegisterOp(SoftmaxBlockwiseOp<input_t, output_t, acc_t, 4>);
-this->RegisterOp(SoftmaxBlockwiseOp<input_t, output_t, acc_t, 8>);
-this->RegisterOp(SoftmaxBlockwiseOp<input_t, output_t, acc_t, 16>);
+this->RegisterOp(SoftmaxBlockwiseStaticSelection<InputT, OutputT, AccT>);
+this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 1>);
+this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 2>);
+this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 4>);
+this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 8>);
+this->RegisterOp(SoftmaxBlockwiseOp<InputT, OutputT, AccT, 16>);

#ifdef USE_COMPOSABLE_KERNEL
-for (auto&& [_, op] : GetCKSoftmaxTypeStringAndOps<input_t, output_t, acc_t>()) {
+for (auto&& [_, op] : GetCKSoftmaxTypeStringAndOps<InputT, OutputT, AccT>()) {
ORT_UNUSED_PARAMETER(_);
this->RegisterOp(std::move(op));
}
#endif // USE_COMPOSABLE_KERNEL

// NOTE: the 1st kernel is SoftmaxBlockwise Original implementation.
this->SetDefaultId(0);
}
};

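Tying it together, SoftmaxTunableOp profiles every registered implementation (the static-selection kernel, the vectorized blockwise variants, and any CK instances) for a given problem signature and reuses the winner; with tuning disabled it falls back to the default op, id 0. A hedged invocation sketch follows; RunTunableSoftmax and the static-instance pattern are illustrative, and it assumes the op is invoked through TunableOp's call operator.

```cpp
#include "core/providers/rocm/math/softmax_common.h"
#include "core/providers/rocm/math/softmax_tunable_op.cuh"

namespace onnxruntime {
namespace rocm {

// Sketch only: let the tunable op choose the fastest softmax for this shape.
Status RunTunableSoftmax(hipStream_t stream, float* d_out, const float* d_in,
                         int softmax_elements, int batch_count) {
  SoftmaxParams<float, float> params(stream, d_out, d_in, softmax_elements,
                                     /*input_stride=*/softmax_elements,
                                     /*output_stride=*/softmax_elements,
                                     batch_count, /*is_log_softmax=*/false);
  // A static instance keeps any tuning results alive across calls.
  static SoftmaxTunableOp<float, float, float> softmax_op;
  return softmax_op(&params);
}

}  // namespace rocm
}  // namespace onnxruntime
```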
