Skip to content

Commit

Permalink
Dev swc (#143)
Browse files Browse the repository at this point in the history
refactor add_regularization_weight method of BaseModel
  • Loading branch information
浅梦 authored Dec 5, 2020
1 parent 2e42ff6 commit 500c0a5
Show file tree
Hide file tree
Showing 29 changed files with 87 additions and 73 deletions.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ Steps to reproduce the behavior:

**Operating environment(运行环境):**
- python version [e.g. 3.5, 3.6]
- torch version [e.g. 1.1.0, 1.2.0]
- deepctr-torch version [e.g. 0.1.0,]
- torch version [e.g. 1.6.0, 1.7.0]
- deepctr-torch version [e.g. 0.2.4,]

**Additional context**
Add any other context about the problem here.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/question.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ Add any other context about the problem here.

**Operating environment(运行环境):**
- python version [e.g. 3.6]
- torch version [e.g. 1.2.0,]
- deepctr-torch version [e.g. 0.1.0,]
- torch version [e.g. 1.7.0,]
- deepctr-torch version [e.g. 0.2.4,]
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
strategy:
matrix:
python-version: [3.6,3.7]
torch-version: [1.1.0,1.2.0,1.3.0,1.4.0,1.5.0,1.6.0]
torch-version: [1.1.0,1.2.0,1.3.0,1.4.0,1.5.0,1.6.0,1.7.0]

# exclude:
# - python-version: 3.5
Expand Down
11 changes: 10 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
This project is under development and we need developers to participate in.

# Join us

If you

- familiar with and interested in CTR models
- familiar with pytorch(both pytorch and tensorflow better)
- have spare time to learn and develop
- familiar with git

please send a brief introduction of your background and experience to [email protected], welcome to join us!
please send a brief introduction of your background and experience to [email protected], welcome to join us!

# Creating a pull request
1. **Become a collaborator**: Send an email with introduction and your github account name to [email protected] and waiting for invitation to become a collaborator.
2. **Fork&Dev**: Fork your own branch(`dev_yourname`) in `DeepCTR-Torch` from the `master` branch for development.If the `master` is updated during the development process, remember to merge and update to `dev_yourname` regularly.
3. **Testing**: Test logical correctness and effect when finishing the code development of the `dev_yourname` branch.
4. **Pre-release** : After testing contact [email protected] for pre-release integration, usually your branch `dev_yourname` will be merged into `release` branch by squash merge.
5. **Release a new version**: After confirming that the change is no longer needed, `release` branch will be merged into `master` and a new python package will be released on pypi.
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,25 +92,25 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
<a href="https://github.com/wutongzhang">Zhang Wutong</a>
<p>Core Dev<br> Beijing University <br> of Posts and <br> Telecommunications</p>​
</td>
<td>
​ <a href="https://github.com/zanshuxun"><img width="70" height="70" src="https://github.com/zanshuxun.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/zanshuxun">Zan Shuxun</a>
<p>Core Dev<br> Beijing University <br> of Posts and <br> Telecommunications</p>​
</td>
<td>
​ <a href="https://github.com/ZhangYuef"><img width="70" height="70" src="https://github.com/ZhangYuef.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/ZhangYuef">Zhang Yuefeng</a>
<p>Core Dev<br>
Peking University <br> <br> </p>​
</td>
</tr>
<tr align="center">
<td>
​ <a href="https://github.com/JyiHUO"><img width="70" height="70" src="https://github.com/JyiHUO.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/JyiHUO">Huo Junyi</a>
<p>Core Dev<br>
University of Southampton <br> <br> </p>​
</td>
</tr>
<tr align="center">
<td>
​ <a href="https://github.com/zanshuxun"><img width="70" height="70" src="https://github.com/zanshuxun.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/zanshuxun">Zan Shuxun</a>
<p>Dev<br> Beijing University <br> of Posts and <br> Telecommunications</p>​
</td>
<td>
​ <a href="https://github.com/Zengai"><img width="70" height="70" src="https://github.com/Zengai.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/Zengai">Zeng Kai</a> ​
Expand Down
2 changes: 1 addition & 1 deletion deepctr_torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from . import models
from .utils import check_version

__version__ = '0.2.3'
__version__ = '0.2.4'
check_version(__version__)
2 changes: 1 addition & 1 deletion deepctr_torch/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tensorflow.python.keras.callbacks import History

EarlyStopping = EarlyStopping

History = History

class ModelCheckpoint(ModelCheckpoint):
"""Save the model after every epoch.
Expand Down
2 changes: 1 addition & 1 deletion deepctr_torch/models/afm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, use_attention=Tr
if use_attention:
self.fm = AFMLayer(self.embedding_size, attention_factor, l2_reg_att, afm_dropout,
seed, device)
self.add_regularization_weight(self.fm.attention_W, l2_reg_att)
self.add_regularization_weight(self.fm.attention_W, l2=l2_reg_att)
else:
self.fm = FM()

Expand Down
2 changes: 1 addition & 1 deletion deepctr_torch/models/autoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3,
activation=dnn_activation, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=dnn_use_bn,
init_std=init_std, device=device)
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn)
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)
self.int_layers = nn.ModuleList(
[InteractingLayer(self.embedding_size if i == 0 else att_embedding_size * att_head_num,
att_embedding_size, att_head_num, att_res, device=device) for i in range(att_layer_num)])
Expand Down
32 changes: 18 additions & 14 deletions deepctr_torch/models/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,8 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, l2_reg_linear=1e

self.regularization_weight = []

self.add_regularization_weight(
self.embedding_dict.parameters(), l2_reg_embedding)
self.add_regularization_weight(
self.linear_model.parameters(), l2_reg_linear)
self.add_regularization_weight(self.embedding_dict.parameters(), l2=l2_reg_embedding)
self.add_regularization_weight(self.linear_model.parameters(), l2=l2_reg_linear)

self.out = PredictionLayer(task, )
self.to(device)
Expand Down Expand Up @@ -208,8 +206,10 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
# configure callbacks
callbacks = (callbacks or []) + [self.history] # add history callback
callbacks = CallbackList(callbacks)
callbacks.set_model(self)
callbacks.on_train_begin()
callbacks.set_model(self)
if not hasattr(callbacks, 'model'):
callbacks.__setattr__('model', self)
callbacks.model.stop_training = False

# Train
Expand Down Expand Up @@ -377,28 +377,32 @@ def compute_input_dim(self, feature_columns, include_sparse=True, include_dense=
input_dim += dense_input_dim
return input_dim

def add_regularization_weight(self, weight_list, weight_decay, p=2):
def add_regularization_weight(self, weight_list, l1=0.0, l2=0.0):
# For a Parameter, put it in a list to keep Compatible with get_regularization_loss()
if isinstance(weight_list, torch.nn.parameter.Parameter):
weight_list = [weight_list]
# For generators, filters and ParameterLists, convert them to a list of tensors to avoid bugs.
# e.g., we can't pickle generator objects when we save the model.
else:
weight_list = list(weight_list)
self.regularization_weight.append((weight_list, weight_decay, p))
self.regularization_weight.append((weight_list, l1, l2))

def get_regularization_loss(self, ):
total_reg_loss = torch.zeros((1,), device=self.device)
for weight_list, weight_decay, p in self.regularization_weight:
weight_reg_loss = torch.zeros((1,), device=self.device)
for weight_list, l1, l2 in self.regularization_weight:
for w in weight_list:
if isinstance(w, tuple):
l2_reg = torch.norm(w[1], p=p, )
parameter = w[1] # named_parameters
else:
l2_reg = torch.norm(w, p=p, )
weight_reg_loss = weight_reg_loss + l2_reg
reg_loss = weight_decay * weight_reg_loss
total_reg_loss += reg_loss
parameter = w
if l1 > 0:
total_reg_loss += torch.sum(l1 * torch.abs(parameter))
if l2 > 0:
try:
total_reg_loss += torch.sum(l2 * torch.square(parameter))
except AttributeError:
total_reg_loss += torch.sum(l2 * parameter * parameter)

return total_reg_loss

def add_auxiliary_loss(self, aux_loss, alpha):
Expand Down
4 changes: 2 additions & 2 deletions deepctr_torch/models/ccpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, conv_kernel_widt
init_std=init_std, device=device)
self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False).to(device)
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn)
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn)

self.to(device)

Expand Down
6 changes: 3 additions & 3 deletions deepctr_torch/models/dcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, cross_num=2, cro
self.crossnet = CrossNet(in_features=self.compute_input_dim(dnn_feature_columns),
layer_num=cross_num, parameterization=cross_parameterization, device=device)
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2_reg_linear)
self.add_regularization_weight(self.crossnet.kernels, l2_reg_cross)
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_linear)
self.add_regularization_weight(self.crossnet.kernels, l2=l2_reg_cross)
self.to(device)

def forward(self, X):
Expand Down
10 changes: 5 additions & 5 deletions deepctr_torch/models/dcnmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ def __init__(self, linear_feature_columns,
low_rank=low_rank, num_experts=num_experts,
layer_num=cross_num, device=device)
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2_reg_linear)
self.add_regularization_weight(self.crossnet.U_list, l2_reg_cross)
self.add_regularization_weight(self.crossnet.V_list, l2_reg_cross)
self.add_regularization_weight(self.crossnet.C_list, l2_reg_cross)
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_linear)
self.add_regularization_weight(self.crossnet.U_list, l2=l2_reg_cross)
self.add_regularization_weight(self.crossnet.V_list, l2=l2_reg_cross)
self.add_regularization_weight(self.crossnet.C_list, l2=l2_reg_cross)
self.to(device)

def forward(self, X):
Expand Down
4 changes: 2 additions & 2 deletions deepctr_torch/models/deepfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def __init__(self,
dnn_hidden_units[-1], 1, bias=False).to(device)

self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn)
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn)
self.to(device)

def forward(self, X):
Expand Down
4 changes: 2 additions & 2 deletions deepctr_torch/models/nfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __init__(self,
self.dnn_linear = nn.Linear(
dnn_hidden_units[-1], 1, bias=False).to(device)
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn)
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn)
self.bi_pooling = BiInteractionPooling()
self.bi_dropout = bi_dropout
if self.bi_dropout > 0:
Expand Down
7 changes: 3 additions & 4 deletions deepctr_torch/models/onn.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def __init__(self, linear_feature_columns, dnn_feature_columns,
dnn_feature_columns, embedding_size=embedding_size, sparse=False).to(device)

# add regularization for second_order_embedding
self.add_regularization_weight(
self.second_order_embedding_dict.parameters(), l2_reg_embedding)
self.add_regularization_weight(self.second_order_embedding_dict.parameters(), l2=l2_reg_embedding)

dim = self.__compute_nffm_dnn_dim(
feature_columns=dnn_feature_columns, embedding_size=embedding_size)
Expand All @@ -82,8 +81,8 @@ def __init__(self, linear_feature_columns, dnn_feature_columns,
self.dnn_linear = nn.Linear(
dnn_hidden_units[-1], 1, bias=False).to(device)
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn)
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn)
self.to(device)

def __compute_nffm_dnn_dim(self, feature_columns, embedding_size):
Expand Down
4 changes: 2 additions & 2 deletions deepctr_torch/models/pnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def __init__(self, dnn_feature_columns, dnn_hidden_units=(128, 128), l2_reg_embe
self.dnn_linear = nn.Linear(
dnn_hidden_units[-1], 1, bias=False).to(device)
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn)
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn)

self.to(device)

Expand Down
4 changes: 2 additions & 2 deletions deepctr_torch/models/wdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def __init__(self,
init_std=init_std, device=device)
self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False).to(device)
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn)
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn)

self.to(device)

Expand Down
8 changes: 4 additions & 4 deletions deepctr_torch/models/xdeepfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, dnn_hidden_units
init_std=init_std, device=device)
self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False).to(device)
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn)
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)

self.add_regularization_weight(self.dnn_linear.weight, l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn)

self.cin_layer_size = cin_layer_size
self.use_cin = len(self.cin_layer_size) > 0 and len(dnn_feature_columns) > 0
Expand All @@ -70,8 +70,8 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, dnn_hidden_units
self.cin = CIN(field_num, cin_layer_size,
cin_activation, cin_split_half, l2_reg_cin, seed, device=device)
self.cin_linear = nn.Linear(self.featuremap_num, 1, bias=False).to(device)
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0], self.cin.named_parameters()), l2_reg_cin)
self.add_regularization_weight(filter(lambda x: 'weight' in x[0], self.cin.named_parameters()),
l2=l2_reg_cin)

self.to(device)

Expand Down
3 changes: 2 additions & 1 deletion docs/source/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ model = DeepFM(linear_feature_columns,dnn_feature_columns)
model.compile(Adagrad(model.parameters(),0.1024),'binary_crossentropy',metrics=['binary_crossentropy'])

es = EarlyStopping(monitor='val_binary_crossentropy', min_delta=0, verbose=1, patience=0, mode='min')
mdckpt = ModelCheckpoint(filepath='model.ckpt')
mdckpt = ModelCheckpoint(filepath = 'model.ckpt', save_best_only= True)
history = model.fit(model_input,data[target].values,batch_size=256,epochs=10,verbose=2,validation_split=0.2,callbacks=[es,mdckpt])
print(history)
```

## 3. How to add a long dense feature vector as a input to the model?
Expand Down
1 change: 1 addition & 0 deletions docs/source/History.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# History
- 12/05/2020 : [v0.2.4](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.4) released.Imporve compatibility & fix issues.Add History callback.([example](https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping)).
- 10/18/2020 : [v0.2.3](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.3) released.Add [DCN-M](./Features.html#dcn-deep-cross-network)&[DCN-Mix](./Features.html#dcn-mix-improved-deep-cross-network-with-mix-of-experts-and-matrix-kernel).Add EarlyStopping and ModelCheckpoint callbacks([example](https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping)).
- 10/09/2020 : [v0.2.2](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.2) released.Improve the reproducibility & fix some bugs.
- 03/27/2020 : [v0.2.1](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.1) released.Add [DIN](./Features.html#din-deep-interest-network) and [DIEN](./Features.html#dien-deep-interest-evolution-network) .
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = '0.2.3'
release = '0.2.4'


# -- General configuration ---------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ You can read the latest code at https://github.com/shenweichen/DeepCTR-Torch and

News
-----
12/05/2020 : Imporve compatibility & fix issues.Add History callback(`example <https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping>`_). `Changelog <https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.3>`_

10/18/2020 : Add `DCN-M <./Features.html#dcn-deep-cross-network>`_ and `DCN-Mix <./Features.html#dcn-mix-improved-deep-cross-network-with-mix-of-experts-and-matrix-kernel>`_ . Add EarlyStopping and ModelCheckpoint callbacks(`example <https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping>`_). `Changelog <https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.3>`_

10/09/2020 : Improve the reproducibility & fix some bugs. `Changelog <https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.2>`_

03/27/2020 : Add `DIN <./Features.html#din-deep-interest-network>`_ and `DIEN <./Features.html#dien-deep-interest-evolution-network>`_ . `Changelog <https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.1>`_


DisscussionGroup
-----------------------
Expand Down
Loading

0 comments on commit 500c0a5

Please sign in to comment.