Skip to content

Commit

Permalink
[MXNET-400] support string type for kvstore key in cpp-package (apach…
Browse files Browse the repository at this point in the history
…e#10792)

* Kvstore strkey (#2)

* support string type for kvstore key in cpp-package

* make lines short

* fix build

* add kvstore testcase

* no rand() use

* fix cpplint sanity check

* support string type for kvstore key in cpp-package

* make lines short

* fix build

* print error log

* Update test_kvstore.cpp

* update

* add gpu unittest

* check gpu count

* fix sanity check
  • Loading branch information
nihui authored and wkcn committed Apr 9, 2019
1 parent 4530ad8 commit 74e71e9
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 7 deletions.
201 changes: 201 additions & 0 deletions cpp-package/example/test_kvstore.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "mxnet/c_api.h" // MXGetGPUCount()
#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;

static bool test_single_key(const Context &context, const std::string &context_str) {
std::string key = "singlekeytest-" + context_str;

NDArray result(Shape(4), context);
NDArray result_cpu;

// initialize data
NDArray data_cpu({0.f, 233.f, -0.12f, 9.f}, Shape(4), Context::cpu());
NDArray data = data_cpu.Copy(context);
NDArray::WaitAll();

KVStore::Init(key, data);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(key, &result);
NDArray::WaitAll();

result_cpu = result.Copy(Context::cpu());
NDArray::WaitAll();

// compare
for (size_t j=0; j < result_cpu.Size(); j++) {
if (result_cpu.GetData()[j] != data_cpu.GetData()[j]) {
LG << "Error: wrong initialized data in singlekeytest-" << context_str
<< ", expect " << data_cpu.GetData()[j]
<< " got " << result_cpu.GetData()[j];
return false;
}
}

// push gradient
NDArray grad_cpu({0.1f, -2.f, -4.4f, 0.f}, Shape(4), Context::cpu());
NDArray grad = grad_cpu.Copy(context);
NDArray::WaitAll();

KVStore::Push(key, grad);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(key, &result);
NDArray::WaitAll();

result_cpu = result.Copy(Context::cpu());
NDArray::WaitAll();

// compare
for (size_t j=0; j < result_cpu.Size(); j++) {
if (result_cpu.GetData()[j] != grad_cpu.GetData()[j]) {
LG << "Error: wrong gradient data in singlekeytest-" << context_str
<< ", expect " << grad_cpu.GetData()[j]
<< " got " << result_cpu.GetData()[j];
return false;
}
}

return true;
}

static bool test_multiple_key(const Context &context, const std::string &context_str) {
std::vector<std::string> keys(2);
keys[0] = "multikeytest-0-" + context_str;
keys[1] = "multikeytest-1-" + context_str;

std::vector<NDArray> results(2);
results[0] = NDArray(Shape(4), context);
results[1] = NDArray(Shape(4), context);
std::vector<NDArray> results_cpu(2);

// initialize data
std::vector<NDArray> data_cpu(2);
data_cpu[0] = NDArray({0.f, 2.f, -3.12f, 4.f}, Shape(4), Context::cpu());
data_cpu[1] = NDArray({0.8f, -2.f, 6.6f, 77.f}, Shape(4), Context::cpu());
std::vector<NDArray> data(2);
data[0] = data_cpu[0].Copy(context);
data[1] = data_cpu[1].Copy(context);
NDArray::WaitAll();

KVStore::Init(keys, data);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(keys, &results);
NDArray::WaitAll();

results_cpu[0] = results[0].Copy(Context::cpu());
results_cpu[1] = results[1].Copy(Context::cpu());
NDArray::WaitAll();

// compare
for (size_t i=0; i < results_cpu.size(); i++) {
for (size_t j=0; j < results_cpu[i].Size(); j++) {
if (results_cpu[i].GetData()[j] != data_cpu[i].GetData()[j]) {
LG << "Error: wrong initialized data in multikeytest-" << context_str
<< ", expect " << data_cpu[i].GetData()[j]
<< " got " << results_cpu[i].GetData()[j];
return false;
}
}
}

// push gradient, reduce for the second
std::vector<std::string> push_keys(3);
push_keys[0] = "multikeytest-0-" + context_str;
push_keys[1] = "multikeytest-1-" + context_str;
push_keys[2] = "multikeytest-1-" + context_str;

std::vector<NDArray> grads_cpu(3);
grads_cpu[0] = NDArray({0.2f, -0.3f, -1.1f, 0.0f}, Shape(4), Context::cpu());
grads_cpu[1] = NDArray({2.f, 4.f, -4.f, -5.f}, Shape(4), Context::cpu());
grads_cpu[2] = NDArray({-3.f, -0.2f, 12.f, -9.f}, Shape(4), Context::cpu());
std::vector<NDArray> grads(3);
grads[0] = grads_cpu[0].Copy(context);
grads[1] = grads_cpu[1].Copy(context);
grads[2] = grads_cpu[2].Copy(context);
NDArray::WaitAll();

KVStore::Push(push_keys, grads);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(keys, &results);
NDArray::WaitAll();

results_cpu[0] = results[0].Copy(Context::cpu());
results_cpu[1] = results[1].Copy(Context::cpu());
NDArray::WaitAll();

// compare the first
for (size_t j=0; j < results_cpu[0].Size(); j++) {
if (results_cpu[0].GetData()[j] != grads_cpu[0].GetData()[j]) {
LG << "Error: wrong gradient data in multikeytest-" << context_str
<< ", expect " << grads_cpu[0].GetData()[j]
<< " got " << results_cpu[0].GetData()[j];
return false;
}
}

// compare the second
for (size_t j=0; j < results_cpu[1].Size(); j++) {
if (results_cpu[1].GetData()[j] != (grads_cpu[1].GetData()[j] + grads_cpu[2].GetData()[j])) {
LG << "Error: wrong reduced gradient data in multikeytest-" << context_str
<< ", expect " << (grads_cpu[1].GetData()[j] + grads_cpu[2].GetData()[j])
<< " got " << results_cpu[1].GetData()[j];
return false;
}
}

return true;
}

int main(int argc, char** argv) {
KVStore::SetType("local");

bool success1 = test_single_key(Context::cpu(), "cpu");
bool success2 = test_multiple_key(Context::cpu(), "cpu");

bool success3 = true;
bool success4 = true;

int gpu_count = 0;
if (MXGetGPUCount(&gpu_count) != 0) {
LG << "Error: MXGetGPUCount";

MXNotifyShutdown();
return 1;
}

if (gpu_count > 0) {
success3 = test_single_key(Context::gpu(), "gpu");
success4 = test_multiple_key(Context::gpu(), "gpu");
}

int ret = (success1 && success2 && success3 && success4) ? 0 : 1;

MXNotifyShutdown();
return ret;
}
13 changes: 11 additions & 2 deletions cpp-package/include/mxnet-cpp/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,21 @@ class KVStore {
static void SetType(const std::string& type);
static void RunServer();
static void Init(int key, const NDArray& val);
static void Init(const std::string& key, const NDArray& val);
static void Init(const std::vector<int>& keys, const std::vector<NDArray>& vals);
static void Init(const std::vector<std::string>& keys, const std::vector<NDArray>& vals);
static void Push(int key, const NDArray& val, int priority = 0);
static void Push(const std::string& key, const NDArray& val, int priority = 0);
static void Push(const std::vector<int>& keys,
const std::vector<NDArray>& vals, int priority = 0);
const std::vector<NDArray>& vals, int priority = 0);
static void Push(const std::vector<std::string>& keys,
const std::vector<NDArray>& vals, int priority = 0);
static void Pull(int key, NDArray* out, int priority = 0);
static void Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int priority = 0);
static void Pull(const std::string& key, NDArray* out, int priority = 0);
static void Pull(const std::vector<int>& keys,
std::vector<NDArray>* outs, int priority = 0);
static void Pull(const std::vector<std::string>& keys,
std::vector<NDArray>* outs, int priority = 0);
// TODO(lx): put lr in optimizer or not?
static void SetOptimizer(std::unique_ptr<Optimizer> optimizer, bool local = false);
static std::string GetType();
Expand Down
78 changes: 75 additions & 3 deletions cpp-package/include/mxnet-cpp/kvstore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ inline void KVStore::Init(int key, const NDArray& val) {
CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle), 0);
}

