forked from teticio/Deej-AI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_track2vec.py
67 lines (57 loc) · 2.05 KB
/
test_track2vec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import argparse
import csv
import gensim
import pandas as pd
from gensim.models.callbacks import CallbackAny2Vec
# Display settings for printing search results: never truncate track names
# (None = unlimited column width) and show up to 1000 matching rows.
# Fully qualified option names avoid relying on pandas' partial-key matching.
pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_rows", 1000)
def _format_track(tracks_df, track_id):
    """Return "artist - title" for ``track_id`` wrapped in an OSC 8 terminal hyperlink.

    The link target is the row's ``url`` column.
    """
    track = tracks_df.loc[track_id]
    return (
        f"\u001b]8;;{track['url']}\u001b\\"
        f"{track['artist']} - {track['title']}"
        "\u001b]8;;\u001b\\"
    )


def _search_tracks(tracks_df, search):
    """Return the ``name`` column of every track whose name contains ``search``.

    Matching is a case-insensitive *literal* substring test. User input is not
    interpreted as a regular expression, so text containing characters such as
    ``(`` or ``+`` can be searched without raising ``re.error``.
    """
    matches = tracks_df["name"].str.contains(search, case=False, regex=False)
    return tracks_df.loc[matches, ["name"]]


def _prompt_track_id(tracks_df, model):
    """Prompt until the user enters an ID present in both the CSV and the model."""
    while True:
        track_id = input("Enter track ID: ").strip()
        # tracks_df.loc[...] and model.wv.most_similar(...) would otherwise
        # fail with KeyError for an unknown ID, so validate before returning.
        if track_id in tracks_df.index and track_id in model.wv:
            return track_id
        print(f"Unknown track ID: {track_id!r}")


def main():
    """Interactively query the Track2Vec model for tracks similar to a chosen one.

    Command-line arguments:
        --dedup_tracks_file (str): Deduplicated tracks CSV file.
            Default is "data/tracks_dedup.csv".
        --model_file (str): Track2Vec model file (without extension).
            Default is "models/track2vec".
        --topn (int): Number of similar tracks to list. Default is 8.

    The user searches for a track by a substring of "artist - title" and picks
    one of the listed track IDs. The script then prints that track, followed by
    its ``topn`` most similar tracks and their similarity scores.

    Returns:
        None
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dedup_tracks_file",
        type=str,
        default="data/tracks_dedup.csv",
        help="Deduplicated tracks CSV file",
    )
    parser.add_argument(
        "--model_file",
        type=str,
        default="models/track2vec",
        help="Track2Vec model file (without extension)",
    )
    parser.add_argument(
        "--topn",
        type=int,
        default=8,
        help="Number of similar tracks to show",
    )
    args = parser.parse_args()

    # Rows are keyed by track ID (the same keys passed to the model below).
    # Blank fields become "" so the name concatenation cannot produce NaN.
    tracks_df = pd.read_csv(
        args.dedup_tracks_file,
        header=None,
        index_col=0,
        names=["artist", "title", "url", "count"],
    ).fillna("")
    tracks_df["name"] = tracks_df["artist"] + " - " + tracks_df["title"]
    model = gensim.models.Word2Vec.load(args.model_file)

    # Keep asking until the search matches at least one track.
    while True:
        track_ids = _search_tracks(tracks_df, input("Search for a track: "))
        if not track_ids.empty:
            break
        print("No matching tracks found.")
    print(track_ids)

    track_id = _prompt_track_id(tracks_df, model)
    print()
    print(_format_track(tracks_df, track_id))

    most_similar = model.wv.most_similar(positive=[track_id], topn=args.topn)
    for rank, (similar_id, score) in enumerate(most_similar, start=1):
        # A model key with no CSV row is shown as its raw ID instead of
        # aborting the listing with a KeyError.
        label = (
            _format_track(tracks_df, similar_id)
            if similar_id in tracks_df.index
            else similar_id
        )
        print(f"{rank}. {label} ({score:.2f})")


if __name__ == "__main__":
    main()