Skip to content

Commit

Permalink
Bound ANN hamming scores between 0.0 and 1.0, closes #838
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Dec 10, 2024
1 parent 6cc886a commit d39b85e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
27 changes: 22 additions & 5 deletions src/python/txtai/ann/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,8 @@ def search(self, queries, limit):
# Map results to [(id, score)]
results = []
for x, score in enumerate(scores):
# Transform scores
score = [1.0 - (x / (self.config["dimensions"] * 8)) for x in score.tolist()] if self.qbits else score.tolist()

# Add results
results.append(list(zip(ids[x].tolist(), score)))
# Transform scores and add results
results.append(list(zip(ids[x].tolist(), self.scores(score))))

return results

Expand Down Expand Up @@ -214,3 +211,23 @@ def nprobe(self):

default = 6 if count <= 5000 else round(self.cells(count) / 16)
return self.setting("nprobe", default)

def scores(self, scores):
    """
    Converts raw search scores into index scores.

    When self.qbits is set, each input score is treated as a hamming distance and mapped to
    1.0 - (distance / (dimensions * 8)), with the result bounded to the range 0.0 - 1.0.
    Otherwise, the input scores are passed through unchanged.

    Args:
        scores: input scores, an array supporting division and tolist()

    Returns:
        index scores as a list
    """

    # Standard scoring, no transformation necessary
    if not self.qbits:
        return scores.tolist()

    # Total bit count used to normalize the hamming distance
    bits = self.config["dimensions"] * 8

    # Hamming score, bound between 0.0 - 1.0
    bounded = np.clip(1.0 - (scores / bits), 0.0, 1.0)
    return bounded.tolist()
5 changes: 3 additions & 2 deletions src/python/txtai/ann/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, config):

# Array function definitions
self.all, self.cat, self.dot, self.zeros = np.all, np.concatenate, np.dot, np.zeros
self.argsort, self.xor = np.argsort, np.bitwise_xor
self.argsort, self.xor, self.clip = np.argsort, np.bitwise_xor, np.clip

# Scalar quantization
quantize = self.config.get("quantize")
Expand Down Expand Up @@ -160,4 +160,5 @@ def hammingscore(self, queries):
delta = self.totype(delta, np.int64)

# Calculate score as 1.0 - percentage of different bits
return 1.0 - (table[delta].sum(axis=2) / (self.config["dimensions"] * 8))
# Bound score from 0 to 1
return self.clip(1.0 - (table[delta].sum(axis=2) / (self.config["dimensions"] * 8)), 0.0, 1.0)
2 changes: 1 addition & 1 deletion src/python/txtai/ann/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, config):

# Define array functions
self.all, self.cat, self.dot, self.zeros = torch.all, torch.cat, torch.mm, torch.zeros
self.argsort, self.xor = torch.argsort, torch.bitwise_xor
self.argsort, self.xor, self.clip = torch.argsort, torch.bitwise_xor, torch.clip

def tensor(self, array):
# Convert array to Tensor
Expand Down

0 comments on commit d39b85e

Please sign in to comment.