From 3a0a0ddd9745076342b34f6d6c50cd7ebf76c91e Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Thu, 4 Jul 2024 14:43:32 -0400 Subject: [PATCH 1/2] Remove unused methods in gflownet.py --- gflownet/gflownet.py | 61 -------------------------------------------- 1 file changed, 61 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 81c881076..4e5d38605 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1350,34 +1350,6 @@ def sample_from_reward( samples_final.extend(samples_accepted[-(n_samples - len(samples_final)) :]) return samples_final - def get_log_corr(self, times): - data_logq = [] - times.update( - { - "test_trajs": 0.0, - "test_logq": 0.0, - } - ) - # TODO: this could be done just once and store it - for statestr, score in tqdm( - zip(self.buffer.test.samples, self.buffer.test["energies"]), disable=True - ): - t0_test_traj = time.time() - traj_list, actions = self.env.get_trajectories( - [], - [], - [self.env.readable2state(statestr)], - [self.env.eos], - ) - t1_test_traj = time.time() - times["test_trajs"] += t1_test_traj - t0_test_traj - t0_test_logq = time.time() - data_logq.append(logq(traj_list, actions, self.forward_policy, self.env)) - t1_test_logq = time.time() - times["test_logq"] += t1_test_logq - t0_test_logq - corr = np.corrcoef(data_logq, self.buffer.test["energies"]) - return corr, data_logq, times - def make_opt(params, logZ, config): """ @@ -1408,36 +1380,3 @@ def make_opt(params, logZ, config): gamma=config.lr_decay_gamma, ) return opt, lr_scheduler - - -def logq(traj_list, actions_list, model, env): - # TODO: this method is probably suboptimal, since it may repeat forward calls for - # the same nodes. - log_q = torch.tensor(1.0) - for traj, actions in zip(traj_list, actions_list): - traj = traj[::-1] - actions = actions[::-1] - masks = tbool( - [env.get_mask_invalid_actions_forward(state, 0) for state in traj], - device=self.device, - ) - with torch.no_grad(): - logits_traj = model( - tfloat( - env.states2policy(traj), - device=self.device, - float_type=self.float, - ) - ) - logits_traj[masks] = -torch.inf - logsoftmax = torch.nn.LogSoftmax(dim=1) - logprobs_traj = logsoftmax(logits_traj) - log_q_traj = torch.tensor(0.0) - for s, a, logprobs in zip(*[traj, actions, logprobs_traj]): - log_q_traj = log_q_traj + logprobs[a] - # Accumulate log prob of trajectory - if torch.le(log_q, 0.0): - log_q = torch.logaddexp(log_q, log_q_traj) - else: - log_q = log_q_traj - return log_q.item() From f280ee1c71f7ee2fb6777fda6b7cb6e981ca13c0 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 10 Jul 2024 17:16:30 +0200 Subject: [PATCH 2/2] Remove unused methods in logger.py --- gflownet/utils/logger.py | 42 ---------------------------------------- 1 file changed, 42 deletions(-) diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index 9fe982b30..b62597ceb 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -254,48 +254,6 @@ def log_train( step=step, ) - def log_sampler_test( - self, corr: array, data_logq: list, step: int, use_context: bool - ): - if not self.do.online: - return - if self.should_eval(step): - test_metrics = dict( - zip( - [ - "test_corr_logq_score", - "test_mean_log", - ], - [ - corr[0, 1], - np.mean(data_logq), - ], - ) - ) - self.log_metrics( - test_metrics, - use_context=use_context, - ) - - def log_losses( - self, - losses: list, - step: int, - use_context: bool, - ): - if not self.do.online: - return - loss_metrics = dict( - zip( - ["loss", "term_loss", "flow_loss"], - [loss.item() for loss in losses], - ) - ) - self.log_metrics( - loss_metrics, - use_context=use_context, - ) - def save_models( self, forward_policy, backward_policy, state_flow, step: int = 1e9, final=False ):