inline void KVStore::Init(const std::string& key, const NDArray& val) {
const char* key_handle = key.c_str();
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle), 0);
}

inline void KVStore::Init(const std::vector<int>& keys, const std::vector<NDArray>& vals) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
Expand All @@ -99,14 +105,36 @@ inline void KVStore::Init(const std::vector<int>& keys, const std::vector<NDArra
val_handles.data()), 0);
}

inline void KVStore::Init(const std::vector<std::string>& keys, const std::vector<NDArray>& vals) {
CHECK_EQ(keys.size(), vals.size());
std::vector<const char*> key_handles(keys.size());
std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
[](const std::string& key) {
return key.c_str();
});
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(),
val_handles.data()), 0);
}

inline void KVStore::Push(int key, const NDArray& val, int priority) {
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle, priority), 0);
}

inline void KVStore::Push(const std::string& key, const NDArray& val, int priority) {
const char* key_handle = key.c_str();
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle, priority), 0);
}

inline void KVStore::Push(const std::vector<int>& keys,
const std::vector<NDArray>& vals,
int priority) {
const std::vector<NDArray>& vals, int priority) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
Expand All @@ -118,12 +146,37 @@ inline void KVStore::Push(const std::vector<int>& keys,
val_handles.data(), priority), 0);
}

