
Commit

Merge pull request #336 from alexhernandezgarcia/proxy-in-batch-buffer-no-proxy

Prevent unnecessary recalculation of proxy values
alexhernandezgarcia authored Jul 17, 2024
2 parents 6ac2c5f + 7cd223f commit aebc5f9
Showing 6 changed files with 472 additions and 122 deletions.
50 changes: 25 additions & 25 deletions gflownet/gflownet.py
@@ -1168,35 +1168,34 @@ def train(self):
             all_losses.append([i.item() for i in losses])
             # Buffer
             t0_buffer = time.time()
-            # TODO: the current implementation recomputes the proxy values of the
-            # terminating states in order to store the proxy values in the Buffer.
-            # Depending on the computational cost of the proxy, this may be very
-            # inneficient. For example, proxy.rewards() could return the proxy values,
-            # which could be stored in the Batch.
-            if it == 0:
-                print(
-                    "IMPORTANT: The current implementation recomputes the proxy "
-                    "values of the terminating states in order to store the proxy "
-                    "values in the Buffer. Depending on the computational cost of "
-                    "the proxy, this may be very inneficient."
-                )
             states_term = batch.get_terminating_states(sort_by="trajectory")
-            states_proxy_term = batch.get_terminating_states(
-                proxy=True, sort_by="trajectory"
-            )
-            proxy_vals = self.proxy(states_proxy_term)
-            rewards = self.proxy.proxy2reward(proxy_vals)
-            rewards = rewards.tolist()
+            proxy_vals = batch.get_terminating_proxy_values(sort_by="trajectory")
+            proxy_vals = proxy_vals.tolist()
+            # The batch will typically have the log-rewards available, since they are
+            # used to compute the losses. In order to avoid recalculating the proxy
+            # values, the natural rewards are computed by taking the exponential of the
+            # log-rewards. In case the rewards are available in the batch but not the
+            # log-rewards, the latter are computed by taking the log of the rewards.
+            # Numerical issues are not critical in this case, since the derived values
+            # are only used for reporting purposes.
+            if batch.rewards_available(log=False):
+                rewards = batch.get_terminating_rewards(sort_by="trajectory")
+            if batch.rewards_available(log=True):
+                logrewards = batch.get_terminating_rewards(
+                    sort_by="trajectory", log=True
+                )
+            if not batch.rewards_available(log=False):
+                assert batch.rewards_available(log=True)
+                rewards = torch.exp(logrewards)
+            if not batch.rewards_available(log=True):
+                assert batch.rewards_available(log=False)
+                logrewards = torch.log(rewards)
+            rewards = rewards.tolist()
+            logrewards = logrewards.tolist()
             actions_trajectories = batch.get_actions_trajectories()
-            self.buffer.add(states_term, actions_trajectories, rewards, proxy_vals, it)
+            self.buffer.add(states_term, actions_trajectories, logrewards, it)
             self.buffer.add(
-                states_term,
-                actions_trajectories,
-                rewards,
-                proxy_vals,
-                it,
-                buffer="replay",
+                states_term, actions_trajectories, logrewards, it, buffer="replay"
             )
             t1_buffer = time.time()
             times.update({"buffer": t1_buffer - t0_buffer})
@@ -1215,6 +1214,7 @@ def train(self):
                 self.logger.log_train(
                     losses=losses,
                     rewards=rewards,
+                    logrewards=logrewards,
                     proxy_vals=proxy_vals,
                     states_term=states_term,
                     batch_size=len(batch),
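
The comment block in the new code captures the key idea: the batch normally holds log-rewards (they are needed for the loss), so the natural rewards can be derived by exponentiation rather than by re-running the proxy, and vice versa. A minimal standalone sketch of that fallback pattern, assuming a `batch` object exposing the `rewards_available()` and `get_terminating_rewards()` methods shown in the diff (the helper function name is illustrative, not part of the PR):

import torch

def terminating_rewards_and_logrewards(batch):
    # Fetch whichever representations the batch already stores.
    rewards = logrewards = None
    if batch.rewards_available(log=False):
        rewards = batch.get_terminating_rewards(sort_by="trajectory")
    if batch.rewards_available(log=True):
        logrewards = batch.get_terminating_rewards(sort_by="trajectory", log=True)
    # Derive the missing representation instead of re-evaluating the proxy.
    if rewards is None:
        assert logrewards is not None
        rewards = torch.exp(logrewards)
    if logrewards is None:
        assert rewards is not None
        # Minor numerical issues are acceptable here: these values are only
        # used for reporting, not for computing losses.
        logrewards = torch.log(rewards)
    return rewards, logrewards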
26 changes: 20 additions & 6 deletions gflownet/proxy/base.py
@@ -105,8 +105,11 @@ def __call__(self, states: Union[TensorType, List, npt.NDArray]) -> TensorType:
         pass
 
     def rewards(
-        self, states: Union[TensorType, List, npt.NDArray], log: bool = False
-    ) -> TensorType:
+        self,
+        states: Union[TensorType, List, npt.NDArray],
+        log: bool = False,
+        return_proxy: bool = False,
+    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
         """
         Computes the rewards of a batch of states.
@@ -120,16 +123,27 @@ def rewards(
         log : bool
             If True, returns the logarithm of the rewards. If False (default), returns
             the natural rewards.
+        return_proxy : bool
+            If True, returns the proxy values, alongside the rewards, as the second
+            element in the returned tuple.
 
         Returns
         -------
-        tensor
-            The reward of all elements in the batch.
+        rewards : tensor
+            The reward or log-reward of all elements in the batch.
+        proxy_values : tensor (optional)
+            The proxy value of all elements in the batch. Included only if return_proxy
+            is True.
         """
+        proxy_values = self(states)
         if log:
-            return self.proxy2logreward(self(states))
+            rewards = self.proxy2logreward(proxy_values)
         else:
-            return self.proxy2reward(self(states))
+            rewards = self.proxy2reward(proxy_values)
+        if return_proxy:
+            return rewards, proxy_values
+        else:
+            return rewards
 
     def proxy2reward(self, proxy_values: TensorType) -> TensorType:
         """
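With the extended signature, a caller can obtain the proxy values as a by-product of a single proxy evaluation instead of calling the proxy twice, which is exactly what the removed TODO comment in gflownet.py suggested. A hypothetical usage sketch, assuming `proxy` is a Proxy instance and `states` is a batch of states (the keyword arguments are the ones added in this diff):

# One proxy evaluation yields both quantities.
logrewards, proxy_vals = proxy.rewards(states, log=True, return_proxy=True)

# Previously, obtaining both required evaluating the proxy twice:
#   logrewards = proxy.rewards(states, log=True)
#   proxy_vals = proxy(states)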
