From 08ff7596df9ee32604b130a699addae417754857 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 1 Oct 2021 13:17:43 -0700 Subject: [PATCH] Make Conv and Deconv cuDNN implementation use v8 API This copies changes I previously implemented in the container. Dick Carter made a number of improvements and fixes (memory use during auto-tuning, proper time calculation and time limit cutoff in auto-tuning sampler, etc). --- CMakeLists.txt | 4 +- src/common/cuda/utils.h | 11 +- src/common/cudnn_cxx.cc | 307 +++++++ src/common/cudnn_cxx.h | 294 +++++++ src/operator/cudnn_ops.cc | 599 +++++++++++++ src/operator/cudnn_ops.h | 191 ++++ src/operator/nn/convolution.cu | 206 ++--- src/operator/nn/cudnn/cudnn_convolution-inl.h | 831 ------------------ src/operator/nn/deconvolution.cu | 167 ++-- 9 files changed, 1559 insertions(+), 1051 deletions(-) create mode 100644 src/common/cudnn_cxx.cc create mode 100644 src/common/cudnn_cxx.h create mode 100644 src/operator/cudnn_ops.cc create mode 100644 src/operator/cudnn_ops.h delete mode 100644 src/operator/nn/cudnn/cudnn_convolution-inl.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 882e8b09d404..44bd744f326a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -112,7 +112,7 @@ endif() message(STATUS "CMake version '${CMAKE_VERSION}' using generator '${CMAKE_GENERATOR}'") if(USE_CUDA) - cmake_minimum_required(VERSION 3.13.2) # CUDA 10 (Turing) detection available starting 3.13.2 + cmake_minimum_required(VERSION 3.17.0) include(CheckLanguage) check_language(CUDA) if (NOT CMAKE_CUDA_COMPILER AND UNIX AND EXISTS "/usr/local/cuda/bin/nvcc") @@ -121,7 +121,7 @@ if(USE_CUDA) "Please fix your cuda installation: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#mandatory-post") endif() enable_language(CUDA) - set(CMAKE_CUDA_STANDARD 14) + set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) endif() diff --git a/src/common/cuda/utils.h b/src/common/cuda/utils.h index c1fde5f571b1..792f7319def8 100644 --- a/src/common/cuda/utils.h +++ b/src/common/cuda/utils.h @@ -645,11 +645,12 @@ static_assert(CUDNN_PATCHLEVEL < 100 && CUDNN_MINOR < 10, "Compiled-against cuDNN version " CUDNN_VERSION_AS_STRING \ " is too old, please upgrade system to version " QUOTEVALUE(min_version) " or later.") -#define CUDNN_CALL(func) \ - { \ - cudnnStatus_t e = (func); \ - CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \ - } +#define CUDNN_CALL_S(f, s) \ + if (cudnnStatus_t unclash_cxx_e = (f); unclash_cxx_e != CUDNN_STATUS_SUCCESS) \ + LOG(s) << "cuDNN: " << cudnnGetErrorString(unclash_cxx_e); + +#define CUDNN_CALL(f) CUDNN_CALL_S(f, FATAL) +#define CUDNN_CALL_NONFATAL(f) CUDNN_CALL_S(f, WARNING) #define CUTENSOR_CALL(func) \ { \ diff --git a/src/common/cudnn_cxx.cc b/src/common/cudnn_cxx.cc new file mode 100644 index 000000000000..ac5b60f56af0 --- /dev/null +++ b/src/common/cudnn_cxx.cc @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2021 by Contributors + * \file cudnn_cxx.cc + */ +#include "cudnn_cxx.h" + +#include +#if MXNET_USE_CUDNN == 1 + +#include +#include +#include +#include + +namespace mxnet { +namespace cudnn_cxx { + +Descriptor Make(cudnnBackendDescriptorType_t type) { + cudnnBackendDescriptor_t desc{}; + CUDNN_CALL(cudnnBackendCreateDescriptor(type, &desc)); + return Descriptor(desc); +} + +std::vector MakeRawDescriptors(size_t n, + cudnnBackendDescriptorType_t type) { + std::vector ret(n); + for (auto& d : ret) CUDNN_CALL(cudnnBackendCreateDescriptor(type, &d)); + return ret; +} + +void SetAttr(const Descriptor& desc, cudnnBackendAttributeName_t name, const Descriptor& val) { + auto raw = val.get(); + CUDNN_CALL(cudnnBackendSetAttribute(desc.get(), name, CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &raw)); +} + +void SetAttr(const Descriptor& desc, cudnnBackendAttributeName_t name, const WeakDescriptor& val) { + auto raw = val.get(); + CUDNN_CALL(cudnnBackendSetAttribute(desc.get(), name, CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &raw)); +} + +void SetAttr(const Descriptor& desc, cudnnBackendAttributeName_t name, + const std::vector& val) { + std::vector raw(val.size()); + std::transform(val.begin(), val.end(), raw.begin(), [](const Descriptor& d) { return d.get(); }); + CUDNN_CALL(cudnnBackendSetAttribute(desc.get(), name, CUDNN_TYPE_BACKEND_DESCRIPTOR, raw.size(), + &raw[0])); +} + +Descriptor GetAttr(const Descriptor& desc, cudnnBackendAttributeName_t name, + cudnnBackendDescriptorType_t type) { + cudnnBackendDescriptor_t ret{}; + CUDNN_CALL(cudnnBackendCreateDescriptor(type, &ret)); + int64_t count = 0; + CUDNN_CALL( + cudnnBackendGetAttribute(desc.get(), name, CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &count, &ret)); + CHECK_EQ(count, 1); + return Descriptor(ret); +} + +std::vector GetAllAttrs(const Descriptor& desc, cudnnBackendAttributeName_t name, + cudnnBackendDescriptorType_t type) { + int64_t count = 0; + CUDNN_CALL(cudnnBackendGetAttribute(desc.get(), name, CUDNN_TYPE_BACKEND_DESCRIPTOR, 0, &count, + nullptr)); + auto raw = MakeRawDescriptors(count, type); + CUDNN_CALL(cudnnBackendGetAttribute(desc.get(), name, CUDNN_TYPE_BACKEND_DESCRIPTOR, raw.size(), + &count, raw.data())); + + // TODO(vcherepanov): uncomment when cuDNN fix 3313649 + // CHECK_EQ(count, raw.size()); + // std::vector ret(raw.begin(), raw.end()); + CHECK_LE(count, raw.size()); + std::vector ret(raw.begin(), raw.begin() + count); + for (size_t i = count; i < raw.size(); ++i) CUDNN_CALL(cudnnBackendDestroyDescriptor(raw[i])); + return ret; +} + +std::vector GetSomeAttrs(size_t max_n, const Descriptor& desc, + cudnnBackendAttributeName_t name, + cudnnBackendDescriptorType_t type) { + auto raw = MakeRawDescriptors(max_n, type); + int64_t count = 0; + CUDNN_CALL(cudnnBackendGetAttribute(desc.get(), name, CUDNN_TYPE_BACKEND_DESCRIPTOR, raw.size(), + &count, raw.data())); + std::vector ret(count); + size_t i = 0; + for (; i < count; ++i) ret[i] = Descriptor(raw[i]); + for (; i < max_n; ++i) CUDNN_CALL(cudnnBackendDestroyDescriptor(raw[i])); + return ret; +} + +std::vector PackedStrides(const std::vector& order, + 
const std::vector& dims) { + CHECK_EQ(order.size(), dims.size()); + std::vector ret(dims.size(), 1); + for (size_t i = dims.size() - 1; i--;) ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]]; + return ret; +} + +Sampler MakeAvgSampler(size_t n, float max_cutoff_msec, size_t warmups) { + size_t warmups_performed = 0; + size_t k = 0; + float s = 0.0f; + if (n < 1) n = 1; + + return [n, max_cutoff_msec, warmups, warmups_performed, k, s](float x) mutable { + if (warmups_performed < warmups && x < max_cutoff_msec) { + warmups_performed++; + } else { + // Add this sample to the average calculation + s += x; + k++; + } + bool keep_going = k < n && x < max_cutoff_msec; + return keep_going ? std::nullopt : std::optional(s / k); + }; +} + +std::vector GetPlans(cudnnBackendHeurMode_t h_mode, cudnnHandle_t handle, + const Descriptor& op_graph, size_t workspace_limit, + size_t* max_workspace, + const std::unordered_set& excl_engines, + const std::vector& req_numeric, + const std::vector& excl_numeric, + const std::vector& req_behavior, + const std::vector& excl_behavior, + bool verbose_filter) { + auto heur = + MakeFinalized(CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR, CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH, + op_graph, CUDNN_ATTR_ENGINEHEUR_MODE, h_mode); + auto cfgs = GetAllAttrs(heur, CUDNN_ATTR_ENGINEHEUR_RESULTS, CUDNN_BACKEND_ENGINECFG_DESCRIPTOR); + std::vector plans; + if (max_workspace) *max_workspace = 0; + for (const auto& cfg : cfgs) { + auto plan = Make(CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR, CUDNN_ATTR_EXECUTION_PLAN_HANDLE, + handle, CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG, cfg); + auto err = cudnnBackendFinalize(plan.get()); + if (err == CUDNN_STATUS_NOT_SUPPORTED || err == CUDNN_STATUS_ARCH_MISMATCH) continue; + if (err != CUDNN_STATUS_SUCCESS) { + LOG(WARNING) << "Unexpected cuDNN status: " << err << ": " << cudnnGetErrorString(err); + continue; + } + auto workspace = GetAttr(plan, CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE); + if (workspace_limit < workspace) { + if (verbose_filter) LOG(INFO) << " Plan " << PlanStr(plan) << " exceeds workspace limit"; + continue; + } + auto engine = GetAttr(cfg, CUDNN_ATTR_ENGINECFG_ENGINE, CUDNN_BACKEND_ENGINE_DESCRIPTOR); + if (excl_engines.count(GetAttr(engine, CUDNN_ATTR_ENGINE_GLOBAL_INDEX))) { + if (verbose_filter) LOG(INFO) << " Plan " << PlanStr(plan) << " excluded by engine"; + continue; + } + auto numerical = GetSomeAttrs( + CUDNN_NUMERICAL_NOTE_TYPE_COUNT, engine, CUDNN_ATTR_ENGINE_NUMERICAL_NOTE); + if (!IsCompatible(numerical, req_numeric, excl_numeric)) { + if (verbose_filter) LOG(INFO) << " Plan " << PlanStr(plan) << " has incompatible numerics"; + continue; + } + auto behavior = GetSomeAttrs(CUDNN_BEHAVIOR_NOTE_TYPE_COUNT, engine, + CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE); + if (!IsCompatible(behavior, req_behavior, excl_behavior)) { + if (verbose_filter) LOG(INFO) << " Plan " << PlanStr(plan) << " has incompatible behavior"; + continue; + } + plans.push_back(std::move(plan)); + if (max_workspace) *max_workspace = std::max(*max_workspace, static_cast(workspace)); + } + return plans; +} + +std::vector FindTopPlans(std::vector&& plans, size_t max_results, + cudnnHandle_t handle, const Descriptor& var_pack, + Sampler sampler) { + // We're about to perform kernel timings, so we need to quiet the system by grabbing + // the Storage lock. Concurrent cudaMalloc's can disrupt the accurate timing + // measurements of the algos, and can prevent the cuda driver's proper freeing + // of temporary workspace allocations. 
Grabbing the lock might also + // impede other threads from launching work on the GPU. + std::lock_guard lock(Storage::Get()->GetMutex(Context::kGPU)); + std::array ev; + for (auto& ee : ev) CUDA_CALL(cudaEventCreate(&ee)); + auto cmp = [](const FindResult& lhs, const FindResult& rhs) { return lhs.time < rhs.time; }; + cudaStream_t stream{}; + CUDNN_CALL(cudnnGetStream(handle, &stream)); + std::vector h; + for (size_t i = 0; i < plans.size(); ++i) { + auto&& plan = plans[i]; + // Make a copy of the unused sampler for each plan's timing. Timed warm-up + // runs are handled by the sampler to enable early loop exit for slow kernels. + auto sampler_copy = sampler; + for (;;) { + CUDA_CALL(cudaEventRecord(ev[0], stream)); + CUDNN_CALL(cudnnBackendExecute(handle, plan.get(), var_pack.get())); + CUDA_CALL(cudaEventRecord(ev[1], stream)); + CUDA_CALL(cudaEventSynchronize(ev[1])); + float t = 0.0f; + CUDA_CALL(cudaEventElapsedTime(&t, ev[0], ev[1])); + if (auto r = sampler_copy(t); r) { + auto time_to_record = r.value(); + if (h.size() == max_results) { + if (time_to_record < h[0].time) { + std::pop_heap(h.begin(), h.end(), cmp); + h.back() = {std::move(plan), i, time_to_record}; + std::push_heap(h.begin(), h.end(), cmp); + } + } else { + h.push_back({std::move(plan), i, time_to_record}); + std::push_heap(h.begin(), h.end(), cmp); + } + break; + } + } + } + for (auto& ee : ev) CUDA_CALL(cudaEventDestroy(ee)); + std::sort_heap(h.begin(), h.end(), cmp); + return h; +} + +std::string NoteStr(cudnnBackendNumericalNote_t note) { + std::unordered_map m{ + {CUDNN_NUMERICAL_NOTE_TENSOR_CORE, "tc"}, + {CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS, "dci"}, + {CUDNN_NUMERICAL_NOTE_REDUCED_PRECISION_REDUCTION, "rp"}, + {CUDNN_NUMERICAL_NOTE_FFT, "fft"}, + {CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC, "nd"}, + {CUDNN_NUMERICAL_NOTE_WINOGRAD, "w"}, + }; + auto it = m.find(note); + return it != m.end() ? it->second : std::to_string(note); +} + +std::string KnobStr(cudnnBackendKnobType_t knob) { + std::unordered_map m{ + {CUDNN_KNOB_TYPE_SPLIT_K, "split_k"}, + {CUDNN_KNOB_TYPE_SWIZZLE, "swizzle"}, + {CUDNN_KNOB_TYPE_TILE_SIZE, "tile_size"}, + {CUDNN_KNOB_TYPE_USE_TEX, "use_tex"}, + {CUDNN_KNOB_TYPE_EDGE, "edge"}, + {CUDNN_KNOB_TYPE_KBLOCK, "kblock"}, + {CUDNN_KNOB_TYPE_LDGA, "ldga"}, + {CUDNN_KNOB_TYPE_LDGB, "ldgb"}, + {CUDNN_KNOB_TYPE_CHUNK_K, "chunk_k"}, + {CUDNN_KNOB_TYPE_SPLIT_H, "split_h"}, + {CUDNN_KNOB_TYPE_WINO_TILE, "wino_tile"}, + {CUDNN_KNOB_TYPE_MULTIPLY, "multiply"}, + {CUDNN_KNOB_TYPE_SPLIT_K_BUF, "split_k_buf"}, + {CUDNN_KNOB_TYPE_TILEK, "tilek"}, + {CUDNN_KNOB_TYPE_STAGES, "stages"}, + {CUDNN_KNOB_TYPE_REDUCTION_MODE, "reduction_mode"}, + {CUDNN_KNOB_TYPE_CTA_SPLIT_K_MODE, "cta_split_k_mode"}, + {CUDNN_KNOB_TYPE_SPLIT_K_SLC, "split_k_slc"}, + {CUDNN_KNOB_TYPE_IDX_MODE, "idx_mode"}, + {CUDNN_KNOB_TYPE_SLICED, "sliced"}, + {CUDNN_KNOB_TYPE_SPLIT_RS, "split_rs"}, + {CUDNN_KNOB_TYPE_SINGLEBUFFER, "singlebuffer"}, + {CUDNN_KNOB_TYPE_LDGC, "ldgc"}, + {CUDNN_KNOB_TYPE_SPECFILT, "specfilt"}, + {CUDNN_KNOB_TYPE_KERNEL_CFG, "kernel_cfg"}, + }; + auto it = m.find(knob); + return it != m.end() ? 
it->second : std::to_string(knob); +} + +std::string PlanStr(const Descriptor& plan) { + auto wks = GetAttr(plan, CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE); + auto cfg = + GetAttr(plan, CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG, CUDNN_BACKEND_ENGINECFG_DESCRIPTOR); + auto engine = GetAttr(cfg, CUDNN_ATTR_ENGINECFG_ENGINE, CUDNN_BACKEND_ENGINE_DESCRIPTOR); + auto engine_idx = GetAttr(engine, CUDNN_ATTR_ENGINE_GLOBAL_INDEX); + std::ostringstream ss; + ss << "eng:" << engine_idx << " wksp:" << wks; + auto notes = GetSomeAttrs(CUDNN_NUMERICAL_NOTE_TYPE_COUNT, engine, + CUDNN_ATTR_ENGINE_NUMERICAL_NOTE); + for (auto note : notes) ss << " " << NoteStr(note); + auto choices = GetSomeAttrs(CUDNN_KNOB_TYPE_COUNTS, cfg, CUDNN_ATTR_ENGINECFG_KNOB_CHOICES, + CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR); + for (const auto& choice : choices) { + auto type = GetAttr(choice, CUDNN_ATTR_KNOB_CHOICE_KNOB_TYPE); + auto val = GetAttr(choice, CUDNN_ATTR_KNOB_CHOICE_KNOB_VALUE); + ss << " " << KnobStr(type) << ":" << val; + } + return ss.str(); +} + +} // namespace cudnn_cxx +} // namespace mxnet + +#endif // MXNET_USE_CUDNN == 1 diff --git a/src/common/cudnn_cxx.h b/src/common/cudnn_cxx.h new file mode 100644 index 000000000000..d66aa0bd989f --- /dev/null +++ b/src/common/cudnn_cxx.h @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2021 by Contributors + * \file cudnn_cxx.h + * \brief Convenience utilities to make coding against cuDNN v8 API less verbose + */ +#ifndef MXNET_COMMON_CUDNN_CXX_H_ +#define MXNET_COMMON_CUDNN_CXX_H_ + +#include +#if MXNET_USE_CUDNN == 1 + +#include +#include +#include +#include +#include +#include // NOLINT(build/include_order) +#include +#include +#include +#include + +#include "cuda/utils.h" + +namespace mxnet { +namespace cudnn_cxx { + +struct DescriptorDestroyer { + using pointer = cudnnBackendDescriptor_t; + + void operator()(cudnnBackendDescriptor_t desc) { + CUDNN_CALL_NONFATAL(cudnnBackendDestroyDescriptor(desc)); + } +}; + +using Descriptor = std::unique_ptr; + +struct WeakDescriptor { + cudnnBackendDescriptor_t desc = nullptr; + + explicit WeakDescriptor(const Descriptor& other) : desc(other.get()) {} + cudnnBackendDescriptor_t get() const { return desc; } +}; + +template +struct AttrType; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_INT64; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_VOID_PTR; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_FLOAT; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_DOUBLE; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_HANDLE; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_BOOLEAN; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_DATA_TYPE; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_CONVOLUTION_MODE; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_NAN_PROPOGATION; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_POINTWISE_MODE; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_REDUCTION_OPERATOR_TYPE; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_HEUR_MODE; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_NUMERICAL_NOTE; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_BEHAVIOR_NOTE; +}; + +template <> +struct AttrType { + static constexpr cudnnBackendAttributeType_t type = CUDNN_TYPE_KNOB_TYPE; +}; + +void SetAttr(const Descriptor& desc, cudnnBackendAttributeName_t name, const Descriptor& val); +void SetAttr(const Descriptor& desc, cudnnBackendAttributeName_t name, const WeakDescriptor& val); +void SetAttr(const Descriptor& desc, cudnnBackendAttributeName_t name, + const std::vector& val); + +template +void SetAttr(const Descriptor& desc, cudnnBackendAttributeName_t name, T val) { + CUDNN_CALL(cudnnBackendSetAttribute(desc.get(), name, AttrType::type, 1, &val)); +} + +template +void SetAttr(const Descriptor& desc, cudnnBackendAttributeName_t name, const std::vector& val) { + CUDNN_CALL(cudnnBackendSetAttribute(desc.get(), name, AttrType::type, val.size(), &val[0])); +} + +template +void SetAttr(const Descriptor& desc, cudnnBackendAttributeName_t name, + const std::array& val) { + CUDNN_CALL(cudnnBackendSetAttribute(desc.get(), name, AttrType::type, val.size(), 
&val[0])); +} + +inline void SetAttrs(const Descriptor& desc) {} + +template +void SetAttrs(const Descriptor& desc, cudnnBackendAttributeName_t name, T&& val, Attrs&&... rest) { + SetAttr(desc, name, std::forward(val)); + SetAttrs(desc, std::forward(rest)...); +} + +std::vector MakeRawDescriptors(size_t n, + cudnnBackendDescriptorType_t type); + +Descriptor Make(cudnnBackendDescriptorType_t type); + +template +Descriptor Make(cudnnBackendDescriptorType_t type, Attrs&&... attrs) { + auto desc = Make(type); + SetAttrs(desc, std::forward(attrs)...); + return desc; +} + +template +Descriptor MakeFinalized(cudnnBackendDescriptorType_t type, Attrs&&... attrs) { + auto desc = Make(type, std::forward(attrs)...); + CUDNN_CALL(cudnnBackendFinalize(desc.get())); + return desc; +} + +template +T GetAttr(const Descriptor& desc, cudnnBackendAttributeName_t name) { + T ret{}; + int64_t ret_count = 0; + CUDNN_CALL(cudnnBackendGetAttribute(desc.get(), name, AttrType::type, 1, &ret_count, &ret)); + CHECK_EQ(ret_count, 1); + return ret; +} + +template +std::vector GetAllAttrs(const Descriptor& desc, cudnnBackendAttributeName_t name) { + int64_t count = 0; + CUDNN_CALL(cudnnBackendGetAttribute(desc.get(), name, AttrType::type, 0, &count, nullptr)); + std::vector ret(count); + CUDNN_CALL(cudnnBackendGetAttribute(desc.get(), name, AttrType::type, ret.size(), &count, + ret.data())); + return ret; +} + +template +std::vector GetSomeAttrs(size_t max_n, const Descriptor& desc, + cudnnBackendAttributeName_t name) { + int64_t count = 0; + std::vector ret(max_n); + CUDNN_CALL(cudnnBackendGetAttribute(desc.get(), name, AttrType::type, ret.size(), &count, + ret.data())); + ret.resize(count); + return ret; +} + +Descriptor GetAttr(const Descriptor& desc, cudnnBackendAttributeName_t name, + cudnnBackendDescriptorType_t type); + +std::vector GetAllAttrs(const Descriptor& desc, cudnnBackendAttributeName_t name, + cudnnBackendDescriptorType_t type); + +std::vector GetSomeAttrs(size_t max_n, const Descriptor& desc, + cudnnBackendAttributeName_t name, + cudnnBackendDescriptorType_t type); + +// Order sets layout, as a permutation of dims, with N,C, being identity. +std::vector PackedStrides(const std::vector& order, + const std::vector& dims); + +// Given an engine config's `notes`, return whether that config is compatible, i.e. does +// the config have all of the required notes and none of the notes that are being excluded. +template +inline bool IsCompatible(const std::vector& notes, const std::vector& require_notes, + const std::vector& exclude_notes) { + for (auto rn : require_notes) { + if (auto it = std::find(notes.begin(), notes.end(), rn); it == notes.end()) { + return false; + } + } + for (auto en : exclude_notes) { + if (auto it = std::find(notes.begin(), notes.end(), en); it != notes.end()) { + return false; + } + } + return true; +} + +// Execution plans are returned in the order of cuDNN heurstics, i.e. from best to worst. +// - max_workspace is an out parameter - the maximum workspace requirement among returned plans, +// may be nullptr if not needed. +std::vector GetPlans(cudnnBackendHeurMode_t h_mode, cudnnHandle_t handle, + const Descriptor& op_graph, size_t workspace_limit, + size_t* max_workspace, + const std::unordered_set& excl_engines, + const std::vector& req_numeric, + const std::vector& excl_numeric, + const std::vector& req_behavior, + const std::vector& excl_behavior, + bool verbose_filter); + +// Defines a sampling algorithm. 
+// Returns an aggregate value, to be used as a metric for time comparison, or std::nullopt to +// perform another time measurement. +using Sampler = std::function(float)>; + +// Return a sampler that after `n` trials returns the average. +// Before tallying trials, `warmups` trials are first ignored. +// If ever a trial that exceeds `max_cutoff_msec` is encountered (even during warmup), +// that trial is tallied and the sampling ends with the then-current trial average. +Sampler MakeAvgSampler(size_t n, float max_cutoff_msec = 1000.0, size_t warmups = 1); + +struct FindResult { + Descriptor plan; + size_t heur_i; + float time; +}; + +// Executes and times the plans. The results are returned in the order from best to worst. +std::vector FindTopPlans(std::vector&& plans, size_t max_results, + cudnnHandle_t handle, const Descriptor& var_pack, + Sampler sampler); + +std::string PlanStr(const Descriptor& plan); + +} // namespace cudnn_cxx +} // namespace mxnet + +#endif // MXNET_USE_CUDNN == 1 + +#endif // MXNET_COMMON_CUDNN_CXX_H_ diff --git a/src/operator/cudnn_ops.cc b/src/operator/cudnn_ops.cc new file mode 100644 index 000000000000..617161e83964 --- /dev/null +++ b/src/operator/cudnn_ops.cc @@ -0,0 +1,599 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2021 by Contributors + * \file cudnn_ops.cc + * \brief cuDNN v8 ops + */ + +#include "cudnn_ops.h" + +#include +#if MXNET_USE_CUDNN == 1 + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mxnet { +namespace op { + +using cudnn_cxx::Descriptor; +using cudnn_cxx::GetAttr; +using cudnn_cxx::GetSomeAttrs; +using cudnn_cxx::MakeAvgSampler; +using cudnn_cxx::MakeFinalized; +using cudnn_cxx::PackedStrides; +using cudnn_cxx::PlanStr; +using cudnn_cxx::IsCompatible; + +namespace cudnn { + +cudnnDataType_t CudnnType(mshadow::TypeFlag dtype) { + static std::unordered_map type_map{ + {mshadow::kFloat32, CUDNN_DATA_FLOAT}, + {mshadow::kFloat64, CUDNN_DATA_DOUBLE}, + {mshadow::kFloat16, CUDNN_DATA_HALF}, + {mshadow::kUint8, CUDNN_DATA_UINT8}, + {mshadow::kInt8, CUDNN_DATA_INT8}, + {mshadow::kInt32, CUDNN_DATA_INT32}, + {mshadow::kInt64, CUDNN_DATA_INT64}, + }; + auto it = type_map.find(dtype); + CHECK(it != type_map.end()) << "Unsupported type: " << dtype; + return it->second; +} + +std::vector LayoutInfo::Order() const { + std::vector ret(n_space_dims + 2); + std::iota(ret.begin(), ret.end(), 0); + if (channel_last) std::rotate(ret.begin() + 1, ret.begin() + 2, ret.end()); + return ret; +} + +size_t LayoutInfo::ChannelIdx() const { return channel_last ? 
1 + n_space_dims : 1; } + +std::vector LayoutInfo::Strides(const std::vector& dims) const { + return PackedStrides(Order(), dims); +} + +LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout) { + static std::unordered_map layout_map{ + {mshadow::kNCW, {1, false}}, + {mshadow::kNWC, {1, true}}, + {mshadow::kNCHW, {2, false}}, + {mshadow::kNHWC, {2, true}}, + {mshadow::kNCDHW, {3, false}}, + {mshadow::kNDHWC, {3, true}}, + }; + auto it = layout_map.find(layout); + CHECK(it != layout_map.end()) << "Unsupported layout: " << layout; + return it->second; +} + +TShape ExpandChannelDims(mshadow::LayoutFlag layout, int c) { + auto li = GetLayoutInfo(layout); + std::vector dims(li.n_space_dims + 2, 1); + dims[li.ChannelIdx()] = c; + return TShape(dims.begin(), dims.end()); +} + +std::vector ReverseOrder(const std::vector& o) { + std::vector ret(o.size()); + for (size_t i = 0; i < ret.size(); ++i) ret[o[i]] = i; + return ret; +} + +std::vector RequireNumerics() { + std::vector ret; + return ret; +} + +std::vector ExcludeNumerics() { + std::vector ret; + if (!dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE", true)) + ret.push_back(CUDNN_NUMERICAL_NOTE_TENSOR_CORE); + if (!dmlc::GetEnv("MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION", false)) + ret.push_back(CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS); + if (!dmlc::GetEnv("MXNET_CUDNN_ALLOW_REDUCED_PRECISION_REDUCTION", true)) + ret.push_back(CUDNN_NUMERICAL_NOTE_REDUCED_PRECISION_REDUCTION); + if (!dmlc::GetEnv("MXNET_CUDNN_ALLOW_FFT", true)) + ret.push_back(CUDNN_NUMERICAL_NOTE_FFT); + if (dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", false)) + ret.push_back(CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC); + if (!dmlc::GetEnv("MXNET_CUDNN_ALLOW_WINOGRAD", true)) + ret.push_back(CUDNN_NUMERICAL_NOTE_WINOGRAD); + return ret; +} + +Descriptor MakeTensorDesc(int64_t uid, cudnnDataType_t dtype, const std::vector& dims, + const std::vector& strides, bool is_virtual) { + int64_t alignment = 16; // TODO(vcherepanov): ? 
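+  // CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT is a hint about the alignment of the data pointer that
+  // will later be bound to this tensor. 16 bytes is a conservative choice: cudaMalloc
+  // returns allocations aligned to at least 256 bytes, so whole buffers always satisfy it.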
+ return MakeFinalized(CUDNN_BACKEND_TENSOR_DESCRIPTOR, + CUDNN_ATTR_TENSOR_UNIQUE_ID, uid, + CUDNN_ATTR_TENSOR_DATA_TYPE, dtype, + CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT, alignment, + CUDNN_ATTR_TENSOR_DIMENSIONS, dims, + CUDNN_ATTR_TENSOR_STRIDES, strides, + CUDNN_ATTR_TENSOR_IS_VIRTUAL, is_virtual); +} + +Descriptor MakeTensorDesc(int64_t uid, const TBlob& blob, const LayoutInfo& li, bool expand_1d, + bool is_virtual) { + std::vector dims(blob.shape_.ndim()); + CHECK_EQ(dims.size(), li.n_space_dims + 2); + auto rev_order = ReverseOrder(li.Order()); + for (size_t i = 0; i < dims.size(); ++i) dims[i] = blob.shape_[rev_order[i]]; + auto strides = li.Strides(dims); + if (li.n_space_dims == 1 && expand_1d) { + dims.insert(dims.begin() + 2, 1); + std::vector order(dims.size()); + std::iota(order.begin(), order.end(), 0); + if (li.channel_last) std::rotate(order.begin() + 1, order.begin() + 2, order.end()); + strides = PackedStrides(order, dims); + } + return MakeTensorDesc(uid, CudnnType(static_cast(blob.type_flag_)), dims, + strides, is_virtual); +} + +Descriptor MakeCTensorDescExpandDims(int64_t uid, const TBlob& b, const LayoutInfo& li, + bool is_virtual) { + std::vector dims(li.n_space_dims + 2, 1); + dims[1] = b.shape_[0]; + auto dtype = CudnnType(static_cast(b.type_flag_)); + return MakeTensorDesc(uid, dtype, dims, li.Strides(dims), is_virtual); +} + +Descriptor MakeConvDesc(const ConvParam& param, mshadow::TypeFlag dtype) { + int64_t sdims = param.kernel.ndim(); + std::vector stride(param.stride.begin(), param.stride.end()); + std::vector dilate(param.dilate.begin(), param.dilate.end()); + std::vector pad(param.pad.begin(), param.pad.end()); + + auto comp_type = CudnnType(dtype); + if (comp_type == CUDNN_DATA_HALF) comp_type = CUDNN_DATA_FLOAT; + + if (sdims == 1) { + // TODO(vcherepanov): remove this once cuDNN properly supports 1D convolutions. + // For now, making spacial dims 2D: 1 x W. + ++sdims; + stride.insert(stride.begin(), 1); + dilate.insert(dilate.begin(), 1); + pad.insert(pad.begin(), 0); + } + return MakeFinalized(CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR, + CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS, sdims, + CUDNN_ATTR_CONVOLUTION_COMP_TYPE, comp_type, + CUDNN_ATTR_CONVOLUTION_CONV_MODE, CUDNN_CROSS_CORRELATION, + CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES, stride, + CUDNN_ATTR_CONVOLUTION_DILATIONS, dilate, + CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS, pad, + CUDNN_ATTR_CONVOLUTION_POST_PADDINGS, pad); +} + +Descriptor MakeConvFwdOp(const Descriptor& conv, const Descriptor& x, const Descriptor& w, + const Descriptor& y, bool add_to) { + auto ret = Make(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC, conv, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X, x, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W, w, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y, y); + if (GetAttr(x, CUDNN_ATTR_TENSOR_DATA_TYPE) == CUDNN_DATA_DOUBLE) { + SetAttrs(ret, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA, 1.0, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA, add_to ? 1.0 : 0.0); + } else { + SetAttrs(ret, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA, 1.0f, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA, add_to ? 
1.0f : 0.0f); + } + CUDNN_CALL(cudnnBackendFinalize(ret.get())); + return ret; +} + +Descriptor MakeConvDgradOp(const Descriptor& conv, const Descriptor& w, const Descriptor& dy, + const Descriptor& dx, bool add_to) { + auto ret = Make(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC, conv, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W, w, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY, dy, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX, dx); + if (GetAttr(w, CUDNN_ATTR_TENSOR_DATA_TYPE) == CUDNN_DATA_DOUBLE) { + SetAttrs(ret, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA, 1.0, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA, add_to ? 1.0 : 0.0); + } else { + SetAttrs(ret, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA, 1.0f, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA, add_to ? 1.0f : 0.0f); + } + CUDNN_CALL(cudnnBackendFinalize(ret.get())); + return ret; +} + +Descriptor MakeConvWgradOp(const Descriptor& conv, const Descriptor& x, const Descriptor& dy, + const Descriptor& dw, bool add_to) { + auto ret = Make(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC, conv, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X, x, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY, dy, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW, dw); + if (GetAttr(x, CUDNN_ATTR_TENSOR_DATA_TYPE) == CUDNN_DATA_DOUBLE) { + SetAttrs(ret, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA, 1.0, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA, add_to ? 1.0 : 0.0); + } else { + SetAttrs(ret, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA, 1.0f, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA, add_to ? 1.0f : 0.0f); + } + CUDNN_CALL(cudnnBackendFinalize(ret.get())); + return ret; +} + +Descriptor MakeOpGraph(cudnnHandle_t handle, const std::vector& ops) { + return MakeFinalized(CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR, + CUDNN_ATTR_OPERATIONGRAPH_HANDLE, handle, + CUDNN_ATTR_OPERATIONGRAPH_OPS, ops); +} + +ConvParam::ConvParam(const ConvolutionParam& p, bool add_to) + : kernel(p.kernel), + stride(p.stride), + dilate(p.dilate), + pad(p.pad), + num_filter(p.num_filter), + num_group(p.num_group), + workspace(p.workspace), + cudnn_tune(p.cudnn_tune), + layout(p.layout), + add_to(add_to) {} + +ConvParam::ConvParam(const DeconvolutionParam& p, bool add_to) + : kernel(p.kernel), + stride(p.stride), + dilate(p.dilate), + pad(p.pad), + num_filter(p.num_filter), + num_group(p.num_group), + workspace(p.workspace), + cudnn_tune(p.cudnn_tune), + layout(p.layout), + add_to(add_to) {} + +void TuneWarnOnce() { + thread_local bool done = false; + if (!done) { + LOG(INFO) << "Auto-tuning cuDNN op, set MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable"; + done = true; + } +} + +std::vector MakeFallbackPlans(const std::vector& ixs, cudnnHandle_t handle, + const Descriptor& op_graph, size_t workspace_limit, + size_t* max_workspace, + const std::unordered_set& excl_engines, + const std::vector& req_numeric, + const std::vector& excl_numeric, + const std::vector& req_behavior, + const std::vector& excl_behavior) { + std::vector plans; + if (max_workspace) *max_workspace = 0; + for (auto ix : ixs) { + if (excl_engines.count(ix)) continue; + auto engine = Make(CUDNN_BACKEND_ENGINE_DESCRIPTOR, + CUDNN_ATTR_ENGINE_OPERATION_GRAPH, op_graph, + CUDNN_ATTR_ENGINE_GLOBAL_INDEX, ix); + auto err = cudnnBackendFinalize(engine.get()); + if (err == CUDNN_STATUS_NOT_SUPPORTED || err == 
CUDNN_STATUS_ARCH_MISMATCH) continue; + if (err != CUDNN_STATUS_SUCCESS) { + LOG(WARNING) << "Unexpected cuDNN status: " << err << ": " << cudnnGetErrorString(err); + continue; + } + auto cfg = MakeFinalized(CUDNN_BACKEND_ENGINECFG_DESCRIPTOR, + CUDNN_ATTR_ENGINECFG_ENGINE, engine); + auto plan = Make(CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR, + CUDNN_ATTR_EXECUTION_PLAN_HANDLE, handle, + CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG, cfg); + err = cudnnBackendFinalize(plan.get()); + if (err == CUDNN_STATUS_NOT_SUPPORTED || err == CUDNN_STATUS_ARCH_MISMATCH) continue; + if (err != CUDNN_STATUS_SUCCESS) { + LOG(WARNING) << "Unexpected cuDNN status: " << err << ": " << cudnnGetErrorString(err); + continue; + } + auto workspace = GetAttr(plan, CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE); + if (workspace > workspace_limit) continue; + auto numerical = GetSomeAttrs( + CUDNN_NUMERICAL_NOTE_TYPE_COUNT, engine, CUDNN_ATTR_ENGINE_NUMERICAL_NOTE); + if (!IsCompatible(numerical, req_numeric, excl_numeric)) continue; + auto behavior = GetSomeAttrs(CUDNN_BEHAVIOR_NOTE_TYPE_COUNT, engine, + CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE); + if (!IsCompatible(behavior, req_behavior, excl_behavior)) continue; + plans.push_back(std::move(plan)); + if (max_workspace) *max_workspace = std::max(*max_workspace, static_cast(workspace)); + } + return plans; +} + +cudnnBackendHeurMode_t HeurMode() { + auto minor = cudnnGetVersion() / 100 % 10; + int default_mode = minor < 2 ? CUDNN_HEUR_MODE_INSTANT : CUDNN_HEUR_MODE_B; + return static_cast(dmlc::GetEnv("MXNET_CUDNN_HEUR_MODE", default_mode)); +} + +std::string ConvParamStr(const ConvParam& param) { + std::ostringstream ss; + ss << " layout: " << param.layout.value(); + ss << " kernel: " << param.kernel; + ss << " stride: " << param.stride; + ss << " dilate: " << param.dilate; + ss << " pad: " << param.pad; + ss << " num_filter: " << param.num_filter; + ss << " num_group: " << param.num_group; + ss << " workspace: " << param.workspace; + return ss.str(); +} + +size_t GetWorkspace(const Descriptor& plan) { + return GetAttr(plan, CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE); +} + +Storage::Handle FailsafeAlloc(size_t workspace_size) { + return Storage::Get()->Alloc(workspace_size, Context::GPU(), true); +} + +Storage::Handle AllocWorkspace(std::vector* plans, size_t* workspace_size) { + Storage::Handle workspace; + size_t alloc_size = *workspace_size; + while ((workspace = FailsafeAlloc(alloc_size)).dptr == nullptr && + alloc_size > 0) { + // Remove any plan whose workspace_size equals the failed allocation size + auto hasMaxWorkspace = [alloc_size](auto const& plan){ + return GetWorkspace(plan) == alloc_size; + }; + plans->erase(std::remove_if(plans->begin(), plans->end(), hasMaxWorkspace), + plans->end()); + // Calculate new maximum workspace_size for remaining plans + alloc_size = 0; + for (auto& plan : *plans) + alloc_size = std::max(alloc_size, GetWorkspace(plan)); + } + *workspace_size = alloc_size; + return workspace; +} + +std::unordered_set ExcludeEngines(const std::string& env_var) { + std::string engines = dmlc::GetEnv(env_var.c_str(), std::string()); + std::replace(engines.begin(), engines.end(), ',', ' '); + std::istringstream ss(engines); + return std::unordered_set(std::istream_iterator(ss), + std::istream_iterator()); +} + +Descriptor SelectPlan(const OpContext& ctx, const ConvParam& param, Descriptor op, + size_t n_fallbacks, const std::function& make_op_str, + const std::vector& ids, const std::vector& tensor_ptrs, + int64_t out_size, const std::string& excl_engines_var) { + 
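+  // Plan selection outline: build a single-op graph for this convolution, ask the cuDNN
+  // heuristics for candidate engine configs (GetPlans), fall back to enumerating engines by
+  // global index if no heuristic result survives the filters, then, unless auto-tuning is
+  // disabled, time the surviving plans with FindTopPlans and return the fastest one.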
auto s = ctx.get_stream(); + std::vector ops; + ops.push_back(std::move(op)); + auto op_graph = MakeOpGraph(s->dnn_handle_, ops); + + int verbose = dmlc::GetEnv("MXNET_CUDNN_ALGO_VERBOSE_LEVEL", 0); + if (verbose > 0) LOG(INFO) << "Selecting plan for " << make_op_str() << ":"; + + auto tune = param.cudnn_tune + ? param.cudnn_tune.value() + : dmlc::GetEnv("MXNET_CUDNN_AUTOTUNE_DEFAULT", static_cast(conv::kLimited)); + size_t workspace_size = 0; + size_t workspace_limit = + tune != conv::kFastest ? param.workspace << 20 : std::numeric_limits::max(); + auto excl_engines = ExcludeEngines(excl_engines_var); + auto plans = GetPlans(HeurMode(), s->dnn_handle_, op_graph, workspace_limit, &workspace_size, + excl_engines, RequireNumerics(), ExcludeNumerics(), {}, {}, verbose > 1); + + + Storage::Handle out_space; + auto ptrs = tensor_ptrs; + if (tune != conv::kOff && param.add_to) { + // Cannot trash output tensor while auto-tuning. + out_space = FailsafeAlloc(out_size); + if (out_space.dptr) ptrs.back() = out_space.dptr; + } + // Todo: + // - should we be able to ask the tempspace for it's current size, then + // alloc the workspace from the tempspace if its current size > workspace_size? + auto workspace = AllocWorkspace(&plans, &workspace_size); + + if (plans.empty()) { + std::vector ixs(n_fallbacks); + std::iota(ixs.begin(), ixs.end(), 0); + plans = MakeFallbackPlans(ixs, s->dnn_handle_, op_graph, workspace_limit, &workspace_size, + excl_engines, RequireNumerics(), ExcludeNumerics(), {}, {}); + workspace = AllocWorkspace(&plans, &workspace_size); + CHECK(!plans.empty()); + LOG(WARNING) << "Using fallback engine(s) for " << make_op_str(); + } + + if (tune == conv::kOff || plans.size() == 1 || (param.add_to && !out_space.dptr)) { + if (verbose > 0) LOG(INFO) << " " << PlanStr(plans[0]); + Storage::Get()->Free(out_space); + Storage::Get()->Free(workspace); + return std::move(plans[0]); + } + + TuneWarnOnce(); + size_t n = verbose > 1 ? plans.size() : 1; + auto var_pack = MakeFinalized(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR, + CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS, ids, + CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS, ptrs, + CUDNN_ATTR_VARIANT_PACK_WORKSPACE, workspace.dptr); + auto top = FindTopPlans(std::move(plans), n, s->dnn_handle_, var_pack, MakeAvgSampler(3)); + Storage::Get()->Free(out_space); + Storage::Get()->Free(workspace); + auto str_time = [](float t) { + std::ostringstream ss; + ss << std::fixed << std::setprecision(6) << t; + return ss.str(); + }; + for (size_t i = 0; verbose > 0 && i < top.size(); ++i) { + auto prefix = i == 0 ? 
" * " : " "; + LOG(INFO) << prefix << top[i].heur_i << ") " << str_time(top[i].time) << "ms " + << PlanStr(top[i].plan); + } + return std::move(top[0].plan); +} + +size_t Size(const TBlob& t) { + return t.Size() * mshadow::mshadow_sizeof(t.type_flag_); +} + +// TODO(vcherepanov): remove these, once fallbacks are received as a heuristics mode in 8.3 +enum MaxFallbacks { + kMaxConvFallbacks = 58, + kMaxDgradFallbacks = 63, + kMaxWgradFallbacks = 62 +}; + +std::optional Conv::Make(const OpContext& ctx, const Param& param, const TBlob& x, + const TBlob& w, const TBlob& y) { + auto conv = MakeConvDesc(param, static_cast(x.type_flag_)); + auto li = GetLayoutInfo(static_cast(param.layout.value())); + auto x_desc = MakeTensorDesc(ID_X, x, li, true, false); + auto w_desc = MakeTensorDesc(ID_W, w, li, true, false); + auto y_desc = MakeTensorDesc(ID_Y, y, li, true, false); + auto conv_fwd = MakeConvFwdOp(conv, x_desc, w_desc, y_desc, param.add_to); + + auto make_op_str = [¶m, &x]() { + std::ostringstream ss; + ss << "fprop " << mshadow::dtype_string(x.type_flag_) << " " << ConvParamStr(param); + return ss.str(); + }; + + std::vector ids{ID_X, ID_W, ID_Y}; + std::vector ptrs{x.dptr_, w.dptr_, y.dptr_}; + + Conv ret; + ret.plan_ = SelectPlan(ctx, param, std::move(conv_fwd), kMaxConvFallbacks, make_op_str, ids, ptrs, + Size(y), "MXNET_CUDNN_DISABLED_CONV_FWD_ENGINES"); + return ret; +} + +void Conv::Exec(const OpContext& ctx, const TBlob& x, const TBlob& w, const TBlob& y) const { + auto s = ctx.get_stream(); + auto workspace_size = GetAttr(plan_, CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE); + auto workspace = ctx.requested[0].get_space_internal(workspace_size, "Conv"); + + std::vector ids{ID_X, ID_W, ID_Y}; + std::vector ptrs{x.dptr_, w.dptr_, y.dptr_}; + auto var_pack = MakeFinalized(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR, + CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS, ids, + CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS, ptrs, + CUDNN_ATTR_VARIANT_PACK_WORKSPACE, workspace); + CUDNN_CALL(cudnnBackendExecute(s->dnn_handle_, plan_.get(), var_pack.get())); +} + +std::optional ConvDgrad::Make(const OpContext& ctx, const Param& param, const TBlob& w, + const TBlob& dy, const TBlob& dx) { + auto conv = MakeConvDesc(param, static_cast(w.type_flag_)); + auto li = GetLayoutInfo(static_cast(param.layout.value())); + auto w_desc = MakeTensorDesc(ID_W, w, li, true, false); + auto dy_desc = MakeTensorDesc(ID_DY, dy, li, true, false); + auto dx_desc = MakeTensorDesc(ID_DX, dx, li, true, false); + auto dgrad = MakeConvDgradOp(conv, w_desc, dy_desc, dx_desc, param.add_to); + + auto make_op_str = [¶m, &dx]() { + std::ostringstream ss; + ss << "dgrad " << mshadow::dtype_string(dx.type_flag_) << " " << ConvParamStr(param); + return ss.str(); + }; + + std::vector ids{ID_W, ID_DY, ID_DX}; + std::vector ptrs{w.dptr_, dy.dptr_, dx.dptr_}; + + ConvDgrad ret; + ret.plan_ = SelectPlan(ctx, param, std::move(dgrad), kMaxDgradFallbacks, make_op_str, ids, ptrs, + Size(dx), "MXNET_CUDNN_DISABLED_CONV_DGRAD_ENGINES"); + return ret; +} + +void ConvDgrad::Exec(const OpContext& ctx, const TBlob& w, const TBlob& dy, const TBlob& dx) const { + auto s = ctx.get_stream(); + auto workspace_size = GetAttr(plan_, CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE); + auto workspace = ctx.requested[0].get_space_internal(workspace_size, "ConvDgrad"); + + std::vector ids{ID_W, ID_DY, ID_DX}; + std::vector ptrs{w.dptr_, dy.dptr_, dx.dptr_}; + auto var_pack = MakeFinalized(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR, + CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS, ids, + 
CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS, ptrs, + CUDNN_ATTR_VARIANT_PACK_WORKSPACE, workspace); + CUDNN_CALL(cudnnBackendExecute(s->dnn_handle_, plan_.get(), var_pack.get())); +} + +std::optional ConvWgrad::Make(const OpContext& ctx, const Param& param, const TBlob& x, + const TBlob& dy, const TBlob& dw) { + auto conv = MakeConvDesc(param, static_cast(x.type_flag_)); + auto li = GetLayoutInfo(static_cast(param.layout.value())); + auto x_desc = MakeTensorDesc(ID_X, x, li, true, false); + auto dy_desc = MakeTensorDesc(ID_DY, dy, li, true, false); + auto dw_desc = MakeTensorDesc(ID_DW, dw, li, true, false); + auto wgrad = MakeConvWgradOp(conv, x_desc, dy_desc, dw_desc, param.add_to); + + auto make_op_str = [¶m, &x]() { + std::ostringstream ss; + ss << "wgrad " << mshadow::dtype_string(x.type_flag_) << " " << ConvParamStr(param); + return ss.str(); + }; + + std::vector ids{ID_X, ID_DY, ID_DW}; + std::vector ptrs{x.dptr_, dy.dptr_, dw.dptr_}; + + ConvWgrad ret; + ret.plan_ = SelectPlan(ctx, param, std::move(wgrad), kMaxWgradFallbacks, make_op_str, ids, ptrs, + Size(dw), "MXNET_CUDNN_DISABLED_CONV_WGRAD_ENGINES"); + return ret; +} + +void ConvWgrad::Exec(const OpContext& ctx, const TBlob& x, const TBlob& dy, const TBlob& dw) const { + auto s = ctx.get_stream(); + auto workspace_size = GetAttr(plan_, CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE); + auto workspace = ctx.requested[0].get_space_internal(workspace_size, "ConvWgrad"); + + std::vector ids{ID_X, ID_DY, ID_DW}; + std::vector ptrs{x.dptr_, dy.dptr_, dw.dptr_}; + auto var_pack = MakeFinalized(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR, + CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS, ids, + CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS, ptrs, + CUDNN_ATTR_VARIANT_PACK_WORKSPACE, workspace); + CUDNN_CALL(cudnnBackendExecute(s->dnn_handle_, plan_.get(), var_pack.get())); +} + +} // namespace cudnn +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_CUDNN == 1 diff --git a/src/operator/cudnn_ops.h b/src/operator/cudnn_ops.h new file mode 100644 index 000000000000..23d66e65e8a7 --- /dev/null +++ b/src/operator/cudnn_ops.h @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2021 by Contributors + * \file cudnn_ops.h + * \brief cuDNN v8 ops + */ +#ifndef MXNET_OPERATOR_CUDNN_OPS_H_ +#define MXNET_OPERATOR_CUDNN_OPS_H_ + +#include +#if MXNET_USE_CUDNN == 1 + +#include + +#include +#include // NOLINT(build/include_order) +#include +#include +#include +#include + +#include "nn/convolution-inl.h" +#include "nn/deconvolution-inl.h" + +#include "../common/cudnn_cxx.h" + +namespace std { + +// Enable tuples as keys. 
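+// This std::hash specialization folds the hash of every tuple element together with
+// dmlc::HashCombine, so tuples of shapes, dtypes and flags can serve as op-cache keys.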
+template +struct hash> { + size_t operator()(const std::tuple& t) const { + size_t ret = 0; + if constexpr (sizeof...(Ts) > 0) { + std::apply( + [&ret](auto head, const auto&... tail) { + ret = dmlc::HashCombine(ret, head); + ret = dmlc::HashCombine(ret, std::make_tuple(tail...)); + }, + t); + } + return ret; + } +}; + +} // namespace std + +namespace mxnet { +namespace op { + +namespace cudnn { + +struct LayoutInfo { + size_t n_space_dims; + bool channel_last; + + std::vector Order() const; + size_t ChannelIdx() const; + std::vector Strides(const std::vector& dims) const; +}; + +LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout); + +TShape ExpandChannelDims(mshadow::LayoutFlag layout, int c); + +void MaybeLogSelectedPlan(const cudnn_cxx::Descriptor& plan); + +// To support cached lookup and execution an operation Op must define: +// +// Op::Param - a type, collecting all data, required to create cuDNN descriptor(s), but not needed +// for execution. +// Op::MakeKey() - a static function, which maps its arguments to a tuple - a key in the op cache. +// Op::Make() - a static function, creating all necessary cuDNN descriptors. +// Op::Exec() - a member function, calling cudnnBackendExecute() with the prepared descriptor(s) and +// the passed arguments. +template +bool Exec(const OpContext& ctx, const typename Op::Param& param, Args&&... args) { + auto key = std::tuple_cat(std::make_tuple(ctx.run_ctx.ctx.dev_id), + Op::MakeKey(param, std::forward(args)...)); + static std::unordered_map> op_map; + static std::mutex mx; + std::unique_lock lk(mx); + auto it = op_map.find(key); + if (it == op_map.end()) { + auto op = Op::Make(ctx, param, std::forward(args)...); + it = op_map.emplace(key, std::move(op)).first; + } + lk.unlock(); + if (!it->second) return false; + it->second.value().Exec(ctx, std::forward(args)...); + return true; +} + +// The subset of ConvolutionParam / DeconvolutionParam fields, +// which unambiguously identify and consturct cuDNN convolution, plus add_to flag. 
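+// add_to mirrors a kAddTo output request (beta == 1 accumulation instead of overwrite), so it
+// is part of both the cache key and the finalized cuDNN operation.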
+struct ConvParam { + mxnet::TShape kernel; + mxnet::TShape stride; + mxnet::TShape dilate; + mxnet::TShape pad; + uint32_t num_filter; + uint32_t num_group; + uint64_t workspace; + dmlc::optional cudnn_tune; + dmlc::optional layout; + + bool add_to; + + ConvParam(const ConvolutionParam& p, bool add_to); + ConvParam(const DeconvolutionParam& p, bool add_to); +}; + +struct Conv { + using Param = ConvParam; + enum UIDs { ID_X = 1, ID_W, ID_Y }; + + cudnn_cxx::Descriptor plan_; + + static auto MakeKey(const Param& p, const TBlob& x, const TBlob& w, const TBlob& y) { + return std::make_tuple(p.kernel, p.stride, p.dilate, p.pad, p.num_filter, p.num_group, + p.workspace, p.layout, p.add_to, x.shape_, x.type_flag_, w.shape_, + w.type_flag_, y.shape_); + } + + static std::optional Make(const OpContext& ctx, const Param& param, const TBlob& x, + const TBlob& w, const TBlob& y); + + void Exec(const OpContext& ctx, const TBlob& x, const TBlob& w, const TBlob& y) const; +}; + +struct ConvDgrad { + using Param = ConvParam; + enum UIDs { ID_W = 1, ID_DY, ID_DX }; + + cudnn_cxx::Descriptor plan_; + + static auto MakeKey(const Param& p, const TBlob& w, const TBlob& dy, const TBlob& dx) { + return std::make_tuple(p.kernel, p.stride, p.dilate, p.pad, p.num_filter, p.num_group, + p.workspace, p.layout, p.add_to, w.shape_, w.type_flag_, dy.shape_, + dy.type_flag_, dx.shape_); + } + + static std::optional Make(const OpContext& ctx, const Param& param, const TBlob& w, + const TBlob& dy, const TBlob& dx); + + void Exec(const OpContext& ctx, const TBlob& w, const TBlob& dy, const TBlob& dx) const; +}; + +struct ConvWgrad { + using Param = ConvParam; + enum UIDs { ID_X = 1, ID_DY, ID_DW }; + + cudnn_cxx::Descriptor plan_; + + static auto MakeKey(const Param& p, const TBlob& x, const TBlob& dy, const TBlob& dw) { + return std::make_tuple(p.kernel, p.stride, p.dilate, p.pad, p.num_filter, p.num_group, + p.workspace, p.layout, p.add_to, x.shape_, x.type_flag_, dy.shape_, + dy.type_flag_, dw.shape_); + } + + static std::optional Make(const OpContext& ctx, const Param& param, const TBlob& x, + const TBlob& dy, const TBlob& dw); + + void Exec(const OpContext& ctx, const TBlob& x, const TBlob& dy, const TBlob& dw) const; +}; + +} // namespace cudnn +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_CUDNN == 1 + +#endif // MXNET_OPERATOR_CUDNN_OPS_H_ diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu index deeac83456db..53d54304630e 100644 --- a/src/operator/nn/convolution.cu +++ b/src/operator/nn/convolution.cu @@ -18,74 +18,25 @@ */ /*! 
+ * Copyright (c) 2017 by Contributors * \file convolution.cu * \brief * \author Bing Xu, Jun Wu, Da Zheng - */ +*/ #include "./convolution-inl.h" #include #include "./depthwise_convolution-inl.h" #if MXNET_USE_CUDNN == 1 -#include "./cudnn/cudnn_convolution-inl.h" +#include "../cudnn_ops.h" +#include "../tensor/broadcast_reduce_op.h" +#include "../tensor/elemwise_binary_broadcast_op.h" +#include "fully_connected-inl.h" #endif // MXNET_USE_CUDNN namespace mxnet { namespace op { -#if MXNET_USE_CUDNN == 1 -template -static CuDNNConvolutionOp& GetCuDNNConvOp(const ConvolutionParam& param, - int forward_compute_type, - int backward_compute_type, - const mxnet::ShapeVector& in_shape, - const mxnet::ShapeVector& out_shape, - const RunContext& rctx, - bool add_to_weight) { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local std:: - unordered_map>, OpHash> - ops; -#else - static MX_THREAD_LOCAL - std::unordered_map>, OpHash> - ops; -#endif - ConvSignature key(param); - size_t ndim = 0; - for (auto& s : in_shape) - ndim += s.ndim(); - for (auto& s : out_shape) - ndim += s.ndim(); - key.Reserve(1 /* for forward_compute_type */ + 1 /* for backward_compute_type */ + - ndim /* for in and out shapes */ + 1 /* for dev_id */ + 1 /* for add_to_weight */); - - key.AddSign(forward_compute_type); - key.AddSign(backward_compute_type); - key.AddSign(in_shape); - key.AddSign(out_shape); - key.AddSign(rctx.ctx.dev_id); - key.AddSign(add_to_weight ? 1 : 0); - - auto it = ops.find(key); - if (it == ops.end()) { - std::shared_ptr> op(new CuDNNConvolutionOp()); - auto ins_ret = - ops.insert(std::pair>>(key, op)); - CHECK(ins_ret.second); - it = ins_ret.first; - it->second->Init(param, - forward_compute_type, - backward_compute_type, - in_shape, - out_shape, - rctx, - add_to_weight); - } - return *it->second; -} -#endif - template <> void ConvolutionCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -93,40 +44,52 @@ void ConvolutionCompute(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const ConvolutionParam& param = nnvm::get(attrs.parsed); - int dtype = inputs[conv::kData].type_flag_; + int dtype = inputs[conv::kData].type_flag_; + CHECK_EQ(req.size(), 1); + CHECK_EQ(req[conv::kOut], kWriteTo); #if MXNET_USE_CUDNN == 1 - STATIC_ASSERT_CUDNN_VERSION_GE(7000); - // On fp16-I/O instances, use fp32 compute (i.e. pseudo-fp16). - int compute_type = (dtype == mshadow::kFloat16) ? 
mshadow::kFloat32 : dtype; - + STATIC_ASSERT_CUDNN_VERSION_GE(8000); MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - if (param.cudnn_off) { - ConvolutionOp op; - op.Init(param); - op.Forward(ctx, inputs, req, outputs); - } else if (!CuDNNConvolutionOp::Supports( - param, compute_type, compute_type, ctx.run_ctx.ctx.dev_id)) { - LOG(WARNING) << "This convolution is not supported by cudnn, MXNET convolution is applied."; + cudnn::ConvParam conv_param(param, false); + bool ok = + !param.cudnn_off && cudnn::Exec(ctx, conv_param, inputs[conv::kData], + inputs[conv::kWeight], outputs[conv::kOut]); + if (ok && !param.no_bias) { + CHECK_EQ(inputs[conv::kBias].shape_.ndim(), 1); + auto layout = static_cast(param.layout.value()); + int k = inputs[conv::kBias].shape_.Size(); + auto b = inputs[conv::kBias].reshape(cudnn::ExpandChannelDims(layout, k)); + BinaryBroadcastRTCCompute {"add"}(attrs, ctx, {outputs[conv::kOut], b}, {kWriteInplace}, + {outputs[conv::kOut]}); + } + if (!ok) { + if (!param.cudnn_off) + LOG(WARNING) << "This convolution is not supported by cuDNN, MXNet convolution is applied."; ConvolutionOp op; op.Init(param); op.Forward(ctx, inputs, req, outputs); - } else { - mxnet::ShapeVector in_shape(inputs.size()); - mxnet::ShapeVector out_shape(1, outputs[0].shape_); - for (size_t i = 0; i < in_shape.size(); i++) - in_shape[i] = inputs[i].shape_; - // req[conv::kWeight] is only set for backward, so assume the typical 'write' for now. - auto add_to_weight = false; - CuDNNConvolutionOp& op = GetCuDNNConvOp( - param, compute_type, compute_type, in_shape, out_shape, ctx.run_ctx, add_to_weight); - op.Forward(ctx, inputs, req, outputs); } }) #else - if (param.num_filter == param.num_group && param.layout.value() == mshadow::kNCHW && - param.num_filter == inputs[conv::kData].shape_[1] && param.kernel.ndim() == 2 && - param.dilate == mshadow::Shape2(1, 1) && dtype == mshadow::kFloat32) { + if (param.layout.value() != kNCW && + param.layout.value() != kNCHW && + param.layout.value() != kNCDHW) { + // Need CuDNN > 5.0 for layout support. use MXNet implementation + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + ConvolutionOp op; + op.Init(param); + op.Forward(ctx, inputs, req, outputs); + }) + return; + } + + if (param.num_filter == param.num_group && + param.layout.value() == mshadow::kNCHW && + param.num_filter == inputs[conv::kData].shape_[1] && + param.kernel.ndim() == 2 && + param.dilate == mshadow::Shape2(1, 1) && + dtype == mshadow::kFloat32) { mxnet::ShapeVector in_shape(inputs.size()); mxnet::ShapeVector out_shape(1, outputs[0].shape_); for (size_t i = 0; i < in_shape.size(); i++) @@ -153,42 +116,65 @@ void ConvolutionGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { const ConvolutionParam& param = nnvm::get(attrs.parsed); std::vector in_data(inputs.begin() + 1, inputs.end()); - const TBlob& out_grad = inputs[0]; - const std::vector& in_grad = outputs; - int dtype = out_grad.type_flag_; + const TBlob &out_grad = inputs[0]; + const std::vector &in_grad = outputs; + int dtype = out_grad.type_flag_; + CHECK_EQ(req.size(), param.no_bias ? 2 : 3); + CHECK_NE(req[conv::kData], kWriteInplace); + CHECK_NE(req[conv::kWeight], kWriteInplace); + if (!param.no_bias) CHECK_NE(req[conv::kBias], kWriteInplace); #if MXNET_USE_CUDNN == 1 - STATIC_ASSERT_CUDNN_VERSION_GE(7000); - // On fp16-I/O instances, use fp32 compute (i.e. pseudo-fp16). - int compute_type = (dtype == mshadow::kFloat16) ? 
mshadow::kFloat32 : dtype; - + STATIC_ASSERT_CUDNN_VERSION_GE(8000); MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - if (param.cudnn_off) { - ConvolutionOp op; - op.Init(param); - op.Backward(ctx, std::vector{out_grad}, in_data, req, in_grad); - } else if (!CuDNNConvolutionOp::Supports( - param, compute_type, compute_type, ctx.run_ctx.ctx.dev_id)) { - LOG(WARNING) << "This convolution is not supported by cudnn, MXNET convolution is applied."; + cudnn::ConvParam conv_param(param, req[conv::kData] == kAddTo); + bool ok = !param.cudnn_off; + ok = ok && (req[conv::kData] == kNullOp || + cudnn::Exec(ctx, conv_param, inputs[1 + conv::kWeight], inputs[0], + outputs[conv::kData])); + conv_param.add_to = req[conv::kWeight] == kAddTo; + ok = ok && (req[conv::kWeight] == kNullOp || + cudnn::Exec(ctx, conv_param, inputs[1 + conv::kData], inputs[0], + outputs[conv::kWeight])); + if (ok && !param.no_bias && req[conv::kBias] != kNullOp) { + auto li = cudnn::GetLayoutInfo(static_cast(param.layout.value())); + if (li.channel_last) { + // This kernel should be faster. + auto y_grad = FlattenAs2DHead(inputs[0], ctx); + AddBiasGrad(outputs[conv::kBias], y_grad, req[conv::kBias], param.num_filter, ctx); + } else { + TShape axes{static_cast(li.ChannelIdx())}; + TShape small = ReduceAxesShapeImpl(inputs[0].shape_, dmlc::optional(axes), true, true); + ReduceAxesRTCComputeImpl(ctx, {inputs[0]}, {req[conv::kBias]}, {outputs[conv::kBias]}, + small, "red::sum{}"); + } + } + if (!ok) { + if (!param.cudnn_off) + LOG(WARNING) << "This convolution backward is not supported by cuDNN, MXNet op is applied."; ConvolutionOp op; op.Init(param); op.Backward(ctx, std::vector{out_grad}, in_data, req, in_grad); - } else { - // The first element stores out grad. - mxnet::ShapeVector in_shape(in_data.size()); - mxnet::ShapeVector out_shape(1, out_grad.shape_); - for (size_t i = 0; i < in_shape.size(); i++) - in_shape[i] = in_data[i].shape_; - auto add_to_weight = req[conv::kWeight] == kAddTo; - CuDNNConvolutionOp& op = GetCuDNNConvOp( - param, compute_type, compute_type, in_shape, out_shape, ctx.run_ctx, add_to_weight); - op.Backward(ctx, std::vector{out_grad}, in_data, req, in_grad); } }) #else - if (param.num_filter == param.num_group && param.layout.value() == mshadow::kNCHW && - param.num_filter == in_data[conv::kData].shape_[1] && param.kernel.ndim() == 2 && - param.dilate == mshadow::Shape2(1, 1) && dtype == mshadow::kFloat32) { + if (param.layout.value() != kNCW && + param.layout.value() != kNCHW && + param.layout.value() != kNCDHW) { + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + ConvolutionOp op; + op.Init(param); + op.Backward(ctx, std::vector{out_grad}, in_data, req, in_grad); + }) + return; + } + + if (param.num_filter == param.num_group && + param.layout.value() == mshadow::kNCHW && + param.num_filter == in_data[conv::kData].shape_[1] && + param.kernel.ndim() == 2 && + param.dilate == mshadow::Shape2(1, 1) && + dtype == mshadow::kFloat32) { // The first element stores out grad. 
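
For readers unfamiliar with the reduction used by the new backward path for channel-first layouts: the bias gradient is the output gradient summed over every axis except the channel axis, optionally accumulated when the request is kAddTo. A minimal CPU sketch of that reduction for an NCHW tensor follows; it is illustrative only, the patch itself performs it with an RTC "red::sum" reduction kernel (or AddBiasGrad for channel-last layouts), not with this loop.

```cpp
#include <cstddef>
#include <vector>

// dy: output gradient with shape [N, C, H, W], stored contiguously.
// db: bias gradient with C elements. If add_to is true, accumulate into db
// (the kAddTo case); otherwise zero it first and overwrite (the kWriteTo case).
void BiasGradNCHW(const std::vector<float>& dy, std::vector<float>* db,
                  std::size_t N, std::size_t C, std::size_t H, std::size_t W,
                  bool add_to) {
  if (!add_to) db->assign(C, 0.0f);
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t c = 0; c < C; ++c)
      for (std::size_t hw = 0; hw < H * W; ++hw)
        (*db)[c] += dy[(n * C + c) * H * W + hw];
}
```
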
mxnet::ShapeVector in_shape(in_data.size()); mxnet::ShapeVector out_shape(1, out_grad.shape_); @@ -208,10 +194,12 @@ void ConvolutionGradCompute(const nnvm::NodeAttrs& attrs, #endif // MXNET_USE_CUDNN } -NNVM_REGISTER_OP(Convolution).set_attr("FCompute", ConvolutionCompute); +NNVM_REGISTER_OP(Convolution) +.set_attr("FCompute", ConvolutionCompute); NNVM_REGISTER_OP(_backward_Convolution) - .set_attr("FCompute", ConvolutionGradCompute); +.set_attr("FCompute", ConvolutionGradCompute); } // namespace op } // namespace mxnet + diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h deleted file mode 100644 index e94b172bc398..000000000000 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ /dev/null @@ -1,831 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file cudnn_convolution-inl.h - * \brief - * \author Bing Xu - */ -#ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_CONVOLUTION_INL_H_ -#define MXNET_OPERATOR_NN_CUDNN_CUDNN_CONVOLUTION_INL_H_ - -#include -#include -#include -#include -#include -#include "../convolution-inl.h" -#include "./cudnn_algoreg-inl.h" -#include "../../../common/cuda/utils.h" - -namespace mxnet { -namespace op { -#if MXNET_USE_CUDNN == 1 - -/*! - * \brief The Operator used to perform convolution using cuDNN kernels. - */ -template -class CuDNNConvolutionOp { - STATIC_ASSERT_CUDNN_VERSION_GE(7000); - - public: - CuDNNConvolutionOp() { - CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc_)); - CUDNN_CALL(cudnnCreateTensorDescriptor(&bias_desc_)); - CUDNN_CALL(cudnnCreateFilterDescriptor(&filter_desc_)); - CUDNN_CALL(cudnnCreateConvolutionDescriptor(&forward_conv_desc_)); - CUDNN_CALL(cudnnCreateConvolutionDescriptor(&back_conv_desc_)); - CUDNN_CALL(cudnnCreateConvolutionDescriptor(&back_conv_desc_w_)); - parallelize_backward_kernels_ = Context::GetGPUStreamsPerWorker() >= 2; - } - - void Init(const ConvolutionParam& param, - int forward_compute_type, - int backward_compute_type, - const mxnet::ShapeVector& in_shape, - const mxnet::ShapeVector& out_shape, - const RunContext& rctx, - bool add_to_weight) { - using namespace mshadow; - this->param_ = param; - this->add_to_weight_ = add_to_weight; - InitBufferForParam(); - auto cudnn_forward_compute_type = convertToCuDNNDataType(forward_compute_type); - auto cudnn_backward_compute_type = convertToCuDNNDataType(backward_compute_type); - // convert MB to words - param_.workspace = (param_.workspace << 20) / sizeof(DType); - dtype_ = DataType::kCudnnFlag; - // TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy. 
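
The deleted class below pairs every cudnnCreate*Descriptor call in its constructor with a cudnnDestroy*Descriptor call in its destructor by hand, which is easy to get out of sync as descriptors are added or removed. A scoped owner makes that pairing automatic. This is only a sketch of the idea using std::unique_ptr with a custom deleter, not code from this patch.

```cpp
#include <cudnn.h>
#include <memory>
#include <stdexcept>
#include <type_traits>

// Deleter that releases a cuDNN tensor descriptor when its owner goes out of scope.
struct TensorDescDeleter {
  void operator()(cudnnTensorDescriptor_t d) const { cudnnDestroyTensorDescriptor(d); }
};
using TensorDesc =
    std::unique_ptr<std::remove_pointer<cudnnTensorDescriptor_t>::type, TensorDescDeleter>;

TensorDesc MakeTensorDesc() {
  cudnnTensorDescriptor_t raw = nullptr;
  if (cudnnCreateTensorDescriptor(&raw) != CUDNN_STATUS_SUCCESS)
    throw std::runtime_error("cudnnCreateTensorDescriptor failed");
  return TensorDesc(raw);  // destroyed automatically, even on early return
}
```
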
- cudnn_tensor_core_ = DataType::kFlag == kFloat16 && GetEnvAllowTensorCore(); - - auto effective_layout = param_.layout.value(); - switch (effective_layout) { - // 1D convolutions will be executed as 2D convolutions with a height of 1. - case mshadow::kNCW: - effective_layout = mshadow::kNCHW; - break; - case mshadow::kNWC: - effective_layout = mshadow::kNHWC; - break; - case mshadow::kCWN: - effective_layout = mshadow::kCHWN; - break; - default: - break; - } - - MSHADOW_LAYOUT_SWITCH(effective_layout, Layout, { format_ = LayoutType::kCudnnFlag; }); - // Double check to make sure this class supports the operation - if (!Supports(param, forward_compute_type, backward_compute_type, rctx.ctx.dev_id)) - LOG(FATAL) << "Convolution parameters not supported by cuDNN implementation."; - - InitDescriptors(in_shape, out_shape, cudnn_forward_compute_type, cudnn_backward_compute_type); - - if (!param_.cudnn_tune) { - param_.cudnn_tune = dmlc::GetEnv("MXNET_CUDNN_AUTOTUNE_DEFAULT", 1); - } - // In cuDNN_v6, dilated convolution descriptors are compatible with only a - // single convolution algorithm. Despite this, we go through the algorithm - // selection process, which will return the only algorithm supported. This - // approach keeps the treatment of convolution cases uniform and will - // naturally respond to more algorithms supporting dilated convolutions in - // future cuDNN releases. - SelectAlgo(rctx, in_shape, out_shape, cudnn_forward_compute_type, cudnn_backward_compute_type); - GetTempSize(rctx); - } - - ~CuDNNConvolutionOp() { - CUDNN_CALL(cudnnDestroyTensorDescriptor(in_desc_)); - CUDNN_CALL(cudnnDestroyTensorDescriptor(out_desc_)); - CUDNN_CALL(cudnnDestroyTensorDescriptor(bias_desc_)); - CUDNN_CALL(cudnnDestroyFilterDescriptor(filter_desc_)); - CUDNN_CALL(cudnnDestroyConvolutionDescriptor(forward_conv_desc_)); - CUDNN_CALL(cudnnDestroyConvolutionDescriptor(back_conv_desc_)); - CUDNN_CALL(cudnnDestroyConvolutionDescriptor(back_conv_desc_w_)); - } - - void Forward(const OpContext& ctx, - const std::vector& in_data, - const std::vector& req, - const std::vector& out_data) { - using namespace mshadow; - size_t expected = param_.no_bias ? 2 : 3; - CHECK_EQ(in_data.size(), expected); - CHECK_EQ(out_data.size(), 1U); - Stream* s = ctx.get_stream(); - Tensor workspace = AllocateTempWorkspace(ctx, forward_workspace_byte_); - size_t workspace_size = TensorSizeBytes(workspace); - - // I/O's should have 2 more dims than the kernel dim - DType* data_ptr = GetNdPtr(in_data[conv::kData], param_.kernel.ndim() + 2, s); - DType* wmat_ptr = GetNdPtr(in_data[conv::kWeight], param_.kernel.ndim() + 2, s); - DType* out_ptr = GetNdPtr(out_data[conv::kOut], param_.kernel.ndim() + 2, s); - - typename DataType::ScaleType alpha = 1.0f; - typename DataType::ScaleType beta = 0.0f; - typename DataType::ScaleType beta_add = 1.0f; - CUDNN_CALL(cudnnConvolutionForward(s->dnn_handle_, - &alpha, - in_desc_, - data_ptr, - filter_desc_, - wmat_ptr, - forward_conv_desc_, - forward_algo_.AlgoNumber(), - workspace.dptr_, - workspace_size, - req[conv::kOut] == kAddTo ? 
&beta_add : &beta, - out_desc_, - out_ptr)); - - if (!param_.no_bias) { - Tensor bias = in_data[conv::kBias].get(s); - CUDNN_CALL(cudnnAddTensor( - s->dnn_handle_, &alpha, bias_desc_, bias.dptr_, &beta_add, out_desc_, out_ptr)); - } - } - - void Backward(const OpContext& ctx, - const std::vector& out_grad, - const std::vector& in_data, - const std::vector& req, - const std::vector& in_grad) { - using namespace mshadow; - using namespace mshadow::expr; - size_t expected = param_.no_bias == 0 ? 3 : 2; - CHECK_EQ(out_grad.size(), 1U); - CHECK_EQ(in_data.size(), expected); - CHECK_EQ(in_grad.size(), expected); - Stream* s = ctx.get_stream(); - // RAII object to handle syncing of the underlying auxiliary stream with the primary stream - SyncedGPUAuxStream s_dgrad = ctx.get_gpu_aux_stream(); - - // I/O's should have 2 more dims than the kernel dim - DType* grad_ptr = GetNdPtr(out_grad[conv::kOut], param_.kernel.ndim() + 2, s); - DType* wmat_ptr = GetNdPtr(in_data[conv::kWeight], param_.kernel.ndim() + 2, s); - DType* gwmat_ptr = GetNdPtr(in_grad[conv::kWeight], param_.kernel.ndim() + 2, s); - DType* data_ptr = GetNdPtr(in_data[conv::kData], param_.kernel.ndim() + 2, s); - DType* gdata_ptr = GetNdPtr(in_grad[conv::kData], param_.kernel.ndim() + 2, s); - - size_t backward_workspace_byte = - parallelize_backward_kernels_ - ? back_workspace_byte_dgrad_ + back_workspace_byte_wgrad_ - : std::max(back_workspace_byte_dgrad_, back_workspace_byte_wgrad_); - Tensor workspace = AllocateTempWorkspace(ctx, backward_workspace_byte); - size_t workspace_size = TensorSizeBytes(workspace); - DType* workspace_dptr_wgrad = workspace.dptr_; - DType* workspace_dptr_dgrad = workspace.dptr_; - if (parallelize_backward_kernels_) { - CHECK_LE(back_workspace_byte_dgrad_ + back_workspace_byte_wgrad_, workspace_size); - // Large allocations at some point will be given their own page. Pass this alignment on to - // the larger of the two separate dgrad/wgrad workspaces. This probably doesn't matter, but - // corresponds more closely to the workspace alignments used during cudnnFind. - if (back_workspace_byte_dgrad_ > back_workspace_byte_wgrad_) - workspace_dptr_wgrad = workspace.dptr_ + back_workspace_byte_dgrad_ / sizeof(DType); - else - workspace_dptr_dgrad = workspace.dptr_ + back_workspace_byte_wgrad_ / sizeof(DType); - } else { - CHECK_LE(back_workspace_byte_dgrad_, workspace_size); - CHECK_LE(back_workspace_byte_wgrad_, workspace_size); - } - typename DataType::ScaleType alpha = 1.0f; - typename DataType::ScaleType beta = 0.0f; - typename DataType::ScaleType beta_add = 1.0f; - if (req[conv::kWeight] != kNullOp) { - CHECK_EQ(add_to_weight_, req[conv::kWeight] == kAddTo); - CUDNN_CALL(cudnnConvolutionBackwardFilter(s->dnn_handle_, - &alpha, - in_desc_, - data_ptr, - out_desc_, - grad_ptr, - back_conv_desc_w_, - back_algo_w_.AlgoNumber(), - workspace_dptr_wgrad, - back_workspace_byte_wgrad_, - req[conv::kWeight] == kAddTo ? &beta_add : &beta, - filter_desc_, - gwmat_ptr)); - } - if (!param_.no_bias && (req[conv::kBias] != kNullOp)) { - Tensor gbias = in_grad[conv::kBias].get(s); - CUDNN_CALL(cudnnConvolutionBackwardBias(s->dnn_handle_, - &alpha, - out_desc_, - grad_ptr, - req[conv::kBias] == kAddTo ? 
&beta_add : &beta, - bias_desc_, - gbias.dptr_)); - } - if (req[conv::kData] != kNullOp) { - CUDNN_CALL(cudnnConvolutionBackwardData(s_dgrad.GetStream()->dnn_handle_, - &alpha, - filter_desc_, - wmat_ptr, - out_desc_, - grad_ptr, - back_conv_desc_, - back_algo_.AlgoNumber(), - workspace_dptr_dgrad, - back_workspace_byte_dgrad_, - req[conv::kData] == kAddTo ? &beta_add : &beta, - in_desc_, - gdata_ptr)); - } - } - - /*! - * \brief Returns whether the cuDNN library version supports the convolution - * operation described by `param`: cuDNN v5 and earlier does not support - * dilated convolutions. Dilation only enabled after v6.0.20. - */ - static bool Supports(ConvolutionParam param, - int forward_compute_type, - int backward_compute_type, - int dev_id) { - using namespace mshadow; - - // NDHWC not supported, NHWC not supported in true fp16 - auto layout_val = param.layout.value(); - auto true_fp16 = DataType::kFlag == kFloat16 && - (forward_compute_type == kFloat16 || backward_compute_type == kFloat16); - if (layout_val == kNDHWC || layout_val == kNWC || layout_val == kNHWC && true_fp16) - return false; - - // Permits graceful fallback to pseudo-fp16 on heterogenous systems - if (!SupportsFloat16Compute(dev_id) && - (forward_compute_type == kFloat16 || backward_compute_type == kFloat16)) { - return false; - } - - return true; - } - - private: - /*! - * \brief Translate an mxnet datatype to the corresponding cudnnDataType_t. - */ - cudnnDataType_t convertToCuDNNDataType(int dtype) { - cudnnDataType_t converted = CUDNN_DATA_FLOAT; - // The following will always assign to `converted` or throw an exception. - MSHADOW_REAL_TYPE_SWITCH( - dtype, mxDType, { converted = mshadow::DataType::kCudnnFlag; }) - return converted; - } - - void InitDescriptors(const mxnet::ShapeVector& in_shape, - const mxnet::ShapeVector& out_shape, - cudnnDataType_t cudnn_forward_compute_type, - cudnnDataType_t cudnn_backward_compute_type) { - using namespace mshadow; - size_t expected = param_.no_bias ? 2 : 3; - CHECK_EQ(in_shape.size(), expected); - CHECK_EQ(out_shape.size(), 1U); - - mxnet::TShape dshape = in_shape[conv::kData]; - mxnet::TShape wshape = in_shape[conv::kWeight]; - mxnet::TShape oshape = out_shape[conv::kOut]; - mxnet::TShape dstride, ostride; - - if (param_.kernel.ndim() == 1 || param_.kernel.ndim() == 2) { - // 1d or 2d conv - auto pad = param_.kernel.ndim() == 2 ? param_.pad : mxnet::TShape({0, param_.pad[0]}); - auto stride = - param_.kernel.ndim() == 2 ? param_.stride : mxnet::TShape({1, param_.stride[0]}); - auto dilate = - param_.kernel.ndim() == 2 ? 
param_.dilate : mxnet::TShape({1, param_.dilate[0]}); - CUDNN_CALL(cudnnSetConvolution2dDescriptor(forward_conv_desc_, - pad[0], - pad[1], - stride[0], - stride[1], - dilate[0], - dilate[1], - CUDNN_CROSS_CORRELATION, - cudnn_forward_compute_type)); - CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_, - pad[0], - pad[1], - stride[0], - stride[1], - dilate[0], - dilate[1], - CUDNN_CROSS_CORRELATION, - cudnn_backward_compute_type)); - CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_w_, - pad[0], - pad[1], - stride[0], - stride[1], - dilate[0], - dilate[1], - CUDNN_CROSS_CORRELATION, - cudnn_backward_compute_type)); - if (param_.kernel.ndim() == 2) { - wshape = ConvertLayout(wshape.get<4>(), param_.layout.value(), kNCHW); - dstride = ConvertLayout(Strides<4>(dshape), param_.layout.value(), kNCHW); - dshape = ConvertLayout(dshape.get<4>(), param_.layout.value(), kNCHW); - ostride = ConvertLayout(Strides<4>(oshape), param_.layout.value(), kNCHW); - oshape = ConvertLayout(oshape.get<4>(), param_.layout.value(), kNCHW); - } else { - wshape = ConvertLayout(wshape.get<3>(), param_.layout.value(), kNCW); - wshape = mxnet::TShape({wshape[0], wshape[1], 1, wshape[2]}); - dstride = ConvertLayout(Strides<3>(dshape), param_.layout.value(), kNCW); - dstride = mxnet::TShape({dstride[0], dstride[1], dstride[1], dstride[2]}); - dshape = ConvertLayout(dshape.get<3>(), param_.layout.value(), kNCW); - dshape = mxnet::TShape({dshape[0], dshape[1], 1, dshape[2]}); - ostride = ConvertLayout(Strides<3>(oshape), param_.layout.value(), kNCW); - ostride = mxnet::TShape({ostride[0], ostride[1], ostride[1], ostride[2]}); - oshape = ConvertLayout(oshape.get<3>(), param_.layout.value(), kNCW); - oshape = mxnet::TShape({oshape[0], oshape[1], 1, oshape[2]}); - } - CUDNN_CALL(cudnnSetFilter4dDescriptor( - filter_desc_, dtype_, format_, wshape[0], wshape[1], wshape[2], wshape[3])); -#if CUDNN_VERSION >= 7301 && CUDNN_VERSION < 7500 - auto kernel_h = wshape[2]; - auto kernel_w = wshape[3]; - auto stride_h = stride[0]; - auto stride_w = stride[1]; - auto pad_h = pad[0]; - auto pad_w = pad[1]; - if (param_.layout.value() == kNCHW && - (((stride_h == 2) && (kernel_h % 2 == 0) && (pad_h % 2 == 0)) || - ((stride_w == 2) && (kernel_w % 2 == 0) && (pad_w % 2 == 0)))) { - exclude_dgrad_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING; - } -#endif - } else if (param_.kernel.ndim() == 3) { - // 3d conv - CHECK_EQ(param_.layout.value(), kNCDHW) << "CuDNN only support 3D conv with NCDHW layout"; - std::vector wshape_buffer(wshape.ndim()); - CUDNN_CALL(cudnnSetFilterNdDescriptor(filter_desc_, - dtype_, - CUDNN_TENSOR_NCHW, - static_cast(wshape.ndim()), - CastTShapeToIntPtr(wshape, &wshape_buffer))); - CUDNN_CALL(cudnnSetConvolutionNdDescriptor(forward_conv_desc_, - 3, - param_pad_.data(), - param_stride_.data(), - param_dilate_.data(), - CUDNN_CROSS_CORRELATION, - cudnn_forward_compute_type)); - - CUDNN_CALL(cudnnSetConvolutionNdDescriptor(back_conv_desc_, - 3, - param_pad_.data(), - param_stride_.data(), - param_dilate_.data(), - CUDNN_CROSS_CORRELATION, - cudnn_backward_compute_type)); - - CUDNN_CALL(cudnnSetConvolutionNdDescriptor(back_conv_desc_w_, - 3, - param_pad_.data(), - param_stride_.data(), - param_dilate_.data(), - CUDNN_CROSS_CORRELATION, - cudnn_backward_compute_type)); - - dstride = ConvertLayout(Strides<5>(dshape), param_.layout.value(), kNCDHW); - dshape = ConvertLayout(dshape.get<5>(), param_.layout.value(), kNCDHW); - ostride = ConvertLayout(Strides<5>(oshape), param_.layout.value(), kNCDHW); - 
oshape = ConvertLayout(oshape.get<5>(), param_.layout.value(), kNCDHW); - } - // Set "allow tensor core" flag in convolution descriptors, if available. - cudnnMathType_t math_type = cudnn_tensor_core_ ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH; -#if CUDNN_VERSION >= 7200 - if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() && - (DataType::kFlag != kFloat16)) - math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION; -#endif - CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, math_type)); - CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, math_type)); - CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, math_type)); - CUDNN_CALL(cudnnSetConvolutionGroupCount(forward_conv_desc_, param_.num_group)); - CUDNN_CALL(cudnnSetConvolutionGroupCount(back_conv_desc_, param_.num_group)); - CUDNN_CALL(cudnnSetConvolutionGroupCount(back_conv_desc_w_, param_.num_group)); - - std::vector dshape_buffer(dshape.ndim()); - nnvm::ShapeTypeCast(dshape.begin(), dshape.end(), dshape_buffer.data()); - std::vector dstride_buffer(dstride.ndim()); - nnvm::ShapeTypeCast(dstride.begin(), dstride.end(), dstride_buffer.data()); - - CUDNN_CALL(cudnnSetTensorNdDescriptor(in_desc_, - dtype_, - static_cast(dshape.ndim()), - dshape_buffer.data(), - dstride_buffer.data())); - - std::vector oshape_buffer(oshape.ndim()); - nnvm::ShapeTypeCast(oshape.begin(), oshape.end(), oshape_buffer.data()); - std::vector ostride_buffer(ostride.ndim()); - nnvm::ShapeTypeCast(ostride.begin(), ostride.end(), ostride_buffer.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(out_desc_, - dtype_, - static_cast(oshape.ndim()), - oshape_buffer.data(), - ostride_buffer.data())); - - if (!param_.no_bias) { - mxnet::TShape bias = in_shape[conv::kBias]; - int bias_dim = static_cast(bias[0]); - std::vector bias_shape = {1, bias_dim, 1, 1}; - std::vector bias_stride = {bias_dim, 1, bias_dim, bias_dim}; - if (param_.kernel.ndim() == 3) { - bias_shape.push_back(1); - bias_stride.push_back(bias_dim); - } - CUDNN_CALL(cudnnSetTensorNdDescriptor(bias_desc_, - dtype_, - static_cast(bias_shape.size()), - &bias_shape[0], - &bias_stride[0])); - } - } - - void CuDNNAlgoSetter(const RunContext& rctx, - const mxnet::ShapeVector& in_shape, - const mxnet::ShapeVector& out_shape, - cudnnDataType_t cudnn_forward_compute_type, - cudnnDataType_t cudnn_backward_compute_type, - CuDNNAlgo* fwd, - CuDNNAlgo* bwd, - CuDNNAlgo* flt) { - // Not in algo registry, must determine via *Get*() or *Find*() - mshadow::Stream* s = rctx.get_stream(); - CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); - size_t workspace_byte = static_cast(param_.workspace * sizeof(DType)); - - // Since the function signature of *Get*_v7() matches that of *Find*(), - // we can unify the find-vs-get logic by using function pointers. - - // Forward Algorithm Find/Get() v7 - std::vector fwd_results(MaxForwardAlgos(s->dnn_handle_)); - int actual_fwd_algos = 0; - auto fwd_algo_discoverer = param_.cudnn_tune.value() == conv::kOff - ? 
cudnnGetConvolutionForwardAlgorithm_v7 - : cudnnFindConvolutionForwardAlgorithm; - CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_, - in_desc_, - filter_desc_, - forward_conv_desc_, - out_desc_, - fwd_results.size(), - &actual_fwd_algos, - fwd_results.data())); - fwd_results.resize(actual_fwd_algos); - AlgoFinalSelect( - fwd_results, "forward", workspace_byte, fwd); - - // Backprop-to-Filter Algorithm Find/Get() v7 - auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_); - std::vector bwd_filt_results(max_bwd_filt_algos); - int actual_bwd_filter_algos = 0; - // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we - // were summing into the output (i.e. beta != 0). Get() returned OK algos though. - auto bwd_filter_algo_discoverer = param_.cudnn_tune.value() == conv::kOff - ? cudnnGetConvolutionBackwardFilterAlgorithm_v7 - : cudnnFindConvolutionBackwardFilterAlgorithm; - CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_, - in_desc_, - out_desc_, - back_conv_desc_w_, - filter_desc_, - bwd_filt_results.size(), - &actual_bwd_filter_algos, - bwd_filt_results.data())); - bwd_filt_results.resize(actual_bwd_filter_algos); - AlgoFinalSelect( - bwd_filt_results, "backprop-to-filter", workspace_byte, flt); - - // Backprop-to-Data Algorithm Find/Get() v7 - auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_); - std::vector bwd_data_results(max_bwd_data_algos); - int actual_bwd_data_algos = 0; - auto bwd_data_algo_discoverer = param_.cudnn_tune.value() == conv::kOff - ? cudnnGetConvolutionBackwardDataAlgorithm_v7 - : cudnnFindConvolutionBackwardDataAlgorithm; - CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_, - filter_desc_, - out_desc_, - back_conv_desc_, - in_desc_, - bwd_data_results.size(), - &actual_bwd_data_algos, - bwd_data_results.data())); - bwd_data_results.resize(actual_bwd_data_algos); - AlgoFinalSelect( - bwd_data_results, "backprop-to-data", workspace_byte, bwd, exclude_dgrad_algo_); - - // Fix for issue #11241 - int cudnn_find_issue_max_features = 64 * 1024; - if (add_to_weight_ && Features(in_shape[conv::kData]) >= cudnn_find_issue_max_features) { - flt->Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true); - } - } - - void SelectAlgo(const RunContext& rctx, - const mxnet::ShapeVector& in_shape, - const mxnet::ShapeVector& out_shape, - cudnnDataType_t cudnn_forward_compute_type, - cudnnDataType_t cudnn_backward_compute_type) { - auto algo_setter = [&](CuDNNAlgo* fwd, - CuDNNAlgo* bwd, - CuDNNAlgo* flt) { - if (param_.cudnn_tune.value() == conv::kOff) { - // The routine will only be calling cudnnGet, so no need to grab the Storage lock. - this->CuDNNAlgoSetter(rctx, - in_shape, - out_shape, - cudnn_forward_compute_type, - cudnn_backward_compute_type, - fwd, - bwd, - flt); - } else { - // One potential problem is that cudnnFind() uses cudaMalloc() to directly allocate - // I/O and workspace areas, and these allocations may result in an out-of-memory - // error even though the StorageMangager free pool is not empty. Ideally, cudnnFind - // would use MXNet's storage allocator for its I/O and workspace areas, instead of using - // the area carved out by MXNET_GPU_MEM_POOL_RESERVE. - // To get somewhat the same effect as this, we can pre-allocate the areas needed for the - // I/Os (possibly triggering a desirable StorageManager::ReleaseAll()), followed by a - // DirectFree(), which makes these areas available for cudnn's subsequent cudaMalloc(). - - // Allocate for x (or dx), w (or dw) and y (or dy). 
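
The comment in the removed CuDNNAlgoSetter notes that the *Get*_v7() and *Find*() entry points share a signature, so one call site can serve both by selecting the routine through a function pointer. A stripped-down illustration of that pattern with stand-in functions (the real code points at cudnnGetConvolutionForwardAlgorithm_v7 / cudnnFindConvolutionForwardAlgorithm and their backward counterparts):

```cpp
#include <cstdio>

struct PerfResult { int algo; float time_ms; };

// Two discovery routines with an identical signature: one returns heuristic
// results, the other benchmarks. Stand-ins for the cudnnGet*/cudnnFind* pairs.
int GetAlgos(PerfResult* out, int max_n)  { if (max_n < 1) return 0; out[0] = {0, -1.0f}; return 1; }
int FindAlgos(PerfResult* out, int max_n) { if (max_n < 1) return 0; out[0] = {1, 0.42f}; return 1; }

int main() {
  bool tuning_off = true;  // e.g. cudnn_tune == kOff
  // The same call site handles both cases because the signatures match.
  int (*discoverer)(PerfResult*, int) = tuning_off ? GetAlgos : FindAlgos;
  PerfResult results[8];
  int n = discoverer(results, 8);
  std::printf("%d algo(s), first algo id %d\n", n, results[0].algo);
  return 0;
}
```
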
- ReserveElements({in_shape[conv::kData].Size(), - in_shape[conv::kWeight].Size(), - out_shape[conv::kOut].Size()}); - - // We're about to call cudnnFind so we need to quiet the system by grabbing - // the Storage lock. Concurrent cudaMalloc's can disrupt the accurate timing - // measurements of the algos, and can prevent the cuda driver's proper freeing - // of cudnnFind's internal temporary allocations. Grabbing the lock might also - // impede other threads from launching work on the GPU. - std::lock_guard lock(Storage::Get()->GetMutex(Context::kGPU)); - this->CuDNNAlgoSetter(rctx, - in_shape, - out_shape, - cudnn_forward_compute_type, - cudnn_backward_compute_type, - fwd, - bwd, - flt); - } - }; - - CuDNNConvAlgoReg::Get()->FindOrElseRegister(param_, - in_shape, - out_shape, - dtype_, - cudnn_forward_compute_type, - cudnn_backward_compute_type, - SMArch(rctx.ctx.dev_id), - add_to_weight_, - &forward_algo_, - &back_algo_, - &back_algo_w_, - algo_setter); - - // If we're allowing Tensor Core variants of the algos to be considered in - // *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest, - // we must change the descriptor to preclude Tensor Core. Simplest is to - // once again set the mathType in all cases. - CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, forward_algo_.MathType())); - CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, back_algo_.MathType())); - CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, back_algo_w_.MathType())); - } - - // Look over the results from *Find*() or *Get*() and pick the fastest algo given possible - // workspace constraints. - template - void AlgoFinalSelect(const std::vector& perf_results, - std::string kernel_name, - size_t workspace_byte, - CuDNNAlgo* algo, - int32_t algo_exclude = -1) { - // Determine the fastest acceptable algo that matches the algo_preference (-1 = any), - // regardless of mathType. - bool enforce_determinism = dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", false); - for (decltype(perf_results.size()) i = 0; i != perf_results.size(); ++i) { - const auto& result = perf_results[i]; - bool algo_exclusion = static_cast(result.algo) == algo_exclude; - bool algo_is_tensor_core = false; - algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH; - if (result.status == CUDNN_STATUS_SUCCESS && - (!enforce_determinism || result.determinism == cudnnDeterminism_t::CUDNN_DETERMINISTIC) && - (param_.cudnn_tune.value() == conv::kLimited || result.memory <= workspace_byte) && - !algo_exclusion) { - algo->Set(result.algo, algo_is_tensor_core); - return; - } - } - auto mode = param_.cudnn_tune.value() == conv::kOff ? " get " : " find "; - LOG(FATAL) << "Failed to" << mode << "any " << kernel_name << " convolution algorithm. " - << " with workspace size of " << workspace_byte << " bytes," - << " please consider reducing batch/model size or increasing the workspace size"; - } - - void GetTempSize(const RunContext& rctx) { - mshadow::Stream* s = rctx.get_stream(); - CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(s->dnn_handle_, - filter_desc_, - out_desc_, - back_conv_desc_, - in_desc_, - back_algo_.AlgoNumber(), - &back_workspace_byte_dgrad_)); - CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize(s->dnn_handle_, - in_desc_, - out_desc_, - back_conv_desc_w_, - filter_desc_, - back_algo_w_.AlgoNumber(), - &back_workspace_byte_wgrad_)); - // cudaMalloc returns addresses that are aligned for large accesses (e.g. to 512 bytes). 
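
The 512-byte rounding described here is what lets the single backward workspace be carved into separate dgrad and wgrad regions whose pointers keep large-allocation alignment, with the larger region placed first. A small host-side sketch of that arithmetic (the byte counts are made up for illustration):

```cpp
#include <cstddef>
#include <cstdio>

// Round x up to the next multiple of 'multiple' (same as the removed RoundToMultiple).
std::size_t RoundToMultiple(std::size_t x, std::size_t multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

int main() {
  const std::size_t kAlign = 512;  // cudaMalloc-style alignment for large accesses
  std::size_t dgrad_bytes = RoundToMultiple(1000000, kAlign);
  std::size_t wgrad_bytes = RoundToMultiple(300000, kAlign);

  // One allocation of dgrad_bytes + wgrad_bytes is split so that the larger
  // region comes first; the second region then starts on a 512-byte boundary.
  std::size_t total = dgrad_bytes + wgrad_bytes;
  std::size_t dgrad_offset = (dgrad_bytes > wgrad_bytes) ? 0 : wgrad_bytes;
  std::size_t wgrad_offset = (dgrad_bytes > wgrad_bytes) ? dgrad_bytes : 0;
  std::printf("total=%zu dgrad@+%zu wgrad@+%zu\n", total, dgrad_offset, wgrad_offset);
  return 0;
}
```
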
- // Since we only make one allocation and divide it into two parts when we parallelize - // the dgrad and wgrad kernels, we round the sizes up to this alignment size so the - // dptrs respect this alignment, even if the separate areas are stacked. - const size_t dptr_alignment = 512; - back_workspace_byte_dgrad_ = RoundToMultiple(back_workspace_byte_dgrad_, dptr_alignment); - back_workspace_byte_wgrad_ = RoundToMultiple(back_workspace_byte_wgrad_, dptr_alignment); - - CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(s->dnn_handle_, - in_desc_, - filter_desc_, - forward_conv_desc_, - out_desc_, - forward_algo_.AlgoNumber(), - &forward_workspace_byte_)); - } - - int* CastTShapeToIntPtr(const mxnet::TShape& s, std::vector* buffer) { - buffer->resize(s.ndim()); - nnvm::ShapeTypeCast(s.begin(), s.end(), buffer->data()); - return buffer->data(); - } - - // Converts a TBlob to a dptr, checking for the expected dim and that it's contiguous. - DType* GetNdPtr(const TBlob& tb, int dim, Stream* s) { - DType* data_ptr = nullptr; - if (dim == 3) { - Tensor data = tb.get(s); - CHECK_EQ(data.CheckContiguous(), true); - data_ptr = data.dptr_; - } else if (dim == 4) { - Tensor data = tb.get(s); - CHECK_EQ(data.CheckContiguous(), true); - data_ptr = data.dptr_; - } else if (dim == 5) { - Tensor data = tb.get(s); - CHECK_EQ(data.CheckContiguous(), true); - data_ptr = data.dptr_; - } else { - LOG(FATAL) << "Unexpected Tensor size " << dim << ", supporting only 3, 4 or 5."; - } - return data_ptr; - } - - // Converts a mxnet::TShape to a Shape<> of strides. - // e.g. {shape[0], shape[1], shape[2]} -> {shape[1]*shape[2], shape[2], 1} - template - inline Shape Strides(const mxnet::TShape& s) { - int ndim = s.ndim(); - mxnet::TShape strides(ndim, -1); - for (int i = 0; i != ndim; ++i) - strides[i] = s.ProdShape(i + 1, ndim); - return strides.get(); - } - - void InitBufferForParam() { - CastTShapeToIntPtr(param_.stride, ¶m_stride_); - CastTShapeToIntPtr(param_.dilate, ¶m_dilate_); - CastTShapeToIntPtr(param_.pad, ¶m_pad_); - } - - // Round a value 'x' up to the next multiple of 'multiple' - size_t RoundToMultiple(size_t x, size_t multiple) { - size_t retVal = ((x + multiple - 1) / multiple) * multiple; - return retVal; - } - - // Allocates a 1D Tensor of words with size in bytes >= `size_bytes`. - // Always allocates at least one word. - mshadow::Tensor AllocateTempWorkspace(const OpContext& ctx, size_t size_bytes) { - mshadow::Stream* s = ctx.get_stream(); - size_t size_words = - std::max(1, RoundToMultiple(size_bytes, sizeof(DType)) / sizeof(DType)); - return ctx.requested[conv::kTempSpace].get_space_typed( - mshadow::Shape1(size_words), s); - } - - // Returns the size in bytes of the 1D Tensor of words. - size_t TensorSizeBytes(const mshadow::Tensor& tensor) { - return tensor.MSize() * sizeof(DType); - } - - // Given a tensor shape of this operation, return the number of features 'c' - int64_t Features(const mxnet::TShape& dshape) { - int c = 0; - switch (dshape.ndim()) { - case 3: - c = ConvertLayout(dshape.get<3>(), param_.layout.value(), kNCW)[1]; - break; - case 4: - c = ConvertLayout(dshape.get<4>(), param_.layout.value(), kNCHW)[1]; - break; - case 5: - c = ConvertLayout(dshape.get<5>(), param_.layout.value(), kNCDHW)[1]; - break; - default: - LOG(FATAL) << "Unexpected convolution data dimension " << dshape.ndim(); - } - return c; - } - - // Make a number of allocations and directly free them, ensuring room for an equivalent set of - // cudaMalloc() calls by (say) cudnnFind(). 
`elements` spec the alloc size in DTypes, not bytes. - void ReserveElements(const std::vector& elements) { - std::vector handles; - for (size_t alloc_element : elements) { - handles.push_back(Storage::Get()->Alloc(alloc_element * sizeof(DType), Context::GPU())); - handles.back().profiler_scope = ":"; - handles.back().name = "reserve_elements"; - } - for (auto& handle : handles) - Storage::Get()->DirectFree(handle); - } - - // Log that no suitable algo was found that met the workspace constraints, then exit. - void LogNoSuitableAlgoAndExit(int num_algos_tried, - size_t min_memory_needs, - size_t workspace_byte, - std::string algo_kind) { - LOG(FATAL) << num_algos_tried << " " << algo_kind << " with minimum memory requirement " - << min_memory_needs << " bytes have been tried. Workspace size is set to " - << workspace_byte << " bytes, please consider reducing the batch/model size, " - << "or increasing workspace size."; - } - - std::vector param_stride_; - std::vector param_dilate_; - std::vector param_pad_; - - // Temp workspace size in bytes needed for Forward() operation. - size_t forward_workspace_byte_; - // Temp workspace size in bytes needed for Backward() dgrad (data gradient) operation. - size_t back_workspace_byte_dgrad_; - // Temp workspace size in bytes needed for Backward() wgrad (weight gradient) operation. - size_t back_workspace_byte_wgrad_; - cudnnDataType_t dtype_; - cudnnTensorDescriptor_t in_desc_; - cudnnTensorDescriptor_t out_desc_; - cudnnTensorDescriptor_t bias_desc_; - cudnnFilterDescriptor_t filter_desc_; - // Convolution descriptor for forward inference operation - cudnnConvolutionDescriptor_t forward_conv_desc_; - // Convolution descriptor for back-prop operations to the data - cudnnConvolutionDescriptor_t back_conv_desc_; - // Convolution descriptor for back-prop operations to the weights - cudnnConvolutionDescriptor_t back_conv_desc_w_; - // Should dgrad and wgrad be launched into separate streams - bool parallelize_backward_kernels_; - // Algorithm for the forward inference operation - CuDNNAlgo forward_algo_; - // Algorithm for the back-prop operation to the data - CuDNNAlgo back_algo_; - // Algorithm for the back-prop operation to the weights - CuDNNAlgo back_algo_w_; - cudnnTensorFormat_t format_; - // Allow TensorCore algo policy - bool cudnn_tensor_core_; - // Is req[kWeight] == conv::kAddTo ? - bool add_to_weight_; - // Is there a dgrad algo that should be avoided (-1 == none)? - int32_t exclude_dgrad_algo_ = -1; - ConvolutionParam param_; -}; -#endif // __CUDACC__ && CUDNN -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_NN_CUDNN_CUDNN_CONVOLUTION_INL_H_ diff --git a/src/operator/nn/deconvolution.cu b/src/operator/nn/deconvolution.cu index 63b8b71ed452..c1919f3e514e 100644 --- a/src/operator/nn/deconvolution.cu +++ b/src/operator/nn/deconvolution.cu @@ -18,72 +18,23 @@ */ /*! 
+ * Copyright (c) 2015 by Contributors * \file deconvolution.cu * \brief * \author Wei Wu, Da Zheng - */ +*/ #include "./deconvolution-inl.h" #if MXNET_USE_CUDNN == 1 -#include "./cudnn/cudnn_deconvolution-inl.h" +#include "../cudnn_ops.h" +#include "../tensor/broadcast_reduce_op.h" +#include "../tensor/elemwise_binary_broadcast_op.h" +#include "fully_connected-inl.h" #endif // MXNET_USE_CUDNN namespace mxnet { namespace op { -#if MXNET_USE_CUDNN == 1 -template -static CuDNNDeconvolutionOp& GetCuDNNDeconvOp(const DeconvolutionParam& param, - int forward_compute_type, - int backward_compute_type, - const mxnet::ShapeVector& in_shape, - const mxnet::ShapeVector& out_shape, - const RunContext& rctx, - bool add_to_weight) { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local std:: - unordered_map>, OpHash> - ops; -#else - static MX_THREAD_LOCAL - std::unordered_map>, OpHash> - ops; -#endif - DeconvSignature key(param); - size_t ndim = 0; - for (auto& s : in_shape) - ndim += s.ndim(); - for (auto& s : out_shape) - ndim += s.ndim(); - key.Reserve(1 /* for forward_compute_type */ + 1 /* for backward_compute_type */ + - ndim /* for in and out shapes */ + 1 /* for dev_id */ + 1 /* for add_to_weight */); - - key.AddSign(forward_compute_type); - key.AddSign(backward_compute_type); - key.AddSign(in_shape); - key.AddSign(out_shape); - key.AddSign(rctx.ctx.dev_id); - key.AddSign(add_to_weight ? 1 : 0); - - auto it = ops.find(key); - if (it == ops.end()) { - std::shared_ptr> op(new CuDNNDeconvolutionOp()); - auto ins_ret = ops.insert( - std::pair>>(key, op)); - CHECK(ins_ret.second); - it = ins_ret.first; - it->second->Init(param, - forward_compute_type, - backward_compute_type, - in_shape, - out_shape, - rctx, - add_to_weight); - } - return *it->second; -} -#endif - template <> void DeconvolutionCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -91,35 +42,32 @@ void DeconvolutionCompute(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const DeconvolutionParam& param = nnvm::get(attrs.parsed); - int dtype = inputs[0].type_flag_; + int dtype = inputs[0].type_flag_; + CHECK_EQ(req.size(), 1); + CHECK_EQ(req[deconv::kOut], kWriteTo); #if MXNET_USE_CUDNN == 1 - // On fp16-I/O instances, use fp32 compute (i.e. pseudo-fp16). - int compute_type = (dtype == mshadow::kFloat16) ? 
mshadow::kFloat32 : dtype; - + STATIC_ASSERT_CUDNN_VERSION_GE(8000); MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - if (param.cudnn_off) { - DeconvolutionOp op; - op.Init(param); - op.Forward(ctx, inputs, req, outputs); - } else if (!CuDNNDeconvolutionOp::Supports( - param, compute_type, compute_type, ctx.run_ctx.ctx.dev_id)) { - LOG(WARNING) - << "This deconvolution is not supported by cudnn, MXNET deconvolution is applied."; + cudnn::ConvParam conv_param(param, false); + bool ok = !param.cudnn_off && + cudnn::Exec(ctx, conv_param, inputs[deconv::kWeight], + inputs[deconv::kData], outputs[deconv::kOut]); + if (ok && !param.no_bias) { + CHECK_EQ(inputs[deconv::kBias].shape_.ndim(), 1); + auto layout = static_cast(param.layout.value()); + int k = inputs[deconv::kBias].shape_.Size(); + auto b = inputs[deconv::kBias].reshape(cudnn::ExpandChannelDims(layout, k)); + BinaryBroadcastRTCCompute {"add"}(attrs, ctx, {outputs[deconv::kOut], b}, {kWriteInplace}, + {outputs[deconv::kOut]}); + } + if (!ok) { + if (!param.cudnn_off) + LOG(WARNING) + << "This deconvolution is not supported by cuDNN, MXNet deconvolution is applied."; DeconvolutionOp op; op.Init(param); op.Forward(ctx, inputs, req, outputs); - } else { - mxnet::ShapeVector in_shape(inputs.size()); - mxnet::ShapeVector out_shape(1, outputs[0].shape_); - for (size_t i = 0; i < in_shape.size(); i++) { - in_shape[i] = inputs[i].shape_; - } - // req[deconv::kWeight] is only set for backward, so assume the typical 'write' for now. - auto add_to_weight = false; - GetCuDNNDeconvOp( - param, compute_type, compute_type, in_shape, out_shape, ctx.run_ctx, add_to_weight) - .Forward(ctx, inputs, req, outputs); } }) #else @@ -139,36 +87,46 @@ void DeconvolutionGradCompute(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { const DeconvolutionParam& param = nnvm::get(attrs.parsed); std::vector in_data(inputs.begin() + 1, inputs.end()); - const TBlob& out_grad = inputs[0]; - const std::vector& in_grad = outputs; - int dtype = out_grad.type_flag_; + const TBlob &out_grad = inputs[0]; + const std::vector &in_grad = outputs; + int dtype = out_grad.type_flag_; + CHECK_EQ(req.size(), param.no_bias ? 2 : 3); + CHECK_NE(req[deconv::kData], kWriteInplace); + CHECK_NE(req[deconv::kWeight], kWriteInplace); + if (!param.no_bias) CHECK_NE(req[deconv::kBias], kWriteInplace); #if MXNET_USE_CUDNN == 1 - // On fp16-I/O instances, use fp32 compute (i.e. pseudo-fp16). - int compute_type = (dtype == mshadow::kFloat16) ? 
mshadow::kFloat32 : dtype; - + STATIC_ASSERT_CUDNN_VERSION_GE(8000); MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - if (param.cudnn_off) { - DeconvolutionOp op; - op.Init(param); - op.Backward(ctx, std::vector{out_grad}, in_data, req, in_grad); - } else if (!CuDNNDeconvolutionOp::Supports( - param, compute_type, compute_type, ctx.run_ctx.ctx.dev_id)) { - LOG(WARNING) - << "This deconvolution is not supported by cudnn, MXNET deconvolution is applied."; + cudnn::ConvParam conv_param(param, req[deconv::kData] == kAddTo); + bool ok = !param.cudnn_off; + ok = ok && (req[deconv::kData] == kNullOp || + cudnn::Exec(ctx, conv_param, inputs[0], inputs[1 + deconv::kWeight], + outputs[deconv::kData])); + conv_param.add_to = req[deconv::kWeight] == kAddTo; + ok = ok && (req[deconv::kWeight] == kNullOp || + cudnn::Exec(ctx, conv_param, inputs[0], inputs[1 + deconv::kData], + outputs[deconv::kWeight])); + if (ok && !param.no_bias && req[deconv::kBias] != kNullOp) { + auto li = cudnn::GetLayoutInfo(static_cast(param.layout.value())); + if (li.channel_last) { + // This kernel should be faster. + auto y_grad = FlattenAs2DHead(inputs[0], ctx); + AddBiasGrad(outputs[deconv::kBias], y_grad, req[deconv::kBias], param.num_filter, ctx); + } else { + TShape axes{static_cast(li.ChannelIdx())}; + TShape small = ReduceAxesShapeImpl(inputs[0].shape_, dmlc::optional(axes), true, true); + ReduceAxesRTCComputeImpl(ctx, {inputs[0]}, {req[deconv::kBias]}, {outputs[deconv::kBias]}, + small, "red::sum{}"); + } + } + if (!ok) { + if (!param.cudnn_off) + LOG(WARNING) + << "This deconvolution backward is not supported by cuDNN, MXNet op is applied."; DeconvolutionOp op; op.Init(param); op.Backward(ctx, std::vector{out_grad}, in_data, req, in_grad); - } else { - mxnet::ShapeVector in_shape(in_data.size()); - mxnet::ShapeVector out_shape(1, out_grad.shape_); - for (size_t i = 0; i < in_shape.size(); i++) { - in_shape[i] = in_data[i].shape_; - } - auto add_to_weight = req[deconv::kWeight] == kAddTo; - GetCuDNNDeconvOp( - param, compute_type, compute_type, in_shape, out_shape, ctx.run_ctx, add_to_weight) - .Backward(ctx, std::vector{out_grad}, in_data, req, in_grad); } }) #else @@ -180,10 +138,11 @@ void DeconvolutionGradCompute(const nnvm::NodeAttrs& attrs, #endif // MXNET_USE_CUDNN } -NNVM_REGISTER_OP(Deconvolution).set_attr("FCompute", DeconvolutionCompute); +NNVM_REGISTER_OP(Deconvolution) +.set_attr("FCompute", DeconvolutionCompute); NNVM_REGISTER_OP(_backward_Deconvolution) - .set_attr("FCompute", DeconvolutionGradCompute); +.set_attr("FCompute", DeconvolutionGradCompute); } // namespace op } // namespace mxnet
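
Both operators now follow the same shape: attempt the cuDNN v8 path, and if it declines (unsupported layout or dtype, or cudnn_off), emit a warning and run the native MXNet kernel. A schematic, library-free distillation of that control flow; the names here are placeholders for illustration, not MXNet APIs.

```cpp
#include <functional>
#include <iostream>

// try_cudnn returns false when the cuDNN path cannot handle the case
// (the analogue of cudnn::Exec returning false); fallback always works.
bool RunWithFallback(const std::function<bool()>& try_cudnn,
                     const std::function<void()>& fallback,
                     bool cudnn_off) {
  bool ok = !cudnn_off && try_cudnn();
  if (!ok) {
    if (!cudnn_off)
      std::cerr << "This op is not supported by cuDNN, MXNet implementation is applied.\n";
    fallback();
  }
  return ok;
}

int main() {
  RunWithFallback([] { return false; },                        // pretend cuDNN declined
                  [] { std::cout << "native kernel ran\n"; },  // fallback path
                  /*cudnn_off=*/false);
  return 0;
}
```
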