Prevent unnecessary recalculation of proxy values #336

Merged: 11 commits (Jul 17, 2024)
gflownet/gflownet.py: 50 changes (25 additions, 25 deletions)
@@ -1168,35 +1168,34 @@ def train(self):
all_losses.append([i.item() for i in losses])
# Buffer
t0_buffer = time.time()
# TODO: the current implementation recomputes the proxy values of the
# terminating states in order to store the proxy values in the Buffer.
# Depending on the computational cost of the proxy, this may be very
# inneficient. For example, proxy.rewards() could return the proxy values,
# which could be stored in the Batch.
if it == 0:
print(
"IMPORTANT: The current implementation recomputes the proxy "
"values of the terminating states in order to store the proxy "
"values in the Buffer. Depending on the computational cost of "
"the proxy, this may be very inneficient."
)
states_term = batch.get_terminating_states(sort_by="trajectory")
states_proxy_term = batch.get_terminating_states(
proxy=True, sort_by="trajectory"
)
proxy_vals = self.proxy(states_proxy_term)
rewards = self.proxy.proxy2reward(proxy_vals)
rewards = rewards.tolist()
proxy_vals = batch.get_terminating_proxy_values(sort_by="trajectory")
proxy_vals = proxy_vals.tolist()
# The batch will typically have the log-rewards available, since they are
# used to compute the losses. In order to avoid recalculating the proxy
# values, the natural rewards are computed by taking the exponential of the
# log-rewards. In case the rewards are available in the batch but not the
# log-rewards, the latter are computed by taking the log of the rewards.
# Numerical issues are not critical in this case, since the derived values
# are only used for reporting purposes.
if batch.rewards_available(log=False):
rewards = batch.get_terminating_rewards(sort_by="trajectory")
if batch.rewards_available(log=True):
logrewards = batch.get_terminating_rewards(
sort_by="trajectory", log=True
)
if not batch.rewards_available(log=False):
assert batch.rewards_available(log=True)
rewards = torch.exp(logrewards)
if not batch.rewards_available(log=True):
assert batch.rewards_available(log=False)
logrewards = torch.log(rewards)
rewards = rewards.tolist()
logrewards = logrewards.tolist()
actions_trajectories = batch.get_actions_trajectories()
self.buffer.add(states_term, actions_trajectories, rewards, proxy_vals, it)
self.buffer.add(states_term, actions_trajectories, logrewards, it)
self.buffer.add(
states_term,
actions_trajectories,
rewards,
proxy_vals,
it,
buffer="replay",
states_term, actions_trajectories, logrewards, it, buffer="replay"
)
t1_buffer = time.time()
times.update({"buffer": t1_buffer - t0_buffer})
@@ -1215,6 +1214,7 @@ def train(self):
self.logger.log_train(
losses=losses,
rewards=rewards,
logrewards=logrewards,
proxy_vals=proxy_vals,
states_term=states_term,
batch_size=len(batch),
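Note on the fallback introduced in train() above: no additional proxy call is ever made; whichever of the natural rewards or log-rewards is missing from the batch is derived from the other. A minimal standalone sketch of that logic follows. The helper name derive_rewards is illustrative only and is not part of the repository.

    import torch

    def derive_rewards(rewards=None, logrewards=None):
        # Illustrative helper mirroring the fallback added in train(): whichever
        # quantity is missing is derived from the other, with no extra proxy call.
        # Numerical precision is not critical here because the results are only
        # used for reporting and buffer storage, not for computing losses.
        if rewards is None and logrewards is None:
            raise ValueError("either rewards or logrewards must be available")
        if rewards is None:
            rewards = torch.exp(logrewards)
        if logrewards is None:
            logrewards = torch.log(rewards)
        return rewards.tolist(), logrewards.tolist()

    # Example: only log-rewards are available, as is typical after computing the loss.
    rewards, logrewards = derive_rewards(logrewards=torch.tensor([-2.3, -0.7, -1.1]))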
gflownet/proxy/base.py: 26 changes (20 additions, 6 deletions)
@@ -105,8 +105,11 @@ def __call__(self, states: Union[TensorType, List, npt.NDArray]) -> TensorType:
pass

def rewards(
self, states: Union[TensorType, List, npt.NDArray], log: bool = False
) -> TensorType:
self,
states: Union[TensorType, List, npt.NDArray],
log: bool = False,
return_proxy: bool = False,
) -> Union[TensorType, Tuple[TensorType, TensorType]]:
"""
Computes the rewards of a batch of states.

@@ -120,16 +123,27 @@ def rewards(
log : bool
If True, returns the logarithm of the rewards. If False (default), returns
the natural rewards.
return_proxy : bool
If True, returns the proxy values, alongside the rewards, as the second
element in the returned tuple.

Returns
-------
tensor
The reward of all elements in the batch.
rewards : tensor
The reward or log-reward of all elements in the batch.
proxy_values : tensor (optional)
The proxy value of all elements in the batch. Included only if return_proxy
is True.
"""
proxy_values = self(states)
if log:
return self.proxy2logreward(self(states))
rewards = self.proxy2logreward(proxy_values)
else:
rewards = self.proxy2reward(proxy_values)
if return_proxy:
return rewards, proxy_values
else:
return self.proxy2reward(self(states))
return rewards

def proxy2reward(self, proxy_values: TensorType) -> TensorType:
"""
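With the extended signature, a caller can obtain the rewards (or log-rewards) and the proxy values from a single proxy evaluation. A toy usage sketch follows; ToyProxy is a hypothetical stand-in for a concrete proxy subclass, and its score and reward transforms are purely illustrative, not the repository's actual ones.

    import torch

    class ToyProxy:
        # Toy stand-in for a concrete proxy; only the rewards() interface shown
        # above is mirrored, and the score/transform choices are illustrative.
        def __call__(self, states):
            return -(states ** 2).sum(dim=-1)

        def proxy2reward(self, proxy_values):
            return torch.exp(proxy_values)

        def proxy2logreward(self, proxy_values):
            return proxy_values

        def rewards(self, states, log=False, return_proxy=False):
            # Single proxy evaluation, reused for both returned quantities.
            proxy_values = self(states)
            if log:
                rewards = self.proxy2logreward(proxy_values)
            else:
                rewards = self.proxy2reward(proxy_values)
            if return_proxy:
                return rewards, proxy_values
            return rewards

    proxy = ToyProxy()
    states = torch.randn(4, 3)
    # One proxy call yields both the log-rewards and the proxy values.
    logrewards, proxy_vals = proxy.rewards(states, log=True, return_proxy=True)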