Skip to content

Commit

Permalink
estimator_outputs can be passed around
Browse files (browse the repository at this point in the history)
  • Loading branch information
josephdviviano committed Nov 21, 2023
1 parent 8999dd3 commit 450ebf0
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions src/gfn/gflownet/trajectory_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def __init__(
self.logZ = nn.Parameter(torch.tensor(init_logZ))
self.log_reward_clip_min = log_reward_clip_min

def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]:
def loss(
self,
env: Env,
trajectories: Trajectories,
estimator_outputs: torch.Tensor = None,
) -> TT[0, float]:
"""Trajectory balance loss.
The trajectory balance loss is described in 2.3 of
Expand All @@ -51,7 +56,7 @@ def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]:
ValueError: if the loss is NaN.
"""
del env # unused
_, _, scores = self.get_trajectories_scores(trajectories)
_, _, scores = self.get_trajectories_scores(trajectories, estimator_outputs)
loss = (scores + self.logZ).pow(2).mean()
if torch.isnan(loss):
raise ValueError("loss is nan")
Expand Down Expand Up @@ -80,14 +85,19 @@ def __init__(

self.log_reward_clip_min = log_reward_clip_min

def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]:
def loss(
self,
env: Env,
trajectories: Trajectories,
estimator_outputs: torch.Tensor = None,
) -> TT[0, float]:
"""Log Partition Variance loss.
This method is described in section 3.2 of
[ROBUST SCHEDULING WITH GFLOWNETS](https://arxiv.org/abs/2302.05446)
"""
del env # unused
_, _, scores = self.get_trajectories_scores(trajectories)
_, _, scores = self.get_trajectories_scores(trajectories, estimator_outputs)
loss = (scores - scores.mean()).pow(2).mean()
if torch.isnan(loss):
raise ValueError("loss is NaN.")
Expand Down

0 comments on commit 450ebf0

Please sign in to comment.