diff --git a/src/index/gpu_raft/gpu_raft_cagra.cc b/src/index/gpu_raft/gpu_raft_cagra.cc index 49da15781..240052eef 100644 --- a/src/index/gpu_raft/gpu_raft_cagra.cc +++ b/src/index/gpu_raft/gpu_raft_cagra.cc @@ -38,7 +38,7 @@ class GpuRaftCagraHybridIndexNode : public GpuRaftCagraIndexNode { Status Train(const DataSetPtr dataset, const Config& cfg) override { const GpuRaftCagraConfig& cagra_cfg = static_cast(cfg); - if (cagra_cfg.adapt_for_cpu) + if (cagra_cfg.adapt_for_cpu.value()) adapt_for_cpu = true; return GpuRaftCagraIndexNode::Train(dataset, cfg); } @@ -111,6 +111,16 @@ class GpuRaftCagraHybridIndexNode : public GpuRaftCagraIndexNode { return result; } + int64_t + Count() const override { + if (!adapt_for_cpu) + return GpuRaftCagraIndexNode::Count(); + if (!hnsw_index_) { + return 0; + } + return hnsw_index_->cur_element_count; + } + Status Deserialize(const BinarySet& binset, const Config& config) override { if (binset.Contains(std::string(this->Type()) + "_cpu")) {