[ModelZoo] Support Co_Action Network #344
Status: Open

aiden-law-tian wants to merge 8 commits into DeepRec-AI:main from aiden-law-tian:aiden.

Changes shown from 2 of 8 commits:
- 32ce9e7 [ModelZoo] Support Co_Action Network (aiden-law-tian)
- d6d5be5 [ModelZoo] Support Co_Action Network (aiden-law-tian)
- b438c64 [ModelZoo] Support Co_Action Net (aiden-law-tian)
- c5df688 [ModelZoo] Support FNN (aiden-law-tian)
- 4e4b400 [ModelZoo] Support Co_Action Network (aiden-law-tian)
- b613472 [ModelZoo] Support FNN (aiden-law-tian)
- 0cc1389 [ModelZoo] Support FwFM (aiden-law-tian)
- 6421ff4 [ModelZoo] Support PNN (aiden-law-tian)
Files changed:
New file, @@ -0,0 +1,25 @@:

```markdown
# Co-Action Network

Implementation of the paper "CAN: Revisiting Feature Co-Action for Click Through Rate Prediction".

Paper: [arXiv (to be released)]()

## Installation
Dependencies:

tensorflow: 1.4.1

python: 2.7

Higher versions of TensorFlow and Python 3 will be supported soon!

## Getting Started
Training:

CUDA_VISIBLE_DEVICES=0 python train.py train {model}

model: CAN, Cartesion, PNN, etc. (see train.py)

## Citation
## Contact
## License
```
New file, @@ -0,0 +1,9 @@:

```sh
export PATH="~/anaconda4/bin:$PATH"

wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Books.json.gz
wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Books.json.gz
gunzip reviews_Books.json.gz
gunzip meta_Books.json.gz
python script/process_data.py meta_Books.json reviews_Books.json
python script/local_aggretor.py
python script/split_by_user.py
python script/generate_voc.py
```

Review comment on the `export PATH` line: Remove the local development variable.
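As a quick sanity check after running the pipeline, a minimal sketch for inspecting one of the generated vocabularies. The file name `uid_voc.pkl` is an assumption borrowed from the upstream DIEN data scripts that this pipeline mirrors; it is not confirmed by this diff:

```python
import _pickle as cPickle

# Hypothetical output file of generate_voc.py; adjust the path to your data layout.
with open("uid_voc.pkl", "rb") as f:
    uid_voc = cPickle.load(f)

print(len(uid_voc))               # number of distinct user ids
print(list(uid_voc.items())[:3])  # a few (id -> index) entries
```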
New file, @@ -0,0 +1,35 @@:

```python
import tensorflow as tf


def dice(_x, axis=-1, epsilon=0.000000001, name=''):
    """Dice activation: sigmoid-gated scaling with a learnable per-channel alpha."""
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        alphas = tf.get_variable('alpha' + name, _x.get_shape()[-1],
                                 initializer=tf.constant_initializer(0.0),
                                 dtype=tf.float32)
        input_shape = list(_x.get_shape())

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[axis] = input_shape[axis]

        # case: train mode (uses stats of the current batch)
        mean = tf.reduce_mean(_x, axis=reduction_axes)
        broadcast_mean = tf.reshape(mean, broadcast_shape)
        std = tf.reduce_mean(tf.square(_x - broadcast_mean) + epsilon, axis=reduction_axes)
        std = tf.sqrt(std)
        broadcast_std = tf.reshape(std, broadcast_shape)
        x_normed = (_x - broadcast_mean) / (broadcast_std + epsilon)
        # x_normed = tf.layers.batch_normalization(_x, center=False, scale=False)
        x_p = tf.sigmoid(x_normed)

    return alphas * (1.0 - x_p) * _x + x_p * _x


def parametric_relu(_x):
    """PReLU: learnable per-channel slope for the negative part of the input."""
    alphas = tf.get_variable('alpha', _x.get_shape()[-1],
                             initializer=tf.constant_initializer(0.0),
                             dtype=tf.float32)
    pos = tf.nn.relu(_x)
    neg = alphas * (_x - abs(_x)) * 0.5

    return pos + neg
```
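For reference, `dice` standardizes its input per channel, gates it with `p = sigmoid(x_normed)`, and returns `p * x + alpha * (1 - p) * x`. A minimal usage sketch in TF 1.x graph mode, assuming the `dice` function above is in scope (the placeholder shape and session setup are illustrative, not part of this PR):

```python
import numpy as np
import tensorflow as tf

# Illustrative only: apply dice() to a [batch, features] placeholder.
x = tf.placeholder(tf.float32, shape=[None, 8])
y = dice(x, name='demo')  # creates the trainable 'alphademo' variable

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={x: np.random.randn(4, 8).astype(np.float32)})
    print(out.shape)  # (4, 8)
```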
New file, @@ -0,0 +1,12 @@:

```python
import tensorflow as tf

# Count the trainable parameters of the latest checkpoint in ./ckpt_path/.
ckpt = tf.train.get_checkpoint_state("./ckpt_path/").model_checkpoint_path
saver = tf.train.import_meta_graph(ckpt + '.meta')  # rebuilds the graph from the meta file
variables = tf.trainable_variables()
total_parameters = 0
for variable in variables:
    shape = variable.get_shape()
    variable_parameters = 1
    for dim in shape:
        variable_parameters *= dim.value
    total_parameters += variable_parameters
print(total_parameters)
```
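A more compact equivalent, as a sketch; it assumes the meta graph has already been imported as above so that `tf.trainable_variables()` is populated:

```python
import numpy as np
import tensorflow as tf

# Same count in one expression: product of each variable's static shape.
total = sum(int(np.prod(v.get_shape().as_list())) for v in tf.trainable_variables())
print(total)
```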
New file, @@ -0,0 +1,224 @@:

```python
import numpy
import json
import _pickle as cPickle
import random

import gzip

from data.script import shuffle


def unicode_to_utf8(d):
    return dict((key.encode("UTF-8"), value) for (key, value) in d.items())

def dict_unicode_to_utf8(d):
    return dict(((key[0].encode("UTF-8"), key[1].encode("UTF-8")), value) for (key, value) in d.items())


def load_dict(filename):
    try:
        with open(filename, 'rb') as f:
            return unicode_to_utf8(json.load(f))
    except Exception:
        try:
            with open(filename, 'rb') as f:
                return unicode_to_utf8(cPickle.load(f))
        except Exception:
            with open(filename, 'rb') as f:
                return dict_unicode_to_utf8(cPickle.load(f))


def fopen(filename, mode='r'):
    if filename.endswith('.gz'):
        return gzip.open(filename, mode)
    return open(filename, mode)


class DataIterator:

    def __init__(self, source,
                 uid_voc,
                 mid_voc,
                 cat_voc,
                 batch_size=128,
                 maxlen=100,
                 skip_empty=False,
                 shuffle_each_epoch=False,
                 sort_by_length=True,
                 max_batch_size=20,
                 minlen=None,
                 label_type=1):
        if shuffle_each_epoch:
            self.source_orig = source
            self.source = shuffle.main(self.source_orig, temporary=True)
        else:
            self.source = fopen(source, 'r')
        self.source_dicts = []
        for source_dict in [uid_voc, mid_voc, cat_voc, '../CAN/data/item_carte_voc.pkl', '../CAN/data/cate_carte_voc.pkl']:
            self.source_dicts.append(load_dict(source_dict))

        # Map each item index to its category index.
        f_meta = open("../CAN/data/item-info", "r")
        meta_map = {}
        for line in f_meta:
            arr = line.strip().split("\t")
            if arr[0] not in meta_map:
                meta_map[arr[0]] = arr[1]
        self.meta_id_map = {}
        for key in meta_map:
            val = meta_map[key]
            if key in self.source_dicts[1]:
                mid_idx = self.source_dicts[1][key]
            else:
                mid_idx = 0
            if val in self.source_dicts[2]:
                cat_idx = self.source_dicts[2][val]
            else:
                cat_idx = 0
            self.meta_id_map[mid_idx] = cat_idx

        # Pool of item indices used to sample negative (no-click) behaviors.
        f_review = open("../CAN/data/reviews-info", "r")
        self.mid_list_for_random = []
        for line in f_review:
            arr = line.strip().split("\t")
            tmp_idx = 0
            if arr[1] in self.source_dicts[1]:
                tmp_idx = self.source_dicts[1][arr[1]]
            self.mid_list_for_random.append(tmp_idx)

        self.batch_size = batch_size
        self.maxlen = maxlen
        self.minlen = minlen
        self.skip_empty = skip_empty

        self.n_uid = len(self.source_dicts[0])
        self.n_mid = len(self.source_dicts[1])
        self.n_cat = len(self.source_dicts[2])
        self.n_carte = [len(self.source_dicts[3]), len(self.source_dicts[4])]

        self.shuffle = shuffle_each_epoch
        self.sort_by_length = sort_by_length

        self.source_buffer = []
        self.k = batch_size * max_batch_size

        self.end_of_data = False
        self.label_type = label_type

    def get_n(self):
        return self.n_uid, self.n_mid, self.n_cat, self.n_carte

    def __iter__(self):
        return self

    def reset(self):
        if self.shuffle:
            self.source = shuffle.main(self.source_orig, temporary=True)
        else:
            self.source.seek(0)

    def __next__(self):
        if self.end_of_data:
            self.end_of_data = False
            self.reset()
            raise StopIteration

        source = []
        target = []

        if len(self.source_buffer) == 0:
            for k_ in range(self.k):
                ss = self.source.readline()
                if ss == "":
                    break
                self.source_buffer.append(ss.strip("\n").split("\t"))

            # sort by history behavior length
            if self.sort_by_length:
                # History fields are joined by the \x02 control character in the prepared data.
                his_length = numpy.array([len(s[4].split("\x02")) for s in self.source_buffer])
                tidx = his_length.argsort()

                _sbuf = [self.source_buffer[i] for i in tidx]
                self.source_buffer = _sbuf
            else:
                self.source_buffer.reverse()

        if len(self.source_buffer) == 0:
            self.end_of_data = False
            self.reset()
            raise StopIteration

        try:
            # actual work here
            while True:

                # read from source file and map to word index
                try:
                    ss = self.source_buffer.pop()
                except IndexError:
                    break

                uid = self.source_dicts[0][ss[1]] if ss[1] in self.source_dicts[0] else 0
                mid = self.source_dicts[1][ss[2]] if ss[2] in self.source_dicts[1] else 0
                cat = self.source_dicts[2][ss[3]] if ss[3] in self.source_dicts[2] else 0

                # Clicked-item history plus item-level cartesian-product features.
                tmp = []
                item_carte = []
                for fea in ss[4].split("\x02"):
                    m = self.source_dicts[1][fea] if fea in self.source_dicts[1] else 0
                    tmp.append(m)
                    i_c = self.source_dicts[3][(ss[2], fea)] if (ss[2], fea) in self.source_dicts[3] else 0
                    item_carte.append(i_c)
                mid_list = tmp

                # Category history plus category-level cartesian-product features.
                tmp1 = []
                cate_carte = []
                for fea in ss[5].split("\x02"):
                    c = self.source_dicts[2][fea] if fea in self.source_dicts[2] else 0
                    tmp1.append(c)
                    c_c = self.source_dicts[4][(ss[3], fea)] if (ss[3], fea) in self.source_dicts[4] else 0
                    cate_carte.append(c_c)
                cat_list = tmp1

                if self.minlen is not None:
                    if len(mid_list) <= self.minlen:
                        continue
                if self.skip_empty and (not mid_list):
                    continue

                # Sample 5 negative (no-click) items per positive behavior.
                noclk_mid_list = []
                noclk_cat_list = []
                for pos_mid in mid_list:
                    noclk_tmp_mid = []
                    noclk_tmp_cat = []
                    noclk_index = 0
                    while True:
                        noclk_mid_indx = random.randint(0, len(self.mid_list_for_random) - 1)
                        noclk_mid = self.mid_list_for_random[noclk_mid_indx]
                        if noclk_mid == pos_mid:
                            continue
                        noclk_tmp_mid.append(noclk_mid)
                        noclk_tmp_cat.append(self.meta_id_map[noclk_mid])
                        noclk_index += 1
                        if noclk_index >= 5:
                            break
                    noclk_mid_list.append(noclk_tmp_mid)
                    noclk_cat_list.append(noclk_tmp_cat)
                carte_list = [item_carte, cate_carte]
                source.append([uid, mid, cat, mid_list, cat_list, noclk_mid_list, noclk_cat_list, carte_list])
                if self.label_type == 1:
                    target.append([float(ss[0])])
                else:
                    target.append([float(ss[0]), 1 - float(ss[0])])

                if len(source) >= self.batch_size or len(target) >= self.batch_size:
                    break
        except IOError:
            self.end_of_data = True

        # all sentence pairs in maxibatch filtered out because of length
        if len(source) == 0 or len(target) == 0:
            source, target = self.__next__()

        return source, target
```
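For orientation, a minimal sketch of driving `DataIterator`. The file names follow the upstream DIEN data layout that the prepare script above mirrors (`local_train_splitByUser` from split_by_user.py, vocabularies from generate_voc.py); they are assumptions, not confirmed by this diff, and the class additionally expects the hard-coded `../CAN/data/` files to exist:

```python
# Hypothetical usage; paths assume the outputs of the data-preparation script.
train_data = DataIterator("local_train_splitByUser",
                          "uid_voc.pkl", "mid_voc.pkl", "cat_voc.pkl",
                          batch_size=128, maxlen=100)

n_uid, n_mid, n_cat, n_carte = train_data.get_n()
for src, tgt in train_data:    # each iteration yields one (features, labels) batch
    uid, mid, cat = src[0][0], src[0][1], src[0][2]
    print(len(src), len(tgt))  # at most batch_size examples per batch
    break
```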
Review comment: Please write this against DeepRec and Python 3.6 as the baseline.