diff --git a/src/common/raft/raft_utils.cc b/src/common/raft/raft_utils.cc
index 74076cff2..14e906fdc 100644
--- a/src/common/raft/raft_utils.cc
+++ b/src/common/raft/raft_utils.cc
@@ -16,6 +16,20 @@ gpu_device_manager::choose_with_load(size_t load) {
     return std::distance(memory_load_.begin(), it);
 }
 
+void
+gpu_device_manager::release_load(int device_id, size_t load) {
+    // Reject negative ids explicitly instead of relying on unsigned wrap-around.
+    if (device_id < 0 || static_cast<size_t>(device_id) >= memory_load_.size()) {
+        LOG_KNOWHERE_WARNING_ << "please check device id " << device_id;
+        return;
+    }
+    std::lock_guard lock(mtx_);
+    auto& device_load = memory_load_[device_id];
+    // Saturate at zero: an unmatched or repeated release must not wrap the
+    // counter (presumably unsigned) and make this device look overloaded.
+    device_load = load > device_load ? 0 : device_load - load;
+}
+
 gpu_device_manager::gpu_device_manager() {
     int device_counts;
     try {
diff --git a/src/common/raft/raft_utils.h b/src/common/raft/raft_utils.h
index 3b157d32b..b7d0ae33e 100644
--- a/src/common/raft/raft_utils.h
+++ b/src/common/raft/raft_utils.h
@@ -156,6 +156,10 @@ class gpu_device_manager {
     random_choose() const;
     int
     choose_with_load(size_t load);
+    // Subtract `load` from the load tracked for `device_id`.
+    // Out-of-range ids are logged and otherwise ignored.
+    void
+    release_load(int device_id, size_t load);
 
  private:
     gpu_device_manager();
@@ -192,3 +196,8 @@ set_mem_pool_size(size_t init_size, size_t max_size) {
     do { \
         x = raft_utils::gpu_device_manager::instance().choose_with_load(load); \
     } while (0)
+// Counterpart of MIN_LOAD_CHOOSE_DEVICE_WITH_ASSIGN: releases `load` on device `x`.
+#define RELEASE_DEVICE(x, load) \
+    do { \
+        raft_utils::gpu_device_manager::instance().release_load(x, load); \
+    } while (0)
diff --git a/src/index/ivf_raft/ivf_raft.cuh b/src/index/ivf_raft/ivf_raft.cuh
index c4bca64b7..02eee67b7 100644
--- a/src/index/ivf_raft/ivf_raft.cuh
+++ b/src/index/ivf_raft/ivf_raft.cuh
@@ -518,6 +518,8 @@ class RaftIvfIndexNode : public IndexNode {
         // status
         is.read((char*)(&this->device_id_), sizeof(this->device_id_));
         MIN_LOAD_CHOOSE_DEVICE_WITH_ASSIGN(this->device_id_, binary->size);
+        // Remember what was registered so the destructor can release it.
+        load_ = binary->size;
         raft_utils::device_setter with_this_device{this->device_id_};
 
         raft_utils::init_gpu_resources();
@@ -574,11 +576,21 @@ class RaftIvfIndexNode : public IndexNode {
             return knowhere::IndexEnum::INDEX_RAFT_IVFPQ;
         }
     }
+    // Hands the load recorded in load_ back to the device manager so the
+    // device is no longer charged for this index once it is destroyed.
+    virtual ~RaftIvfIndexNode() {
+        // load_ stays 0 unless a load was recorded; skip the locked call then.
+        if (device_id_ >= 0 && load_ > 0) {
+            RELEASE_DEVICE(this->device_id_, this->load_);
+        }
+    }
 
  private:
     int device_id_ = -1;
     int64_t dim_ = 0;
     int64_t counts_ = 0;
+    // Load recorded for device_id_ via the device manager; 0 when none.
+    size_t load_ = 0;
     std::optional gpu_index_;
 
     template
diff --git a/tests/ut/test_get_vector.cc b/tests/ut/test_get_vector.cc
index f4cab60c4..182be9702 100644
--- a/tests/ut/test_get_vector.cc
+++ b/tests/ut/test_get_vector.cc
@@ -18,7 +18,6 @@
 #include "knowhere/comp/knowhere_config.h"
 #include "knowhere/factory.h"
 #include "utils.h"
-
 TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") {
     using Catch::Approx;
 