Commit: black, isort
AlexandraVolokhova committed Nov 24, 2023
1 parent ebb2737 commit 64f7c3b
Showing 11 changed files with 101 additions and 52 deletions.
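
This commit appears to be a pure formatting pass: black reflows statements that exceed its line-length limit (long constructor calls such as Batch(...) become one keyword argument per line), and isort regroups imports into standard-library, third-party, and first-party blocks, alphabetized within each block, as visible in gflownet/policy/base.py and gflownet/policy/state_flow.py below. A minimal sketch of reproducing this kind of reformatting programmatically, assuming black >= 22 and isort >= 5 (which expose black.format_str with black.Mode and isort.code); the versions actually used for this commit are not stated:

import black
import isort

# Toy source mirroring the patterns changed in this diff: unsorted imports and an
# over-long keyword-argument call. (The names in the string are illustrative only.)
source = (
    "from torch import nn\n"
    "import torch\n"
    "from abc import ABC, abstractmethod\n"
    "\n"
    "batch = Batch(env=env, device=device, float_type=float_type, "
    "non_terminal_rewards=non_terminal_rewards)\n"
)

source = isort.code(source)  # stdlib first, then third-party, each group alphabetized
source = black.format_str(source, mode=black.Mode())  # splits the long call, one argument per line
print(source)

The same result is typically produced from the repository root by running "black ." and "isort --profile black ." so that the two tools agree on style.
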
95 changes: 72 additions & 23 deletions gflownet/gflownet.py
@@ -128,7 +128,7 @@ def __init__(
print(f"\tStd score: {self.buffer.test['energies'].std()}")
print(f"\tMin score: {self.buffer.test['energies'].min()}")
print(f"\tMax score: {self.buffer.test['energies'].max()}")

# Models
self.forward_policy = forward_policy
if self.forward_policy.checkpoint is not None:
@@ -157,8 +157,8 @@ def __init__(

self.state_flow = state_flow
if self.state_flow is not None and self.state_flow.checkpoint is not None:
self.logger.set_state_flow_ckpt_path(self.state_flow.checkpoint)
# TODO: add the logic and conditions to reload a model
self.logger.set_state_flow_ckpt_path(self.state_flow.checkpoint)
# TODO: add the logic and conditions to reload a model
else:
self.logger.set_state_flow_ckpt_path(None)

@@ -202,7 +202,7 @@ def parameters(self):
parameters += list(self.backward_policy.model.parameters())
if self.state_flow is not None:
if self.loss != "forwardlooking":
raise ValueError(f"State flow cannot be trained in {self.loss} loss.")
raise ValueError(f"State flow cannot be trained in {self.loss} loss.")
parameters += list(self.state_flow.model.parameters())
return parameters

@@ -424,12 +424,22 @@ def sample_batch(
"actions_envs": 0.0,
}
t0_all = time.time()
batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards)
batch = Batch(
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)

# ON-POLICY FORWARD trajectories
t0_forward = time.time()
envs = [self.env.copy().reset(idx) for idx in range(n_forward)]
batch_forward = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards)
batch_forward = Batch(
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
while envs:
# Sample actions
t0_a_envs = time.time()
@@ -451,7 +461,12 @@
# TRAIN BACKWARD trajectories
t0_train = time.time()
envs = [self.env.copy().reset(idx) for idx in range(n_train)]
batch_train = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards)
batch_train = Batch(
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
if n_train > 0 and self.buffer.train_pkl is not None:
with open(self.buffer.train_pkl, "rb") as f:
dict_tr = pickle.load(f)
@@ -482,7 +497,12 @@

# REPLAY BACKWARD trajectories
t0_replay = time.time()
batch_replay = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards)
batch_replay = Batch(
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
if n_replay > 0 and self.buffer.replay_pkl is not None:
with open(self.buffer.replay_pkl, "rb") as f:
dict_replay = pickle.load(f)
@@ -720,18 +740,17 @@ def forwardlooking_loss(self, it, batch):
masks_b = batch.get_masks_backward()
policy_output_b = self.backward_policy(states_policy)
logprobs_bkw = self.env.get_logprobs(
policy_output_b, actions, masks_b, states, is_backward=True
)
policy_output_b, actions, masks_b, states, is_backward=True
)
masks_f = batch.get_masks_forward(of_parents=True)
policy_output_f = self.forward_policy(parents_policy)
logprobs_fwd = self.env.get_logprobs(
policy_output_f, actions, masks_f, parents, is_backward=False
)

policy_output_f, actions, masks_f, parents, is_backward=False
)

states_log_flflow = self.state_flow(states_policy)
# forward-looking flow is 1 in the terminal states
states_log_flflow[done.eq(1)] = 0.
states_log_flflow[done.eq(1)] = 0.0
# Can be optimised by reusing states_log_flflow and batch.get_parent_indices
parents_log_flflow = self.state_flow(parents_policy)

@@ -741,9 +760,15 @@
energies_states = -torch.log(rewards_states)
energies_parents = -torch.log(rewards_parents)

per_node_loss = (parents_log_flflow - states_log_flflow + logprobs_fwd - logprobs_bkw +
energies_states - energies_parents).pow(2)

per_node_loss = (
parents_log_flflow
- states_log_flflow
+ logprobs_fwd
- logprobs_bkw
+ energies_states
- energies_parents
).pow(2)

term_loss = per_node_loss[done].mean()
nonterm_loss = per_node_loss[~done].mean()
loss = per_node_loss.mean()
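
For readability, the per-transition term assembled in the forwardlooking_loss hunk above can be written out. The notation below is a reading of the code rather than anything stated in the commit: \tilde{F} denotes the state_flow output (with \log\tilde{F}(s') = 0 whenever s' is terminal) and E(x) = -\log R(x) is the energy derived from the reward:

    \ell(s \to s') = \Big( \log \tilde{F}(s) - \log \tilde{F}(s')
        + \log P_F(s' \mid s) - \log P_B(s \mid s')
        + E(s') - E(s) \Big)^{2}

The reported loss is the mean of \ell over all transitions in the batch, with term_loss and nonterm_loss the means over terminal and non-terminal transitions respectively.
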
Expand Down Expand Up @@ -840,7 +865,12 @@ def estimate_logprobs_data(
end_batch = min(batch_size, n_states)
pbar = tqdm(total=n_states)
while init_batch < n_states:
batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards)
batch = Batch(
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
# Create an environment for each data point and trajectory and set the state
envs = []
for state_idx in range(init_batch, end_batch):
@@ -939,7 +969,12 @@ def train(self):
self.logger.log_metrics(metrics, use_context=self.use_context, step=it)
self.logger.log_summary(summary)
t0_iter = time.time()
batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards)
batch = Batch(
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
for j in range(self.sttr):
sub_batch, times = self.sample_batch(
n_forward=self.batch_size.forward,
@@ -1021,7 +1056,9 @@ def train(self):
times.update({"log": t1_log - t0_log})
# Save intermediate models
t0_model = time.time()
self.logger.save_models(self.forward_policy, self.backward_policy, self.state_flow, step=it)
self.logger.save_models(
self.forward_policy, self.backward_policy, self.state_flow, step=it
)
t1_model = time.time()
times.update({"save_interim_model": t1_model - t0_model})

@@ -1050,7 +1087,9 @@ def train(self):
self.logger.log_time(times, use_context=self.use_context)

# Save final model
self.logger.save_models(self.forward_policy, self.backward_policy, self.state_flow, final=True)
self.logger.save_models(
self.forward_policy, self.backward_policy, self.state_flow, final=True
)
# Close logger
if self.use_context is False:
self.logger.end()
@@ -1225,7 +1264,12 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None):
print()
if not gfn_states:
# sample states from the current gfn
batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards)
batch = Batch(
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
self.random_action_prob = 0
t = time.time()
print("Sampling from GFN...", end="\r")
@@ -1248,7 +1292,12 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None):
if do_random:
# sample random states from uniform actions
if not random_states:
batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards)
batch = Batch(
env=self.env,
device=self.device,
float_type=self.float,
non_terminal_rewards=self.non_terminal_rewards,
)
self.random_action_prob = 1.0
print("[test_top_k] Sampling at random...", end="\r")
for b in batch_with_rest(
12 changes: 6 additions & 6 deletions gflownet/policy/base.py
@@ -1,7 +1,9 @@
from abc import ABC, abstractmethod

import torch
from omegaconf import OmegaConf
from torch import nn
from abc import ABC, abstractmethod

from gflownet.utils.common import set_device, set_float_precision


@@ -14,7 +16,7 @@ def __init__(self, config, input_dim, device, float_precision, base=None):
self.input_dim = input_dim
# Must be redefined in the children classes
self.output_dim = None

# Optional base model
self.base = base

@@ -90,13 +92,12 @@ def make_mlp(self, activation):
)



class Policy(ModelBase):
def __init__(self, config, env, device, float_precision, base=None):
super().__init__(config, env.policy_input_dim, device, float_precision, base)

# Outputs
# Outputs

self.fixed_output = torch.tensor(env.fixed_policy_output).to(
dtype=self.float, device=self.device
)
@@ -107,7 +108,6 @@ def __init__(self, config, env, device, float_precision, base=None):

self.instantiate()


def instantiate(self):
if self.type == "fixed":
self.model = self.fixed_distribution
10 changes: 6 additions & 4 deletions gflownet/policy/state_flow.py
@@ -1,19 +1,21 @@
import torch
from torch import nn

from gflownet.utils.common import set_device, set_float_precision
from gflownet.policy.base import ModelBase
from gflownet.utils.common import set_device, set_float_precision


class StateFlow(ModelBase):
"""
Takes state in the policy format and predicts its flow (a scalar)
"""

def __init__(self, config, env, device, float_precision, base=None):
super().__init__(config, env.policy_input_dim, device, float_precision, base)

# output dim
self.output_dim = 1

# Instantiate neural network
self.instantiate()

@@ -22,4 +24,4 @@ def instantiate(self):
self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)
self.is_model = True
else:
raise "StateFlow model type not defined"
raise "StateFlow model type not defined"
24 changes: 13 additions & 11 deletions gflownet/utils/batch.py
@@ -39,7 +39,7 @@ def __init__(
env: Optional[GFlowNetEnv] = None,
device: Union[str, torch.device] = "cpu",
float_type: Union[int, torch.dtype] = 32,
non_terminal_rewards: bool = False
non_terminal_rewards: bool = False,
):
"""
env : GFlowNetEnv
@@ -567,8 +567,10 @@ def _compute_parents(self):
# Sort parents list in the same order as states
# TODO: check if tensor and sort without iter
self.parents = [self.parents[indices.index(idx)] for idx in range(len(self))]
self.parents_indices = tlong([self.parents_indices[indices.index(idx)] for idx in range(len(self))],
device=self.device)
self.parents_indices = tlong(
[self.parents_indices[indices.index(idx)] for idx in range(len(self))],
device=self.device,
)
self.parents_available = True

# TODO: consider converting directly from self.parents
@@ -853,19 +855,17 @@ def get_rewards(
if self.rewards_available is False or force_recompute is True:
self._compute_rewards()
return self.rewards

def _compute_rewards(self):
"""
Computes rewards for all self.states by first converting the states into proxy
format. The result is stored in self.rewards as a torch.tensor
"""

self.rewards = torch.zeros(len(self), dtype=self.float, device=self.device)
done = self.get_done()
if self.non_terminal_rewards:
self.rewards = self.env.proxy2reward(
self.env.proxy(self.states2proxy())
)
self.rewards = self.env.proxy2reward(self.env.proxy(self.states2proxy()))
elif len(done) > 0:
states_proxy_done = self.get_terminating_states(proxy=True)
self.rewards[done] = self.env.proxy2reward(
@@ -884,13 +884,15 @@ def get_rewards_parents(self) -> TensorType["n_states"]:
def _compute_rewards_parents(self):
"""
Computes rewards of the self.parents by reusing rewards of the states (i.e. self.rewards).
Stores the result in self.rewards_parents
Stores the result in self.rewards_parents
"""
state_rewards = self.get_rewards()
self.rewards_parents = torch.zeros_like(state_rewards)
parent_is_source = self.get_parent_is_source()
parent_indices = self.get_parents_indices()
self.rewards_parents[~parent_is_source] = self.rewards[parent_indices[~parent_is_source]]
self.rewards_parents[~parent_is_source] = self.rewards[
parent_indices[~parent_is_source]
]
rewards_source = self.get_rewards_source()
self.rewards_parents[parent_is_source] = rewards_source[parent_is_source]
self.rewards_parents_available = True
@@ -902,7 +904,7 @@ def get_rewards_source(self) -> TensorType["n_states"]:
if not self.rewards_source_available:
self._compute_rewards_source()
return self.rewards_source

def _compute_rewards_source(self):
"""
Computes a tensor of length len(self.states) with rewards of the corresponding source states.
3 changes: 0 additions & 3 deletions gflownet/utils/logger.py
@@ -387,9 +387,6 @@ def save_models(
path = self.sf_ckpt_path.parent / stem
torch.save(state_flow.model.state_dict(), path)




def log_time(self, times: dict, use_context: bool):
if self.do.times:
times = {"time_{}".format(k): v for k, v in times.items()}
2 changes: 1 addition & 1 deletion gflownet/utils/policy.py
@@ -30,4 +30,4 @@ def parse_policy_config(config: DictConfig, kind: str) -> Optional[DictConfig]:
del policy_config.backward
del policy_config.shared

return policy_config
return policy_config
2 changes: 1 addition & 1 deletion main.py
@@ -68,7 +68,7 @@ def main(config):
device=config.device,
float_precision=config.float_precision,
base=forward_policy,
)
)
else:
state_flow = None

2 changes: 1 addition & 1 deletion mila/launch.py
@@ -7,8 +7,8 @@
from os.path import expandvars
from pathlib import Path
from textwrap import dedent
from git import Repo

from git import Repo
from yaml import safe_load

ROOT = Path(__file__).resolve().parent.parent
1 change: 1 addition & 0 deletions scripts/dav_mp20_stats.py
@@ -20,6 +20,7 @@
from collections import Counter

from external.repos.ActiveLearningMaterials.dave.utils.loaders import make_loaders

from gflownet.proxy.crystals.dave import DAVE
from gflownet.utils.common import load_gflow_net_from_run_path, resolve_path
