diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index a5d6a1dabeff..43081dec25c4 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -474,6 +474,31 @@ class CommDevice : public Comm { } } + const NDArray& ReduceRowSparse(int key, const std::vector& src, + int priority) { + auto& buf = merge_buf_[key]; + std::vector reduce(src.size()); + + const NDArrayStorageType stype = src[0].storage_type(); + NDArray& buf_merged = buf.merged_buf(stype); + if (buf.copy_buf.empty()) { + // initialize buffer for copying during reduce + buf.copy_buf.resize(src.size()); + for (size_t j = 0; j < src.size(); ++j) { + buf.copy_buf[j] = NDArray(stype, src[0].shape(), buf_merged.ctx(), true, src[0].dtype()); + } + } + CHECK(src[0].storage_type() == buf.copy_buf[0].storage_type()) + << "Storage type mismatch detected. " << src[0].storage_type() << "(src) vs. " + << buf.copy_buf[0].storage_type() << "(buf.copy_buf)"; + for (size_t i = 0; i < src.size(); ++i) { + CopyFromTo(src[i], &(buf.copy_buf[i]), priority); + reduce[i] = buf.copy_buf[i]; + } + ElementwiseSum(reduce, &buf_merged, priority); + return buf_merged; + } + const NDArray& Reduce(int key, const std::vector& src, int priority) override { // when this reduce is called from kvstore_dist, gc is not set @@ -490,13 +515,14 @@ class CommDevice : public Comm { InitBuffersAndComm(src); auto& buf = merge_buf_[key]; - std::vector reduce(src.size()); const NDArrayStorageType stype = src[0].storage_type(); NDArray& buf_merged = buf.merged_buf(stype); // normal dense reduce if (stype == kDefaultStorage) { CopyFromTo(src[0], &buf_merged, priority); + + std::vector reduce(src.size()); reduce[0] = buf_merged; if (buf.copy_buf.empty()) { @@ -514,24 +540,11 @@ class CommDevice : public Comm { CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority); reduce[i+1] = buf.copy_buf[i]; } + ElementwiseSum(reduce, &buf_merged, priority); } else { // sparse reduce - if (buf.copy_buf.empty()) { - // initialize buffer for copying during reduce - buf.copy_buf.resize(src.size()); - for (size_t j = 0; j < src.size(); ++j) { - buf.copy_buf[j] = NDArray(stype, src[0].shape(), buf_merged.ctx(), true, src[0].dtype()); - } - } - CHECK(src[0].storage_type() == buf.copy_buf[0].storage_type()) - << "Storage type mismatch detected. " << src[0].storage_type() << "(src) vs. 
" - << buf.copy_buf[0].storage_type() << "(buf.copy_buf)"; - for (size_t i = 0; i < src.size(); ++i) { - CopyFromTo(src[i], &(buf.copy_buf[i]), priority); - reduce[i] = buf.copy_buf[i]; - } + buf_merged = ReduceRowSparse(key, src, priority); } - ElementwiseSum(reduce, &buf_merged, priority); return buf_merged; } @@ -658,6 +671,42 @@ class CommDevice : public Comm { } } + using KeyAttrs = std::tuple; + // try to allocate buff on device evenly + void InitMergeBuffer(const std::vector& devs) { + std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), []( + const KeyAttrs& a, const KeyAttrs& b) { + return std::get<1>(a).Size() > std::get<1>(b).Size(); + }); + + std::unordered_map> ctx_info; + for (auto d : devs) { + ctx_info[d.dev_id] = std::make_pair(d, 0); + } + + for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) { + const int key = std::get<0>(sorted_key_attrs_[i]); + const TShape& shape = std::get<1>(sorted_key_attrs_[i]); + const int type = std::get<2>(sorted_key_attrs_[i]); + auto& buf = merge_buf_[key]; + Context ctx; + size_t min_size = std::numeric_limits::max(); + for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) { + size_t size = it->second.second; + if (size <= min_size) { + ctx = it->second.first; + min_size = size; + } + } + // Delayed allocation - as the dense merged buffer might not be used at all if push() + // only sees sparse arrays + bool delay_alloc = true; + buf.merged = NDArray(shape, ctx, delay_alloc, type); + ctx_info[ctx.dev_id].second += shape.Size(); + } + inited_ = true; + } + private: void EnableP2P(const std::vector& devs) { #if MXNET_USE_CUDA @@ -701,43 +750,6 @@ class CommDevice : public Comm { #endif } - using KeyAttrs = std::tuple; - // try to allocate buff on device evenly - void InitMergeBuffer(const std::vector& devs) { - std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), []( - const KeyAttrs& a, const KeyAttrs& b) { - return std::get<1>(a).Size() > std::get<1>(b).Size(); - }); - - std::unordered_map> ctx_info; - for (auto d : devs) { - ctx_info[d.dev_id] = std::make_pair(d, 0); - } - - for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) { - const int key = std::get<0>(sorted_key_attrs_[i]); - const TShape& shape = std::get<1>(sorted_key_attrs_[i]); - const int type = std::get<2>(sorted_key_attrs_[i]); - auto& buf = merge_buf_[key]; - Context ctx; - size_t min_size = std::numeric_limits::max(); - for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) { - size_t size = it->second.second; - if (size <= min_size) { - ctx = it->second.first; - min_size = size; - } - } - // Delayed allocation - as the dense merged buffer might not be used at all if push() - // only sees sparse arrays - bool delay_alloc = true; - buf.merged = NDArray(shape, ctx, delay_alloc, type); - ctx_info[ctx.dev_id].second += shape.Size(); - } - inited_ = true; - } - - std::vector sorted_key_attrs_; /// \brief temporal space for pushing and pulling struct BufferEntry { /// \brief the dense merged value for reduce and broadcast operations @@ -772,7 +784,10 @@ class CommDevice : public Comm { NDArray sparse_merged; }; std::unordered_map merge_buf_; + + public: bool inited_; + std::vector sorted_key_attrs_; }; } // namespace kvstore diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h new file mode 100644 index 000000000000..f2cf4861ca2b --- /dev/null +++ b/src/kvstore/comm_tree.h @@ -0,0 +1,500 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Copyright (c) 2018 by Contributors + */ +#ifndef MXNET_KVSTORE_COMM_TREE_H_ +#define MXNET_KVSTORE_COMM_TREE_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "mxnet/ndarray.h" +#include "gradient_compression.h" +#include "../ndarray/ndarray_function.h" +#include "../operator/tensor/sparse_retain-inl.h" +#include "./kvstore_utils.h" +#include "./gpu_topology.h" +namespace mxnet { +namespace kvstore { +/** + * \brief an implementation of Comm that performs reduction on device + * directly using tree. + * + * It is faster if the total device-to-device bandwidths is larger than + * device-to-cpu, which is often true for 4 or 8 GPUs. But it uses more device + * memory. + */ +class CommDeviceTree : public CommDevice { + public: + CommDeviceTree() { + inited_ = false; + gpuarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_GPUARRAY_BOUND", 10000000); + backtrack_ = dmlc::GetEnv("MXNET_KVSTORE_BACKTRACK", 0); + link_usage_penalty_ = dmlc::GetEnv("MXNET_KVSTORE_LINK_USAGE_PENALTY", 0.7); + } + + virtual ~CommDeviceTree() { } + + void Init(int key, const NDArrayStorageType stype, const TShape& shape, + int dtype = mshadow::kFloat32) override { + tree_sorted_key_attrs_.emplace_back(key, shape, dtype); + sorted_key_attrs_.emplace_back(key, shape, dtype); + } + + void InitBuffersAndComm(const std::vector& src) { + if (!inited_) { + for (const auto& a : src) { + devs_.push_back(a.ctx()); + } + QueryTopology(); + // Note: delayed allocation set to true, because we do not want to allocate + // both in TreeBufferEntry and BufferEntry, so we use a size_t to keep + // track of each key's shape within BufferEntry + // -this information is required for inherited Reduce- and + // BroadcastRowSparse + InitMergeBuffer(devs_); + InitMergeBufferTree(); + if (dmlc::GetEnv("MXNET_ENABLE_GPU_P2P", 1)) { + EnableP2P(); + } + } + } + + // src is sliced shape + // copy_buf not sliced + // merged not sliced + const NDArray& ReduceInner(int key, const std::vector& src, int root, + int merged_row, int priority) { + std::vector> reduce(devs_.size()); + + TreeBufferEntry& random_buf = tree_merge_buf_[0][key]; + const NDArrayStorageType stype = random_buf.merged[0].storage_type(); + std::vector& topology = topology_[root]; + NDArray buf_slice; + + if (stype == kDefaultStorage) { + // Copy everything into buf.merged for each gpu + for (size_t i = 0; i < src.size(); ++i) { + int start = scan_[root][depth_ ]; + int end = scan_[root][depth_+1]; + + for (int j = start; j < end; ++j) { + int topo_id = topology[j]; + TreeBufferEntry& buf = tree_merge_buf_[topo_id][key]; + + if (devs_[topo_id] == src[i].ctx()) { + CopyFromTo(src[i], &(buf.merged[merged_row]), priority); + } + } + } + + for (int level = depth_; level > 0; --level) { + int start = scan_[root][level ]; + int end = 
scan_[root][level+1]; + + unsigned is_dest = 0; + int dest_id = 0; + for (int j = start; j < end; ++j) { + int topo_id = topology[j]; + dest_id = (is_dest == 0) ? topo_id : dest_id; + + TreeBufferEntry& buf_dest = tree_merge_buf_[dest_id][key]; + TreeBufferEntry& buf_from = tree_merge_buf_[topo_id][key]; + + if (!is_dest) { + reduce[dest_id].push_back(buf_dest.merged[merged_row]); + } else { + if (dest_id != topo_id) { + CopyFromTo(buf_from.merged[merged_row], + &(buf_dest.copy_buf[merged_row][is_dest-1]), + priority); + reduce[dest_id].push_back( + buf_dest.copy_buf[merged_row][is_dest-1]); + } + } + + is_dest = (is_dest == static_cast(kBranch)-1) ? + 0 : is_dest+1; + } + + start = scan_[root][level-1]; + end = scan_[root][level ]; + for (int i = start; i < end; ++i) { + int gpu_id = topology[i]; + + // conditional to detect whether operation must be done + if (reduce[gpu_id].size() > 1) { + TreeBufferEntry& buf = tree_merge_buf_[gpu_id][key]; + ElementwiseSum(reduce[gpu_id], &(buf.merged[merged_row]), priority); + } + } + + // reset + for (unsigned i = 0; i < devs_.size(); ++i) { + reduce[i].clear(); + } + } + } else { + LOG(WARNING) << "Only dense input supported for now"; + } + + int topo_id = topology[0]; + TreeBufferEntry& buf = tree_merge_buf_[topo_id][key]; + return buf.merged[merged_row]; + } + + const NDArray& Reduce(int key, const std::vector& src, + int priority) override { + // when this reduce is called from kvstore_dist, gc is not set + // we don't do compression twice in dist_sync_device + if ((gc_ != nullptr) && (gc_->get_type() != CompressionType::kNone)) { + return ReduceCompressed(key, src, priority); + } + + // avoid extra copy for single device, but it may bring problems for + // abnormal usage of kvstore + if (src.size() == 1) { + return src[0]; + } + + InitBuffersAndComm(src); + std::vector> slice(devs_.size()); + std::vector> broadcast_slice(devs_.size()); + std::vector slice_scan(devs_.size()+1); + + int total_size = src[0].shape().Size(); + unsigned first_size = src[0].shape()[0]; + + const NDArrayStorageType stype = src[0].storage_type(); + // normal dense reduce + if (stype == kDefaultStorage) { + if (total_size > gpuarray_bound_ && first_size >= devs_.size()) { + // Find slice bounds + slice_scan[0] = 0; + int slice_size = (first_size + devs_.size()-1)/devs_.size(); + for (unsigned i = 1; i < devs_.size(); ++i) { + slice_scan[i] = slice_scan[i-1] + slice_size; + } + slice_scan[devs_.size()] = src[0].shape()[0]; + + // row: which slice + // col: which gpu + for (unsigned row = 0; row < devs_.size(); ++row) { + for (unsigned col = 0; col < devs_.size(); ++col) { + TreeBufferEntry& buf = tree_merge_buf_[col][key]; + NDArray curr_slice = src[col].Slice(slice_scan[row], + slice_scan[row+1]); + slice[row].push_back(curr_slice); + broadcast_slice[row].push_back(&(buf.merged[row])); + } + } + + // Do reduce-scatter (multiroot reduce) + // input: slice (src) + // output: buf.merge_buf + for (unsigned i = 0; i < devs_.size(); ++i) { + ReduceInner(key, slice[i], i, i, priority); + } + + for (unsigned i = 0; i < devs_.size(); ++i) { + BroadcastInner(key, *(broadcast_slice[i][i]), broadcast_slice[i], i, i, priority); + } + } else { + int root = 0; + ReduceInner(key, src, root, 0, priority); + + TreeBufferEntry& buf = tree_merge_buf_[root][key]; + return buf.merged[0]; + } + + // Copy from list of small NDArrays to one big NDArray, which is returned + int gpu_id = 0; + return src[gpu_id]; + } else { + // sparse reduce + return ReduceRowSparse(key, src, priority); + } + } + 
+ void BroadcastInner(int key, const NDArray& src, + const std::vector& dst, int root, + int merged_row, int priority) { + // copy to root of tree + std::vector& topology = topology_[root]; + std::vector temp(devs_.size()); + int gpu_id = topology[0]; + if (merged_row == -1) + CopyFromTo(src, dst[gpu_id], priority); + temp[gpu_id] = *dst[gpu_id]; + + for (int level = 1; level <= depth_; ++level) { + int start = scan_[root][level]; + int end = scan_[root][level+1]; + + unsigned is_src = 0; + int src_id = 0; + for (int j = start; j < end; ++j) { + int topo_id = topology[j]; + src_id = (is_src == 0) ? topo_id : src_id; + + if (is_src && src_id != topo_id) { + CopyFromTo(temp[src_id], dst[topo_id], priority); + temp[topo_id] = *dst[topo_id]; + } + + is_src = (is_src == static_cast(kBranch)-1) ? 0 : is_src+1; + } + } + } + + void Broadcast(int key, const NDArray& src, + const std::vector dst, int priority) override { + if (!inited_) { + // copy to a random device first + int dev_id = key % dst.size(); + CopyFromTo(src, dst[dev_id], priority); + for (size_t i = 0; i < dst.size(); ++i) { + if (i != static_cast(dev_id)) { + CopyFromTo(*dst[dev_id], dst[i], priority); + } + } + } else { + int total_size = src.shape().Size(); + unsigned first_size = src.shape()[0]; + if (total_size > gpuarray_bound_ && first_size >= devs_.size()) { + std::vector slice_scan(devs_.size()+1); + slice_scan[0] = 0; + int slice_size = (dst[0]->shape()[0]+devs_.size()-1)/devs_.size(); + for (unsigned i = 1; i < devs_.size(); ++i) { + slice_scan[i] = slice_scan[i-1] + slice_size; + } + slice_scan[devs_.size()] = dst[0]->shape()[0]; + + for (unsigned gpu_id = 0; gpu_id < dst.size(); ++gpu_id) { + TreeBufferEntry& buf = tree_merge_buf_[gpu_id][key]; + for (unsigned i = 0; i < devs_.size(); ++i) { + if (devs_[gpu_id] == dst[gpu_id]->ctx()) { + NDArray curr_slice = dst[gpu_id]->Slice(slice_scan[i], slice_scan[i+1]); + CopyFromTo(buf.merged[i], &curr_slice, priority); + } + } + } + } else { + int root = 0; + BroadcastInner(key, src, dst, root, -1, priority); + } + } + } + + private: + void EnableP2P() { +#if MXNET_USE_CUDA + std::vector gpus; + for (const auto& d : devs_) { + if (d.dev_mask() == gpu::kDevMask) { + gpus.push_back(d.dev_id); + } + } + int n = static_cast(gpus.size()); + int enabled = 0; + std::vector p2p(n*n); + for (int i = 0; i < n; ++i) { + cudaSetDevice(gpus[i]); + for (int j = 0; j < n; j++) { + int access; + cudaDeviceCanAccessPeer(&access, gpus[i], gpus[j]); + if (access) { + cudaError_t e = cudaDeviceEnablePeerAccess(gpus[j], 0); + if (e == cudaSuccess || e == cudaErrorPeerAccessAlreadyEnabled) { + ++enabled; + p2p[i*n+j] = 1; + } + } + } + } + if (enabled != n*(n-1)) { + // print warning info if not fully enabled + LOG(WARNING) << "only " << enabled << " out of " + << n*(n-1) << " GPU pairs are enabled direct access. " + << "It may affect the performance. " + << "You can set MXNET_ENABLE_GPU_P2P=0 to turn it off"; + std::string access(n, '.'); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + access[j] = p2p[i*n+j] ? 
'v' : '.'; + } + LOG(WARNING) << access; + } + } +#endif + } + + void QueryTopology() { +#if MXNET_USE_CUDA + std::vector link_matrix(devs_.size()*devs_.size()); + GetP2PWeight(devs_, &link_matrix); + if (backtrack_) + LOG(WARNING) << "Using Backtracking to generate trees"; + else + LOG(WARNING) << "Using Kernighan-Lin to generate trees"; + ComputeTrees(link_matrix, devs_.size(), link_usage_penalty_, backtrack_, + &topology_, &scan_); + + depth_ = ComputeDepth(devs_.size()); +#endif + } + + using KeyAttrs = std::tuple; + // try to allocate buff on device evenly + void InitMergeBufferTree() { + LOG(WARNING) << "Using Tree"; + + // same as all-reduce, except: + // 1) Allocate copy_buf here instead of in Reduce() + // 2) Force copy_buf to be of kRecvBufferSize + // 3) Do not use greedy assignment; all keys are assigned to each GPU + for (unsigned i = 0; i < devs_.size(); ++i) + tree_merge_buf_.push_back(std::unordered_map()); + + bool delay_alloc = true; + std::map key_dist; + + for (size_t i = 0; i < tree_sorted_key_attrs_.size(); ++i) { + const int key = std::get<0>(tree_sorted_key_attrs_[i]); + const TShape& shape = std::get<1>(tree_sorted_key_attrs_[i]); + const int type = std::get<2>(tree_sorted_key_attrs_[i]); + + if (key_dist.find(shape.Size()) == key_dist.end()) + key_dist[shape.Size()] = 1; + else + key_dist[shape.Size()]++; + + int start = scan_[0][depth_ ]; + int end = scan_[0][depth_+1]; + + // In order to generalize to any number of GPUs, we use strategy of having + // found the mapping from 0, 1, ..., n_gpus to dev_id i.e. + // idx: 0 1 2 3 4 5 6 + // dev_id: 4 2 3 1 7 5 0 + // and generated an n_gpus x n_gpus link topology matrix: + // + // 1) The reduction trees are saved as indices on 0, 1, ..., n_gpus + // 2) We use the mapping to retrieve dev_id and device context + for (int j = start; j < end; ++j) { + int topo_id = topology_[0][j]; + auto& buf = tree_merge_buf_[topo_id][key]; + Context ctx = devs_[topo_id]; + + // buf.merged enforces that we only visit each GPU once + if (buf.merged.empty()) { + TShape shape_copy = shape; + int total_size = shape.Size(); + unsigned first_size = shape[0]; + if (total_size > gpuarray_bound_ && first_size >= devs_.size()) { + // Find slice bounds + int slice_size = (first_size+devs_.size()-1)/devs_.size(); + int last_slice = first_size-(devs_.size()-1)*slice_size; + shape_copy[0] = slice_size; + buf.merged.resize(devs_.size()); + for (unsigned row = 0; row < devs_.size(); ++row) { + if (row == devs_.size()-1) + shape_copy[0] = last_slice; + buf.merged[row] = NDArray(shape_copy, ctx, delay_alloc, type); + buf.copy_buf.push_back(std::vector()); + if (buf.copy_buf[row].empty()) { + buf.copy_buf[row].resize(kBranch-1); + for (size_t col = 0; col < buf.copy_buf[0].size(); ++col) { + buf.copy_buf[row][col] = NDArray(buf.merged[row].shape(), + buf.merged[row].ctx(), + delay_alloc, + buf.merged[row].dtype()); + } + } + } + } else { + buf.merged.push_back(NDArray(shape, ctx, false, type)); + if (buf.copy_buf.empty()) { + buf.copy_buf.push_back(std::vector()); + buf.copy_buf[0].resize(kBranch-1); + for (size_t col = 0; col < buf.copy_buf[0].size(); ++col) { + buf.copy_buf[0][col] = NDArray(buf.merged[0].shape(), + buf.merged[0].ctx(), delay_alloc, + buf.merged[0].dtype()); + } + } + } + } + } + } + + for (auto it = key_dist.begin(); it != key_dist.end(); ++it) { + LOG(WARNING) << "Size " << it->first << " occurs " << it->second << " times"; + } + inited_ = true; + } + + std::vector tree_sorted_key_attrs_; + /// \brief temporal space for pushing and 
pulling + struct TreeBufferEntry { + /// \brief the dense merged value for reduce and broadcast operations + std::vector merged; + /// \brief the gpu buffer for copy during reduce operation + std::vector> copy_buf; + /// \brief the residual buffer for gradient compression + std::vector residual; + /// \brief the small buffer for compressed data in sender + std::vector compressed_send_buf; + /// \brief the small buffer for compressed data in receiver + std::vector compressed_recv_buf; + + private: + /// \brief the sparse merged value for reduce and rowsparse broadcast operations + NDArray sparse_merged; + }; + /// \brief intent of tree_merge_buf_ in old comm.h: store key->gpu mapping + /// new intent: for every gpu: store key->memory mapping + std::vector> tree_merge_buf_; + + /// \brief NVLink-connected topology in full binary tree format + std::vector> topology_; + std::vector> scan_; + std::vector devs_; + + /// \brief Highest numbered device + int max_dev_; + int depth_; + int gpuarray_bound_; + bool backtrack_; + float link_usage_penalty_; + + /// \brief constant for maximum size of recv buffer per GPU + /// 2: only receive from 1 other GPU + const int kBranch = 2; +}; + +} // namespace kvstore +} // namespace mxnet +#endif // MXNET_KVSTORE_COMM_TREE_H_ diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h new file mode 100644 index 000000000000..bb466bc4f58b --- /dev/null +++ b/src/kvstore/gpu_topology.h @@ -0,0 +1,1067 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +/** + * Copyright (c) 2015 by Contributors + */ +#ifndef MXNET_KVSTORE_GPU_TOPOLOGY_H_ +#define MXNET_KVSTORE_GPU_TOPOLOGY_H_ +#if MXNET_USE_CUDA + #include + #include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define MXNET_KVSTORE_MAXDEPTH 16 + +namespace mxnet { +namespace kvstore { + +template +inline void PrintVector(const std::string& str, const std::vector& vec) { + std::cout << str << ":\n"; + for (unsigned i = 0; i < vec.size(); ++i) + std::cout << vec[i] << " "; + std::cout << std::endl; +} + +template +inline void PrintMatrix(const std::string& str, const std::vector& matrix, + int num_rows, int num_cols) { + std::cout << str << ":\n"; + int count = 0; + for (int row = 0; row < num_rows; ++row) { + for (int col = 0; col < num_cols; ++col) { + std::cout << matrix[count++] << " "; + } + std::cout << std::endl; + } +} + +inline void PrintTopo(const std::string& str, const std::vector& topo_row, + std::vector scan_row) { + PrintVector("Topo vector", topo_row); + PrintVector("Scan vector", scan_row); + std::cout << str << ":\n"; + int depth = scan_row.size()-1; + for (int row = 0; row < depth; ++row) { + int start = scan_row[row]; + int end = scan_row[row+1]; + for (; start < end; start++) { + for (int i = 0; i < (2 << (depth-row-2))+1; ++i) { + std::cout << " "; + } + std::cout << topo_row[start]; + } + std::cout << std::endl; + } +} + +// Uses BFS to find whether undirected graph is connected or not given its +// adjacency matrix +// Note: only consider matrix values > 1, because we care about whether it is +// connected using only NVLink connections +template +inline bool IsConnected(const std::vector& matrix, + int num_gpus) { + int source = 0; + std::vector visited(num_gpus, false); + std::queue work_list; + + work_list.push(source); + visited[source] = true; + while (!work_list.empty()) { + int curr = work_list.front(); + work_list.pop(); + + for (int i = 0; i < num_gpus; ++i) { + int neighbour = matrix[curr*num_gpus + i]; + if (i != curr && neighbour > 1 && visited[i] == false) { + visited[i] = true; + work_list.push(i); + } + } + } + + for (int i = 0; i < num_gpus; ++i) { + if (visited[i] == false) + return false; + } + return true; +} + +// Generate adjacency matrix with row/col numbering from 0, 1, ..., n_gpu +// @input: devs is a vector of GPU contexts +// @output: matrix is adjacency matrix of link topology graph +// where edge weight represents relative performance of NVIDIA GPUs +// 0: Self-connection +// 1: PCI-E +// 2: 1 NVLink connection +// 3: 2 NVLink connections +template +inline void GetP2PWeight(const std::vector& devs, + std::vector* matrix) { + int num_gpus = devs.size(); + int count = 0; + std::vector zero_dev_id(num_gpus, -1); + for (auto d : devs) { + zero_dev_id[count] = d.dev_id; + count++; + } + +#if MXNET_USE_CUDA + cudaDeviceP2PAttr attr; + attr = cudaDevP2PAttrPerformanceRank; + std::vector max(num_gpus, 0); + + for (int row = 0; row < num_gpus; ++row) { + for (int col = 0; col < num_gpus; ++col) { + if (row == col) { + (*matrix)[row*num_gpus+col] = 0; + } else { + int value; + int row_gpu = zero_dev_id[row]; + int col_gpu = zero_dev_id[col]; + cudaDeviceGetP2PAttribute(&value, attr, row_gpu, col_gpu); + if (value > max[row]) + max[row] = value; + (*matrix)[row*num_gpus+col] = static_cast(value)+1; + } + } + } + + // Check that all GPUs have at least 1 NVLink connection + int max_value = 0; + for (unsigned int i = 0; i < max.size(); ++i) { + if (max[i] > max_value) + 
max_value = max[i]; + } + + // If all GPUs are connected by NVLink, then we can use NVLink only + // to communicate instead of going over PCI-E + bool connected = IsConnected(*matrix, num_gpus); + + if (connected) { + for (auto& matrix_value : *matrix) { + matrix_value = (matrix_value == 1) ? 0 : matrix_value; + } + } + PrintMatrix("Weight W", *matrix, num_gpus, num_gpus); +#else + LOG(WARNING) << "GPU required for link topology"; +#endif +} + +// Dense matrix-vector multiplication +// Assume: matrix is square +// y = A*x (no accumulate) +template +inline void gemv(const std::vector& A, + const std::vector& x, + std::vector* y) { + int nrows = x.size(); + int count = 0; + for (int row=0; row < nrows; ++row) { + (*y)[row] = 0; + for (int col=0; col < nrows; ++col) { + (*y)[row] += A[count]*static_cast(x[col]); + count++; + } + } +} + +// Element-wise multiplication between 2 dense vectors +// w = w * alpha*u +template +inline void ewisemult(const std::vector& u, + T alpha, + std::vector* w) { + int nelem = u.size(); + for (int i=0; i < nelem; ++i) { + (*w)[i] *= alpha*static_cast(u[i]); + } +} + +// Computes best 2 nodes a,b to swap given objective function: +// g = max_{a \in A, b \in B} D(a) + D(b) - 2*W(a,b) +// +// Optimization: Only need to look at upper triangular since weight matrix is +// symmetric +template +inline void FindBestMove(const std::vector& W, + const std::vector& P_temp, + const std::vector& D, + const std::unordered_set& used, + int* a, + int* b, + T* g) { + int nrows = P_temp.size(); + *g = 0; + *a = -1; + *b = -1; + for (int row=0; row < nrows; ++row) { + if (P_temp[row] == 0 || used.find(row) != used.end()) continue; + for (int col=row+1; col < nrows; ++col) { + if (P_temp[col] == 0 || P_temp[row] == P_temp[col]) continue; + + T cost = D[row]+D[col]-2*W[row*nrows+col]; + if (cost > *g) { + *g = cost; + *a = row; + *b = col; + } + } + } +} + +// Performs partition on each existing partition in graph W if partition has +// more than 4 elements in it +// @output: stop returns true if no partitions with >=4 elements found +// returns false otherwise +// cluster_pairs stores the mapping that tells us which 2 clusters are +// the output of partitioning one large cluster +template +inline bool KernighanLin(const std::vector& W, + std::vector* P, + int* num_partitions, + std::vector>* cluster_pairs, + std::mt19937* gen) { + std::vector histogram(*num_partitions, 0); + std::vector P_temp(P->size(), 0); + std::vector P_temp2(P->size(), 0); + std::vector D(P->size(), 0); + std::vector D_temp(P->size(), 0); + + // 0) For every partition, determine if it can be partitioned further. + // To do this, we must do a histogram of each partition: + for (unsigned i=0; i < P->size(); ++i) { + histogram[(*P)[i]]++; + } + + bool stop = true; + for (unsigned color=0; color < histogram.size(); ++color) { + int partition_size = histogram[color]; + // Save cluster in preparation for push to topo in GenerateBinaryTree() + if (partition_size <= 2) { + cluster_pairs->push_back( + std::pair(static_cast(color), -partition_size)); + + // Do Kernighan-Lin if clustering is necessary + } else { + stop = false; + + // 1) If it has more than 4 elements, we can partition further. 
+ // Assign random balanced partition of it + // -balanced is more important than random, so allocate first half to A + // and rest to B + int first_partition = 0; + int target_partition = partition_size/2; + std::vector cluster_list; + + for (unsigned i = 0; i < P->size(); ++i) { + // Required to shift from [0,1] to {-1,1} + // 1 means vertex i is in Cluster A + // -1 means vertex i is in Cluster B + if ((*P)[i] == static_cast(color)) { + cluster_list.push_back(i); + } else { + P_temp[i] = 0; + } + } + + // 1b) Shuffle using random generator + std::shuffle(cluster_list.begin(), cluster_list.end(), *gen); + for (unsigned i = 0; i < cluster_list.size(); ++i) { + if (first_partition < target_partition) { + int dest = cluster_list[i]; + P_temp[dest] = 1; + first_partition++; + } else { + int dest = cluster_list[i]; + P_temp[dest] = -1; + } + } + + // 2) Do iterations of Kernighan-Lin until convergence + T g_max = 0; + int g_k = -1; + unsigned count = 0; + do { + count++; + P_temp2 = P_temp; + + // a) Compute difference between external and internal costs of all + // elements in vector D + gemv(W, P_temp, &D); + ewisemult(P_temp, -1.f, &D); + + // av and bv are used to hold candidates for moving + // gv stores the score associated with move + std::vector av; + std::vector bv; + std::vector gv; + + std::unordered_set used; + + for (int iter=0; iter < partition_size/2; ++iter) { + // b) Find best move by looking through upper triangular of W matrix + int a, b; + T g; + FindBestMove(W, P_temp, D, used, &a, &b, &g); + if (g > 0) { + } else { + g_max = 0; + break; + } + + // c) Store best move to av, bv, gv + av.push_back(a); + bv.push_back(b); + gv.push_back(g); + + // d) Eliminate best move from consideration in vector P_temp + P_temp[a] *= -1; + P_temp[b] *= -1; + used.insert(a); + used.insert(b); + + // e) Update D using P_temp + gemv(W, P_temp, &D); + ewisemult(P_temp, -1.f, &D); + D[a] = 0; + D[b] = 0; + } + + // 3) Find when to stop by doing linear scan through gv + // Recompute score g_max + for (unsigned k = 0; k < gv.size(); ++k) { + if (k > 0) + gv[k] += gv[k-1]; + if (gv[k] > g_max) { + g_max = gv[k]; + g_k = k + 1; + } + } + + // 4) If move is "good", commit moves by updating P_temp and P_temp2 + // Otherwise, rollback changes to P_temp2 + if (g_max > 0) { + for (int i = 0; i < g_k; i++) { + int a = av[i]; + int b = bv[i]; + int temp = P_temp2[a]; + P_temp2[a] = P_temp2[b]; + P_temp2[b] = temp; + + P_temp = P_temp2; + } + } else { + P_temp = P_temp2; + } + } while (g_max > 0 && count <= P->size()); + + // 5) Update P using P_temp + int moves = 0; + for (unsigned i=0; i < P->size(); ++i) { + if (P_temp[i] == -1) { + (*P)[i] = *num_partitions; + moves++; + } + } + cluster_pairs->push_back(std::pair(static_cast(color), + static_cast(*num_partitions))); + + (*num_partitions)++; + } + } + + return stop; +} + +// Returns root of a given color if found in roots +// Returns -1 if it is not found +inline int GetRoot(const std::vector& P, + int color, + const std::unordered_set& roots) { + for (auto root : roots) { + if (P[root] == color) + return root; + } + return -1; +} + +// Returns root of a given color if found in roots +// Returns -1 if it is not found +inline int GetChild(const std::vector& P, + int color, + int parent) { + for (unsigned i = 0; i < P.size(); ++i) { + if (P[i] == color && static_cast(i) != parent) + return i; + } + return -1; +} + +// Computes highest weighted edge a-b +// +// Contraints: +// -vertex a must be parent +// -vertex b must be in dest_cluster +// +// @output: 
b is vector of candidates if a tie happens +// g is weight of edge +// Optimization: Only need to look at row a in matrix +template +inline void FindBestEdge(const std::vector& W, + const std::vector& P, + int parent, + int dest_cluster, + std::vector* b, + T* g) { + int nrows = P.size(); + int row = parent; + *g = 0; + b->push_back(-1); + for (int col=0; col < nrows; ++col) { + if (col == row || P[col] != dest_cluster) continue; + + T cost = W[row*nrows+col]; + if (cost > *g) { + b->clear(); + } + if (cost >= *g) { + b->push_back(col); + *g = cost; + } + } +} + +// Given a vector of color pairs, appends to binary tree matrix topo +// @input: W gives the link topology +// P gives the result of KL partitioning +// cluster_pairs gives pairing between clusters, an edge is found +// between each pairing +// roots gives source vertices +// gen gives random number generation to break ties +// @output: cluster_pairs +// topo_row says where new edges are appended to +// scan_row says where we should start looking for topo_row +template +inline int KLGenerateBinaryTree(const std::vector& W, + const std::vector& P, + std::vector>* cluster_pairs, + std::unordered_set* roots, + std::vector* topo_row, + std::vector* scan_row, + std::mt19937* gen) { + std::unordered_set new_roots; + std::unordered_map new_topo; + int reset = 0; + + for (unsigned i = 0; i < cluster_pairs->size(); ++i) { + if (i == 0) + scan_row->push_back(topo_row->size()); + int parent, child = -1; + if ((*cluster_pairs)[i].second == -2) { + // Root must be color of pair.first + int color = (*cluster_pairs)[i].first; + parent = GetRoot(P, color, *roots); + if (parent == -1) return 1; + child = GetChild(P, color, parent); + } else if ((*cluster_pairs)[i].second == -1) { + int color = (*cluster_pairs)[i].first; + parent = GetRoot(P, color, *roots); + if (parent == -1) return 1; + child = parent; + } else { + // Root must exist in either first or second element of pair + int color = (*cluster_pairs)[i].first; + parent = GetRoot(P, color, *roots); + color = (parent == -1) ? (*cluster_pairs)[i].second : color; + parent = (parent == -1) ? GetRoot(P, color, *roots) : parent; + + int from_cluster = color; + int dest_cluster = (from_cluster == (*cluster_pairs)[i].first) ? 
+ (*cluster_pairs)[i].second : (*cluster_pairs)[i].first; + + std::vector candidates; + T weight; + FindBestEdge(W, P, parent, dest_cluster, &candidates, &weight); + + // If no candidates + if (candidates[0] != -1) { + std::shuffle(candidates.begin(), candidates.end(), *gen); + child = candidates[0]; + } + + if (child == -1) { + new_roots.insert(parent); + return 1; + } else { + new_roots.insert(parent); + new_roots.insert(child); + } + } + + new_topo[parent] = child; + } + + int depth = scan_row->size(); + int start = (*scan_row)[depth-2]; + int end = (*scan_row)[depth-1]; + + for (int i = start; i < end; ++i) { + int parent = (*topo_row)[i]; + int child; + + // If not first, check previous level whether or not we are encountering + // this root for the first time in this level of the tree + if (i != start && parent == static_cast((*topo_row)[i-1])) + child = parent; + else + child = new_topo[parent]; + topo_row->push_back(parent); + topo_row->push_back(child); + } + + cluster_pairs->clear(); + roots->clear(); + *roots = std::move(new_roots); + + return reset; +} + +// @input: n is the number of nodes in a balanced binary tree +// @output: returns how many levels of binary tree there are +inline int ComputeDepth(int n) { + for (int depth = 0; depth < MXNET_KVSTORE_MAXDEPTH; ++depth) { + int num = 2 << depth; + if (n <= num) + return depth+1; + } + return 0; +} + +// Checks whether a given state forms a spanning tree that satisfies: +// -balanced +// -binary +// -each edge in tree corresponds to link in network topology +// -each edge in tree does not form self-loop +template +inline bool IsValid(const std::vector& W, + const std::vector& state, + int num_elements, + int row, + int depth) { + // At each level of tree, check whether edge: + // -corresponds to link in network topology + // -corresponds to self-loop + for (int i = 0; i < depth; ++i) { + int stride = 1 << i; + for (int j = 0; j+stride < row; j += 2*stride) { + int from = state[j]; + int dest = state[j+stride]; + if (W[from*num_elements + dest] == static_cast(0) && from != dest) { + return false; + } + } + } + + // If we encounter GPU for first time, increment found_vec. + // Otherwise, do nothing + std::unordered_set found; + std::vector found_vec(num_elements, 0); + for (auto val : state) { + if (val == -1) + continue; + if (val < num_elements) { + if (found.find(val) == found.end()) { + found.insert(val); + found_vec[val] = 1; + } + } else { + return false; + } + } + + // modifier is maximum number of repeats a single GPU can take + // e.g. 5 GPUs in 3-level binary tree => one GPU can repeat 3x + // GPU0 GPU0 GPU0 GPU0 GPU1 GPU2 GPU3 GPU4 + int modifier = (1 << depth) - num_elements; + int num_found = found.size(); + + // So we know we have an invalid state if we find: + // -only 4 unique GPUs + // -9 unique GPUs + if (row < num_elements) { + if (num_found > row || num_found < row - modifier) { + return false; + } + + // If we are at last recursive level, we can apply a more stringent check: + // -if some GPU is not found, then we are in invalid state + } else if (row == static_cast(state.size())) { + for (int i = 0; i < num_elements; ++i) { + if (found_vec[i] == 0) { + return false; + } + } + } + + return true; +} + +// This function takes a spanning tree encoded as state (result), which may have +// repeated GPUs representing NO-SENDs and converts it into a unique format. +// This has the effect of recognizing redundant sends, grouping them together, +// so that the Reduce call knows not to perform a CopyFromTo. 
+// +// Initial result: [3 0 0 4 1 2 5 6] +// Final result: [3 3 0 4 1 2 5 6] +// +// Initial: +// 3 +// 3 1 +// 3 0 1 5 +// 3 0 0 4 1 2 5 6 // GPU3 will make redundant send to GPU0 +// +// Final: +// 3 +// 3 1 +// 3 0 1 5 +// 3 3 0 4 1 2 5 6 // GPU3 knows not to make redundant send to itself +inline void Postprocess(std::vector* result, int num_elements, int depth) { + for (int level = depth - 1; level >= 0; --level) { + int stride = 1 << level; + std::vector histogram_above(num_elements, 0); + for (unsigned i = 0; i < result->size(); i += 2*stride) { + int val = (*result)[i]; + histogram_above[val]++; + } + std::vector histogram(num_elements, 0); + for (unsigned i = 0; i < result->size(); i += stride) { + int val = (*result)[i]; + histogram[val]++; + } + + for (int i = result->size()-stride; i-stride >= 0; i -= 2*stride) { + int from = (*result)[i]; + int dest = (*result)[i-stride]; + if ((histogram[from] > 1 || histogram_above[from] >= 1) && from != dest) { + (*result)[i] = dest; + histogram[from]--; + } + } + } +} + +// Given a spanning tree encoded as a state (result) and weight of each edge +// in the link topology graph, compute its weight. +// @input: penalty controls whether or not penalties are applied to tree +// -usually turned on when backtracking to get better solutions +// -usually turned off when outside the penalty to get weight of tree +template +inline T ComputeTreeWeight(const std::vector& W, + const std::vector& result, + int num_elements, + int depth, + bool penalty) { + T weight = 0.f; + std::unordered_set links_used; + + for (int i = 0; i < depth; ++i) { + int stride = 1 << i; + std::vector nodes_used(num_elements, false); + for (unsigned j = 0; j+stride < result.size(); j += 2*stride) { + int from = result[j]; + int dest = result[j+stride]; + if (from != dest) { + weight += W[from*num_elements+dest]; + + // Penalize: (1) use of redundant edges in a single tree + // (2) repeated use of a GPU in a single tree at the same + // level above the leaf level + if (links_used.find(from*num_elements+dest) != links_used.end() + && penalty) { + weight -= 100; + } + links_used.insert(from*num_elements+dest); + links_used.insert(dest*num_elements+from); + } + + nodes_used[from] = true; + if (i > 0 && nodes_used[dest] && penalty) { + weight -= 10; + } + nodes_used[dest] = true; + } + } + + return weight; +} + +// Given a spanning tree encoded as result, which was convenient for performing +// backtracking, convert it topology_ and scan_ in the classic "binary tree +// stored in an array" format. For binary trees scan_ is redundant, but this +// additional data structure leaves future generalization to k-radix trees. +// +// Initial result: [3 3 0 4 1 2 5 6] +// topology_: [3 3 1 3 0 1 5 3 3 0 4 1 2 5 6] +// scan_: [0 1 3 7 15] +// +// topology_ is stored in the classic "binary tree stored in an array" format +// e.g. 
3 +// 3 1 +// 3 0 1 5 +// 3 3 0 4 1 2 5 6 +inline void FormTopology(const std::vector& result, + std::vector* topo_row, + std::vector* scan_row, + int depth) { + scan_row->push_back(topo_row->size()); + for (int i = depth; i > 0; --i) { + int stride = 1 << i; + for (unsigned j = 0; j < result.size(); j += stride) { + int from = result[j]; + topo_row->push_back(from); + } + scan_row->push_back(topo_row->size()); + } + + // Insert at the end, result vector + topo_row->insert(topo_row->end(), result.begin(), result.end()); + scan_row->push_back(topo_row->size()); +} + +// Recursive function that finds a spanning tree, which fulfills the following +// conditions: +// -balanced +// -binary +// -maximum weight +template +inline bool RecursiveBacktrack(const std::vector& W, + std::vector* state, + std::vector* best_result, + T* best_result_weight, + int row, + int num_elements, + int depth, + bool optimal) { + if (row == static_cast(state->size())) { + std::vector result = *state; + Postprocess(&result, num_elements, depth); + T weight = ComputeTreeWeight(W, result, num_elements, depth, true); + + // Save this spanning tree if it is highest weight tree found sofar + if (weight > *best_result_weight) { + std::swap(*best_result_weight, weight); + *best_result = result; + } + return !optimal; + } + + // If not last recursive level, try to find valid tree for next level + bool stop = false; + for (int j = 0; j < num_elements; ++j) { + (*state)[row] = j; + if (IsValid(W, state, num_elements, row+1, depth)) + stop = RecursiveBacktrack(W, state, best_result, best_result_weight, + row+1, num_elements, depth, optimal); + (*state)[row] = -1; + if (stop) + return stop; + } + return stop; +} + +template +inline void IterativeBacktrack(const std::vector& W, + std::vector* state, + std::vector* best_result, + T* best_result_weight, + int row, + int num_elements, + int depth, + bool optimal) { + std::stack state_stack; + row = 1; + int pos = 0; + state_stack.push(pos); + + while (true) { + // If there is no valid position, 2 cases: + // a) if stack is empty, break and stop search + // b) if stack is not empty, pop stack and set current position to next + // position backtrack to previous row + while (!state_stack.empty() && pos >= num_elements) { + pos = state_stack.top(); + pos++; + state_stack.pop(); + (*state)[state_stack.size()+1] = -1; + row--; + } + if (state_stack.empty()) break; + + (*state)[row] = pos; + // If there is a valid position push the position to stack, set current + // position to 0 and move to next row + if (IsValid(W, *state, num_elements, row+1, depth)) { + state_stack.push(pos); + pos = 0; + row++; + } else { + pos++; + (*state)[row] = -1; + } + + // If stack has size N, a solution is found + // Pop stack, set current position to next position + // Backtrack to find next solution + if (row == static_cast(state->size())) { + std::vector result = *state; + Postprocess(&result, num_elements, depth); + T weight = ComputeTreeWeight(W, result, num_elements, depth, true); + + // Save this spanning tree if it is highest weight tree found so far + if (weight > *best_result_weight) { + std::swap(*best_result_weight, weight); + *best_result = result; + } + if (!optimal) break; + + pos = state_stack.top(); + pos++; + state_stack.pop(); + (*state)[state_stack.size()+1] = -1; + row--; + } + } +} + +// Apply penalty factor alpha to each link in link topology graph that is used +// by the spanning tree +template +inline void UpdateWeight(std::vector* W, + const std::vector& topo_row, + int num_elements, 
+ float alpha) { + for (unsigned i = 1; i < topo_row.size() - 1; i += 2) { + unsigned parent = topo_row[i]; + unsigned child = topo_row[i+1]; + if (!(parent >= num_elements*num_elements || + child >= num_elements*num_elements) && (parent != child)) { + (*W)[parent*num_elements+child] *= alpha; + (*W)[child*num_elements+parent] *= alpha; + } + } +} + +// Do brute-force backtracking approach if Kernighan-Lin fails to find a binary +// tree of height Log P. +// +// Constraints: +// 1) minimize depth (balance) +// 2) maximize edge weight +// 3) tree is binary +template +inline void BacktrackGenerateBinaryTree(std::vector* W, + int num_elements, + int root, + std::vector* topo_row, + std::vector* scan_row) { + // Clear before starting + topo_row->clear(); + scan_row->clear(); + + // Compute depth + // num_elements: depth + // 5: 3 + // 6: 3 + // 7: 3 + // 8: 3 + // 9: 4 + int depth = ComputeDepth(num_elements); + int depth_leaves = 1 << depth; + + // State vector + // -1 means unplaced + std::vector state(depth_leaves, -1); + std::vector result(depth_leaves, -1); + T result_weight = std::numeric_limits::lowest(); + + // Place root and try all combinations + state[0] = root; + + // Seek optimal solution until depth <= 3 i.e. 8 GPUs + // For larger numbers of GPUs, settle for first tree found (non-optimal), but + // this saves a lot of runtime, because Backtrack is exponential time + if (depth <= 3) + IterativeBacktrack(*W, &state, &result, &result_weight, 1, num_elements, + depth, true); + else + IterativeBacktrack(*W, &state, &result, &result_weight, 1, num_elements, + depth, false); + FormTopology(result, topo_row, scan_row, depth); +} + +// ComputeTreesFromRoot does the same thing as ComputeTrees, with the only +// exception being it will do it from a fixed GPU as root +template +inline void ComputeTreesFromRoot(std::vector* W, + int num_elements, + int root, + float alpha, + bool backtrack, + std::vector* topo, + std::vector* scan) { + int num_partitions = 1; + + // Initialize partition array to indicate which partition each element belongs + // to beginning with 0 + std::vector P(num_elements, 0); + + // Initialize vector of pairs that will tell us edges between what 2 clusters + // we should be looking to build the tree from + std::vector> cluster_pairs; + + // Initialize vector of roots that will tell us edges between + std::unordered_set roots; + roots.insert(root); + + // Will be used to obtain a seed for the random number engine + // RNG: Standard mersenne_twister_engine seeded with rd() + // -use 0 for testing (TODO: remove this) + // std::random_device rd; + // std::mt19937 gen(rd()); + std::mt19937 gen(1); + + // Temporary variables for rewinding + std::vector P_temp; + int num_partitions_temp; + std::unordered_set roots_temp; + std::vector topo_temp; + std::vector scan_temp; + + // Determine number of partition levels + // If first partition, determine root of maximal spanning tree + bool stop = false; + int reset = 1; + int level = 0; + + while (!backtrack && (!stop || reset)) { + if (reset == 1) { + cluster_pairs.clear(); + P_temp = P; + num_partitions_temp = num_partitions; + roots_temp = roots; + topo_temp = *topo; + scan_temp = *scan; + } + + // Run Kernighan-Lin to generate partition + stop = KernighanLin(*W, &P_temp, &num_partitions_temp, &cluster_pairs, + &gen); + + // Use partitions found and a given root to find best inter-cluster edge for + // each pair of clusters, and returns them as roots of next cluster + // If reset is true, then rewind back to previous clustering + 
reset = KLGenerateBinaryTree(*W, P_temp, &cluster_pairs, &roots_temp, + &topo_temp, &scan_temp, &gen); + + if (reset) + level++; + if (level > 10) break; + } + + if (reset == 1) { + // if (!backtrack) + // LOG(WARNING) << "No valid binary tree found from root " << root << ", try backtracking"; + BacktrackGenerateBinaryTree(W, num_elements, root, topo, scan); + } else { + *topo = topo_temp; + *scan = scan_temp; + scan->push_back(topo->size()); + } + UpdateWeight(W, *topo, num_elements, alpha); +} + +// ComputeTrees computes balanced binary spanning trees of maximum edge weight +// given a link topology graph stored in adjacency matrix format +// @input: W is the link topology matrix +// num_elements is the number of GPUs +// alpha is the link usage penalty +// backtrack is whether or not we use backtracking to generate trees +// @output: topo stores the trees generated +// scan stores the start of each level of each tree +template +inline void ComputeTrees(const std::vector& W, + int num_elements, + float alpha, + bool backtrack, + std::vector>* topo, + std::vector>* scan) { + std::vector W_copy = W; + + topo->clear(); + scan->clear(); + for (int i = 0; i < num_elements; ++i) { + topo->push_back(std::vector()); + scan->push_back(std::vector()); + (*topo)[i].push_back(i); + (*scan)[i].push_back(0); + ComputeTreesFromRoot(&W_copy, num_elements, i, alpha, backtrack, + &((*topo)[i]), &((*scan)[i])); + } + + // Note: must sum up adj matrix to show link usage before we readjust topo + // from 0, 1, ..., n_gpus format to dev_id format, which will cause segfault + std::vector adj(W.size(), 0); + for (int row = 0; row < num_elements; ++row) { + for (unsigned col = 1; col < (*topo)[0].size(); col += 2) { + int from = std::min((*topo)[row][col], (*topo)[row][col+1]); + int dest = std::max((*topo)[row][col], (*topo)[row][col+1]); + if (from != dest) { + adj[from*num_elements+dest] += 1; + adj[dest*num_elements+from] += 1; + } + } + } + + std::vector> topo_temp(num_elements, + std::vector()); + + /*for (int i = 0; i < num_elements; ++i) + PrintTopo("Topo", topo[i], scan[i]); + + PrintMatrix("W", W, num_elements, num_elements); + PrintMatrix("Links", adj, num_elements, num_elements);*/ +} +} // namespace kvstore +} // namespace mxnet +#endif // MXNET_KVSTORE_GPU_TOPOLOGY_H_ diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 38ecf121dfeb..791ad3362010 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -34,6 +34,7 @@ #include #include #include "./comm.h" +#include "./comm_tree.h" #include "./kvstore_utils.h" #include "../ndarray/ndarray_function.h" @@ -56,7 +57,12 @@ class KVStoreLocal : public KVStore { */ explicit KVStoreLocal(bool use_device_comm) : KVStore() { if (use_device_comm) { - comm_ = new CommDevice(); + bool tree = dmlc::GetEnv("MXNET_KVSTORE_USETREE", 0) & MXNET_USE_CUDA; + if (tree) { + comm_ = new CommDeviceTree(); + } else { + comm_ = new CommDevice(); + } } else { comm_ = new CommCPU(); } diff --git a/tests/cpp/kvstore/gpu_topology_test.cc b/tests/cpp/kvstore/gpu_topology_test.cc new file mode 100644 index 000000000000..1fba05ef64cf --- /dev/null +++ b/tests/cpp/kvstore/gpu_topology_test.cc @@ -0,0 +1,671 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file gpu_topology_test.cc + * \brief gpu topology tests +*/ + +#include +#include +#include +#include "../src/kvstore/gpu_topology.h" + +void GenerateMatrix(std::vector* W, int num_gpus, float k, + std::mt19937* gen) { + std::uniform_real_distribution<> dis(0., 1.); + for (int row = 0; row < num_gpus; ++row) { + for (int col = row+1; col < num_gpus; ++col) { + float sample = dis(*gen); + if (sample < k) + continue; + sample = dis(*gen); + if (sample < 0.33f) { + (*W)[row*num_gpus+col] = 1.f; + (*W)[col*num_gpus+row] = 1.f; + } else if (sample < 0.66f) { + (*W)[row*num_gpus+col] = 2.f; + (*W)[col*num_gpus+row] = 2.f; + } else { + (*W)[row*num_gpus+col] = 3.f; + (*W)[col*num_gpus+row] = 3.f; + } + } + } +} + +bool IsSatisfactory(const std::vector& W, int num_gpus, int depth) { + for (int row = 0; row < num_gpus; ++row) { + int out_edges = 0; + for (int col = 0; col < num_gpus; ++col) { + if (W[row*num_gpus+col] > 0.f) + out_edges++; + } + if (out_edges < depth) + return false; + } + return true; +} + +// Generates random link topology matrix using random number generator +void TestComputeTreesRandomized(int num_gpus, float alpha, int backtrack, + std::mt19937* gen) { + std::uniform_real_distribution<> dis(0.f, 1.f); + bool satisfied = false; + std::vector W(num_gpus*num_gpus, 0.f); + int depth = mxnet::kvstore::ComputeDepth(num_gpus); + while (!satisfied) { + float k = dis(*gen); + std::fill(W.begin(), W.end(), 0.f); + GenerateMatrix(&W, num_gpus, k, gen); + satisfied = IsSatisfactory(W, num_gpus, depth); + } + + std::vector> topo; + std::vector> scan; + mxnet::kvstore::ComputeTrees(W, num_gpus, alpha, backtrack, &topo, &scan); + + unsigned correct_topo_size = (1 << (depth + 1)) - 1; + unsigned correct_scan_size = depth+2; + for (int i = 0; i < num_gpus; ++i) { + ASSERT_EQ(correct_topo_size, topo[i].size()); + ASSERT_EQ(correct_scan_size, scan[i].size()); + } +} + +// Permutes matrix W using permutation vector P and stores output in matrix A +// Assumption: W is square and symmetric +void PermuteMatrix(const std::vector& W, + const std::vector& P, + std::vector* A) { + int nrows = P.size(); + std::vector temp(nrows*nrows, 0); + + int count = 0; + for (int row=0; row < nrows; ++row) { + for (int col=0; col < nrows; ++col) { + int row_start = P[row]; + temp[count] = W[row_start*nrows+col]; + count++; + } + } + + count = 0; + for (int row=0; row < nrows; ++row) { + for (int col=0; col < nrows; ++col) { + int col_index = P[col]; + (*A)[count] = temp[row*nrows+col_index]; + count++; + } + } +} + +TEST(GpuTopology, TestFormTopology) { + std::vector state0 = {3, 2, 1, 5, 0, 0, 4, 6}; + std::vector topo0; + std::vector scan0; + std::vector correct0 = {3, 3, 0, 3, 1, 0, 4, 3, 2, 1, 5, 0, 0, 4, 6}; + std::vector correct_scan0 = {0, 1, 3, 7, 15}; + mxnet::kvstore::FormTopology(state0, &topo0, &scan0, 3); + ASSERT_EQ(topo0.size(), correct0.size()); + for (unsigned i = 
0; i < correct0.size(); ++i) + ASSERT_EQ(static_cast(topo0[i]), correct0[i]); + ASSERT_EQ(scan0.size(), correct_scan0.size()); + for (unsigned i = 0; i < correct_scan0.size(); ++i) + ASSERT_EQ(static_cast(scan0[i]), correct_scan0[i]); + + std::vector state1 = {3, 2, 0, 4, 1, 1, 5, 6}; + std::vector topo1; + std::vector scan1; + std::vector correct1 = {3, 3, 1, 3, 0, 1, 5, 3, 2, 0, 4, 1, 1, 5, 6}; + std::vector correct_scan1 = {0, 1, 3, 7, 15}; + mxnet::kvstore::FormTopology(state1, &topo1, &scan1, 3); + ASSERT_EQ(topo1.size(), correct1.size()); + for (unsigned i = 0; i < correct1.size(); ++i) + ASSERT_EQ(static_cast(topo1[i]), correct1[i]); + ASSERT_EQ(scan1.size(), correct_scan1.size()); + for (unsigned i = 0; i < correct_scan1.size(); ++i) + ASSERT_EQ(static_cast(scan1[i]), correct_scan1[i]); +} + +TEST(GpuTopology, TestComputeTreeWeight) { + std::vector W = {0, 2, 2, 3, 3, 0, 0, + 2, 0, 3, 2, 0, 3, 0, + 2, 3, 0, 3, 0, 0, 2, + 3, 2, 3, 0, 0, 0, 0, + 3, 0, 0, 0, 0, 2, 2, + 0, 3, 0, 0, 2, 0, 3, + 0, 0, 2, 0, 2, 3, 0}; + + std::vector state0 = {3, 2, 1, 5, 0, 0, 4, 6}; + ASSERT_EQ(mxnet::kvstore::ComputeTreeWeight(W, state0, 7, 3, false), 16); + + std::vector state1 = {3, 2, 0, 4, 1, 1, 5, 6}; + ASSERT_EQ(mxnet::kvstore::ComputeTreeWeight(W, state1, 7, 3, false), 17); +} + +TEST(GpuTopology, TestPostprocess) { + std::vector result0 = {3, 0, 0, 4, 1, 2, 5, 6}; + std::vector correct0 = {3, 3, 0, 4, 1, 2, 5, 6}; + mxnet::kvstore::Postprocess(&result0, 7, 3); + for (unsigned i = 0; i < correct0.size(); ++i) + ASSERT_EQ(result0[i], correct0[i]); + + std::vector result1 = {2, 0, 0, 4, 1, 3, 5, 1}; + std::vector correct1 = {2, 2, 0, 4, 1, 3, 5, 5}; + mxnet::kvstore::Postprocess(&result1, 6, 3); + for (unsigned i = 0; i < correct1.size(); ++i) + ASSERT_EQ(result1[i], correct1[i]); + + std::vector result2 = {5, 4, 1, 3, 1, 0, 2, 0}; + std::vector correct2 = {5, 4, 5, 3, 1, 0, 2, 2}; + mxnet::kvstore::Postprocess(&result2, 6, 3); + for (unsigned i = 0; i < correct2.size(); ++i) + ASSERT_EQ(result2[i], correct2[i]); + + std::vector result3 = {10, 10, 0, 0, 0, 0, 0, 1, 2, 3, 6, 4, 7, 5, 8, 9}; + std::vector correct3 = {10, 10, 10, 10, 0, 0, 0, 1, 2, 3, 6, 4, 7, 5, 8, 9}; + mxnet::kvstore::Postprocess(&result3, 11, 4); + for (unsigned i = 0; i < correct3.size(); ++i) + ASSERT_EQ(result3[i], correct3[i]); +} + +TEST(GpuTopology, TestDepth) { + ASSERT_EQ(mxnet::kvstore::ComputeDepth(8), 3); + ASSERT_EQ(mxnet::kvstore::ComputeDepth(7), 3); + ASSERT_EQ(mxnet::kvstore::ComputeDepth(5), 3); + ASSERT_EQ(mxnet::kvstore::ComputeDepth(4), 2); + ASSERT_EQ(mxnet::kvstore::ComputeDepth(16), 4); +} + +TEST(GpuTopology, TestIsValid) { + std::vector W = {0, 2, 2, 3, 3, 0, 0, + 2, 0, 3, 2, 0, 3, 0, + 2, 3, 0, 3, 0, 0, 2, + 3, 2, 3, 0, 0, 0, 0, + 3, 0, 0, 0, 0, 2, 2, + 0, 3, 0, 0, 2, 0, 3, + 0, 0, 2, 0, 2, 3, 0}; + + std::vector state0 = {3, 2, 1, 5, 0, 0, 4, 6}; + ASSERT_EQ(mxnet::kvstore::IsValid(W, state0, 7, 7, 3), true); + + // 3 connects to 1 first + std::vector state1 = {3, 2, 0, 4, 1, 1, 5, 6}; + ASSERT_EQ(mxnet::kvstore::IsValid(W, state1, 7, 7, 3), true); + + // 3 does not connect to 5 + std::vector state2 = {3, 2, 5, 1, 0, 4, 2, 5}; + ASSERT_EQ(mxnet::kvstore::IsValid(W, state2, 7, 7, 3), false); + + // 7 exceeds number of GPUs + std::vector state3 = {3, 7, 2, 6, 0, 1, 4, 5}; + ASSERT_EQ(mxnet::kvstore::IsValid(W, state3, 7, 7, 3), false); + + // Test -1 + std::vector state4 = {3, -1, 2, 6, 0, 1, 4, 5}; + ASSERT_EQ(mxnet::kvstore::IsValid(W, state4, 7, 7, 3), true); + + // Test -1 + std::vector state5 = {3, 
+  std::vector<int> state5 = {3, -1, 2, 6, 0, 1, 4, -1};
+  ASSERT_EQ(mxnet::kvstore::IsValid(W, state5, 7, 8, 3), false);
+
+  // Test 1 row
+  std::vector<int> state6 = {3, -1, -1, -1, -1, -1, -1, -1};
+  ASSERT_EQ(mxnet::kvstore::IsValid(W, state6, 7, 1, 3), true);
+}
+
+// gemvTest
+TEST(GpuTopology, TestGemv) {
+  std::vector<int> A = {0, 2, 2, 3, 3, 1, 1, 1,   // 13
+                        2, 0, 3, 2, 1, 3, 1, 1,   // 13
+                        2, 3, 0, 3, 1, 1, 2, 1,   // 13
+                        3, 2, 3, 0, 1, 1, 1, 2,   // 13
+                        3, 1, 1, 1, 0, 2, 2, 3,   // 13
+                        1, 3, 1, 1, 2, 0, 3, 2,   // 13
+                        1, 1, 2, 1, 2, 3, 0, 3,   // 13
+                        1, 1, 1, 2, 3, 2, 3, 0};  // 13
+  std::vector<int> x(8, 1);
+  std::vector<int> y(8, 0);
+  std::iota(y.begin(), y.end(), 0);
+  std::vector<int> correct_y(8, 13);
+  mxnet::kvstore::gemv(A, x, &y);
+
+  ASSERT_EQ(y.size(), correct_y.size());
+  for (unsigned i = 0; i < y.size(); ++i)
+    ASSERT_EQ(y[i], correct_y[i]);
+}
+
+// ewisemultTest
+TEST(GpuTopology, TestEwisemult) {
+  std::vector<int> x(8, 1);
+  std::vector<int> y(8, 0);
+  std::iota(y.begin(), y.end(), 0);
+  int alpha = 5;
+  std::vector<int> correct_y = {0, 5, 10, 15, 20, 25, 30, 35};
+  mxnet::kvstore::ewisemult(x, alpha, &y);
+
+  ASSERT_EQ(y.size(), correct_y.size());
+  for (unsigned i = 0; i < y.size(); ++i)
+    ASSERT_EQ(y[i], correct_y[i]);
+}
+
+// FindBestMoveTest
+TEST(GpuTopology, TestFindBestMove) {
+  std::vector<int> W = {0, 2, 2, 3, 3, 1, 1, 1,
+                        2, 0, 3, 2, 1, 3, 1, 1,
+                        2, 3, 0, 3, 1, 1, 2, 1,
+                        3, 2, 3, 0, 1, 1, 1, 2,
+                        3, 1, 1, 1, 0, 2, 2, 3,
+                        1, 3, 1, 1, 2, 0, 3, 2,
+                        1, 1, 2, 1, 2, 3, 0, 3,
+                        1, 1, 1, 2, 3, 2, 3, 0};
+  std::vector<int> P(8, 0);
+  std::iota(P.begin(), P.end(), 1);
+  std::unordered_set<int> used;
+
+  std::vector<int> D1 = {20, 0, 0, 0, 0, 0, 0, 20};
+  int a1, b1, g1;
+  int correct_a1 = 0;
+  int correct_b1 = 7;
+  int correct_g1 = 38;
+  mxnet::kvstore::FindBestMove(W, P, D1, used, &a1, &b1, &g1);
+  ASSERT_EQ(a1, correct_a1);
+  ASSERT_EQ(b1, correct_b1);
+  ASSERT_EQ(g1, correct_g1);
+
+  // -1, -1, 0 indicates no best edge found
+  std::vector<int> D2 = {0, 0, 0, 0, 0, 0, 0, 0};
+  int a2, b2, g2;
+  int correct_a2 = -1;
+  int correct_b2 = -1;
+  int correct_g2 = 0;
+  mxnet::kvstore::FindBestMove(W, P, D2, used, &a2, &b2, &g2);
+  ASSERT_EQ(a2, correct_a2);
+  ASSERT_EQ(b2, correct_b2);
+  ASSERT_EQ(g2, correct_g2);
+}
+
+// GetRootTest
+TEST(GpuTopology, TestGetRoot) {
+  std::vector<int> P = {0, 0, 1, 1, 2, 2, 3, 3};
+
+  // Test when roots are non-empty, and matches color
+  std::unordered_set<int> roots1 = {0, 2, 4, 6};
+  std::vector<int> color1 = {0, 1, 2, 3};
+  for (unsigned i = 0; i < color1.size(); ++i) {
+    int root1 = mxnet::kvstore::GetRoot(P, color1[i], roots1);
+    int correct_root1 = 2*i;
+    ASSERT_EQ(root1, correct_root1);
+  }
+
+  // Test when roots is empty
+  std::unordered_set<int> roots2;
+  int color2 = 0;
+  int correct_root2 = -1;
+  int root2 = mxnet::kvstore::GetRoot(P, color2, roots2);
+  ASSERT_EQ(root2, correct_root2);
+
+  // Test when roots is non-empty, but no root matches color
+  std::unordered_set<int> roots3 = {0};
+  int color3 = 1;
+  int correct_root3 = -1;
+  int root3 = mxnet::kvstore::GetRoot(P, color3, roots3);
+  ASSERT_EQ(root3, correct_root3);
+
+  std::vector<int> P2 = {0, 1, 1, 0, 2, 3, 3, 2};
+  std::unordered_set<int> roots4 = roots1;
+  int color4 = 0;
+  int correct_root4 = 0;
+  int root4 = mxnet::kvstore::GetRoot(P, color4, roots4);
+  ASSERT_EQ(root4, correct_root4);
+}
+
+// GetChildTest
+TEST(GpuTopology, TestGetChild) {
+  std::vector<int> P = {0, 0, 1, 2, 2, 2, 3, 3};
+
+  // Test when color is not found
+  int color1 = 4;
+  int parent1 = 4;
+  int correct_child1 = -1;
+  int child1 = mxnet::kvstore::GetChild(P, color1, parent1);
+  ASSERT_EQ(child1, correct_child1);
+
+  // Test when color is found, but is equal to parent
+  int color2 = 1;
+  int parent2 = 2;
+  int correct_child2 = -1;
+  int child2 = mxnet::kvstore::GetChild(P, color2, parent2);
+  ASSERT_EQ(child2, correct_child2);
+
+  // Test when color is found and not equal to parent
+  int color3 = 3;
+  int parent3 = 6;
+  int correct_child3 = 7;
+  int child3 = mxnet::kvstore::GetChild(P, color3, parent3);
+  ASSERT_EQ(child3, correct_child3);
+}
+
+// FindBestEdgeTest
+TEST(GpuTopology, TestFindBestEdge) {
+  std::vector<int> W = {0, 2, 2, 3, 3, 1, 1, 1,
+                        2, 0, 3, 2, 1, 3, 1, 1,
+                        2, 3, 0, 3, 1, 1, 2, 1,
+                        3, 2, 3, 0, 1, 1, 1, 2,
+                        3, 1, 1, 1, 0, 2, 2, 3,
+                        1, 3, 1, 1, 2, 0, 3, 2,
+                        1, 1, 2, 1, 2, 3, 0, 3,
+                        1, 1, 1, 2, 3, 2, 3, 0};
+  std::vector<int> P(8, 0);
+  std::unordered_set<int> used;
+
+  int parent1 = 3;
+  int dest1 = 0;
+  std::vector<int> b1;
+  int g1;
+  std::vector<int> correct_b1 = {0, 2};
+  int correct_g1 = 3;
+  mxnet::kvstore::FindBestEdge(W, P, parent1, dest1, &b1, &g1);
+  ASSERT_EQ(b1.size(), correct_b1.size());
+  for (unsigned i = 0; i < b1.size(); ++i)
+    ASSERT_EQ(b1[i], correct_b1[i]);
+  ASSERT_EQ(g1, correct_g1);
+
+  // {-1}, 0 indicates no best edge found
+  int parent2 = 4;
+  int dest2 = 1;
+  std::vector<int> b2;
+  int g2;
+  std::vector<int> correct_b2 = {-1};
+  int correct_g2 = 0;
+  mxnet::kvstore::FindBestEdge(W, P, parent2, dest2, &b2, &g2);
+  ASSERT_EQ(b2.size(), correct_b2.size());
+  for (unsigned i = 0; i < b2.size(); ++i)
+    ASSERT_EQ(b2[i], correct_b2[i]);
+  ASSERT_EQ(g2, correct_g2);
+}
+
+// KLGenerateBinaryTreeTest
+TEST(GpuTopology, TestKLGenerateBinaryTree1) {
+  std::vector<int> W = {0, 2, 3, 3, 3, 1, 1, 1,
+                        2, 0, 3, 2, 1, 3, 1, 1,
+                        2, 3, 0, 3, 1, 1, 2, 1,
+                        3, 2, 3, 0, 1, 1, 1, 2,
+                        3, 1, 1, 1, 0, 2, 3, 3,
+                        1, 3, 1, 1, 2, 0, 3, 2,
+                        1, 1, 2, 1, 2, 3, 0, 3,
+                        1, 1, 1, 2, 3, 2, 3, 0};
+  std::vector<int> P = {0, 1, 1, 0, 2, 3, 3, 2};
+  std::vector<std::pair<int, int>> cluster_pairs;
+  cluster_pairs.push_back(std::pair<int, int>(0, -2));
+  cluster_pairs.push_back(std::pair<int, int>(1, -2));
+  cluster_pairs.push_back(std::pair<int, int>(2, -2));
+  cluster_pairs.push_back(std::pair<int, int>(3, -2));
+  std::unordered_set<int> roots = {0, 2, 4, 6};
+  std::vector<size_t> topo = {0, 2, 4, 6};
+  std::vector<size_t> scan(2, 0);
+  std::mt19937 gen(1);
+  mxnet::kvstore::KLGenerateBinaryTree(W, P, &cluster_pairs, &roots, &topo,
+                                       &scan, &gen);
+  std::vector<size_t> correct_topo = {0, 2, 4, 6, 0, 3, 2, 1, 4, 7, 6, 5};
+  std::vector<size_t> correct_scan = {0, 0, 4};
+  ASSERT_EQ(topo.size(), correct_topo.size());
+  for (unsigned i = 0; i < topo.size(); ++i)
+    ASSERT_EQ(topo[i], correct_topo[i]);
+  ASSERT_EQ(scan.size(), correct_scan.size());
+  for (unsigned i = 0; i < scan.size(); ++i)
+    ASSERT_EQ(scan[i], correct_scan[i]);
+}
+
+TEST(GpuTopology, TestKLGenerateBinaryTree2) {
+  std::vector<int> W = {0, 2, 3, 3, 3, 1, 1, 1,
+                        2, 0, 3, 2, 1, 3, 1, 1,
+                        2, 3, 0, 3, 1, 1, 2, 1,
+                        3, 2, 3, 0, 1, 1, 1, 2,
+                        3, 1, 1, 1, 0, 2, 3, 3,
+                        1, 3, 1, 1, 2, 0, 3, 2,
+                        1, 1, 2, 1, 2, 3, 0, 3,
+                        1, 1, 1, 2, 3, 2, 3, 0};
+  std::vector<int> P = {0, 1, 1, 0, 2, 3, 3, 2};
+  std::vector<std::pair<int, int>> cluster_pairs;
+  cluster_pairs.push_back(std::pair<int, int>(0, -2));
+  cluster_pairs.push_back(std::pair<int, int>(1, -2));
+  cluster_pairs.push_back(std::pair<int, int>(2, -2));
+  cluster_pairs.push_back(std::pair<int, int>(3, -2));
+  std::unordered_set<int> roots = {0, 2, 4, 6};
+  std::vector<size_t> topo = {0, 6, 4, 2};
+  std::vector<size_t> scan(2, 0);
+  std::mt19937 gen(1);
+  mxnet::kvstore::KLGenerateBinaryTree(W, P, &cluster_pairs, &roots, &topo,
+                                       &scan, &gen);
+  std::vector<size_t> correct_topo = {0, 6, 4, 2, 0, 3, 6, 5, 4, 7, 2, 1};
+  std::vector<size_t> correct_scan = {0, 0, 4};
+  ASSERT_EQ(topo.size(), correct_topo.size());
+  for (unsigned i = 0; i < topo.size(); ++i)
+    ASSERT_EQ(topo[i], correct_topo[i]);
+  ASSERT_EQ(scan.size(), correct_scan.size());
+  for (unsigned i = 0; i < scan.size(); ++i)
+    ASSERT_EQ(scan[i], correct_scan[i]);
+}
+
+// UpdateWeightTest
+TEST(GpuTopology, TestUpdateWeight) {
+  std::vector<float> W = {0.f, 1.f,
+                          1.f, 0.f};
+  std::vector<size_t> topo = {1, 1, 0};
+  int num_gpus = 2;
+  float alpha = 0.7;
+  std::vector<float> correct_W = {0.f, 0.7f,
+                                  0.7f, 0.f};
+  mxnet::kvstore::UpdateWeight(&W, topo, num_gpus, alpha);
+  ASSERT_EQ(W.size(), correct_W.size());
+  for (unsigned i = 0; i < W.size(); ++i) {
+    ASSERT_EQ(W[i], correct_W[i]);
+  }
+}
+
+// ComputeTreesFromRoot
+TEST(GpuTopology, TestComputeTreesFromRoot1) {
+  std::vector<float> W = {0, 2, 2, 3, 3, 1, 1, 1,
+                          2, 0, 3, 2, 1, 3, 1, 1,
+                          2, 3, 0, 3, 1, 1, 2, 1,
+                          3, 2, 3, 0, 1, 1, 1, 2,
+                          3, 1, 1, 1, 0, 2, 2, 3,
+                          1, 3, 1, 1, 2, 0, 3, 2,
+                          1, 1, 2, 1, 2, 3, 0, 3,
+                          1, 1, 1, 2, 3, 2, 3, 0};
+  int num_gpus = 8;
+  int root = 0;
+  float alpha = 0.7;
+  bool backtrack = true;
+  unsigned correct_topo_size = 15;
+  unsigned correct_scan_size = 5;
+  std::vector<size_t> topo;
+  std::vector<size_t> scan;
+
+  mxnet::kvstore::ComputeTreesFromRoot(&W, num_gpus, root, alpha, backtrack,
+                                       &topo, &scan);
+
+  ASSERT_EQ(topo.size(), correct_topo_size);
+  ASSERT_EQ(scan.size(), correct_scan_size);
+}
+
+// IsConnected
+// Test on graph that is "disconnected" by NVLink
+TEST(GpuTopology, TestIsConnected1) {
+  std::vector<float> W = {0, 0, 2, 0,
+                          0, 0, 0, 2,
+                          2, 0, 0, 0,
+                          0, 2, 0, 0};
+  int num_gpus = 4;
+
+  bool connected = mxnet::kvstore::IsConnected(W, num_gpus);
+
+  bool correct_connected = false;
+  ASSERT_EQ(connected, correct_connected);
+}
+
+// IsConnected
+// Test on graph that is "disconnected" by NVLink
+TEST(GpuTopology, TestIsConnected2) {
+  std::vector<float> W = {1, 1, 2, 1,
+                          1, 1, 1, 2,
+                          2, 1, 1, 1,
+                          1, 2, 1, 1};
+  int num_gpus = 4;
+
+  bool connected = mxnet::kvstore::IsConnected(W, num_gpus);
+
+  bool correct_connected = false;
+  ASSERT_EQ(connected, correct_connected);
+}
+
+// IsConnected
+// Test on graph that is connected by NVLink
+TEST(GpuTopology, TestIsConnected3) {
+  std::vector<float> W = {1, 1, 2, 2,
+                          1, 1, 1, 2,
+                          2, 1, 1, 1,
+                          2, 2, 1, 1};
+  int num_gpus = 4;
+
+  bool connected = mxnet::kvstore::IsConnected(W, num_gpus);
+
+  bool correct_connected = true;
+  ASSERT_EQ(connected, correct_connected);
+}
+
+// ComputeTreesTest with backtracking
+TEST(GpuTopology, TestComputeTrees1) {
+  std::mt19937 gen(1);
+  float alpha = 0.7;
+  bool backtrack = true;
+  // Do 5 randomized tests per GPU count from 2 to 16
+  for (int num_gpus = 2; num_gpus <= 16; ++num_gpus) {
+    for (int i = 0; i < 5; ++i) {
+      TestComputeTreesRandomized(num_gpus, alpha, backtrack, &gen);
+    }
+  }
+}
+
+// ComputeTreesTest with Kernighan-Lin
+TEST(GpuTopology, TestComputeTrees2) {
+  std::mt19937 gen(1);
+  float alpha = 0.7;
+  bool backtrack = false;
+  // Do 5 randomized tests per GPU count from 2 to 16
+  for (int num_gpus = 2; num_gpus <= 16; ++num_gpus) {
+    for (int i = 0; i < 5; ++i) {
+      TestComputeTreesRandomized(num_gpus, alpha, backtrack, &gen);
+    }
+  }
+}
+
+TEST(GpuTopology, TestPermuteMatrix) {
+  std::vector<int> W = {0, 2, 2, 3, 3, 1, 1, 1,
+                        2, 0, 3, 2, 1, 3, 1, 1,
+                        2, 3, 0, 3, 1, 1, 2, 1,
+                        3, 2, 3, 0, 1, 1, 1, 2,
+                        3, 1, 1, 1, 0, 2, 2, 3,
+                        1, 3, 1, 1, 2, 0, 3, 2,
+                        1, 1, 2, 1, 2, 3, 0, 3,
+                        1, 1, 1, 2, 3, 2, 3, 0};
+
+  std::vector<int> P1 = {0, 1, 2, 3, 4, 5, 6, 7};
+  std::vector<int> A(8*8, 0);
+  PermuteMatrix(W, P1, &A);
+  for (unsigned i=0; i < W.size(); ++i)
+    ASSERT_EQ(A[i], W[i]);
+}
+
+TEST(GpuTopology, TestKernighanLin1) {
+  std::vector<float> W = {0, 1, 2, 3, 2, 4,
+                          1, 0, 1, 4, 2, 1,
+                          2, 1, 0, 3, 2, 1,
+                          3, 4, 3, 0, 4, 3,
+                          2, 2, 2, 4, 0, 2,
+                          4, 1, 1, 3, 2, 0};
+  std::vector<int> P(6, 0);
+  std::vector<std::pair<int, int>> cluster_pairs;
+  int num_partitions = 1;
+  std::mt19937 gen(1);
+  bool stop = mxnet::kvstore::KernighanLin(W, &P, &num_partitions,
+                                           &cluster_pairs, &gen);
+
+  std::vector<std::pair<int, int>> correct_pairs;
+  correct_pairs.push_back(std::pair<int, int>(0, 1));
+  std::vector<int> correct_P = {0, 1, 0, 1, 1, 0};
+  ASSERT_EQ(stop, false);
+  ASSERT_EQ(num_partitions, 2);
+  ASSERT_EQ(cluster_pairs.size(), correct_pairs.size());
+  for (unsigned i = 0; i < cluster_pairs.size(); ++i) {
+    ASSERT_EQ(cluster_pairs[i].first, correct_pairs[i].first);
+    ASSERT_EQ(cluster_pairs[i].second, correct_pairs[i].second);
+  }
+  ASSERT_EQ(P.size(), correct_P.size());
+  unsigned error = 0;
+  for (unsigned i = 0; i < P.size(); ++i) {
+    if (P[i] != correct_P[i])
+      error++;
+  }
+  EXPECT_TRUE(error == 0 || error == P.size())
+    << "Mismatch count " << error
+    << " is neither 0 nor " << P.size() << ".";
+}
+
+TEST(GpuTopology, TestKernighanLin2) {
+  std::vector<float> W = {0, 1, 0, 0, 1, 1, 0, 0,
+                          1, 0, 0, 0, 1, 1, 0, 0,
+                          0, 0, 0, 1, 0, 1, 1, 1,
+                          0, 0, 1, 0, 0, 0, 1, 1,
+                          1, 1, 0, 0, 0, 1, 0, 0,
+                          1, 1, 1, 0, 1, 0, 0, 0,
+                          0, 0, 1, 1, 0, 0, 0, 1,
+                          0, 0, 1, 1, 0, 0, 1, 0};
+  std::vector<int> P(8, 0);
+  std::vector<std::pair<int, int>> cluster_pairs;
+  int num_partitions = 1;
+  std::mt19937 gen(1);
+  bool stop = mxnet::kvstore::KernighanLin(W, &P, &num_partitions,
+                                           &cluster_pairs, &gen);
+
+  std::vector<std::pair<int, int>> correct_pairs;
+  correct_pairs.push_back(std::pair<int, int>(0, 1));
+  std::vector<int> correct_P = {0, 0, 1, 1, 0, 0, 1, 1};
+  ASSERT_EQ(stop, false);
+  ASSERT_EQ(num_partitions, 2);
+  ASSERT_EQ(cluster_pairs.size(), correct_pairs.size());
+  for (unsigned i = 0; i < cluster_pairs.size(); ++i) {
+    ASSERT_EQ(cluster_pairs[i].first, correct_pairs[i].first);
+    ASSERT_EQ(cluster_pairs[i].second, correct_pairs[i].second);
+  }
+  ASSERT_EQ(P.size(), correct_P.size());
+  unsigned error = 0;
+  for (unsigned i = 0; i < P.size(); ++i) {
+    if (P[i] != correct_P[i])
+      error++;
+  }
+  EXPECT_TRUE(error == 0 || error == P.size())
+    << "Mismatch count " << error
+    << " is neither 0 nor " << P.size() << ".";
+}
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index b5b4d90ca1dd..70a776e2d76c 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -21,6 +21,7 @@
 import mxnet as mx
 import numpy as np
 import unittest
+import logging
 from mxnet.test_utils import assert_almost_equal, default_context
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
@@ -30,6 +31,21 @@
 keys = [5, 7, 11]
 str_keys = ['b', 'c', 'd']
 
+class EnvManager:
+    def __init__(self, key, val):
+        self._key = key
+        self._next_val = val
+        self._prev_val = None
+
+    def __enter__(self):
+        try:
+            self._prev_val = os.environ[self._key]
+        except KeyError:
+            self._prev_val = ""
+        os.environ[self._key] = self._next_val
+
+    def __exit__(self, ptype, value, trace):
+        os.environ[self._key] = self._prev_val
 
 def init_kv_with_str(stype='default', kv_type='local'):
     """init kv """
@@ -86,34 +102,48 @@ def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False):
     check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], use_slice=True)
     check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], use_slice=True)

-    check_rsp_push_pull('local')
-    check_rsp_push_pull('device')
-    check_rsp_push_pull('device', is_push_cpu=False)
+    # test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/9384
+    # check_rsp_push_pull('local')
+    envs = ["","1"]
+    key = "MXNET_KVSTORE_USETREE"
+    for val in envs:
+        with EnvManager(key, val):
+            check_rsp_push_pull('local')
+            check_rsp_push_pull('device')
+            check_rsp_push_pull('device', is_push_cpu=False)

 def test_row_sparse_pull_single_device():
-    kvstore = mx.kv.create('device')
-    copy = mx.nd.random_normal(shape=(4,4), ctx=mx.gpu(0))
-    grad = copy.tostype("row_sparse")
+    envs = ["","1"]
+    key = "MXNET_KVSTORE_USETREE"
+    for val in envs:
+        with EnvManager(key, val):
+            kvstore = mx.kv.create('device')
+            copy = mx.nd.random_normal(shape=(4,4), ctx=mx.gpu(0))
+            grad = copy.tostype("row_sparse")

-    key = 0
-    kvstore.init(key, grad)
-    idx = grad.indices
-    kvstore.push(key, grad)
-    kvstore.row_sparse_pull(key, out=grad, row_ids=idx)
+            k = 0
+            kvstore.init(k, grad)
+            idx = grad.indices
+            kvstore.push(k, grad)
+            kvstore.row_sparse_pull(k, out=grad, row_ids=idx)

-    assert_almost_equal(grad.asnumpy(), copy.asnumpy())
+            assert_almost_equal(grad.asnumpy(), copy.asnumpy())

 def test_rsp_push_pull_large_rowid():
-    num_rows = 793470
-    val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu())
-    kv = mx.kv.create('device')
-    kv.init('a', val)
-    out = mx.nd.zeros((num_rows,1), stype='row_sparse').copyto(mx.gpu())
-    kv.push('a', val)
-    kv.row_sparse_pull('a', out=out, row_ids=mx.nd.arange(0, num_rows, dtype='int64'))
-    assert(out.indices.shape[0] == num_rows)
+    envs = ["","1"]
+    key = "MXNET_KVSTORE_USETREE"
+    for val in envs:
+        with EnvManager(key, val):
+            num_rows = 793470
+            val = mx.nd.ones((num_rows, 1)).tostype('row_sparse').copyto(mx.gpu())
+            kv = mx.kv.create('device')
+            kv.init('a', val)
+            out = mx.nd.zeros((num_rows,1), stype='row_sparse').copyto(mx.gpu())
+            kv.push('a', val)
+            kv.row_sparse_pull('a', out=out, row_ids=mx.nd.arange(0, num_rows, dtype='int64'))
+            assert(out.indices.shape[0] == num_rows)

 if __name__ == '__main__':
     import nose
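Reviewer note (not part of the diff): the randomized C++ test above only checks the shapes of the generated trees, and those expected sizes follow directly from ComputeDepth. Assuming depth behaves as the TestDepth cases indicate (depth = ceil(log2(num_gpus))), the following minimal Python sketch reproduces the same arithmetic and is handy when sanity-checking new test cases; the helper name expected_sizes is hypothetical and exists only for illustration.

import math

def expected_sizes(num_gpus):
    # Mirrors the TestDepth expectations: ComputeDepth(8)=3, (5)=3, (4)=2, (16)=4.
    depth = int(math.ceil(math.log2(num_gpus)))
    # Number of entries in one topo row: nodes of a complete binary tree of that depth.
    topo_size = (1 << (depth + 1)) - 1
    # Number of entries in one scan row: one offset per level plus a terminating entry.
    scan_size = depth + 2
    return depth, topo_size, scan_size

print(expected_sizes(8))   # (3, 15, 5) -> matches correct_topo_size / correct_scan_size above
print(expected_sizes(4))   # (2, 7, 4)

For 8 GPUs this gives a topo row of 15 entries and a scan row of 5 entries, which is exactly what TestComputeTreesFromRoot1 and TestComputeTreesRandomized assert.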