. #48
base: master
ge/models/line.py
Outdated
@@ -205,9 +211,13 @@ def get_embeddings(self,):

         return self._embeddings

-    def train(self, batch_size=1024, epochs=1, initial_epoch=0, verbose=1, times=1):
+    def train(self, batch_size=1024, epochs=1, initial_epoch=0, verbose=1, times=1,workers=tf.data.experimental.AUTOTUNE,use_multiprocessing=True):
What is the reason for this change?
tf.data.experimental.AUTOTUNE lets the program automatically choose the optimal number of parallel threads.
Of course, users can also set the number of workers themselves; this is only the default setting.
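The default-with-override behaviour described in this thread could be sketched roughly as follows. This is only an illustration and does not call TensorFlow: AUTOTUNE here is a local constant mirroring the value of tf.data.experimental.AUTOTUNE (-1), and resolve_workers is a hypothetical helper that is not part of the PR. Falling back to the CPU count is an assumption made for the sketch.

```python
import os

# Local sentinel mirroring tf.data.experimental.AUTOTUNE, which is the constant -1.
AUTOTUNE = -1


def resolve_workers(workers=AUTOTUNE):
    """Hypothetical helper: turn the `workers` argument into a concrete count."""
    if workers == AUTOTUNE:
        # Assumption for this sketch: "automatic" means one worker per CPU.
        return os.cpu_count() or 1
    if workers < 1:
        raise ValueError("workers must be a positive integer or AUTOTUNE")
    # An explicit value chosen by the user always wins over the default.
    return workers
```

Calling resolve_workers() uses the automatic default, while resolve_workers(4) keeps the caller's explicit choice.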
ge/models/node2vec.py
Outdated
"""
Why was this whole block deleted?
While making the changes I pasted the new code straight in, and it overwrote that block...
ge/models/node2vec.py
Outdated
    def train(self, embed_size=128, window_size=5, workers=3, iter=5, **kwargs):
    def __init__(self, graph, walk_length, num_walks, p=1.0, q=1.0,threads=1):
The new function has fewer parameters than the old one...
def __init__(self, graph, walk_length, num_walks, p=1.0, q=1.0, workers=1, use_rejection_sampling=0): some of these parameters were moved to train, and use_rejection_sampling is not implemented.
If a numba implementation of use_rejection_sampling is needed, I can write one.
That amounts to removing functionality that already existed; this should stay consistent with the original interface.
ge/models/node2vec.py
Outdated
    def get_embeddings(self,):
        if self.w2v_model is None:
            print("model not train")
            return {}

        self._embeddings = {}
        for word in self.graph.nodes():
            self._embeddings[word] = self.w2v_model.wv[word]
        for word in self.node_dict.keys():
Why replace self.graph with self.node_dict?
csrgraph stores the graph in scipy form, so the node names become 0, 1, 2, 3, and so on. node_dict is the mapping between the networkx node names and the csrgraph node names; for example, a node originally named "XXX" might correspond to the new node name 1.
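A minimal sketch of the name mapping described above, using plain Python instead of networkx or csrgraph. The helper names build_node_dict and remap_embeddings, the mapping direction (original name to integer index), and the vectors are all made up for illustration; the PR's actual node_dict may be built differently.

```python
def build_node_dict(node_names):
    """Map each original node name to a consecutive integer index."""
    return {name: idx for idx, name in enumerate(node_names)}


def remap_embeddings(node_dict, vectors_by_index):
    """Key the embeddings by original node name instead of integer index."""
    return {name: vectors_by_index[idx] for name, idx in node_dict.items()}


# Original (networkx-style) node names, in a fixed order.
node_dict = build_node_dict(["XXX", "YYY", "ZZZ"])

# Made-up vectors, indexed by the integer ids used after conversion.
vectors_by_index = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

# Iterating over node_dict returns embeddings under the original names.
embeddings = remap_embeddings(node_dict, vectors_by_index)
```

Iterating over the mapping rather than over the original graph is what lets get_embeddings return vectors keyed by the names the user supplied.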
You deleted the file ge/models/deepwalk.py...
I also have some questions about the changes in the other files; please take a look.
The node2vec interface already covers deepwalk, so the original deepwalk was removed. When p and q are both 1, csrgraph internally selects the optimized walk strategy that corresponds to deepwalk.
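Why p = q = 1 corresponds to deepwalk can be seen from the node2vec transition weights on an unweighted graph. The sketch below is not csrgraph's internal code; node2vec_weights and the toy graph are hypothetical and only illustrate that the second-order bias disappears when both parameters are 1.

```python
def node2vec_weights(adj, prev, cur, p=1.0, q=1.0):
    """Unnormalized weights for stepping out of `cur`, having arrived from `prev`."""
    weights = {}
    for nxt in adj[cur]:
        if nxt == prev:
            weights[nxt] = 1.0 / p  # step back to the previous node
        elif nxt in adj[prev]:
            weights[nxt] = 1.0      # stay at distance 1 from the previous node
        else:
            weights[nxt] = 1.0 / q  # move to distance 2 from the previous node
    return weights


# Toy undirected graph stored as adjacency sets.
adj = {
    "a": {"b", "c"},
    "b": {"a", "c", "d"},
    "c": {"a", "b"},
    "d": {"b"},
}

# With p = q = 1 every neighbour gets the same weight: a uniform, deepwalk-style step.
uniform = node2vec_weights(adj, prev="a", cur="b")

# With other values the walk is biased by where it came from.
biased = node2vec_weights(adj, prev="a", cur="b", p=4.0, q=0.5)
```

Because the weights are uniform in the first case, the previous node no longer matters, which is why a library can switch to a simpler first-order walk.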
Removing deepwalk would confuse users. I suggest keeping the deepwalk interface; under the hood it can call node2vec.
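One way the suggestion could look, as a rough sketch: a thin DeepWalk class that delegates to Node2Vec with p and q fixed at 1. The Node2Vec class below is a stand-in with only a constructor, and both signatures are assumptions based on the parameters quoted in this review, not the PR's actual code.

```python
class Node2Vec:
    """Stand-in for the PR's Node2Vec class; only the constructor is sketched."""

    def __init__(self, graph, walk_length, num_walks, p=1.0, q=1.0, workers=1):
        self.graph = graph
        self.walk_length = walk_length
        self.num_walks = num_walks
        self.p = p
        self.q = q
        self.workers = workers


class DeepWalk(Node2Vec):
    """Keeps the DeepWalk entry point; the walks are node2vec walks with p = q = 1."""

    def __init__(self, graph, walk_length, num_walks, workers=1):
        # p and q are fixed so callers cannot accidentally request a biased walk.
        super().__init__(graph, walk_length, num_walks, p=1.0, q=1.0, workers=workers)
```

Existing user code that constructs DeepWalk would keep working, while the walk logic lives in one place.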