Docstring of base Proxy
alexhernandezgarcia committed May 17, 2024
1 parent 361cdc0 commit 4c1b1de
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions gflownet/proxy/base.py
@@ -16,10 +16,6 @@


 class Proxy(ABC):
-    """
-    Generic proxy class
-    """
-
     def __init__(
         self,
         device,
@@ -31,6 +27,38 @@ def __init__(
         do_clip_rewards: bool = False,
         **kwargs,
     ):
r"""
Base Proxy class for GFlowNet proxies.
A proxy is the input to a reward function. Depending on the
``reward_function``, the reward may be directly the output of the proxy or a
function of it.
Arguments
---------
device : str or torch.device
The device to be passed to torch tensors.
float_precision : int or torch.dtype
The floating point precision to be passed to torch tensors.
reward_function : str or Callable
The transformation applied to the proxy outputs to obtain a GFlowNet
reward. See :py:meth:`Proxy._get_reward_functions`.
logreward_function : Callable
The transformation applied to the proxy outputs to obtain a GFlowNet
log reward. See :meth:`Proxy._get_reward_functions`. If None (default), the
log of the reward function is used. The Callable may be used to improve the
numerical stability of the transformation.
reward_function_kwargs : dict
A dictionary of arguments to be passed to the reward function.
reward_min : float
The minimum value allowed for rewards, 0.0 by default, which results in a
minimum log reward of :py:const:`LOGZERO`. Note that certain loss
functions, for example the Forward Looking loss may not work as desired if
the minimum reward is 0.0. It may be set to a small (positive) value close
to zero in order to prevent numerical stability issues.
do_clip_rewards : bool
Whether to clip the rewards according to the minimum value.
"""
         # Proxy to reward function
         self.reward_function = reward_function
         self.logreward_function = logreward_function
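
The docstring above specifies the proxy-to-reward interface, but the diff shows no concrete proxy. Below is a minimal sketch of how one might look, assuming that concrete proxies subclass Proxy and implement __call__ to return one score per state; the class name NegativeNormProxy, the tensor shapes, and the "exponential" reward string are illustrative assumptions, not part of this commit.

import torch

from gflownet.proxy.base import Proxy


class NegativeNormProxy(Proxy):
    """Hypothetical proxy that scores each state by its negative L2 norm."""

    def __call__(self, states: torch.Tensor) -> torch.Tensor:
        # One scalar per state in the batch; the reward_function configured
        # in the base class then maps these proxy outputs to rewards.
        return -torch.linalg.norm(states, dim=-1)


# Illustrative usage; the constructor arguments follow the docstring above.
proxy = NegativeNormProxy(
    device="cpu",
    float_precision=32,
    reward_function="exponential",  # assumed to be an accepted string
)
scores = proxy(torch.randn(8, 3))  # shape: (8,)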
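
The logreward_function and reward_min notes in the docstring both concern numerical stability. Here is a small sketch of the failure mode and the two mitigations, using an exponential reward exp(beta * proxy) purely as an assumed example.

import torch

beta = 1.0
proxy_values = torch.tensor([-1000.0, -5.0, 0.0])

# Naive route: compute the reward, then take its log. exp(-1000) underflows
# to 0.0, so the corresponding log reward becomes -inf.
reward = torch.exp(beta * proxy_values)
naive_logreward = torch.log(reward)  # tensor([-inf, -5., 0.])

# A dedicated logreward_function can compute the log reward directly and
# stays finite, which is the stability argument made in the docstring.
stable_logreward = beta * proxy_values  # tensor([-1000., -5., 0.])

# Alternatively, reward_min with do_clip_rewards=True clips rewards at a
# small positive minimum, bounding the log reward away from -inf
# (log(1e-8) is about -18.4).
clipped_logreward = torch.log(torch.clamp(reward, min=1e-8))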
