Skip to content

Commit

Permalink
minor bugs on experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
qiauil committed Nov 14, 2024
1 parent f505b9a commit ccf8fa2
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/examples/pinn_burgers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
" for i in range(n_layers):\n",
" self.net.append(nn.Sequential(nn.Linear(channel_basics, channel_basics), nn.Tanh()))\n",
" self.net=nn.Sequential(*self.net)\n",
" self.out_net=nn.Sequential(nn.Linear(channel_basics, 1))\n",
" self.out_net=nn.Linear(channel_basics, 1)\n",
" \n",
" \n",
" def forward(self, x, t):\n",
Expand Down
2 changes: 1 addition & 1 deletion experiments/PINN/lib_pinns/burgers/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self,channel_basics=50,n_layers=4, *args, **kwargs) -> None:
for i in range(n_layers):
self.net.append(nn.Sequential(nn.Linear(channel_basics, channel_basics), nn.Tanh()))
self.net=nn.Sequential(*self.net)
self.out_net=nn.Sequential(nn.Linear(channel_basics, 1))
self.out_net=nn.Linear(channel_basics, 1)

def forward(self, x, t):
ini_shape=x.shape
Expand Down
4 changes: 3 additions & 1 deletion experiments/PINN/lib_pinns/network_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@
def xavier_init_weights(m):
    """Initialize a linear layer with Xavier (Glorot) normal weights.

    Intended to be used via ``network.apply(xavier_init_weights)``:
    linear layers get Xavier-normal weights (gain 1) and a small
    constant bias of 0.001; every other module type is left untouched.

    Args:
        m: A ``torch.nn.Module`` visited by ``Module.apply``.
    """
    # isinstance is the idiomatic type check; the original's
    # ``type(m) == nn.Linear`` also excluded nn.Linear subclasses.
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_normal_(m.weight, 1)
        # nn.Linear(..., bias=False) leaves .bias as None, so guard on
        # None only — every nn.Linear has the .bias attribute, making
        # the original's extra hasattr() check redundant.
        if m.bias is not None:
            m.bias.data.fill_(0.001)
2 changes: 0 additions & 2 deletions experiments/PINN/lib_pinns/trainer_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ def get_momentum_trainer(sub_trainer,operator):
class MomentumTrainer(MomentumGradVecTrainerBasis):
def initialize_momentum_handler(self,network):
self.momentum_handler=PseudoMomentumOperator(num_vectors=self.configs.n_losses,
network=network,
gradient_operator=operator,
loss_recorder=LatestLossRecorder(self.configs.n_losses))
class Trainer(sub_trainer,MomentumTrainer):
Expand All @@ -308,7 +307,6 @@ def get_separate_momentum_trainer(sub_trainer,operator):
class MomentumTrainer(MomentumGradVecTrainerBasis):
def initialize_momentum_handler(self,network):
self.momentum_handler=SeparateMomentumOperator(num_vectors=self.configs.n_losses,
network=network,
gradient_operator=operator,
loss_recorder=LatestLossRecorder(self.configs.n_losses))
class Trainer(sub_trainer,MomentumTrainer):
Expand Down

0 comments on commit ccf8fa2

Please sign in to comment.