diff --git a/src/ocl/convolutionocl.cpp b/src/ocl/convolutionocl.cpp index 29650085d1..2827ce62a8 100644 --- a/src/ocl/convolutionocl.cpp +++ b/src/ocl/convolutionocl.cpp @@ -131,6 +131,18 @@ static inline void ValidateGroupCount(const TensorDescriptor& x, MIOPEN_THROW(miopenStatusBadParm, "Invalid group number"); } +static inline void ValidateWorkspace(Data_t workSpace, const size_t workSpaceSize) +{ + + [[maybe_unused]] bool x = (workSpace != nullptr); + [[maybe_unused]] bool y = (workSpaceSize != 0); + + assert(((x && y) || (!x && !y)) && "workspace pointer and size don't match. Either both should " + "be zero or both should be non-zero"); + + /// \todo could add a check here that workSpace points to GPU memory +} + static Invoker PrepareInvoker(ExecutionContext ctx, const conv::ProblemDescription& problem, const NetworkConfig& config, @@ -258,6 +270,7 @@ void ConvolutionDescriptor::FindConvFwdAlgorithm(Handle& handle, bool exhaustiveSearch) const { MIOPEN_LOG_I("requestAlgoCount = " << requestAlgoCount << ", workspace = " << workSpaceSize); + ValidateWorkspace(workSpace, workSpaceSize); if(x == nullptr || w == nullptr || y == nullptr) MIOPEN_THROW(miopenStatusBadParm, "Buffers cannot be NULL"); if(returnedAlgoCount == nullptr) @@ -492,6 +505,7 @@ void ConvolutionDescriptor::ConvolutionForward(Handle& handle, size_t workSpaceSize) const { MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize); + ValidateWorkspace(workSpace, workSpaceSize); const auto tensors = ConvFwdTensors{xDesc, x, wDesc, w, yDesc, y}; ValidateTensors(tensors); @@ -807,6 +821,7 @@ void ConvolutionDescriptor::ConvolutionForwardImmediate(Handle& handle, const solver::Id solver_id) const { MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize); + ValidateWorkspace(workSpace, workSpaceSize); const auto tensors = ConvFwdTensors{xDesc, x, wDesc, w, yDesc, y}; ValidateTensors(tensors); @@ -841,6 +856,7 @@ void ConvolutionDescriptor::FindConvBwdDataAlgorithm(Handle& handle, bool exhaustiveSearch) const { MIOPEN_LOG_I("requestAlgoCount = " << requestAlgoCount << ", workspace = " << workSpaceSize); + ValidateWorkspace(workSpace, workSpaceSize); if(dx == nullptr || w == nullptr || dy == nullptr) MIOPEN_THROW(miopenStatusBadParm, "Buffers cannot be NULL"); if(returnedAlgoCount == nullptr) @@ -938,6 +954,7 @@ void ConvolutionDescriptor::ConvolutionBackwardData(Handle& handle, size_t workSpaceSize) const { MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize); + ValidateWorkspace(workSpace, workSpaceSize); auto tensors = ConvBwdTensors{dyDesc, dy, wDesc, w, dxDesc, dx}; @@ -1005,6 +1022,7 @@ void ConvolutionDescriptor::ConvolutionBackwardImmediate(Handle& handle, solver::Id solver_id) const { MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize); + ValidateWorkspace(workSpace, workSpaceSize); auto tensors = ConvBwdTensors{dyDesc, dy, wDesc, w, dxDesc, dx}; ValidateTensors(tensors); @@ -1045,6 +1063,7 @@ void ConvolutionDescriptor::FindConvBwdWeightsAlgorithm(Handle& handle, bool exhaustiveSearch) const { MIOPEN_LOG_I("requestAlgoCount = " << requestAlgoCount << ", workspace = " << workSpaceSize); + ValidateWorkspace(workSpace, workSpaceSize); if(x == nullptr || dw == nullptr || dy == nullptr) MIOPEN_THROW(miopenStatusBadParm, "Buffers cannot be NULL"); if(returnedAlgoCount == nullptr) @@ -1140,6 +1159,7 @@ void ConvolutionDescriptor::ConvolutionBackwardWeights(const Handle& handle, size_t workSpaceSize) const { MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize); + ValidateWorkspace(workSpace, workSpaceSize); decltype(auto) tensors = ConvWrwTensors{dyDesc, dy, xDesc, x, dwDesc, dw}; ValidateTensors(tensors); ValidateAlphaBeta(alpha, beta); @@ -1203,6 +1223,7 @@ void ConvolutionDescriptor::ConvolutionWrwImmediate(Handle& handle, solver::Id solver_id) const { MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize); + ValidateWorkspace(workSpace, workSpaceSize); auto tensors = ConvWrwTensors{dyDesc, dy, xDesc, x, dwDesc, dw}; ValidateTensors(tensors);