diff --git a/src/solver/batchnorm/backward_ck.cpp b/src/solver/batchnorm/backward_ck.cpp
index a83f52edef..19d346811f 100644
--- a/src/solver/batchnorm/backward_ck.cpp
+++ b/src/solver/batchnorm/backward_ck.cpp
@@ -154,40 +154,6 @@ static bool CheckCKApplicability(const miopen::batchnorm::ProblemDescription& pr
                           CKArgsBNormBwd>(problem);
 }
-#endif
-
-bool BnCKBwdBackward::IsApplicable(
-    [[maybe_unused]] const ExecutionContext& context,
-    [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const
-{
-#if MIOPEN_BACKEND_HIP || MIOPEN_USE_COMPOSABLEKERNEL
-    if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_CK_BN_BACK{}))
-        return false;
-    if(!bn_problem.IsLayoutNHWC())
-        return false;
-    if(!ck_utility::is_ck_whitelist(context.GetStream()))
-        return false;
-    if(bn_problem.GetXDesc().GetType() != bn_problem.GetScaleBiasDiffDesc().GetType())
-        return false;
-
-    switch(bn_problem.GetXDesc().GetType())
-    {
-    case miopenFloat: return CheckCKApplicability<F32, F32, F32, F32, F32, F32, F32>(bn_problem);
-    case miopenDouble: return CheckCKApplicability<F64, F64, F64, F64, F64, F64, F64>(bn_problem);
-    case miopenHalf: return CheckCKApplicability<F16, F32, F32, F32, F16, F32, F32>(bn_problem);
-    case miopenBFloat16:
-        return CheckCKApplicability<BF16, F32, F32, F32, BF16, F32, F32>(bn_problem);
-    case miopenInt32:
-    case miopenInt8:
-    case miopenInt8x4:
-    case miopenBFloat8:
-    case miopenFloat8:
-    default: MIOPEN_THROW("BnCKBwdBackward operation does not support this data type");
-    }
-    return false;
-#endif
-}
-
 template <typename XDataType,
           typename DxDataType,
           typename DyDataType,
@@ -195,7 +161,7 @@ template <typename XDataType,
           typename ScaleDataType,
           typename DscaleDbiasDataType,
           typename MeanVarDataType>
-ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& bn_problem)
+static ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& bn_problem)
 {
     const auto& valid_kernel_ids = FillValidKernelsIDs<DeviceOpBNBwdPtrs<XDataType,
                                                                          DxDataType,
@@ -218,6 +184,39 @@ ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription&
                                miopen::batchnorm::BwdInvokeParams>(bn_problem, kernel_id);
 }
