diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index ba304ede..f2a73189 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -19,8 +19,8 @@ Steps to reproduce the behavior: **Operating environment(运行环境):** - python version [e.g. 3.5, 3.6] - - torch version [e.g. 1.1.0, 1.2.0] - - deepctr-torch version [e.g. 0.1.0,] + - torch version [e.g. 1.6.0, 1.7.0] + - deepctr-torch version [e.g. 0.2.4,] **Additional context** Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md index 0b7d76fb..d51629f1 100644 --- a/.github/ISSUE_TEMPLATE/question.md +++ b/.github/ISSUE_TEMPLATE/question.md @@ -16,5 +16,5 @@ Add any other context about the problem here. **Operating environment(运行环境):** - python version [e.g. 3.6] - - torch version [e.g. 1.2.0,] - - deepctr-torch version [e.g. 0.1.0,] + - torch version [e.g. 1.7.0,] + - deepctr-torch version [e.g. 0.2.4,] diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bd0e66cd..7685cfb2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,7 @@ jobs: strategy: matrix: python-version: [3.6,3.7] - torch-version: [1.1.0,1.2.0,1.3.0,1.4.0,1.5.0,1.6.0] + torch-version: [1.1.0,1.2.0,1.3.0,1.4.0,1.5.0,1.6.0,1.7.0] # exclude: # - python-version: 3.5 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 15fc17bd..9fd04410 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,7 @@ This project is under development and we need developers to participate in. +# Join us + If you - familiar with and interested in CTR models @@ -7,4 +9,11 @@ If you - have spare time to learn and develop - familiar with git -please send a brief introduction of your background and experience to wcshen1994@163.com, welcome to join us! \ No newline at end of file +please send a brief introduction of your background and experience to wcshen1994@163.com, welcome to join us! + +# Creating a pull request +1. **Become a collaborator**: Send an email with introduction and your github account name to wcshen1994@163.com and waiting for invitation to become a collaborator. +2. **Fork&Dev**: Fork your own branch(`dev_yourname`) in `DeepCTR-Torch` from the `master` branch for development.If the `master` is updated during the development process, remember to merge and update to `dev_yourname` regularly. +3. **Testing**: Test logical correctness and effect when finishing the code development of the `dev_yourname` branch. +4. **Pre-release** : After testing contact wcshen1994@163.com for pre-release integration, usually your branch `dev_yourname` will be merged into `release` branch by squash merge. +5. **Release a new version**: After confirming that the change is no longer needed, `release` branch will be merged into `master` and a new python package will be released on pypi. diff --git a/README.md b/README.md index d367d80c..a108659e 100644 --- a/README.md +++ b/README.md @@ -92,25 +92,25 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St Zhang Wutong

Core Dev
Beijing University
of Posts and
Telecommunications

​ + + ​ pic
+ ​ Zan Shuxun +

Core Dev
Beijing University
of Posts and
Telecommunications

​ + ​ pic
Zhang Yuefeng

Core Dev
Peking University

​ + + ​ pic
Huo Junyi

Core Dev
University of Southampton

​ - - - - ​ pic
- ​ Zan Shuxun -

Dev
Beijing University
of Posts and
Telecommunications

