Skip to content

Commit

Permalink
Various small improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Feb 9, 2024
1 parent a05a4d6 commit 6bb3654
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
11 changes: 10 additions & 1 deletion gflownet/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,16 @@ def state2proxy(
A state
"""
state = self._get_state(state)
return torch.squeeze(self.states2proxy([state]), dim=0)
state_proxy = self.states2proxy([state])
if isinstance(state_proxy, list):
return state_proxy[0]
elif torch.is_tensor(state_proxy):
return torch.squeeze(state_proxy, dim=0)
else:
raise NotImplementedError(
"The output of states2proxy must be either a list or a tensor. "
f"Got {type(state_proxy)}."
)

def states2policy(
self, states: Union[List, TensorType["batch", "state_dim"]]
Expand Down
8 changes: 6 additions & 2 deletions gflownet/proxy/uniform.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Union

import torch
from torchtyping import TensorType

Expand All @@ -8,8 +10,10 @@ class Uniform(Proxy):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]:
return -1.0 * torch.ones(states.shape[0]).to(states)
def __call__(
self, states: Union[List, TensorType["batch", "state_dim"]]
) -> TensorType["batch"]:
return -1.0 * torch.ones(len(states), device=self.device, dtype=self.float)

@property
def min(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/gflownet/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def test__trajectories_are_reversible(env):
while not env.done:
state, action, valid = env.step_random(backward=False)
if valid:
states_trajectory_fw.append(state)
states_trajectory_fw.append(copy(state))
actions_trajectory_fw.append(action)

# Sample backward trajectory with actions in forward trajectory
Expand All @@ -402,7 +402,7 @@ def test__trajectories_are_reversible(env):
while not env.equal(env.state, env.source) or env.done:
state, action, valid = env.step_backwards(actions_trajectory_fw_copy.pop())
if valid:
states_trajectory_bw.append(state)
states_trajectory_bw.append(copy(state))
actions_trajectory_bw.append(action)

assert all(
Expand Down

0 comments on commit 6bb3654

Please sign in to comment.