-
Notifications
You must be signed in to change notification settings - Fork 78
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
the gammagl code of FedHGNN #199
Open
ZeroerWiser
wants to merge
2
commits into
BUPT-GAMMA:main
Choose a base branch
from
ZeroerWiser:submit_fedhgnn
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# FedHGNN | ||
The source code of WWW 2024 paper "Federated Heterogeneous Graph Neural Network for Privacy-preserving Recommendation" | ||
|
||
|
||
# Requirements | ||
``` | ||
dgl==1.1.0+cu113 | ||
numpy==1.21.6 | ||
ogb==1.3.6 | ||
python==3.7.13 | ||
scikit-learn==1.0.2 | ||
scipy==1.7.3 | ||
torch==1.12.1+cu113 | ||
torchaudio==0.12.1+cu113 | ||
torchvision==0.13.1+cu113 | ||
``` | ||
|
||
|
||
# Easy Run | ||
``` | ||
cd ./codes/FedHGNN | ||
python main.py --dataset acm --shared_num 20 --p1 1 --p2 2 --lr 0.01 --device cuda:0 | ||
``` | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
import copy | ||
import tensorlayerx as tlx | ||
import os | ||
os.environ['TL_BACKEND'] = 'torch' | ||
import random | ||
|
||
from local_differential_privacy_library import * | ||
from util import * | ||
from random import sample | ||
from sklearn.metrics.pairwise import cosine_similarity | ||
|
||
from warnings import simplefilter | ||
simplefilter(action='ignore', category=FutureWarning) | ||
|
||
|
||
|
||
class Client(tlx.nn.Module): | ||
def __init__(self, user_id, item_id, args): | ||
super().__init__() | ||
self.device = args.device | ||
self.user_id = user_id | ||
self.item_id = item_id #list | ||
#self.semantic_neighbors = semantic_neighbors | ||
|
||
|
||
def negative_sample(self, total_item_num): | ||
'''生成item负样本集合''' | ||
#从item列表里随机选取item作为user的负样本 | ||
item_neg_ind = [] | ||
#item_neg_ind和item_id数量一样 | ||
for _ in self.item_id: | ||
neg_item = np.random.randint(1, total_item_num) | ||
while neg_item in self.item_id: | ||
neg_item = np.random.randint(1, total_item_num) | ||
item_neg_ind.append(neg_item) | ||
'''生成item负样本集合end''' | ||
return item_neg_ind | ||
|
||
def negative_sample_with_augment(self, total_item_num, sampled_items): | ||
item_set = self.item_id+sampled_items | ||
'''生成item负样本集合''' | ||
#从item列表里随机选取item作为user的负样本 | ||
item_neg_ind = [] | ||
#item_neg_ind和item_id数量一样 | ||
for _ in item_set: | ||
neg_item = np.random.randint(1, total_item_num) | ||
while neg_item in item_set: | ||
neg_item = np.random.randint(1, total_item_num) | ||
item_neg_ind.append(neg_item) | ||
'''生成item负样本集合end''' | ||
return item_neg_ind | ||
|
||
def sample_item_augment(self, item_num): | ||
ls = [i for i in range(item_num) if i not in self.item_id] | ||
sampled_items = sample(ls, 5) | ||
|
||
return sampled_items | ||
|
||
|
||
def perturb_adj(self, value, label_author, author_label, label_count, shared_knowledge_rep, eps1, eps2): | ||
#print(value.shape) #1,17431 | ||
#此用户的item共可分成多少个groups | ||
groups = {} | ||
for item in self.item_id: | ||
group = author_label[item] | ||
if(group not in groups.keys()): | ||
groups[group] = [item] | ||
else: | ||
groups[group].append(item) | ||
|
||
'''step1:EM''' | ||
num_groups = len(groups) | ||
quality = np.array([0.0]*len(label_author)) | ||
G_s_u = groups.keys() | ||
if(len(G_s_u)==0):#此用户没有交互的item,则各个位置quality平均 | ||
for group in label_author.keys(): | ||
quality[group] = 1 | ||
num_groups = 1 | ||
else: | ||
for group in label_author.keys(): | ||
qua = max([(cosine_similarity(shared_knowledge_rep[g], shared_knowledge_rep[group])+1)/2.0 for g in G_s_u]) | ||
quality[group] = qua | ||
|
||
EM_eps = eps1/num_groups | ||
EM_p = EM_eps*quality/2 #隐私预算1 eps | ||
EM_p = softmax(EM_p) | ||
|
||
#按照概率选择group | ||
select_group_keys = np.random.choice(range(len(label_author)), size = len(groups), replace = False, p = EM_p) | ||
select_group_keys_temp = list(select_group_keys) | ||
degree_list = [len(v) for _, v in groups.items()] | ||
new_groups = {} | ||
|
||
for key in select_group_keys:#先把存在于当前用户的shared knowledge拿出来 | ||
key_temp = key | ||
if(key_temp in groups.keys()): | ||
new_groups[key_temp] = groups[key_temp] | ||
degree_list.remove(len(groups[key_temp])) | ||
select_group_keys_temp.remove(key_temp) | ||
|
||
for key in select_group_keys_temp:#不存在的随机采样交互的item,并保持度一致 | ||
key_temp = key | ||
cur_degree = degree_list[0] | ||
if(len(label_author[key_temp]) >= cur_degree): | ||
new_groups[key_temp] = random.sample(label_author[key_temp], cur_degree) | ||
else:#需要的度比当前group的size大,则将度设为当前group的size | ||
new_groups[key_temp] = label_author[key_temp] | ||
degree_list.remove(cur_degree) | ||
|
||
groups = new_groups | ||
value = np.zeros_like(value)#一定要更新value | ||
for group_id, items in groups.items(): | ||
value[:,items] = 1 | ||
'''pure em''' | ||
#value_rr = value | ||
|
||
|
||
|
||
'''step2:rr''' | ||
all_items = set(range(len(author_label))) | ||
select_items = [] | ||
for group_id, items in groups.items(): | ||
select_items.extend(label_author[group_id]) | ||
mask_rr = list(all_items - set(select_items)) | ||
|
||
'''rr''' | ||
value_rr = perturbation_test(value, 1-value, eps2) | ||
#print(np.sum(value_rr)) 4648 | ||
value_rr[:, mask_rr] = 0 | ||
# #print(np.sum(value_rr)) 469 | ||
# | ||
'''dprr''' | ||
for group_id, items in groups.items(): | ||
degree = len(items) | ||
n = len(label_author[group_id]) | ||
p = eps2p(eps2) | ||
q = degree/(degree*(2*p-1) + (n)*(1-p)) | ||
rnd = np.random.random(value_rr.shape) | ||
#原来是0的一定还是0,原来是1的以概率q保持1,以达到degree减少 | ||
dprr_results = np.where(rnd<q, value_rr, np.zeros((value_rr.shape))) | ||
value_rr[:, label_author[group_id]] = dprr_results[:, label_author[group_id]] | ||
|
||
|
||
#print('....') | ||
#print(self.item_id) | ||
#print(value_rr.nonzero()[1]) | ||
return value_rr | ||
|
||
|
||
|
||
|
||
|
||
def update(self, model_user, model_item): | ||
self.model_user = copy.deepcopy(model_user) | ||
self.model_item = copy.deepcopy(model_item) | ||
# self.item_emb.weight.data = Parameter(aggr_param['item'].weight.data.clone()) | ||
|
||
|
||
def train_(self, hg, user_emb, item_emb): | ||
total_item_num = item_emb.embeddings.data.shape[0]#item_emb.weight.shape[0] | ||
#user_emb = torch.clone(user_emb.weight).detach() | ||
user_emb = tlx.identity(user_emb.embeddings.data).detach() | ||
item_emb = tlx.identity(item_emb.embeddings.data).detach() | ||
user_emb.requires_grad = True | ||
item_emb.requires_grad = True | ||
user_emb.grad = tlx.zeros_like(user_emb) | ||
item_emb.grad = tlx.zeros_like(item_emb) | ||
hg_user = hg[0] | ||
hg_item = hg[1] | ||
|
||
self.model_user.train() | ||
self.model_item.train() | ||
|
||
#sample_item_augment | ||
sampled_item = self.sample_item_augment(total_item_num) | ||
item_neg_id = self.negative_sample_with_augment(total_item_num, sampled_item) | ||
#item_neg_id = self.negative_sample(total_item_num) | ||
|
||
logits_user = self.model_user(hg_user, user_emb)#+user_emb | ||
logits_item = self.model_item(hg_item, item_emb)#+item_emb | ||
|
||
cur_user = logits_user[self.user_id] | ||
#cur_item_pos = logits_item[self.item_id] | ||
cur_item_pos = logits_item[self.item_id+sampled_item] | ||
cur_item_neg = logits_item[item_neg_id] | ||
|
||
#pos_scores = torch.sum(cur_user * cur_item_pos, dim=-1) | ||
#neg_scores = torch.sum(cur_user * cur_item_neg, dim=-1) | ||
pos_scores = tlx.reduce_sum(cur_user * cur_item_pos, axis=-1) | ||
neg_scores = tlx.reduce_sum(cur_user * cur_item_neg, axis=-1) | ||
loss = -(pos_scores - neg_scores).sigmoid().log().sum() | ||
|
||
|
||
self.model_user.zero_grad() | ||
self.model_item.zero_grad() | ||
|
||
loss.backward() | ||
#self.optimizer.step() | ||
|
||
#grad | ||
model_grad_user = [] | ||
model_grad_item = [] | ||
for param in list(self.model_user.parameters()): | ||
grad = param.grad# | ||
model_grad_user.append(grad) | ||
for param in list(self.model_item.parameters()): | ||
grad = param.grad# | ||
model_grad_item.append(grad) | ||
|
||
mask_item = item_emb.grad.sum(-1)!=0#直接通过grad!=0 | ||
updated_items = np.array(range(item_emb.shape[0]))[mask_item.cpu()]#list(set(self.item_id + item_neg_id)) | ||
#print(updated_items) | ||
item_grad = item_emb.grad[updated_items, :]# | ||
|
||
|
||
mask_user = user_emb.grad.sum(-1)!=0 | ||
updated_users = np.array(range(user_emb.shape[0]))[mask_user.cpu()]#list(set([self.user_id] + self.semantic_neighbors)) | ||
#print(len(updated_users)) | ||
user_grad = user_emb.grad[updated_users, :]# | ||
#print(user_grad) | ||
# torch.cuda.empty_cache() | ||
|
||
|
||
return {'user': (user_grad, updated_users), 'item' : (item_grad, updated_items), 'model': (model_grad_user, model_grad_item)}, \ | ||
loss.detach() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import os | ||
from tensorlayerx.backend.ops.torch_backend import topk | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不能这样导入topk, |
||
os.environ['TL_BACKEND'] = 'torch' | ||
import numpy as np | ||
import math | ||
|
||
def getP(ranklist, gtItems): | ||
p = 0 | ||
for item in ranklist: | ||
if item in gtItems: | ||
p += 1 | ||
return p * 1.0 / len(ranklist) | ||
|
||
def getR(ranklist, gtItems): | ||
r = 0 | ||
for item in ranklist: | ||
if item in gtItems: | ||
r += 1 | ||
return r * 1.0 / len(gtItems) | ||
|
||
|
||
def getHitRatio(ranklist, gtItem): | ||
for item in ranklist: | ||
if item == gtItem: | ||
return 1 | ||
return 0 | ||
|
||
def getDCG(ranklist, gtItems): | ||
dcg = 0.0 | ||
for i in range(len(ranklist)): | ||
item = ranklist[i] | ||
if item in gtItems: | ||
dcg += 1.0 / math.log(i + 2) | ||
return dcg | ||
|
||
def getIDCG(ranklist, gtItems): | ||
idcg = 0.0 | ||
i = 0 | ||
for item in ranklist: | ||
if item in gtItems: | ||
idcg += 1.0 / math.log(i + 2) | ||
i += 1 | ||
return idcg | ||
|
||
def getNDCG(ranklist, gtItems): | ||
dcg = getDCG(ranklist, gtItems) | ||
idcg = getIDCG(ranklist, gtItems) | ||
if idcg == 0: | ||
return 0 | ||
return dcg / idcg | ||
|
||
|
||
'''下面是两个大指标(recall或ndcg),4个小指标的计算代码''' | ||
#指标1 top_k=5 或 10 得到recall@5 或recall@10 指标 | ||
def evaluate_recall(rating, ground_truth, top_k): | ||
_, rating_k = topk(rating, top_k) | ||
rating_k = rating_k.cpu().tolist() | ||
|
||
hit = 0 # | ||
for i, v in enumerate(rating_k): | ||
if v in ground_truth: | ||
hit += 1 | ||
|
||
recall = hit / len(ground_truth) | ||
return recall | ||
|
||
#指标2 top_k = 5 或 10 得到ndcg@5 或ndcg@10 指标 | ||
def evaluate_ndcg(rating, ground_truth, top_k):#参照NDCG的定义 | ||
_, rating_k = topk(rating, top_k)#values, indices | ||
rating_k = rating_k.cpu().tolist() #indices | ||
dcg, idcg = 0., 0. | ||
|
||
for i, v in enumerate(rating_k): | ||
if i < len(ground_truth):#前len()个是真实交互的 | ||
idcg += (1 / np.log2(2 + i))#这里相关性为0或1(真实交互为1,未交互为0) | ||
if v in ground_truth: | ||
dcg += (1 / np.log2(2 + i)) | ||
|
||
ndcg = dcg / idcg | ||
return ndcg |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要提供模型的运行结果以及运行命令,可以参考
gcn
模型的readme.md
文档的内容