Skip to content

Commit

Permalink
common: add multiple key hash functions (#140)
Browse files Browse the repository at this point in the history
* common: add multiple hash functions

* common: improve format

* common: improve built-in hash

* quick fix
  • Loading branch information
ymjiang authored Nov 4, 2019
1 parent 63cb3bb commit b7615c5
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 8 deletions.
61 changes: 55 additions & 6 deletions byteps/common/global.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ uint32_t BytePSGlobal::_partition_bytes = 4096000;
std::shared_ptr<BytePSComm> BytePSGlobal::_basic_comm;
std::shared_ptr<BytePSSharedMemory> BytePSGlobal::_shm_obj;
std::unordered_map<uint64_t, PSKV> BytePSGlobal::ps_kv_;
std::hash<ps::Key> BytePSGlobal::_hash_fn;
std::vector<unsigned long> BytePSGlobal::_server_accumulated_len;
bool BytePSGlobal::_use_hash;
std::string BytePSGlobal::_hash_knob;

volatile BytePSScheduledQueue* BytePSGlobal::_queues[QueueNum] = {NULL};
std::mutex BytePSGlobal::_queues_mutex[QueueNum];
Expand All @@ -65,6 +64,9 @@ cudaStream_t* BytePSGlobal::_copy_host2device_stream;
std::shared_ptr<NcclManager> BytePSGlobal::_nccl_manager;
std::shared_ptr<CpuReducer> BytePSGlobal::_cpu_reducer;

std::hash<std::string> BytePSGlobal::_built_in_hash_fn;
unsigned int BytePSGlobal::_built_in_hash_coefficient;

uint64_t BytePSGlobal::_sample_key = std::numeric_limits<uint64_t>::max();
std::atomic_int BytePSGlobal::joined_thread_cnt;

Expand Down Expand Up @@ -115,7 +117,16 @@ void BytePSGlobal::Init() {
if (_is_distributed_job) {
BPS_CHECK(getenv("DMLC_NUM_SERVER"))
<< "error: launch distributed job, but env DMLC_NUM_SERVER not set";
_use_hash = getenv("BYTEPS_USE_HASH_KEY") ? atoi(getenv("BYTEPS_USE_HASH_KEY")) : false;

// set hash function
_hash_knob = std::string(getenv("BYTEPS_KEY_HASH_FN") ? getenv("BYTEPS_KEY_HASH_FN") : "djb2");
BPS_LOG(DEBUG) << "Using key hash function type: " << _hash_knob;
if (!_hash_knob.compare(std::string("built_in"))) {
_built_in_hash_coefficient = getenv("BYTEPS_BUILT_IN_HASH_COEF") ? atoi(getenv("BYTEPS_BUILT_IN_HASH_COEF")) : 1;
BPS_LOG(DEBUG) << "The built in hash coefficient is set to " << _built_in_hash_coefficient;
}

// set server load counter
int num_server = atoi(getenv("DMLC_NUM_SERVER"));
for (int i = 0; i < num_server; ++i) _server_accumulated_len.push_back(0);
}
Expand Down Expand Up @@ -315,6 +326,36 @@ bool BytePSGlobal::IsTensorDeclared(const std::string& name) {
return true;
}

// Naive key hash: adds the key shifted right by 16 bits to the key's low
// 16 bits (key % 65536), then multiplies the sum by 9973. The caller reduces
// the result modulo the number of servers.
uint64_t BytePSGlobal::Hash_Naive(uint64_t key) {
  const uint64_t upper_part = key >> 16;
  const uint64_t lower_part = key % 65536;
  return (upper_part + lower_part) * 9973;
}
// Hashes the decimal string form of `key` with std::hash<std::string>
// (_built_in_hash_fn) and multiplies the result by
// _built_in_hash_coefficient.
//
// The string is passed to the hasher directly. The previous code kept only
// std::to_string(key).c_str(), a pointer into a temporary that is destroyed
// at the end of that statement, so the hasher read freed memory.
uint64_t BytePSGlobal::Hash_BuiltIn(uint64_t key) {
  const std::string key_str = std::to_string(key);
  return _built_in_hash_fn(key_str) * _built_in_hash_coefficient;
}

// DJB2 hash over the decimal string form of `key`.
// Starts from 5381 and, for each character c, computes
//   hash = hash * 33 + c
// (written as (hash << 5) + hash + c).
//
// The decimal string is held in a local std::string for the whole loop. The
// previous code iterated over std::to_string(key).c_str(), a pointer into a
// temporary that was already destroyed, which is undefined behaviour.
uint64_t BytePSGlobal::Hash_DJB2(uint64_t key) {
  const std::string digits = std::to_string(key);
  uint64_t hash = 5381;
  for (const char c : digits) {
    hash = ((hash << 5) + hash) + c;
  }
  return hash;
}

// SDBM hash over the decimal string form of `key`.
// Starts from 0 and, for each character c, computes
//   hash = hash * 65599 + c
// (written as c + (hash << 6) + (hash << 16) - hash).
//
// The decimal string is held in a local std::string for the whole loop. The
// previous code iterated over std::to_string(key).c_str(), a pointer into a
// temporary that was already destroyed, which is undefined behaviour.
uint64_t BytePSGlobal::Hash_SDBM(uint64_t key) {
  const std::string digits = std::to_string(key);
  uint64_t hash = 0;
  for (const char c : digits) {
    hash = c + (hash << 6) + (hash << 16) - hash;
  }
  return hash;
}

PSKV& BytePSGlobal::EncodeDefaultKey(uint64_t key, size_t len) {
std::lock_guard<std::mutex> lock(_encode_mutex);
PSKV& pskv = ps_kv_[key];
Expand All @@ -327,11 +368,19 @@ PSKV& BytePSGlobal::EncodeDefaultKey(uint64_t key, size_t len) {
BPS_CHECK_GT(num_servers, 0);
// send it to a single random picked server
int server;
if (_use_hash) {
server = _hash_fn(key) % num_servers;
if (!_hash_knob.compare(std::string("naive"))) {
server = Hash_Naive(key) % num_servers;
} else if (!_hash_knob.compare(std::string("built_in"))) {
server = Hash_BuiltIn(key) % num_servers;
} else if (!_hash_knob.compare(std::string("djb2"))) {
server = Hash_DJB2(key) % num_servers;
} else if (!_hash_knob.compare(std::string("sdbm"))) {
server = Hash_SDBM(key) % num_servers;
} else {
server = (((key >> 16) + (key % 65536)) * 9973) % num_servers;
BPS_CHECK(0) << "Unsupported BYTEPS_KEY_HASH_FN, "
<< "must be one of [naive, built_in, djb2, sdbm]";
}

_server_accumulated_len[server] += len;
BPS_LOG(DEBUG) << "key " << key << " assigned to server " << server
<< ", accumulated workload for this server is "
Expand Down
11 changes: 9 additions & 2 deletions byteps/common/global.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ class BytePSGlobal {
static BPSContext& GetContextFromName(const std::string& name);
static uint32_t GetTensorCount();

static bool _use_hash;
static std::hash<ps::Key> _hash_fn;
static std::vector<unsigned long> _server_accumulated_len;
static std::unordered_map<uint64_t, PSKV> ps_kv_;
static PSKV& EncodeDefaultKey(uint64_t key, size_t len);
Expand Down Expand Up @@ -164,6 +162,15 @@ class BytePSGlobal {
static int AlignTo(int input, int alignment) {
return input / alignment * alignment;
}

// hash functions
static std::string _hash_knob;
static std::hash<std::string> _built_in_hash_fn;
static unsigned int _built_in_hash_coefficient;
static uint64_t Hash_Naive(uint64_t key);
static uint64_t Hash_BuiltIn(uint64_t key);
static uint64_t Hash_DJB2(uint64_t key);
static uint64_t Hash_SDBM(uint64_t key);
};

} // namespace common
Expand Down

0 comments on commit b7615c5

Please sign in to comment.