-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathrelated.py
39 lines (30 loc) · 1.12 KB
/
related.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from typing import List, Tuple
from load import load_words
from vectors import Vector, cosine_similarity_normalized
from word import Word, find_word
WORDS_FILE_NAME = 'data/words.vec'
def related_words(base_vector: Vector, words: List[Word]) -> List[Tuple[float, Word]]:
    """Rank every word in ``words`` by its similarity to ``base_vector``.

    Each word is scored with
    ``cosine_similarity_normalized(base_vector, word.vector)``.

    Returns:
        A new list of ``(similarity, word)`` tuples, ordered from the highest
        similarity score to the lowest. Entries with equal scores keep the
        relative order they had in ``words``.
    """
    scored = []
    for candidate in words:
        similarity = cosine_similarity_normalized(base_vector, candidate.vector)
        scored.append((similarity, candidate))
    # A larger cosine similarity (closer to 1) means a closer match, so the
    # best candidates must come first: sort in descending order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored
def print_related_words(words: List[Word], text: str) -> None:
    """Print up to ten words from ``words`` that are most similar to ``text``.

    ``text`` is resolved with ``find_word``. If that lookup yields a falsy
    result, an "Unknown word" message is printed and nothing else happens.
    Otherwise all words are ranked with ``related_words`` against the found
    word's vector, entries whose text equals the found word's text (ignoring
    case) are dropped, and the first ten remaining texts are printed as a
    comma-separated line.
    """
    base_word = find_word(text, words)
    if base_word:
        ranked = related_words(base_word.vector, words)
        base_text = base_word.text.lower()
        names = []
        for _, candidate in ranked:
            # Skip the queried word itself (and its case variants).
            if candidate.text.lower() == base_text:
                continue
            names.append(candidate.text)
        print(', '.join(names[:10]))
    else:
        print(f"Unknown word: {text}")
def main() -> None:
    """Run the interactive "related words" prompt.

    Loads the words with ``load_words(WORDS_FILE_NAME)`` once, then repeatedly
    asks for a word and prints its related words via ``print_related_words``.
    The loop ends cleanly when stdin is exhausted (EOF / Ctrl-D) or the user
    presses Ctrl-C at the prompt.
    """
    words = load_words(WORDS_FILE_NAME)
    print("")
    while True:
        try:
            text = input("Words related to: ")
        except (EOFError, KeyboardInterrupt):
            # input() raises EOFError on end of input and KeyboardInterrupt on
            # Ctrl-C; treat both as "quit" instead of dumping a traceback.
            # The empty print moves the cursor off the prompt line.
            print("")
            break
        print_related_words(words, text)
        print("")


# Only start the prompt when executed as a script, so that importing this
# module does not load the data file or block on input().
if __name__ == "__main__":
    main()