inline void KVStore::Push(const std::vector<std::string>& keys,
const std::vector<NDArray>& vals, int priority) {
CHECK_EQ(keys.size(), vals.size());
std::vector<const char*> key_handles(keys.size());
std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
[](const std::string& key) {
return key.c_str();
});
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(),
val_handles.data(), priority), 0);
}

inline void KVStore::Pull(int key, NDArray* out, int priority) {
NDArrayHandle out_handle = out->GetHandle();
CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle, priority), 0);
}

inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int priority) {
inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) {
const char* key_handle = key.c_str();
NDArrayHandle out_handle = out->GetHandle();
CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key_handle, &out_handle, priority), 0);
}

inline void KVStore::Pull(const std::vector<int>& keys,
std::vector<NDArray>* outs, int priority) {
CHECK_EQ(keys.size(), outs->size());

std::vector<NDArrayHandle> out_handles(keys.size());
Expand All @@ -136,6 +189,25 @@ inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* ou
out_handles.data(), priority), 0);
}

inline void KVStore::Pull(const std::vector<std::string>& keys,
std::vector<NDArray>* outs, int priority) {
CHECK_EQ(keys.size(), outs->size());

std::vector<const char*> key_handles(keys.size());
std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
[](const std::string& key) {
return key.c_str();
});
std::vector<NDArrayHandle> out_handles(keys.size());
std::transform(outs->cbegin(), outs->cend(), out_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(),
out_handles.data(), priority), 0);
}

inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local,
void* handle_) {
Optimizer *opt = static_cast<Optimizer*>(handle_);
Expand Down
7 changes: 5 additions & 2 deletions cpp-package/tests/ci_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@ cp ../../build/cpp-package/example/mlp_cpu .
cp ../../build/cpp-package/example/mlp_gpu .
./mlp_gpu

cp ../../build/cpp-package/example/test_optimizer .
./test_optimizer
cp ../../build/cpp-package/example/test_optimizer .
./test_optimizer

cp ../../build/cpp-package/example/test_kvstore .
./test_kvstore

cp ../../build/cpp-package/example/test_score .
./test_score 0.93
Expand Down

0 comments on commit 74e71e9

Please sign in to comment.