relative_album_caculate.py
import pickle

import hnswlib
import numpy as np
import pandas as pd
import tensorflow as tf

from .din_model import Model
from .tdm import get_data
from ximalaya_brain_jobs.model.util import upload_result_to_hdfs


def load_data():
    # Must be opened in binary *read* mode ('rb'); the original 'wb'
    # would truncate the file and make pickle.load fail.
    with open('/home/dev/data/andrew.zhu/tdm/data_flow/final_tree.pkl', 'rb') as f:
        tree = pickle.load(f)
    return tree


def embedding_index(embeddings, space='ip'):
    """
    Build an hnswlib index over the item embedding vectors so that
    nearest-neighbor lookups can be done quickly.
    :param embeddings: np.ndarray of shape (num_items, dim)
    :param space: similarity space; 'ip' = inner product
    :return: the populated hnswlib index
    """
    print("embeddings type is %s" % type(embeddings))
    dim = embeddings.shape[1]
    print('embeddings shape is %s' % str(embeddings.shape))
    # Build the index; add_items labels rows 0..n-1 in insertion order,
    # matching the positions used by index_item_dict later on.
    nmsl_index = hnswlib.Index(space=space, dim=dim)
    nmsl_index.init_index(max_elements=100000, ef_construction=200)
    nmsl_index.set_ef(50)
    nmsl_index.add_items(embeddings)
    return nmsl_index
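

# A minimal usage sketch for the index above (illustrative only; `item_vecs`
# is a hypothetical (n, dim) float array). hnswlib's knn_query returns the
# integer labels assigned by add_items along with the matching distances:
#
#   index = embedding_index(item_vecs)
#   labels, distances = index.knn_query(item_vecs[:1], k=5)
#   # labels.shape == (1, 5); map each label back to an item id via
#   # index_item_dict, as get_item_similar_item does below.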


def get_embedding():
    data_train, data_validate, cache = get_data()
    print('data_train len %d' % len(data_train))
    print('data_validate len %d' % len(data_validate))
    # Samples are uid, ts, item_list, behavior_list + mask;
    # cache carries (user_dict, item_dict, tree).
    _, _, tree = cache
    item_ids, item_size, node_size = tree.items, len(tree.items), tree.node_size
    print('item_size %d' % item_size)
    print('node_size %d' % node_size)
    model = Model(item_size, node_size, 10)
    with tf.Session() as sess:
        saver = tf.train.Saver()
        saver.restore(sess, "/home/dev/data/andrew.zhu/tdm/model/tdm.ckpt")
        item_embeddings = sess.run(model.item_emb_w)
    return np.array(item_embeddings)


def get_dict():
    # sample.pkl stores three consecutive pickles; the third is the
    # cache tuple (user_dict, item_dict, random_tree) consumed in main().
    with open('/home/dev/data/andrew.zhu/tdm/data_flow/sample.pkl', 'rb') as f:
        data_train = pickle.load(f)
        data_validate = pickle.load(f)
        cache = pickle.load(f)
    return cache


def get_item_similar_item(index_item_dict, nmsl, embeddings, save_path, file_name, topK=30):
    """
    For every item, retrieve its topK most similar items from the index.
    :param index_item_dict: mapping from index position to item id
    :param nmsl: hnswlib index built by embedding_index
    :param embeddings: np.ndarray of item embeddings; row i is the item
        at position i of index_item_dict
    :param save_path: output directory for the CSV result
    :param file_name: output file name
    :param topK: number of neighbors to retrieve per item
    :return: dict mapping item_id to a '|'-joined string of similar item ids
    """
    print("top k is %s" % topK)
    labels, distance = nmsl.knn_query(embeddings, k=topK)
    print(labels.shape)
    print(distance.shape)
    item_length = len(index_item_dict)
    print("sim album num is %d" % item_length)
    result = {}
    for i in range(item_length):
        item_id = index_item_dict[i]
        items = []
        for j in labels[i].tolist():
            try:
                items.append(index_item_dict[int(j)])
            except KeyError:
                print('-- %s -- %s' % (j, type(j)))
        # Join whatever neighbors resolved; slicing avoids an IndexError
        # when fewer than topK labels could be mapped back to item ids.
        result[item_id] = '|'.join(str(item) for item in items[:topK])
    df = pd.DataFrame.from_dict(result, orient='index', columns=['re_items'])
    df = df.reset_index().rename(columns={'index': 'item_id'})
    print(df.head(5))
    df.to_csv(save_path + file_name, index=True)
    upload_result_to_hdfs("/user/dev/andrew.zhu/test",
                          save_path + file_name)
    return result
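

# For reference, the returned mapping has the shape (ids illustrative):
#   {item_id: 'id1|id2|id3|id4', ...}
# i.e. each item id maps to a '|'-joined string of its nearest neighbors,
# which is also what gets written to the CSV and uploaded to HDFS.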


def main():
    item_embedding = get_embedding()
    item_embedding_index = embedding_index(item_embedding)
    (user_dict, item_dict, random_tree) = get_dict()
    item_sim_item_save_path = '/home/dev/data/andrew.zhu/tdm/data_flow/'
    file_name = 'album_sim'
    # item_dict maps item id -> index position; invert it so that
    # get_item_similar_item can map index positions back to item ids.
    item_dict = dict(zip(item_dict.values(), item_dict.keys()))
    get_item_similar_item(item_dict, item_embedding_index, item_embedding,
                          item_sim_item_save_path, file_name, 4)


if __name__ == '__main__':
    main()