Skip to content

Commit

Permalink
move self.is_model to instantiate and add super().parse_config(config…
Browse files Browse the repository at this point in the history
…) in the parse_config
  • Loading branch information
engmubarak48 committed Jul 9, 2024
1 parent 2ac9af5 commit 7e02200
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion gflownet/policy/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def __init__(self, config, env, device, float_precision, base=None):
float_precision=float_precision,
base=base,
)
self.is_model = True

def make_cnn(self):
"""
Expand Down Expand Up @@ -73,6 +72,7 @@ def make_cnn(self):
return model.to(self.device)

def parse_config(self, config):
super().parse_config(config)
if config is None:
config = OmegaConf.create()
self.checkpoint = config.get("checkpoint", None)
Expand All @@ -85,6 +85,7 @@ def parse_config(self, config):

def instantiate(self):
self.model = self.make_cnn()
self.is_model = True

def __call__(self, states):
states = states.unsqueeze(1) # (batch_size, channels, height, width)
Expand Down
4 changes: 2 additions & 2 deletions gflownet/policy/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def __init__(self, config, env, device, float_precision, base=None):
float_precision=float_precision,
base=base,
)
self.is_model = True

def make_mlp(self, activation):
"""
Expand Down Expand Up @@ -61,9 +60,9 @@ def make_mlp(self, activation):
)

def parse_config(self, config):
super().parse_config(config)
if config is None:
config = OmegaConf.create()
config.type = "mlp"
self.checkpoint = config.get("checkpoint", None)
self.shared_weights = config.get("shared_weights", False)
self.n_hid = config.get("n_hid", 128)
Expand All @@ -73,6 +72,7 @@ def parse_config(self, config):

def instantiate(self):
self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)
self.is_model = True

def __call__(self, states):
return self.model(states)

0 comments on commit 7e02200

Please sign in to comment.