Skip to content

Commit

Permalink
Merge pull request #1468 from zhengbw0324/master
Browse files Browse the repository at this point in the history
FIX: update scipy version
  • Loading branch information
zhengbw0324 authored Oct 4, 2022
2 parents b616af6 + 629b915 commit 83e931c
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 20 deletions.
12 changes: 6 additions & 6 deletions conda/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ requirements:
host:
- python
- numpy >=1.17.2
- scipy ==1.6.0
- scipy >=1.6.0
- pandas >=1.0.5
- tqdm >=4.48.2
- pyyaml >=5.1.0
Expand All @@ -20,12 +20,12 @@ requirements:
- colorlog==4.7.2
- colorama==0.4.4
- tensorboard >=2.5.0
- tabulate>=0.8.10
- plotly>=4.0.0
- tabulate >=0.8.10
- plotly >=4.0.0
run:
- python
- numpy >=1.17.2
- scipy ==1.6.0
- scipy >=1.6.0
- pandas >=1.0.5
- tqdm >=4.48.2
- pyyaml >=5.1.0
Expand All @@ -34,8 +34,8 @@ requirements:
- colorlog==4.7.2
- colorama==0.4.4
- tensorboard >=2.5.0
- tabulate>=0.8.10
- plotly>=4.0.0
- tabulate >=0.8.10
- plotly >=4.0.0
test:
imports:
- recbole
Expand Down
16 changes: 4 additions & 12 deletions recbole/model/context_aware_recommender/fwfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def __init__(self, config, dataset):
self._get_feature2field()
self.num_fields = len(set(self.feature2field.values())) # the number of fields
self.num_pair = self.num_fields * self.num_fields

self.weight = torch.randn(
self.num_fields, self.num_fields, 1, requires_grad=True, device=self.device
)
self.loss = nn.BCEWithLogitsLoss()

# parameters initialization
Expand Down Expand Up @@ -106,17 +108,7 @@ def fwfm_layer(self, infeature):
"""
# get r(Fi, Fj)
batch_size = infeature.shape[0]
para = (
torch.randn(self.num_fields * self.num_fields * self.embedding_size)
.expand(batch_size, self.num_fields * self.num_fields * self.embedding_size)
.to(self.device)
) # [batch_size*num_pairs*emb_dim]
para = para.reshape(
batch_size, self.num_fields, self.num_fields, self.embedding_size
)
r = nn.Parameter(
para, requires_grad=True
) # [batch_size, num_fields, num_fields, emb_dim]
weight = self.weight.expand(batch_size, -1, -1, -1)

fwfm_inter = list() # [batch_size, num_fields, emb_dim]
for i in range(self.num_features - 1):
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
torch>=1.10.0
numpy>=1.17.2
scipy==1.6.0
scipy>=1.6.0
hyperopt==0.2.5
pandas>=1.0.5
tqdm>=4.48.2
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
install_requires = [
"torch>=1.10.0",
"numpy>=1.17.2",
"scipy==1.6.0",
"scipy>=1.6.0",
"pandas>=1.0.5",
"tqdm>=4.48.2",
"colorlog==4.7.2",
Expand Down

0 comments on commit 83e931c

Please sign in to comment.