Skip to content

Commit

Permalink
Added inverse square distance weighting to predict
Browse files Browse the repository at this point in the history
  • Loading branch information
daenuprobst committed Dec 23, 2019
1 parent f88abab commit 9aa043d
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 25 deletions.
2 changes: 1 addition & 1 deletion tmap/recipe/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package:
name: tmap
version: 1.0.3
version: 1.0.4

build:
number: 0
Expand Down
9 changes: 8 additions & 1 deletion tmap/tmap/bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -657,13 +657,20 @@ PYBIND11_MODULE(tmap, m)
vecs (:obj:`List` of :obj:`VectorUint`): A list of MinHash vectors that is to be added to the LSH forest
labels (:obj:`VectorUint`) A vector containing labels.
)pbdoc")
.def("predict", &LSHForest::Predict, R"pbdoc(
.def("predict",
&LSHForest::Predict,
py::arg("vecs"),
py::arg("k") = 10,
py::arg("kc") = 10,
py::arg("weighted") = false,
R"pbdoc(
Predict labels of Minhashes using the kNN algorithm (parallelized).
Arguments:
vecs (:obj:`List` of :obj:`VectorUint`): A list of MinHash vectors whose labels are to be predicted
k (:obj:`int`) The degree of the kNN algorithm
kc (:obj:`int`) The scalar by which k is multiplied before querying the LSH
weighted (:obj:`bool`) Whether the votes of the nearest neighbors are weighted by their inverse squared distance
Returns:
:obj:`VectorUint` The predicted labels
)pbdoc")
Expand Down
75 changes: 53 additions & 22 deletions tmap/tmap/lshforest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,36 +124,67 @@ tmap::LSHForest::Fit(std::vector<std::vector<uint32_t>>& vecs,
std::vector<uint32_t>
tmap::LSHForest::Predict(std::vector<std::vector<uint32_t>>& vecs,
unsigned int k,
unsigned int kc)
unsigned int kc,
bool weighted)
{
std::vector<uint32_t> pred_labels(vecs.size());

if (!weighted) {
#pragma omp parallel for
for (size_t i = 0; i < vecs.size(); i++) {
auto nn = QueryLinearScan(vecs[i], k, kc);

std::sort(nn.begin(), nn.end(), [this](auto &left, auto &right) {
return labels_[left.second] < labels_[right.second];
});

uint32_t max_element = labels_[nn[0].second];
uint32_t max_count = 1;
uint32_t count = 1;


for (size_t j = 1; j < nn.size(); j++) {
if (labels_[nn[j].second] == labels_[nn[j-1].second]) {
count++;
if (count > max_count) {
max_count = count;
max_element = labels_[nn[j].second];
for (size_t i = 0; i < vecs.size(); i++) {
auto nn = QueryLinearScan(vecs[i], k, kc);

std::sort(nn.begin(), nn.end(), [this](auto &left, auto &right) {
return labels_[left.second] < labels_[right.second];
});

uint32_t max_element = labels_[nn[0].second];
uint32_t max_count = 1;
uint32_t count = 1;


for (size_t j = 1; j < nn.size(); j++) {
if (labels_[nn[j].second] == labels_[nn[j-1].second]) {
count++;
if (count > max_count) {
max_count = count;
max_element = labels_[nn[j].second];
}
} else {
count = 1;
}
} else {
count = 1;
}

pred_labels[i] = max_element;
}
} else {
#pragma omp parallel for
for (size_t i = 0; i < vecs.size(); i++) {
auto nn = QueryLinearScan(vecs[i], k, kc);

std::sort(nn.begin(), nn.end(), [this](auto &left, auto &right) {
return labels_[left.second] < labels_[right.second];
});

uint32_t max_element = labels_[nn[0].second];
double max_count = 1.0 / (nn[0].first * nn[0].first);
double count = max_count;


for (size_t j = 1; j < nn.size(); j++) {
if (labels_[nn[j].second] == labels_[nn[j-1].second]) {
count += 1.0 / (nn[j].first * nn[j].first);
if (count > max_count) {
max_count = count;
max_element = labels_[nn[j].second];
}
} else {
count = 1.0 / (nn[j].first * nn[j].first);
}
}

pred_labels[i] = max_element;
pred_labels[i] = max_element;
}
}

return pred_labels;
Expand Down
4 changes: 3 additions & 1 deletion tmap/tmap/lshforest.hh
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,14 @@ public:
* @param vecs A vector containing MinHash vectors.
* @param k The degree of the kNN algorithm.
* @param kc The scalar by which k is multiplied before querying the LSH
* @param weighted Whether the votes of the nearest neighbors are weighted by their inverse squared distance
*
* @return std::vector<uint32_t> The predicted labels.
*/
std::vector<uint32_t> Predict(std::vector<std::vector<uint32_t>>& vecs,
unsigned int k = 10,
unsigned int kc = 10);
unsigned int kc = 10,
bool weighted = false);

/**
* @brief Create the index (trees).
Expand Down

0 comments on commit 9aa043d

Please sign in to comment.