diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py
index 81c88107..4e5d3860 100644
--- a/gflownet/gflownet.py
+++ b/gflownet/gflownet.py
@@ -1350,34 +1350,6 @@ def sample_from_reward(
             samples_final.extend(samples_accepted[-(n_samples - len(samples_final)) :])
         return samples_final
 
-    def get_log_corr(self, times):
-        data_logq = []
-        times.update(
-            {
-                "test_trajs": 0.0,
-                "test_logq": 0.0,
-            }
-        )
-        # TODO: this could be done just once and store it
-        for statestr, score in tqdm(
-            zip(self.buffer.test.samples, self.buffer.test["energies"]), disable=True
-        ):
-            t0_test_traj = time.time()
-            traj_list, actions = self.env.get_trajectories(
-                [],
-                [],
-                [self.env.readable2state(statestr)],
-                [self.env.eos],
-            )
-            t1_test_traj = time.time()
-            times["test_trajs"] += t1_test_traj - t0_test_traj
-            t0_test_logq = time.time()
-            data_logq.append(logq(traj_list, actions, self.forward_policy, self.env))
-            t1_test_logq = time.time()
-            times["test_logq"] += t1_test_logq - t0_test_logq
-        corr = np.corrcoef(data_logq, self.buffer.test["energies"])
-        return corr, data_logq, times
-
 
 def make_opt(params, logZ, config):
     """
@@ -1408,36 +1380,3 @@ def make_opt(params, logZ, config):
         gamma=config.lr_decay_gamma,
     )
     return opt, lr_scheduler
-
-
-def logq(traj_list, actions_list, model, env):
-    # TODO: this method is probably suboptimal, since it may repeat forward calls for
-    # the same nodes.
-    log_q = torch.tensor(1.0)
-    for traj, actions in zip(traj_list, actions_list):
-        traj = traj[::-1]
-        actions = actions[::-1]
-        masks = tbool(
-            [env.get_mask_invalid_actions_forward(state, 0) for state in traj],
-            device=self.device,
-        )
-        with torch.no_grad():
-            logits_traj = model(
-                tfloat(
-                    env.states2policy(traj),
-                    device=self.device,
-                    float_type=self.float,
-                )
-            )
-        logits_traj[masks] = -torch.inf
-        logsoftmax = torch.nn.LogSoftmax(dim=1)
-        logprobs_traj = logsoftmax(logits_traj)
-        log_q_traj = torch.tensor(0.0)
-        for s, a, logprobs in zip(*[traj, actions, logprobs_traj]):
-            log_q_traj = log_q_traj + logprobs[a]
-        # Accumulate log prob of trajectory
-        if torch.le(log_q, 0.0):
-            log_q = torch.logaddexp(log_q, log_q_traj)
-        else:
-            log_q = log_q_traj
-    return log_q.item()
diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py
index 9fe982b3..b62597ce 100644
--- a/gflownet/utils/logger.py
+++ b/gflownet/utils/logger.py
@@ -254,48 +254,6 @@ def log_train(
             step=step,
         )
 
-    def log_sampler_test(
-        self, corr: array, data_logq: list, step: int, use_context: bool
-    ):
-        if not self.do.online:
-            return
-        if self.should_eval(step):
-            test_metrics = dict(
-                zip(
-                    [
-                        "test_corr_logq_score",
-                        "test_mean_log",
-                    ],
-                    [
-                        corr[0, 1],
-                        np.mean(data_logq),
-                    ],
-                )
-            )
-            self.log_metrics(
-                test_metrics,
-                use_context=use_context,
-            )
-
-    def log_losses(
-        self,
-        losses: list,
-        step: int,
-        use_context: bool,
-    ):
-        if not self.do.online:
-            return
-        loss_metrics = dict(
-            zip(
-                ["loss", "term_loss", "flow_loss"],
-                [loss.item() for loss in losses],
-            )
-        )
-        self.log_metrics(
-            loss_metrics,
-            use_context=use_context,
-        )
-
     def save_models(
         self, forward_policy, backward_policy, state_flow, step: int = 1e9, final=False
     ):