​ - ​ pic
Zeng Kai ​ diff --git a/deepctr_torch/__init__.py b/deepctr_torch/__init__.py index 760ef662..6161b74f 100644 --- a/deepctr_torch/__init__.py +++ b/deepctr_torch/__init__.py @@ -2,5 +2,5 @@ from . import models from .utils import check_version -__version__ = '0.2.3' +__version__ = '0.2.4' check_version(__version__) \ No newline at end of file diff --git a/deepctr_torch/callbacks.py b/deepctr_torch/callbacks.py index d3fd7aed..d1a69fe5 100644 --- a/deepctr_torch/callbacks.py +++ b/deepctr_torch/callbacks.py @@ -1,9 +1,10 @@ import torch from tensorflow.python.keras.callbacks import EarlyStopping from tensorflow.python.keras.callbacks import ModelCheckpoint +from tensorflow.python.keras.callbacks import History EarlyStopping = EarlyStopping - +History = History class ModelCheckpoint(ModelCheckpoint): """Save the model after every epoch. diff --git a/deepctr_torch/inputs.py b/deepctr_torch/inputs.py index ea272fb1..4371d404 100644 --- a/deepctr_torch/inputs.py +++ b/deepctr_torch/inputs.py @@ -266,6 +266,6 @@ def get_dense_input(X, features, feature_columns): def maxlen_lookup(X, sparse_input_dict, maxlen_column): if maxlen_column is None or len(maxlen_column)==0: - raise ValueError('please add max length column for VarLenSparseFeat of DIEN input') + raise ValueError('please add max length column for VarLenSparseFeat of DIN/DIEN input') lookup_idx = np.array(sparse_input_dict[maxlen_column[0]]) return X[:, lookup_idx[0]:lookup_idx[1]].long() diff --git a/deepctr_torch/layers/activation.py b/deepctr_torch/layers/activation.py index 6e2a8c35..4ba8758e 100644 --- a/deepctr_torch/layers/activation.py +++ b/deepctr_torch/layers/activation.py @@ -25,10 +25,11 @@ def __init__(self, emb_size, dim=2, epsilon=1e-8, device='cpu'): self.sigmoid = nn.Sigmoid() self.dim = dim + # wrap alpha in nn.Parameter to make it trainable if self.dim == 2: - self.alpha = torch.zeros((emb_size,)).to(device) + self.alpha = nn.Parameter(torch.zeros((emb_size,)).to(device)) else: - self.alpha = torch.zeros((emb_size, 1)).to(device) + self.alpha = nn.Parameter(torch.zeros((emb_size, 1)).to(device)) def forward(self, x): assert x.dim() == self.dim diff --git a/deepctr_torch/layers/interaction.py b/deepctr_torch/layers/interaction.py index f4ccd1b8..11213647 100644 --- a/deepctr_torch/layers/interaction.py +++ b/deepctr_torch/layers/interaction.py @@ -512,7 +512,7 @@ def forward(self, inputs): # (2) E(x_l) # project the input x_l to $\mathbb{R}^{r}$ - v_x = torch.matmul(self.V_list[i][expert_id].T, x_l) # (bs, low_rank, 1) + v_x = torch.matmul(self.V_list[i][expert_id].t(), x_l) # (bs, low_rank, 1) # nonlinear activation in low rank space v_x = torch.tanh(v_x) diff --git a/deepctr_torch/models/afm.py b/deepctr_torch/models/afm.py index 19eeaf9e..caf508ea 100644 --- a/deepctr_torch/models/afm.py +++ b/deepctr_torch/models/afm.py @@ -43,7 +43,7 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, use_attention=Tr if use_attention: self.fm = AFMLayer(self.embedding_size, attention_factor, l2_reg_att, afm_dropout, seed, device) - self.add_regularization_weight(self.fm.attention_W, l2_reg_att) + self.add_regularization_weight(self.fm.attention_W, l2=l2_reg_att) else: self.fm = FM() diff --git a/deepctr_torch/models/autoint.py b/deepctr_torch/models/autoint.py index 027e229b..c3b10eaa 100644 --- a/deepctr_torch/models/autoint.py +++ b/deepctr_torch/models/autoint.py @@ -69,7 +69,7 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3, activation=dnn_activation, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=dnn_use_bn, init_std=init_std, device=device) self.add_regularization_weight( - filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn) + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn) self.int_layers = nn.ModuleList( [InteractingLayer(self.embedding_size if i == 0 else att_embedding_size * att_head_num, att_embedding_size, att_head_num, att_res, device=device) for i in range(att_layer_num)]) diff --git a/deepctr_torch/models/basemodel.py b/deepctr_torch/models/basemodel.py index f0942b75..6d7998f5 100644 --- a/deepctr_torch/models/basemodel.py +++ b/deepctr_torch/models/basemodel.py @@ -26,8 +26,8 @@ from ..inputs import build_input_features, SparseFeat, DenseFeat, VarLenSparseFeat, get_varlen_pooling_list, \ create_embedding_matrix from ..layers import PredictionLayer - from ..layers.utils import slice_arrays +from ..callbacks import History class Linear(nn.Module): @@ -55,8 +55,8 @@ def __init__(self, feature_columns, feature_index, init_std=0.0001, device='cpu' nn.init.normal_(tensor.weight, mean=0, std=init_std) if len(self.dense_feature_columns) > 0: - self.weight = nn.Parameter(torch.Tensor(sum(fc.dimension for fc in self.dense_feature_columns), 1)).to( - device) + self.weight = nn.Parameter(torch.Tensor(sum(fc.dimension for fc in self.dense_feature_columns), 1).to( + device)) torch.nn.init.normal_(self.weight, mean=0, std=init_std) def forward(self, X): @@ -117,14 +117,16 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, l2_reg_linear=1e self.regularization_weight = [] - self.add_regularization_weight( - self.embedding_dict.parameters(), l2_reg_embedding) - self.add_regularization_weight( - self.linear_model.parameters(), l2_reg_linear) + self.add_regularization_weight(self.embedding_dict.parameters(), l2=l2_reg_embedding) + self.add_regularization_weight(self.linear_model.parameters(), l2=l2_reg_linear) self.out = PredictionLayer(task, ) self.to(device) - self._is_graph_network = True # used for callbacks + + # parameters of callbacks + self._is_graph_network = True # used for ModelCheckpoint + self.stop_training = False # used for EarlyStopping + self.history = History() def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoch=0, validation_split=0., validation_data=None, shuffle=True, callbacks=None): @@ -142,6 +144,7 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc :param shuffle: Boolean. Whether to shuffle the order of the batches at the beginning of each epoch. :param callbacks: List of `deepctr_torch.callbacks.Callback` instances. List of callbacks to apply during training and validation (if ). See [callbacks](https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks). Now available: `EarlyStopping` , `ModelCheckpoint` + :return: A `History` object. Its `History.history` attribute is a record of training loss values and metrics values at successive epochs, as well as validation loss values and validation metrics values (if applicable). """ if isinstance(x, dict): x = [x[feature] for feature in self.feature_index] @@ -200,10 +203,14 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc sample_num = len(train_tensor_data) steps_per_epoch = (sample_num - 1) // batch_size + 1 + # configure callbacks + callbacks = (callbacks or []) + [self.history] # add history callback callbacks = CallbackList(callbacks) - callbacks.set_model(self) callbacks.on_train_begin() - self.stop_training = False # used for early stopping + callbacks.set_model(self) + if not hasattr(callbacks, 'model'): + callbacks.__setattr__('model', self) + callbacks.model.stop_training = False # Train print("Train on {0} samples, validate on {1} samples, {2} steps per epoch".format( @@ -231,7 +238,7 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc loss_epoch += loss.item() total_loss_epoch += total_loss.item() - total_loss.backward(retain_graph=True) + total_loss.backward() optim.step() if verbose > 0: @@ -279,6 +286,8 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc callbacks.on_train_end() + return self.history + def evaluate(self, x, y, batch_size=256): """ @@ -368,21 +377,32 @@ def compute_input_dim(self, feature_columns, include_sparse=True, include_dense= input_dim += dense_input_dim return input_dim - def add_regularization_weight(self, weight_list, weight_decay, p=2): - self.regularization_weight.append((list(weight_list), weight_decay, p)) + def add_regularization_weight(self, weight_list, l1=0.0, l2=0.0): + # For a Parameter, put it in a list to keep Compatible with get_regularization_loss() + if isinstance(weight_list, torch.nn.parameter.Parameter): + weight_list = [weight_list] + # For generators, filters and ParameterLists, convert them to a list of tensors to avoid bugs. + # e.g., we can't pickle generator objects when we save the model. + else: + weight_list = list(weight_list) + self.regularization_weight.append((weight_list, l1, l2)) def get_regularization_loss(self, ): total_reg_loss = torch.zeros((1,), device=self.device) - for weight_list, weight_decay, p in self.regularization_weight: - weight_reg_loss = torch.zeros((1,), device=self.device) + for weight_list, l1, l2 in self.regularization_weight: for w in weight_list: if isinstance(w, tuple): - l2_reg = torch.norm(w[1], p=p, ) + parameter = w[1] # named_parameters else: - l2_reg = torch.norm(w, p=p, ) - weight_reg_loss = weight_reg_loss + l2_reg - reg_loss = weight_decay * weight_reg_loss - total_reg_loss += reg_loss + parameter = w + if l1 > 0: + total_reg_loss += torch.sum(l1 * torch.abs(parameter)) + if l2 > 0: + try: + total_reg_loss += torch.sum(l2 * torch.square(parameter)) + except AttributeError: + total_reg_loss += torch.sum(l2 * parameter * parameter) + return total_reg_loss def add_auxiliary_loss(self, aux_loss, alpha): diff --git a/deepctr_torch/models/ccpm.py b/deepctr_torch/models/ccpm.py index 62767271..73272b66 100644 --- a/deepctr_torch/models/ccpm.py +++ b/deepctr_torch/models/ccpm.py @@ -60,8 +60,8 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, conv_kernel_widt init_std=init_std, device=device) self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False).to(device) self.add_regularization_weight( - filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn) - self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn) + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn) + self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn) self.to(device) diff --git a/deepctr_torch/models/dcn.py b/deepctr_torch/models/dcn.py index df8b71e2..4528b9a7 100644 --- a/deepctr_torch/models/dcn.py +++ b/deepctr_torch/models/dcn.py @@ -65,9 +65,9 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, cross_num=2, cro self.crossnet = CrossNet(in_features=self.compute_input_dim(dnn_feature_columns), layer_num=cross_num, parameterization=cross_parameterization, device=device) self.add_regularization_weight( - filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn) - self.add_regularization_weight(self.dnn_linear.weight, l2_reg_linear) - self.add_regularization_weight(self.crossnet.kernels, l2_reg_cross) + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn) + self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_linear) + self.add_regularization_weight(self.crossnet.kernels, l2=l2_reg_cross) self.to(device) def forward(self, X): diff --git a/deepctr_torch/models/dcnmix.py b/deepctr_torch/models/dcnmix.py index 8fb95a73..c01fd44c 100644 --- a/deepctr_torch/models/dcnmix.py +++ b/deepctr_torch/models/dcnmix.py @@ -68,11 +68,11 @@ def __init__(self, linear_feature_columns, low_rank=low_rank, num_experts=num_experts, layer_num=cross_num, device=device) self.add_regularization_weight( - filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn) - self.add_regularization_weight(self.dnn_linear.weight, l2_reg_linear) - self.add_regularization_weight(self.crossnet.U_list, l2_reg_cross) - self.add_regularization_weight(self.crossnet.V_list, l2_reg_cross) - self.add_regularization_weight(self.crossnet.C_list, l2_reg_cross) + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn) + self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_linear) + self.add_regularization_weight(self.crossnet.U_list, l2=l2_reg_cross) + self.add_regularization_weight(self.crossnet.V_list, l2=l2_reg_cross) + self.add_regularization_weight(self.crossnet.C_list, l2=l2_reg_cross) self.to(device) def forward(self, X): diff --git a/deepctr_torch/models/deepfm.py b/deepctr_torch/models/deepfm.py index 5c351dcc..187c2592 100644 --- a/deepctr_torch/models/deepfm.py +++ b/deepctr_torch/models/deepfm.py @@ -59,8 +59,8 @@ def __init__(self, dnn_hidden_units[-1], 1, bias=False).to(device) self.add_regularization_weight( - filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn) - self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn) + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn) + self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn) self.to(device) def forward(self, X): diff --git a/deepctr_torch/models/din.py b/deepctr_torch/models/din.py index 19cc8dfa..0a8e46af 100644 --- a/deepctr_torch/models/din.py +++ b/deepctr_torch/models/din.py @@ -63,7 +63,7 @@ def __init__(self, dnn_feature_columns, history_feature_list, dnn_use_bn=False, self.attention = AttentionSequencePoolingLayer(att_hidden_units=att_hidden_size, embedding_dim=att_emb_dim, - activation=att_activation, + att_activation=att_activation, return_score=False, supports_masking=False, weight_normalization=att_weight_normalization) @@ -79,16 +79,15 @@ def __init__(self, dnn_feature_columns, history_feature_list, dnn_use_bn=False, def forward(self, X): - sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns, - self.embedding_dict) + _, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns, self.embedding_dict) # sequence pooling part query_emb_list = embedding_lookup(X, self.embedding_dict, self.feature_index, self.sparse_feature_columns, - self.history_feature_list, self.history_feature_list, to_list=True) + return_feat_list=self.history_feature_list, to_list=True) keys_emb_list = embedding_lookup(X, self.embedding_dict, self.feature_index, self.history_feature_columns, - self.history_fc_names, self.history_fc_names, to_list=True) - dnn_input_emb_list = embedding_lookup(X, self.embedding_dict, self.feature_index, self.sparse_feature_columns, mask_feat_list=self.history_feature_list, to_list=True) - + return_feat_list=self.history_fc_names, to_list=True) + dnn_input_emb_list = embedding_lookup(X, self.embedding_dict, self.feature_index, self.sparse_feature_columns, + to_list=True) sequence_embed_dict = varlen_embedding_lookup(X, self.embedding_dict, self.feature_index, self.sparse_varlen_feature_columns) @@ -97,12 +96,15 @@ def forward(self, X): self.sparse_varlen_feature_columns, self.device) dnn_input_emb_list += sequence_embed_list + deep_input_emb = torch.cat(dnn_input_emb_list, dim=-1) # concatenate query_emb = torch.cat(query_emb_list, dim=-1) # [B, 1, E] keys_emb = torch.cat(keys_emb_list, dim=-1) # [B, T, E] - keys_length = torch.ones((query_emb.size(0), 1)).to(self.device) # [B, 1] - deep_input_emb = torch.cat(dnn_input_emb_list, dim=-1) + + keys_length_feature_name = [feat.length_name for feat in self.varlen_sparse_feature_columns if + feat.length_name is not None] + keys_length = torch.squeeze(maxlen_lookup(X, self.feature_index, keys_length_feature_name), 1) # [B, 1] hist = self.attention(query_emb, keys_emb, keys_length) # [B, 1, E] diff --git a/deepctr_torch/models/nfm.py b/deepctr_torch/models/nfm.py index d5ed9d9a..73a15923 100644 --- a/deepctr_torch/models/nfm.py +++ b/deepctr_torch/models/nfm.py @@ -48,8 +48,8 @@ def __init__(self, self.dnn_linear = nn.Linear( dnn_hidden_units[-1], 1, bias=False).to(device) self.add_regularization_weight( - filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn) - self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn) + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn) + self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn) self.bi_pooling = BiInteractionPooling() self.bi_dropout = bi_dropout if self.bi_dropout > 0: diff --git a/deepctr_torch/models/onn.py b/deepctr_torch/models/onn.py index bdab09eb..b4d4d085 100644 --- a/deepctr_torch/models/onn.py +++ b/deepctr_torch/models/onn.py @@ -69,8 +69,7 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, dnn_feature_columns, embedding_size=embedding_size, sparse=False).to(device) # add regularization for second_order_embedding - self.add_regularization_weight( - self.second_order_embedding_dict.parameters(), l2_reg_embedding) + self.add_regularization_weight(self.second_order_embedding_dict.parameters(), l2=l2_reg_embedding) dim = self.__compute_nffm_dnn_dim( feature_columns=dnn_feature_columns, embedding_size=embedding_size) @@ -82,8 +81,8 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, self.dnn_linear = nn.Linear( dnn_hidden_units[-1], 1, bias=False).to(device) self.add_regularization_weight( - filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn) - self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn) + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn) + self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn) self.to(device) def __compute_nffm_dnn_dim(self, feature_columns, embedding_size): diff --git a/deepctr_torch/models/pnn.py b/deepctr_torch/models/pnn.py index 00fb3416..7a79d827 100644 --- a/deepctr_torch/models/pnn.py +++ b/deepctr_torch/models/pnn.py @@ -69,8 +69,8 @@ def __init__(self, dnn_feature_columns, dnn_hidden_units=(128, 128), l2_reg_embe self.dnn_linear = nn.Linear( dnn_hidden_units[-1], 1, bias=False).to(device) self.add_regularization_weight( - filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn) - self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn) + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn) + self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn) self.to(device) diff --git a/deepctr_torch/models/wdl.py b/deepctr_torch/models/wdl.py index f65b585a..0fdd374c 100644 --- a/deepctr_torch/models/wdl.py +++ b/deepctr_torch/models/wdl.py @@ -51,8 +51,8 @@ def __init__(self, init_std=init_std, device=device) self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False).to(device) self.add_regularization_weight( - filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn) - self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn) + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn) + self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn) self.to(device) diff --git a/deepctr_torch/models/xdeepfm.py b/deepctr_torch/models/xdeepfm.py index dc00ec4c..87cac472 100644 --- a/deepctr_torch/models/xdeepfm.py +++ b/deepctr_torch/models/xdeepfm.py @@ -54,9 +54,9 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, dnn_hidden_units init_std=init_std, device=device) self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False).to(device) self.add_regularization_weight( - filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn) + filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn) - self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn) + self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn) self.cin_layer_size = cin_layer_size self.use_cin = len(self.cin_layer_size) > 0 and len(dnn_feature_columns) > 0 @@ -70,8 +70,8 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, dnn_hidden_units self.cin = CIN(field_num, cin_layer_size, cin_activation, cin_split_half, l2_reg_cin, seed, device=device) self.cin_linear = nn.Linear(self.featuremap_num, 1, bias=False).to(device) - self.add_regularization_weight( - filter(lambda x: 'weight' in x[0], self.cin.named_parameters()), l2_reg_cin) + self.add_regularization_weight(filter(lambda x: 'weight' in x[0], self.cin.named_parameters()), + l2=l2_reg_cin) self.to(device) diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md index 2909aacb..3399bb06 100644 --- a/docs/source/FAQ.md +++ b/docs/source/FAQ.md @@ -33,8 +33,9 @@ model = DeepFM(linear_feature_columns,dnn_feature_columns) model.compile(Adagrad(model.parameters(),0.1024),'binary_crossentropy',metrics=['binary_crossentropy']) es = EarlyStopping(monitor='val_binary_crossentropy', min_delta=0, verbose=1, patience=0, mode='min') -mdckpt = ModelCheckpoint(filepath='model.ckpt') +mdckpt = ModelCheckpoint(filepath = 'model.ckpt', save_best_only= True) history = model.fit(model_input,data[target].values,batch_size=256,epochs=10,verbose=2,validation_split=0.2,callbacks=[es,mdckpt]) +print(history) ``` ## 3. How to add a long dense feature vector as a input to the model? diff --git a/docs/source/History.md b/docs/source/History.md index 713480b8..c0469087 100644 --- a/docs/source/History.md +++ b/docs/source/History.md @@ -1,4 +1,5 @@ # History +- 12/05/2020 : [v0.2.4](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.4) released.Imporve compatibility & fix issues.Add History callback.([example](https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping)). - 10/18/2020 : [v0.2.3](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.3) released.Add [DCN-M](./Features.html#dcn-deep-cross-network)&[DCN-Mix](./Features.html#dcn-mix-improved-deep-cross-network-with-mix-of-experts-and-matrix-kernel).Add EarlyStopping and ModelCheckpoint callbacks([example](https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping)). - 10/09/2020 : [v0.2.2](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.2) released.Improve the reproducibility & fix some bugs. - 03/27/2020 : [v0.2.1](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.1) released.Add [DIN](./Features.html#din-deep-interest-network) and [DIEN](./Features.html#dien-deep-interest-evolution-network) . diff --git a/docs/source/conf.py b/docs/source/conf.py index 425de06e..40f2e387 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,7 +26,7 @@ # The short X.Y version version = '' # The full version, including alpha/beta/rc tags -release = '0.2.3' +release = '0.2.4' # -- General configuration --------------------------------------------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index d1d154ff..fc1cec28 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -34,12 +34,12 @@ You can read the latest code at https://github.com/shenweichen/DeepCTR-Torch and News ----- +12/05/2020 : Imporve compatibility & fix issues.Add History callback(`example `_). `Changelog `_ + 10/18/2020 : Add `DCN-M <./Features.html#dcn-deep-cross-network>`_ and `DCN-Mix <./Features.html#dcn-mix-improved-deep-cross-network-with-mix-of-experts-and-matrix-kernel>`_ . Add EarlyStopping and ModelCheckpoint callbacks(`example `_). `Changelog `_ 10/09/2020 : Improve the reproducibility & fix some bugs. `Changelog `_ -03/27/2020 : Add `DIN <./Features.html#din-deep-interest-network>`_ and `DIEN <./Features.html#dien-deep-interest-evolution-network>`_ . `Changelog `_ - DisscussionGroup ----------------------- diff --git a/examples/run_classification_criteo.py b/examples/run_classification_criteo.py index 1201d2b0..881fdfbb 100644 --- a/examples/run_classification_criteo.py +++ b/examples/run_classification_criteo.py @@ -58,8 +58,8 @@ model.compile("adagrad", "binary_crossentropy", metrics=["binary_crossentropy", "auc"], ) - model.fit(train_model_input, train[target].values, batch_size=32, epochs=10, verbose=2, validation_split=0.2) - + history = model.fit(train_model_input, train[target].values, batch_size=32, epochs=10, verbose=2, + validation_split=0.2) pred_ans = model.predict(test_model_input, 256) print("") print("test LogLoss", round(log_loss(test[target].values, pred_ans), 4)) diff --git a/examples/run_din.py b/examples/run_din.py index fccc02c3..de716e16 100644 --- a/examples/run_din.py +++ b/examples/run_din.py @@ -14,9 +14,8 @@ def get_xy_fd(): SparseFeat('item', 3 + 1, embedding_dim=8), SparseFeat('item_gender', 2 + 1, embedding_dim=8), DenseFeat('score', 1)] - feature_columns += [VarLenSparseFeat(SparseFeat('hist_item', 3 + 1, embedding_dim=8), 4), - VarLenSparseFeat(SparseFeat('hist_item_gender', 2 + 1, embedding_dim=8), 4)] - + feature_columns += [VarLenSparseFeat(SparseFeat('hist_item', 3 + 1, embedding_dim=8), 4, length_name="seq_length"), + VarLenSparseFeat(SparseFeat('hist_item_gender', 2 + 1, embedding_dim=8), 4, length_name="seq_length")] behavior_feature_list = ["item", "item_gender"] uid = np.array([0, 1, 2]) ugender = np.array([0, 1, 0]) @@ -26,9 +25,11 @@ def get_xy_fd(): hist_iid = np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0]]) hist_igender = np.array([[1, 1, 2, 0], [2, 1, 1, 0], [2, 1, 0, 0]]) + behavior_length = np.array([3, 3, 2]) feature_dict = {'user': uid, 'gender': ugender, 'item': iid, 'item_gender': igender, - 'hist_item': hist_iid, 'hist_item_gender': hist_igender, 'score': score} + 'hist_item': hist_iid, 'hist_item_gender': hist_igender, 'score': score, + "seq_length": behavior_length} x = {name: feature_dict[name] for name in get_feature_names(feature_columns)} y = np.array([1, 0, 1]) @@ -43,7 +44,7 @@ def get_xy_fd(): print('cuda ready...') device = 'cuda:0' - model = DIN(feature_columns, behavior_feature_list, device=device) + model = DIN(feature_columns, behavior_feature_list, device=device, att_weight_normalization=True) model.compile('adagrad', 'binary_crossentropy', metrics=['binary_crossentropy']) history = model.fit(x, y, batch_size=3, epochs=10, verbose=2, validation_split=0.0) diff --git a/setup.py b/setup.py index dd2bc9f3..5a44ccd1 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ setuptools.setup( name="deepctr-torch", - version="0.2.3", + version="0.2.4", author="Weichen Shen", author_email="wcshen1994@163.com", description="Easy-to-use,Modular and Extendible package of deep learning based CTR(Click Through Rate) prediction models with PyTorch", diff --git a/tests/models/DCNMix_test.py b/tests/models/DCNMix_test.py index ce1119fa..1dc90ef6 100644 --- a/tests/models/DCNMix_test.py +++ b/tests/models/DCNMix_test.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize( 'embedding_size,cross_num,hidden_size,sparse_feature_num', - [(8, 0, (32,), 2), + [(8, 0, (32,), 2), (8, 1, (32,), 2) ] # ('auto', 1, (32,), 3) , ('auto', 1, (), 1), ('auto', 1, (32,), 3) ) def test_DCNMix(embedding_size, cross_num, hidden_size, sparse_feature_num): @@ -18,7 +18,7 @@ def test_DCNMix(embedding_size, cross_num, hidden_size, sparse_feature_num): sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=sparse_feature_num) model = DCNMix(linear_feature_columns=feature_columns, dnn_feature_columns=feature_columns, - cross_num=cross_num, dnn_hidden_units=hidden_size, dnn_dropout=0.5, device=get_device()) + cross_num=cross_num, dnn_hidden_units=hidden_size, dnn_dropout=0.5, device=get_device()) check_model(model, model_name, x, y) diff --git a/tests/models/DCN_test.py b/tests/models/DCN_test.py index 8be0ef00..bffc2ac1 100644 --- a/tests/models/DCN_test.py +++ b/tests/models/DCN_test.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize( 'embedding_size,cross_num,hidden_size,sparse_feature_num,cross_parameterization', - [(8, 0, (32,), 2, 'vector'), (8, 0, (32,), 2, 'matrix'), + [(8, 2, (32,), 2, 'vector'), (8, 1, (32,), 2, 'matrix'), ] # ('auto', 1, (32,), 3) , ('auto', 1, (), 1), ('auto', 1, (32,), 3) ) def test_DCN(embedding_size, cross_num, hidden_size, sparse_feature_num, cross_parameterization): diff --git a/tests/models/PNN_test.py b/tests/models/PNN_test.py index 0e419a6c..8afc9efe 100644 --- a/tests/models/PNN_test.py +++ b/tests/models/PNN_test.py @@ -5,17 +5,18 @@ @pytest.mark.parametrize( - 'use_inner, use_outter,sparse_feature_num', - [(True, True, 2), (True, False, 2), (False, True, 3), (False, False, 1) + 'use_inner, use_outter, kernel_type, sparse_feature_num', + [(True, True, 'mat', 2), (True, False, 'mat', 2), (False, True, 'vec', 3), (False, True, 'num', 3), + (False, False, 'mat', 1) ] ) -def test_PNN(use_inner, use_outter, sparse_feature_num): +def test_PNN(use_inner, use_outter, kernel_type, sparse_feature_num): model_name = "PNN" sample_size = SAMPLE_SIZE x, y, feature_columns = get_test_data(sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=sparse_feature_num) model = PNN(feature_columns, dnn_hidden_units=[32, 32], dnn_dropout=0.5, use_inner=use_inner, - use_outter=use_outter, device=get_device()) + use_outter=use_outter, kernel_type=kernel_type, device=get_device()) check_model(model, model_name, x, y) diff --git a/tests/models/WDL_test.py b/tests/models/WDL_test.py index bea91a8d..c06be14e 100644 --- a/tests/models/WDL_test.py +++ b/tests/models/WDL_test.py @@ -16,7 +16,7 @@ def test_WDL(sparse_feature_num, dense_feature_num): x, y, feature_columns = get_test_data( sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=dense_feature_num) - model = WDL(feature_columns, feature_columns, + model = WDL(feature_columns, feature_columns, dnn_activation='prelu', dnn_hidden_units=[32, 32], dnn_dropout=0.5, device=get_device()) check_model(model, model_name, x, y)