forked from teticio/Deej-AI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinstall_model.py
108 lines (98 loc) · 3.47 KB
/
install_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import argparse
import os
import pickle
import shutil
from utils import read_tracks
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards.

    NOTE(review): unpickling can execute arbitrary code; only point this script
    at model/metadata files you produced or otherwise trust.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing (and therefore flushing) the file afterwards."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def main(argv=None):
    """
    Entry point for the install_model script.

    Installs model to deej-ai.online app: writes the Track2Vec and MP3ToVec
    vectors (restricted to tracks present in both), the track URL / name
    metadata, and a copy of the MP3ToVec model file into the deej-ai.online
    model directory.

    Args:
        argv (list[str] | None): Command-line arguments to parse. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Command-line options:
        --deejai_model_dir (str): Path to the deej-ai.online model directory. Default is "../deej-ai.online-dev/model".
        --mp3tovec_model_file (str): Path to the MP3ToVec model file. Default is "models/speccy_model".
        --mp3tovec_file (str): Path to the MP3ToVec file. Default is "models/mp3tovec.p".
        --old_deejai_model_dir (str): Optionally merge old track metadata for backwards compatibility. Default is None.
        --track2vec_file (str): Path to the Track2Vec file. Default is "models/track2vec.p".
        --tracks_file (str): Path to the tracks CSV file. Default is "data/tracks_dedup.csv".

    Returns:
        None
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--deejai_model_dir",
        type=str,
        default="../deej-ai.online-dev/model",
        help="deej-ai.online model directory",
    )
    parser.add_argument(
        "--mp3tovec_model_file",
        type=str,
        default="models/speccy_model",
        help="MP3ToVec model file",
    )
    parser.add_argument(
        "--mp3tovec_file",
        type=str,
        default="models/mp3tovec.p",
        help="MP3ToVec file",
    )
    parser.add_argument(
        "--old_deejai_model_dir",
        type=str,
        default=None,
        help="Merge old track metadata (optional)",
    )
    parser.add_argument(
        "--track2vec_file",
        type=str,
        default="models/track2vec.p",
        help="Track2Vec file",
    )
    parser.add_argument(
        "--tracks_file",
        type=str,
        default="data/tracks_dedup.csv",
        help="Tracks CSV file",
    )
    args = parser.parse_args(argv)

    track2vec = _load_pickle(args.track2vec_file)
    spotify2vec = _load_pickle(args.mp3tovec_file)
    tracks = read_tracks(args.tracks_file)

    # Only tracks with both a Track2Vec and an MP3ToVec vector are installed.
    common_tracks = set(track2vec.keys()).intersection(set(spotify2vec.keys()))
    print(f"{len(common_tracks)} tracks")
    # Prune in place (rather than rebuilding) so the pickled mapping types are
    # exactly those that were loaded.
    for vectors in (track2vec, spotify2vec):
        for track_id in set(vectors.keys()).difference(common_tracks):
            del vectors[track_id]

    spotify_tracks = {}
    spotify_urls = {}
    if args.old_deejai_model_dir is not None:
        # Start from the previously installed metadata (backwards compatibility);
        # entries for the current common tracks are overwritten below.
        spotify_tracks = _load_pickle(
            os.path.join(args.old_deejai_model_dir, "spotify_tracks.p")
        )
        spotify_urls = _load_pickle(
            os.path.join(args.old_deejai_model_dir, "spotify_urls.p")
        )

    # NOTE(review): presumably read_tracks returns a mapping of track id to a
    # row exposing "url", "artist" and "title"; a common track missing from
    # the tracks CSV makes this lookup fail — TODO confirm that cannot happen.
    for track_id in common_tracks:
        track = tracks[track_id]
        spotify_urls[track_id] = track["url"]
        spotify_tracks[track_id] = f"{track['artist']} - {track['title']}"

    out_dir = args.deejai_model_dir
    _dump_pickle(track2vec, os.path.join(out_dir, "tracktovec.p"))
    _dump_pickle(spotify2vec, os.path.join(out_dir, "spotifytovec.p"))
    _dump_pickle(spotify_urls, os.path.join(out_dir, "spotify_urls.p"))
    _dump_pickle(spotify_tracks, os.path.join(out_dir, "spotify_tracks.p"))
    shutil.copyfile(args.mp3tovec_model_file, os.path.join(out_dir, "speccy_model"))


if __name__ == "__main__":
    main()