diff --git a/tmap/recipe/meta.yaml b/tmap/recipe/meta.yaml index cc8defa..82e3600 100644 --- a/tmap/recipe/meta.yaml +++ b/tmap/recipe/meta.yaml @@ -1,6 +1,6 @@ package: name: tmap - version: 1.0.3 + version: 1.0.4 build: number: 0 diff --git a/tmap/tmap/bindings.cc b/tmap/tmap/bindings.cc index 4d14bd7..eb61208 100644 --- a/tmap/tmap/bindings.cc +++ b/tmap/tmap/bindings.cc @@ -657,13 +657,20 @@ PYBIND11_MODULE(tmap, m) vecs (:obj:`List` of :obj:`VectorUint`): A list of MinHash vectors that is to be added to the LSH forest labels (:obj:`VectorUint`) A vector containing labels. )pbdoc") - .def("predict", &LSHForest::Predict, R"pbdoc( + .def("predict", + &LSHForest::Predict, + py::arg("vecs"), + py::arg("k") = 10, + py::arg("kc") = 10, + py::arg("weighted") = false, + R"pbdoc( Predict labels of Minhashes using the kNN algorithm (parallelized). Arguments: vecs (:obj:`List` of :obj:`VectorUint`): A list of MinHash vectors that is to be added to the LSH forest k (:obj:`int`) The degree of the kNN algorithm kc (:obj:`int`) The scalar by which k is multiplied before querying the LSH + weighted (:obj:`bool` Whether distances are used as weights by the knn algorithm) Returns: :obj:`VectorUint` The predicted labels )pbdoc") diff --git a/tmap/tmap/lshforest.cc b/tmap/tmap/lshforest.cc index 6f5112a..974d7d3 100644 --- a/tmap/tmap/lshforest.cc +++ b/tmap/tmap/lshforest.cc @@ -124,36 +124,67 @@ tmap::LSHForest::Fit(std::vector>& vecs, std::vector tmap::LSHForest::Predict(std::vector>& vecs, unsigned int k, - unsigned int kc) + unsigned int kc, + bool weighted) { std::vector pred_labels(vecs.size()); + if (!weighted) { #pragma omp parallel for - for (size_t i = 0; i < vecs.size(); i++) { - auto nn = QueryLinearScan(vecs[i], k, kc); - - std::sort(nn.begin(), nn.end(), [this](auto &left, auto &right) { - return labels_[left.second] < labels_[right.second]; - }); - - uint32_t max_element = labels_[nn[0].second]; - uint32_t max_count = 1; - uint32_t count = 1; - - - for (size_t j = 1; j < nn.size(); j++) { - if (labels_[nn[j].second] == labels_[nn[j-1].second]) { - count++; - if (count > max_count) { - max_count = count; - max_element = labels_[nn[j].second]; + for (size_t i = 0; i < vecs.size(); i++) { + auto nn = QueryLinearScan(vecs[i], k, kc); + + std::sort(nn.begin(), nn.end(), [this](auto &left, auto &right) { + return labels_[left.second] < labels_[right.second]; + }); + + uint32_t max_element = labels_[nn[0].second]; + uint32_t max_count = 1; + uint32_t count = 1; + + + for (size_t j = 1; j < nn.size(); j++) { + if (labels_[nn[j].second] == labels_[nn[j-1].second]) { + count++; + if (count > max_count) { + max_count = count; + max_element = labels_[nn[j].second]; + } + } else { + count = 1; } - } else { - count = 1; } + + pred_labels[i] = max_element; } + } else { +#pragma omp parallel for + for (size_t i = 0; i < vecs.size(); i++) { + auto nn = QueryLinearScan(vecs[i], k, kc); + + std::sort(nn.begin(), nn.end(), [this](auto &left, auto &right) { + return labels_[left.second] < labels_[right.second]; + }); + + uint32_t max_element = labels_[nn[0].second]; + double max_count = 1.0 / (nn[0].first * nn[0].first); + double count = max_count; + + + for (size_t j = 1; j < nn.size(); j++) { + if (labels_[nn[j].second] == labels_[nn[j-1].second]) { + count += 1.0 / (nn[j].first * nn[j].first); + if (count > max_count) { + max_count = count; + max_element = labels_[nn[j].second]; + } + } else { + count = 1.0 / (nn[j].first * nn[j].first); + } + } - pred_labels[i] = max_element; + pred_labels[i] = max_element; + } } return pred_labels; diff --git a/tmap/tmap/lshforest.hh b/tmap/tmap/lshforest.hh index 057aeaa..b0ecf53 100644 --- a/tmap/tmap/lshforest.hh +++ b/tmap/tmap/lshforest.hh @@ -162,12 +162,14 @@ public: * @param vecs A vector containing MinHash vectors. * @param k The degree of the kNN algorithm. * @param kc The scalar by which k is multiplied before querying the LSH + * @param weighted Whether distances are used as weights by the knn algorithm * * @return std::vector The predicted labels. */ std::vector Predict(std::vector>& vecs, unsigned int k = 10, - unsigned int kc = 10); + unsigned int kc = 10, + bool weighted = false); /** * @brief Create the index (trees).