+#endif
+
+bool BnCKBwdBackward::IsApplicable(
+    [[maybe_unused]] const ExecutionContext& context,
+    [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const
+{
+#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
+    if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_CK_BN_BACK{}))
+        return false;
+    if(!bn_problem.IsLayoutNHWC())
+        return false;
+    if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
+        return false;
+    if(bn_problem.GetXDesc().GetType() != bn_problem.GetScaleBiasDiffDesc().GetType())
+        return false;
+
+    switch(bn_problem.GetXDesc().GetType())
+    {
+    case miopenFloat: return CheckCKApplicability<F32, F32, F32, F32, F32, F32, F32>(bn_problem);
+    case miopenDouble: return CheckCKApplicability<F64, F64, F64, F64, F64, F64, F64>(bn_problem);
+    case miopenHalf: return CheckCKApplicability<F16, F32, F32, F32, F16, F32, F32>(bn_problem);
+    case miopenBFloat16:
+        return CheckCKApplicability<BF16, F32, F32, F32, BF16, F32, F32>(bn_problem);
+    case miopenInt32:
+    case miopenInt8:
+    case miopenInt8x4:
+    case miopenBFloat8:
+    case miopenFloat8: break;
+    }
+#endif
+    return false;
+}
+
 ConvSolution BnCKBwdBackward::GetSolution(
     [[maybe_unused]] const ExecutionContext& context,
     [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const
diff --git a/src/solver/batchnorm/forward_inference_ck.cpp b/src/solver/batchnorm/forward_inference_ck.cpp
index 0b8f734037..df863b95c5 100644
--- a/src/solver/batchnorm/forward_inference_ck.cpp
+++ b/src/solver/batchnorm/forward_inference_ck.cpp
@@ -175,14 +175,11 @@ static void RunCKSolution(const Handle& handle,
 }
 #endif
 
-bool BnCKFwdInference::IsApplicable(const ExecutionContext& context,
-                                    const miopen::batchnorm::ProblemDescription& bn_problem) const
+bool BnCKFwdInference::IsApplicable(
+    [[maybe_unused]] const ExecutionContext& context,
+    [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const
 {
-#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL
-    std::ignore = context;
-    std::ignore = bn_problem;
-    return false;
-#else
+#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
     if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_CK_BN_INFER{}))
         return false;
     if(!bn_problem.IsLayoutNHWC())
@@ -202,24 +199,17 @@ bool BnCKFwdInference::IsApplicable(const ExecutionContext& context,
     case miopenInt8:
     case miopenInt8x4: // Support discontinued.
     case miopenFloat8:
-    case miopenBFloat8:
-    default: MIOPEN_THROW("Unsupported datatype");
+    case miopenBFloat8: break;
     }
-    return false;
 #endif
+    return false;
 }
 
-ConvSolution
-BnCKFwdInference::GetSolution(const ExecutionContext& context,
-                              const miopen::batchnorm::ProblemDescription& bn_problem) const
+ConvSolution BnCKFwdInference::GetSolution(
+    [[maybe_unused]] const ExecutionContext& context,
+    [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const
 {
-#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL
-    std::ignore = context;
-    std::ignore = bn_problem;
-    return {};
-#else
-    std::ignore = context;
-
+#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
     ConvSolution result;
     result.invoker_factory = [=](const std::vector<Kernel>& kernels) {
         std::ignore = kernels;
@@ -252,6 +242,8 @@ BnCKFwdInference::GetSolution(const ExecutionContext& context,
         };
     };
     return result;
+#else
+    return {};
 #endif
 }
 
diff --git a/src/solver/batchnorm/forward_training_ck.cpp b/src/solver/batchnorm/forward_training_ck.cpp
index 5feda645b8..cb95b2fb8b 100644
--- a/src/solver/batchnorm/forward_training_ck.cpp
+++ b/src/solver/batchnorm/forward_training_ck.cpp
@@ -149,36 +149,6 @@ static bool CheckCKApplicability(const miopen::batchnorm::ProblemDescription& pr
                              MeanVarDataType>,
              CKArgsBNormFwdTraining>(problem);
 }
-#endif
-
-bool BnCKFwdTraining::IsApplicable(
-    [[maybe_unused]] const ExecutionContext& context,
-    [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const
-{
-#if MIOPEN_BACKEND_HIP || MIOPEN_USE_COMPOSABLEKERNEL
-    if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_CK_BN_FWD_TRAINING{}))
-        return false;
-    if(!bn_problem.IsLayoutNHWC())
-        return false;
-    if(!ck_utility::is_ck_whitelist(context.GetStream()))
-        return false;
-
-    switch(bn_problem.GetXDesc().GetType())
-    {
-    case miopenHalf: return CheckCKApplicability<F16, F16, F32, F16, F16, F32>(bn_problem);
-    case miopenFloat: return CheckCKApplicability<F32, F32, F32, F32, F32, F32>(bn_problem);
-    case miopenDouble: return CheckCKApplicability<F64, F64, F64, F64, F64, F64>(bn_problem);
-    case miopenBFloat16: return CheckCKApplicability<BF16, BF16, F32, BF16, BF16, F32>(bn_problem);
-    case miopenInt32:
-    case miopenInt8:
-    case miopenInt8x4:
-    case miopenBFloat8:
-    case miopenFloat8:
-    default: MIOPEN_THROW("BnCKFwdTraining operation does not support this data type");
-    }
-    return false;
-#endif
-}
 
 template <typename XDataType,
           typename YDataType,
@@ -186,7 +156,7 @@ template <typename XDataType,
           typename ScaleDataType,
           typename BiasDataType,
           typename MeanVarDataType>
-ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& bn_problem)
+static ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& bn_problem)
 {
     const auto& valid_kernel_ids = FillValidKernelsIDs<DeviceOpBNFwdTrainingPtrs<XDataType,
                                                                                  YDataType,
@@ -206,6 +176,35 @@ ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription&
                                      CKArgsBNormFwdTraining,
                                      miopen::batchnorm::InvokeParams>(bn_problem, kernel_id);
 }
+#endif
+
+bool BnCKFwdTraining::IsApplicable(
+    [[maybe_unused]] const ExecutionContext& context,
+    [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const
+{
+#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
+    if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_CK_BN_FWD_TRAINING{}))
+        return false;
+    if(!bn_problem.IsLayoutNHWC())
+        return false;
+    if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
+        return false;
+
+    switch(bn_problem.GetXDesc().GetType())
+    {
+    case miopenHalf: return CheckCKApplicability<F16, F16, F32, F16, F16, F32>(bn_problem);
+    case miopenFloat: return CheckCKApplicability<F32, F32, F32, F32, F32, F32>(bn_problem);
+    case miopenDouble: return CheckCKApplicability<F64, F64, F64, F64, F64, F64>(bn_problem);
+    case miopenBFloat16: return CheckCKApplicability<BF16, BF16, F32, BF16, BF16, F32>(bn_problem);
+    case miopenInt32:
+    case miopenInt8:
+    case miopenInt8x4:
+    case miopenBFloat8:
+    case miopenFloat8: break;
+    }
+#endif
+    return false;
+}
 
 ConvSolution BnCKFwdTraining::GetSolution(
     [[maybe_unused]] const ExecutionContext& context,