Skip to content

Commit

Permalink
Merge pull request #332 from alexhernandezgarcia/cleanup
Browse files Browse the repository at this point in the history
[Cleanup] Remove old, unused methods in gflownet.py
  • Loading branch information
alexhernandezgarcia authored Jul 12, 2024
2 parents 0be7346 + f280ee1 commit e2de6db
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 103 deletions.
61 changes: 0 additions & 61 deletions gflownet/gflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,34 +1350,6 @@ def sample_from_reward(
samples_final.extend(samples_accepted[-(n_samples - len(samples_final)) :])
return samples_final

def get_log_corr(self, times):
    """Compute the correlation between trajectory log-likelihoods and test energies.

    For every sample in the test buffer, all trajectories leading to the
    sample are enumerated and their total log-probability under the forward
    policy is computed via ``logq``. Wall-clock time spent in the two phases
    is accumulated into ``times`` in place.

    Parameters
    ----------
    times : dict
        Timing accumulator; the entries "test_trajs" and "test_logq" are
        (re)initialised to 0.0 and updated in place.

    Returns
    -------
    corr : np.ndarray
        Correlation matrix between the log-likelihoods and the test energies.
    data_logq : list
        Log-likelihood of each test sample, in buffer order.
    times : dict
        The updated timing dictionary (same object as the argument).
    """
    times["test_trajs"] = 0.0
    times["test_logq"] = 0.0
    data_logq = []
    # TODO: this could be done just once and store it
    iterator = tqdm(
        zip(self.buffer.test.samples, self.buffer.test["energies"]), disable=True
    )
    for statestr, _score in iterator:
        tic = time.time()
        traj_list, actions = self.env.get_trajectories(
            [], [], [self.env.readable2state(statestr)], [self.env.eos]
        )
        times["test_trajs"] += time.time() - tic
        tic = time.time()
        data_logq.append(logq(traj_list, actions, self.forward_policy, self.env))
        times["test_logq"] += time.time() - tic
    corr = np.corrcoef(data_logq, self.buffer.test["energies"])
    return corr, data_logq, times


def make_opt(params, logZ, config):
"""
Expand Down Expand Up @@ -1408,36 +1380,3 @@ def make_opt(params, logZ, config):
gamma=config.lr_decay_gamma,
)
return opt, lr_scheduler


def logq(traj_list, actions_list, model, env, device="cpu", float_type=torch.float32):
    """Compute the log-likelihood of reaching a state under a forward policy.

    Each trajectory is reversed into forward order, invalid actions are masked
    out of the policy logits, and the per-step log-probabilities of the taken
    actions are summed. The likelihoods of all trajectories leading to the
    same state are accumulated via log-sum-exp.

    Parameters
    ----------
    traj_list : list
        Trajectories (each a list of states, given in backward order).
    actions_list : list
        Action sequences aligned with traj_list (backward order).
    model : callable
        Forward policy model mapping a batch of policy-encoded states to logits.
    env : GFlowNetEnv
        Environment, used for action masks and state-to-policy conversion.
    device : str or torch.device
        Device on which tensors are created. Default: "cpu".
    float_type : torch.dtype
        Float precision of the policy inputs. Default: torch.float32.

    Returns
    -------
    float
        Total log-likelihood of the given trajectories.
    """
    # TODO: this method is probably suboptimal, since it may repeat forward calls for
    # the same nodes.
    # Sentinel: a log-probability is never > 0, so 1.0 marks "nothing
    # accumulated yet" (see the accumulation branch below).
    log_q = torch.tensor(1.0)
    for traj, actions in zip(traj_list, actions_list):
        # Trajectories are provided backward; reverse into forward order.
        traj = traj[::-1]
        actions = actions[::-1]
        # Bug fix: the original referenced self.device and self.float, but this
        # is a module-level function with no `self` in scope (NameError at
        # runtime). Device and float type are now explicit keyword parameters.
        masks = tbool(
            [env.get_mask_invalid_actions_forward(state, 0) for state in traj],
            device=device,
        )
        with torch.no_grad():
            logits_traj = model(
                tfloat(
                    env.states2policy(traj),
                    device=device,
                    float_type=float_type,
                )
            )
        # Invalid actions get probability zero.
        logits_traj[masks] = -torch.inf
        logsoftmax = torch.nn.LogSoftmax(dim=1)
        logprobs_traj = logsoftmax(logits_traj)
        log_q_traj = torch.tensor(0.0)
        for _state, action, logprobs in zip(traj, actions, logprobs_traj):
            log_q_traj = log_q_traj + logprobs[action]
        # Accumulate log prob of trajectory: the first trajectory replaces the
        # sentinel (else branch); subsequent ones are log-add-exp'ed in.
        if torch.le(log_q, 0.0):
            log_q = torch.logaddexp(log_q, log_q_traj)
        else:
            log_q = log_q_traj
    return log_q.item()
42 changes: 0 additions & 42 deletions gflownet/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,48 +254,6 @@ def log_train(
step=step,
)

def log_sampler_test(
    self, corr: array, data_logq: list, step: int, use_context: bool
):
    """Log sampler test metrics: logq/score correlation and mean logq.

    No-op when online logging is disabled or when `step` is not an
    evaluation step according to `self.should_eval`.
    """
    if not self.do.online:
        return
    if not self.should_eval(step):
        return
    test_metrics = {
        "test_corr_logq_score": corr[0, 1],
        "test_mean_log": np.mean(data_logq),
    }
    self.log_metrics(test_metrics, use_context=use_context)

def log_losses(
    self,
    losses: list,
    step: int,
    use_context: bool,
):
    """Log the individual loss components ("loss", "term_loss", "flow_loss").

    No-op when online logging is disabled. `losses` must be aligned with the
    metric names above; each element exposes `.item()`.
    """
    if not self.do.online:
        return
    names = ("loss", "term_loss", "flow_loss")
    loss_metrics = {name: value.item() for name, value in zip(names, losses)}
    self.log_metrics(loss_metrics, use_context=use_context)

def save_models(
self, forward_policy, backward_policy, state_flow, step: int = 1e9, final=False
):
Expand Down

0 comments on commit e2de6db

Please sign in to comment.