From 17d312e4c882f15e0f638eae43e70a4cc5e2211c Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 12 Jul 2024 09:37:05 +0200 Subject: [PATCH 01/11] Remove energies (proxy values) from buffer data frames --- gflownet/utils/buffer.py | 34 ++++++---------------------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index 1593c48ee..770b02f4f 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -36,13 +36,12 @@ def __init__( self.env = env self.proxy = proxy self.replay_capacity = replay_capacity - self.main = pd.DataFrame(columns=["state", "traj", "reward", "energy", "iter"]) + self.main = pd.DataFrame(columns=["state", "traj", "reward", "iter"]) self.replay = pd.DataFrame( np.empty((self.replay_capacity, 5), dtype=object), - columns=["state", "traj", "reward", "energy", "iter"], + columns=["state", "traj", "reward", "iter"], ) self.replay.reward = pd.to_numeric(self.replay.reward) - self.replay.energy = pd.to_numeric(self.replay.energy) self.replay.reward = [-1 for _ in range(self.replay_capacity)] self.replay_states = {} self.replay_trajs = {} @@ -127,16 +126,7 @@ def save_replay(self): f, ) - def add( - self, - states, - trajs, - rewards, - energies, - it, - buffer="main", - criterion="greater", - ): + def add(self, states, trajs, rewards, it, buffer="main", criterion="greater"): if buffer == "main": self.main = pd.concat( [ @@ -146,7 +136,6 @@ def add( "state": [self.env.state2readable(s) for s in states], "traj": [self.env.traj2readable(p) for p in trajs], "reward": rewards, - "energy": energies, "iter": it, } ), @@ -156,20 +145,10 @@ def add( ) elif buffer == "replay" and self.replay_capacity > 0: if criterion == "greater": - self.replay = self._add_greater(states, trajs, rewards, energies, it) + self.replay = self._add_greater(states, trajs, rewards, it) - def _add_greater( - self, - states, - trajs, - rewards, - energies, - it, - allow_duplicate_states=False, - ): - for idx, (state, traj, reward, energy) in enumerate( - zip(states, trajs, rewards, energies) - ): + def _add_greater(self, states, trajs, rewards, it, allow_duplicate_states=False): + for idx, (state, traj, reward) in enumerate(zip(states, trajs, rewards)): if not allow_duplicate_states: if isinstance(state, torch.Tensor): is_duplicate = False @@ -189,7 +168,6 @@ def _add_greater( "state": self.env.state2readable(state), "traj": self.env.traj2readable(traj), "reward": reward, - "energy": energy, "iter": it, } self.replay_states[(idx, it)] = state From 27bdd780d69e048b1e447ce495b34e8185fbdbf7 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 12 Jul 2024 12:32:39 +0200 Subject: [PATCH 02/11] Proxy base: rewards() can return the proxy values too --- gflownet/proxy/base.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index d7be5bf40..efe85e9e9 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -105,8 +105,11 @@ def __call__(self, states: Union[TensorType, List, npt.NDArray]) -> TensorType: pass def rewards( - self, states: Union[TensorType, List, npt.NDArray], log: bool = False - ) -> TensorType: + self, + states: Union[TensorType, List, npt.NDArray], + log: bool = False, + return_proxy: bool = False, + ) -> Union[TensorType, Tuple[TensorType, TensorType]]: """ Computes the rewards of a batch of states. @@ -120,16 +123,27 @@ def rewards( log : bool If True, returns the logarithm of the rewards. 
If False (default), returns the natural rewards. + return_proxy : bool + If True, returns the proxy values, alongside the rewards, as the second + element in the returned tuple. Returns ------- - tensor - The reward of all elements in the batch. + rewards + The reward or log-reward of all elements in the batch. + tensor (optional) + The proxy value of all elements in the batch. Included only if return_proxy + is True. """ + proxy_values = self(states) if log: - return self.proxy2logreward(self(states)) + rewards = self.proxy2logreward(proxy_values) + else: + rewards = self.proxy2reward(proxy_values) + if return_proxy: + return rewards, proxy_values else: - return self.proxy2reward(self(states)) + return rewards def proxy2reward(self, proxy_values: TensorType) -> TensorType: """ From ecf6b9d374fb89ab60f8ac8b13f7f08cc53c3684 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 12 Jul 2024 12:33:23 +0200 Subject: [PATCH 03/11] Buffer: docstring and fix --- gflownet/utils/buffer.py | 53 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index 770b02f4f..e54c314fe 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -38,7 +38,7 @@ def __init__( self.replay_capacity = replay_capacity self.main = pd.DataFrame(columns=["state", "traj", "reward", "iter"]) self.replay = pd.DataFrame( - np.empty((self.replay_capacity, 5), dtype=object), + np.empty((self.replay_capacity, 4), dtype=object), columns=["state", "traj", "reward", "iter"], ) self.replay.reward = pd.to_numeric(self.replay.reward) @@ -127,6 +127,26 @@ def save_replay(self): ) def add(self, states, trajs, rewards, it, buffer="main", criterion="greater"): + """ + Adds a batch of states (with the trajectory actions and rewards) to the buffer. + + Note that the rewards may be log-rewards. + + Parameters + ---------- + states : list + A batch of terminating states. + trajs : list + The list of trajectory actions of each terminating state. + rewards : list + The reward or log-reward of each terminating state. + it : int + Iteration number. + buffer : str + Identifier of the buffer: main or replay + criterion : str + Identifier of the criterion. Currently, only greater is implemented. + """ if buffer == "main": self.main = pd.concat( [ @@ -146,8 +166,39 @@ def add(self, states, trajs, rewards, it, buffer="main", criterion="greater"): elif buffer == "replay" and self.replay_capacity > 0: if criterion == "greater": self.replay = self._add_greater(states, trajs, rewards, it) + else: + raise ValueError( + f"Unknown criterion identifier. Received {buffer}, expected greater" + ) + else: + raise ValueError( + f"Unknown buffer identifier. Received {buffer}, expected main or replay" + ) def _add_greater(self, states, trajs, rewards, it, allow_duplicate_states=False): + """ + Adds a batch of states (with the trajectory actions and rewards) to the buffer + if the state reward is larger than the minimum reward in the buffer and the + trajectory is not yet in the buffer. + + Note that the rewards may be log-rewards. The reward is only used to check the + inclusion criterion. Since the logarithm is a monotonic function, using the log + or natural rewards is equivalent for this purpose. + + Parameters + ---------- + states : list + A batch of terminating states. + trajs : list + The list of trajectory actions of each terminating state. + rewards : list + The reward or log-reward of each terminating state. + it : int + Iteration number. 
+ allow_duplicate_states : bool + If True, terminating states already present in the buffer will be added + provided the trajectory is different and the reward criterion is satisfied. + """ for idx, (state, traj, reward) in enumerate(zip(states, trajs, rewards)): if not allow_duplicate_states: if isinstance(state, torch.Tensor): From 5c789facf5197b754ba79a82d12ed756f31917c7 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 12 Jul 2024 12:33:54 +0200 Subject: [PATCH 04/11] Batch: functionality to store and retrieve proxy values --- gflownet/utils/batch.py | 70 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 66 insertions(+), 4 deletions(-) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index adbe869f2..8d756623e 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -105,6 +105,7 @@ def __init__( self.logrewards_available = False self.logrewards_parents_available = False self.logrewards_source_available = False + self.proxy_values_available = False def __len__(self): return self.size @@ -881,7 +882,7 @@ def get_rewards( force_recompute : bool If True, the rewards are recomputed even if they are available. do_non_terminating : bool - If True, compute the actual rewards of the non-terminating states. If + If True, return the actual rewards of the non-terminating states. If False, non-terminating states will be assigned reward 0. """ if self.rewards_available is False or force_recompute is True: @@ -891,6 +892,26 @@ def get_rewards( else: return self.rewards + def get_proxy_values( + self, + force_recompute: Optional[bool] = False, + do_non_terminating: Optional[bool] = False, + ) -> TensorType["n_states"]: + """ + Returns the proxy values of all states in the batch (including not done). + + Parameters + ---------- + force_recompute : bool + If True, the proxy values are recomputed even if they are available. + do_non_terminating : bool + If True, return the actual proxy values of the non-terminating states. If + False, non-terminating states will be assigned value inf. + """ + if self.proxy_values_available is False or force_recompute is True: + self._compute_rewards(do_non_terminating=do_non_terminating) + return self.proxy_values + def _compute_rewards( self, log: bool = False, do_non_terminating: Optional[bool] = False ): @@ -904,19 +925,27 @@ def _compute_rewards( If True, compute the logarithm of the rewards. do_non_terminating : bool If True, compute the rewards of the non-terminating states instead of - assigning reward 0. + assigning reward 0 and proxy value inf. 
""" if do_non_terminating: - rewards = self.proxy.rewards(self.states2proxy(), log) + rewards, proxy_values = self.proxy.rewards( + self.states2proxy(), log, return_proxy=True + ) else: rewards = self.proxy.get_min_reward(log) * torch.ones( len(self), dtype=self.float, device=self.device ) + proxy_values = torch.full_like(rewards, torch.inf) done = self.get_done() if len(done) > 0: states_proxy_done = self.get_terminating_states(proxy=True) - rewards[done] = self.proxy.rewards(states_proxy_done, log) + rewards[done], proxy_values[done] = self.proxy.rewards( + states_proxy_done, log, return_proxy=True + ) + + self.proxy_values = proxy_values + self.proxy_values_available = True if log: self.logrewards = rewards self.logrewards_available = True @@ -1125,6 +1154,39 @@ def get_terminating_rewards( else: return self.rewards[indices][done] + def get_terminating_proxy_values( + self, + sort_by: str = "insertion", + force_recompute: Optional[bool] = False, + ) -> TensorType["n_trajectories"]: + """ + Returns the proxy values of the terminating states in the batch, that is all + states with done = True. The returned proxy values may be sorted by order of + insertion (sort_by = "insert[ion]", default) or by trajectory index (sort_by = + "traj[ectory]". + + Parameters + ---------- + sort_by : str + Indicates how to sort the output: + - insert[ion]: sort by order of insertion (proxy values of trajectories + that reached the terminating state first come first) + - traj[ectory]: sort by trajectory index (the order in the ordered + dict self.trajectories) + force_recompute : bool + If True, the proxy_values are recomputed even if they are available. + """ + if sort_by == "insert" or sort_by == "insertion": + indices = np.arange(len(self)) + elif sort_by == "traj" or sort_by == "trajectory": + indices = np.argsort(self.traj_indices) + else: + raise ValueError("sort_by must be either insert[ion] or traj[ectory]") + if self.proxy_values_available is False or force_recompute is True: + self._compute_rewards(log, do_non_terminating=False) + done = self.get_done()[indices] + return self.proxy_values[indices][done] + def get_actions_trajectories(self) -> List[List[Tuple]]: """ Returns the actions corresponding to all trajectories in the batch, sorted by From 5f1c7a408e90ba8b085693024c0fdfd6238f6d48 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 12 Jul 2024 12:34:22 +0200 Subject: [PATCH 05/11] Batch: tests for new functionality about proxy values --- tests/gflownet/utils/test_batch.py | 190 ++++++++++++++++++++++++++++- 1 file changed, 184 insertions(+), 6 deletions(-) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 5c0ca8628..4ef944deb 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -340,6 +340,7 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ parents_all = [] parents_all_a = [] rewards = [] + proxy_values = [] traj_indices = [] state_indices = [] states_term_sorted = [None for _ in range(batch_size)] @@ -372,7 +373,11 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ parents_all.extend(env_parents) parents_all_a.extend(env_parents_a) if env.done: - rewards.append(proxy.rewards(env.state2proxy())[0]) + reward, proxy_value = proxy.rewards( + env.state2proxy(), return_proxy=True + ) + rewards.append(reward[0]) + proxy_values.append(proxy_value[0]) else: rewards.append( tfloat( @@ -381,6 +386,9 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, 
proxy, batch, requ device=batch.device, ) ) + proxy_values.append( + tfloat(torch.inf, float_type=batch.float, device=batch.device) + ) traj_indices.append(env.id) state_indices.append(env.n_actions) if env.done: @@ -457,6 +465,15 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ tfloat(rewards, device=batch.device, float_type=batch.float), ) ), (rewards, rewards_batch) + # Check proxy values + proxy_values_batch = batch.get_proxy_values() + proxy_values = torch.stack(proxy_values) + assert torch.all( + torch.isclose( + proxy_values_batch, + tfloat(proxy_values, device=batch.device, float_type=batch.float), + ) + ), (proxy_values, proxy_values_batch) # Check terminating states (sorted by trajectory) states_term_batch = batch.get_terminating_states(sort_by="traj") states_term_policy_batch = batch.get_terminating_states(sort_by="traj", policy=True) @@ -503,6 +520,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req parents_all = [] parents_all_a = [] rewards = [] + proxy_values = [] traj_indices = [] state_indices = [] states_term_sorted = [copy(x) for x in x_batch] @@ -521,11 +539,18 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req if not env.continuous: env_parents, env_parents_a = env.get_parents() if env.done: - reward = proxy.rewards(env.state2proxy())[0] + reward, proxy_value = proxy.rewards( + env.state2proxy(), return_proxy=True + ) + reward = reward[0] + proxy_value = proxy_value[0] else: reward = tfloat( proxy.get_min_reward(), float_type=batch.float, device=batch.device ) + proxy_value = tfloat( + torch.inf, float_type=batch.float, device=batch.device + ) if env.done: states_term_sorted[env.id] = env.state # Sample random action @@ -546,6 +571,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req parents_all.extend(env_parents) parents_all_a.extend(env_parents_a) rewards.append(reward) + proxy_values.append(proxy_value) traj_indices.append(env.id) state_indices.append(env.n_actions) # Add all envs, actions and valids to batch @@ -620,6 +646,15 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req tfloat(rewards, device=batch.device, float_type=batch.float), ) ), (rewards, rewards_batch) + # Check proxy values + proxy_values_batch = batch.get_proxy_values() + proxy_values = torch.stack(proxy_values) + assert torch.all( + torch.isclose( + proxy_values_batch, + tfloat(proxy_values, device=batch.device, float_type=batch.float), + ) + ), (proxy_values, proxy_values_batch) # Check terminating states (sorted by trajectory) states_term_batch = batch.get_terminating_states(sort_by="traj") states_term_policy_batch = batch.get_terminating_states(sort_by="traj", policy=True) @@ -657,6 +692,7 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques parents_all = [] parents_all_a = [] rewards = [] + proxy_values = [] traj_indices = [] state_indices = [] states_term_sorted = [] @@ -700,7 +736,11 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques parents_all.extend(env_parents) parents_all_a.extend(env_parents_a) if env.done: - rewards.append(proxy.rewards(env.state2proxy())[0]) + reward, proxy_value = proxy.rewards( + env.state2proxy(), return_proxy=True + ) + rewards.append(reward[0]) + proxy_values.append(proxy_value[0]) else: rewards.append( tfloat( @@ -709,6 +749,9 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques device=batch.device, ) ) 
+ proxy_values.append( + tfloat(torch.inf, float_type=batch.float, device=batch.device) + ) traj_indices.append(env.id) state_indices.append(env.n_actions) if env.done: @@ -746,11 +789,18 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques if not env.continuous: env_parents, env_parents_a = env.get_parents() if env.done: - reward = proxy.rewards(env.state2proxy())[0] + reward, proxy_value = proxy.rewards( + env.state2proxy(), return_proxy=True + ) + reward = reward[0] + proxy_value = proxy_value[0] else: reward = tfloat( proxy.get_min_reward(), float_type=batch.float, device=batch.device ) + proxy_value = tfloat( + torch.inf, float_type=batch.float, device=batch.device + ) if env.done: states_term_sorted[env.id] = env.state # Sample random action @@ -771,6 +821,7 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques parents_all.extend(env_parents) parents_all_a.extend(env_parents_a) rewards.append(reward) + proxy_values.append(proxy_value) traj_indices.append(env.id) state_indices.append(env.n_actions) # Add all envs, actions and valids to batch @@ -847,6 +898,15 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques tfloat(rewards, device=batch.device, float_type=batch.float), ) ), (rewards, rewards_batch) + # Check proxy values + proxy_values_batch = batch.get_proxy_values() + proxy_values = torch.stack(proxy_values) + assert torch.all( + torch.isclose( + proxy_values_batch, + tfloat(proxy_values, device=batch.device, float_type=batch.float), + ) + ), (proxy_values, proxy_values_batch) # Check terminating states (sorted by trajectory) states_term_batch = batch.get_terminating_states(sort_by="traj") states_term_policy_batch = batch.get_terminating_states(sort_by="traj", policy=True) @@ -884,6 +944,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): parents_all = [] parents_all_a = [] rewards = [] + proxy_values = [] traj_indices = [] state_indices = [] states_term_sorted = [] @@ -927,7 +988,11 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): parents_all.extend(env_parents) parents_all_a.extend(env_parents_a) if env.done: - rewards.append(proxy.rewards(env.state2proxy())[0]) + reward, proxy_value = proxy.rewards( + env.state2proxy(), return_proxy=True + ) + rewards.append(reward[0]) + proxy_values.append(proxy_value[0]) else: rewards.append( tfloat( @@ -936,6 +1001,11 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): device=batch_fw.device, ) ) + proxy_values.append( + tfloat( + torch.inf, float_type=batch_fw.float, device=batch_fw.device + ) + ) traj_indices.append(env.id) state_indices.append(env.n_actions) if env.done: @@ -973,13 +1043,20 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): if not env.continuous: env_parents, env_parents_a = env.get_parents() if env.done: - reward = proxy.rewards(env.state2proxy())[0] + reward, proxy_value = proxy.rewards( + env.state2proxy(), return_proxy=True + ) + reward = reward[0] + proxy_value = proxy_value[0] else: reward = tfloat( proxy.get_min_reward(), float_type=batch_bw.float, device=batch_bw.device, ) + proxy_value = tfloat( + torch.inf, float_type=batch_bw.float, device=batch_bw.device + ) if env.done: states_term_sorted[env.id + batch_size_forward] = env.state # Sample random action @@ -1000,6 +1077,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): parents_all.extend(env_parents) parents_all_a.extend(env_parents_a) 
rewards.append(reward) + proxy_values.append(proxy_value) traj_indices.append(env.id + batch_size_forward) state_indices.append(env.n_actions) # Add all envs, actions and valids to batch @@ -1081,6 +1159,15 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): tfloat(rewards, device=batch.device, float_type=batch.float), ) ), (rewards, rewards_batch) + # Check proxy values + proxy_values_batch = batch.get_proxy_values() + proxy_values = torch.stack(proxy_values) + assert torch.all( + torch.isclose( + proxy_values_batch, + tfloat(proxy_values, device=batch.device, float_type=batch.float), + ) + ), (proxy_values, proxy_values_batch) # Check terminating states (sorted by trajectory) states_term_batch = batch.get_terminating_states(sort_by="traj") states_term_policy_batch = batch.get_terminating_states(sort_by="traj", policy=True) @@ -1319,6 +1406,40 @@ def test__get_rewards__single_env_returns_expected_non_terminating( ), (rewards, rewards_batch) +@pytest.mark.repeat(N_REPETITIONS) +@pytest.mark.parametrize( + "env, proxy", + [("grid2d", "corners"), ("tetris6x4", "tetris_score"), ("ctorus2d5l", "corners")], +) +def test__get_proxy_values__single_env_returns_expected_non_terminating( + env, proxy, batch, request +): + env = request.getfixturevalue(env) + proxy = request.getfixturevalue(proxy) + proxy.setup(env) + env = env.reset() + batch.set_env(env) + batch.set_proxy(proxy) + + proxy_values = [] + while not env.done: + parent = env.state + # Sample random action + _, action, valid = env.step_random() + # Add to batch + batch.add_to_batch([env], [action], [valid]) + if valid: + proxy_values.append(proxy(env.state2proxy())[0]) + proxy_values_batch = batch.get_proxy_values(do_non_terminating=True) + proxy_values = torch.stack(proxy_values) + assert torch.all( + torch.isclose( + proxy_values_batch, + tfloat(proxy_values, device=batch.device, float_type=batch.float), + ) + ), (proxy_values, proxy_values_batch) + + @pytest.mark.repeat(N_REPETITIONS) @pytest.mark.parametrize( "env, proxy", @@ -1410,6 +1531,63 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating( ), rewards_batch +@pytest.mark.repeat(N_REPETITIONS) +@pytest.mark.parametrize( + "env, proxy", + [("grid2d", "corners"), ("tetris6x4", "tetris_score_norm")], +) +def test__get_proxy_values_multiple_env_returns_expected_non_zero_non_terminating( + env, proxy, batch, request +): + batch_size = BATCH_SIZE + env_ref = request.getfixturevalue(env) + proxy = request.getfixturevalue(proxy) + proxy.setup(env_ref) + env_ref = env_ref.reset() + + batch.set_env(env_ref) + batch.set_proxy(proxy) + + # Make list of envs + envs = [] + for idx in range(batch_size): + env_aux = env_ref.copy().reset(idx) + envs.append(env_aux) + + proxy_values = [] + + # Iterate until envs is empty + while envs: + actions_iter = [] + valids_iter = [] + # Make step env by env (different to GFN Agent) to have full control + for env in envs: + parent = copy(env.state) + # Sample random action + state, action, valid = env.step_random() + if valid: + # Add to iter lists + actions_iter.append(action) + valids_iter.append(valid) + proxy_values.append(proxy(env.state2proxy())[0]) + # Add all envs, actions and valids to batch + batch.add_to_batch(envs, actions_iter, valids_iter) + # Remove done envs + envs = [env for env in envs if not env.done] + + proxy_values_batch = batch.get_proxy_values(do_non_terminating=True) + proxy_values = torch.stack(proxy_values) + assert torch.all( + torch.isclose( + proxy_values_batch, + 
tfloat(proxy_values, device=batch.device, float_type=batch.float), + ) + ), (proxy_values, proxy_values_batch) + assert ~torch.any( + torch.isclose(proxy_values_batch, torch.zeros_like(proxy_values_batch)) + ), proxy_values_batch + + @pytest.mark.repeat(N_REPETITIONS) @pytest.mark.parametrize( "env, proxy", From 902a7b5f947b516c6854d88c5a4230d551fd970e Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 12 Jul 2024 12:36:12 +0200 Subject: [PATCH 06/11] GFlowNet Agent: adapt train() to take proxy values and rewards from the Batch to avoid recomputing --- gflownet/gflownet.py | 49 ++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 81c881076..b29de1a47 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1168,35 +1168,34 @@ def train(self): all_losses.append([i.item() for i in losses]) # Buffer t0_buffer = time.time() - # TODO: the current implementation recomputes the proxy values of the - # terminating states in order to store the proxy values in the Buffer. - # Depending on the computational cost of the proxy, this may be very - # inneficient. For example, proxy.rewards() could return the proxy values, - # which could be stored in the Batch. - if it == 0: - print( - "IMPORTANT: The current implementation recomputes the proxy " - "values of the terminating states in order to store the proxy " - "values in the Buffer. Depending on the computational cost of " - "the proxy, this may be very inneficient." - ) states_term = batch.get_terminating_states(sort_by="trajectory") - states_proxy_term = batch.get_terminating_states( - proxy=True, sort_by="trajectory" - ) - proxy_vals = self.proxy(states_proxy_term) - rewards = self.proxy.proxy2reward(proxy_vals) - rewards = rewards.tolist() + proxy_vals = batch.get_terminating_proxy_values(sort_by="trajectory") proxy_vals = proxy_vals.tolist() + # The batch will typically have the log-rewards available, since they are + # used to compute the losses. In order to avoid recalculating the proxy + # values, the natural rewards are computed by taking the exponential of the + # log-rewards. In case the rewards are available in the batch but not the + # log-rewards, the latter are computed by taking the log of the rewards. + # Numerical issues are not critical in this case, since the derived values + # are only used for reporting purposes. 
+ if batch.rewards_available: + rewards = batch.get_terminating_rewards(sort_by="trajectory") + if batch.logrewards_available: + logrewards = batch.get_terminating_rewards( + sort_by="trajectory", log=True + ) + if not batch.rewards_available: + assert batch.logrewards_available + rewards = torch.exp(logrewards) + if batch.logrewards_available: + assert batch.rewards_available + logrewards = torch.log(rewards) + rewards = rewards.tolist() + logrewards = logrewards.tolist() actions_trajectories = batch.get_actions_trajectories() - self.buffer.add(states_term, actions_trajectories, rewards, proxy_vals, it) + self.buffer.add(states_term, actions_trajectories, logrewards, it) self.buffer.add( - states_term, - actions_trajectories, - rewards, - proxy_vals, - it, - buffer="replay", + states_term, actions_trajectories, logrewards, it, buffer="replay" ) t1_buffer = time.time() times.update({"buffer": t1_buffer - t0_buffer}) From 6fe0e8e2d7c58f6ece264c5dff86fe6a8bb4bd4d Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 12 Jul 2024 12:36:45 +0200 Subject: [PATCH 07/11] Log stats of logrewards too --- gflownet/gflownet.py | 1 + gflownet/utils/logger.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index b29de1a47..1d9e94a6d 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1214,6 +1214,7 @@ def train(self): self.logger.log_train( losses=losses, rewards=rewards, + logrewards=logrewards, proxy_vals=proxy_vals, states_term=states_term, batch_size=len(batch), diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index 9fe982b30..c5129bc2c 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -193,6 +193,7 @@ def log_train( self, losses, rewards: list, + logrewards: list, proxy_vals: array, states_term: list, batch_size: int, @@ -212,6 +213,8 @@ def log_train( [ "mean_reward", "max_reward", + "mean_logreward", + "max_logreward", "mean_proxy", "min_proxy", "max_proxy", @@ -225,6 +228,8 @@ def log_train( [ np.mean(rewards), np.max(rewards), + np.mean(logrewards), + np.max(logrewards), np.mean(proxy_vals), np.min(proxy_vals), np.max(proxy_vals), From 19b1f27d87a92b326109c7706ffd78e055c52448 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 12 Jul 2024 12:41:49 +0200 Subject: [PATCH 08/11] Fix bugs --- gflownet/gflownet.py | 2 +- gflownet/utils/buffer.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 1d9e94a6d..d464e6c7d 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1187,7 +1187,7 @@ def train(self): if not batch.rewards_available: assert batch.logrewards_available rewards = torch.exp(logrewards) - if batch.logrewards_available: + if not batch.logrewards_available: assert batch.rewards_available logrewards = torch.log(rewards) rewards = rewards.tolist() diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index e54c314fe..3b12ee4d8 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -163,13 +163,15 @@ def add(self, states, trajs, rewards, it, buffer="main", criterion="greater"): axis=0, join="outer", ) - elif buffer == "replay" and self.replay_capacity > 0: - if criterion == "greater": - self.replay = self._add_greater(states, trajs, rewards, it) - else: - raise ValueError( - f"Unknown criterion identifier. 
Received {buffer}, expected greater" - ) + elif buffer == "replay": + if self.replay_capacity > 0: + if criterion == "greater": + self.replay = self._add_greater(states, trajs, rewards, it) + else: + raise ValueError( + f"Unknown criterion identifier. Received {buffer}, " + "expected greater" + ) else: raise ValueError( f"Unknown buffer identifier. Received {buffer}, expected main or replay" From 95bd066cde9a492dfb83f10cfd561ef4b5ff2014 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 12 Jul 2024 13:23:41 +0200 Subject: [PATCH 09/11] Fix bug in checking whether (log)rewards are available --- gflownet/gflownet.py | 12 ++++++------ gflownet/utils/batch.py | 40 ++++++++++++++++++++++++++++++---------- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index d464e6c7d..e7adcc6d2 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1178,17 +1178,17 @@ def train(self): # log-rewards, the latter are computed by taking the log of the rewards. # Numerical issues are not critical in this case, since the derived values # are only used for reporting purposes. - if batch.rewards_available: + if batch.rewards_available(log=False): rewards = batch.get_terminating_rewards(sort_by="trajectory") - if batch.logrewards_available: + if batch.rewards_available(log=True): logrewards = batch.get_terminating_rewards( sort_by="trajectory", log=True ) - if not batch.rewards_available: - assert batch.logrewards_available + if not batch.rewards_available(log=False): + assert batch.rewards_available(log=True) rewards = torch.exp(logrewards) - if not batch.logrewards_available: - assert batch.rewards_available + if not batch.rewards_available(log=True): + assert batch.rewards_available(log=False) logrewards = torch.log(rewards) rewards = rewards.tolist() logrewards = logrewards.tolist() diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 8d756623e..bd319ba7b 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -99,10 +99,10 @@ def __init__( self.parents_all_available = False self.masks_forward_available = False self.masks_backward_available = False - self.rewards_available = False + self._rewards_available = False self.rewards_parents_available = False self.rewards_source_available = False - self.logrewards_available = False + self._logrewards_available = False self.logrewards_parents_available = False self.logrewards_source_available = False self.proxy_values_available = False @@ -139,6 +139,26 @@ def traj_idx_action_idx_to_batch_idx( def idx2state_idx(self, idx: int): return self.trajectories[self.traj_indices[idx]].index(idx) + def rewards_available(self, log: bool = False) -> bool: + """ + Returns True if the (log)rewards are available. + + Parameters + ---------- + log : bool + If True, check self._logrewards_available. Otherwise (default), check + self._rewards_available. + + Returns + ------- + bool + True if the (log)rewards are available, False otherwise. 
+ """ + if log: + return self._logrewards_available + else: + return self._rewards_available + def set_env(self, env: GFlowNetEnv): """ Sets the generic environment passed as an argument and initializes the @@ -256,8 +276,8 @@ def add_to_batch( self.masks_backward_available = False self.parents_policy_available = False self.parents_all_available = False - self.rewards_available = False - self.logrewards_available = False + self._rewards_available = False + self._logrewards_available = False def get_n_trajectories(self) -> int: """ @@ -885,7 +905,7 @@ def get_rewards( If True, return the actual rewards of the non-terminating states. If False, non-terminating states will be assigned reward 0. """ - if self.rewards_available is False or force_recompute is True: + if self.rewards_available(log) is False or force_recompute is True: self._compute_rewards(log, do_non_terminating) if log: return self.logrewards @@ -948,10 +968,10 @@ def _compute_rewards( self.proxy_values_available = True if log: self.logrewards = rewards - self.logrewards_available = True + self._logrewards_available = True else: self.rewards = rewards - self.rewards_available = True + self._rewards_available = True def get_rewards_parents(self, log: bool = False) -> TensorType["n_states"]: """ @@ -1146,7 +1166,7 @@ def get_terminating_rewards( indices = np.argsort(self.traj_indices) else: raise ValueError("sort_by must be either insert[ion] or traj[ectory]") - if self.rewards_available is False or force_recompute is True: + if self.rewards_available(log) is False or force_recompute is True: self._compute_rewards(log, do_non_terminating=False) done = self.get_done()[indices] if log: @@ -1305,11 +1325,11 @@ def merge(self, batches: List): self.parents_all = extend(self.parents_all, batch.parents_all) else: self.parents_all = None - if self.rewards_available and batch.rewards_available: + if self._rewards_available and batch._rewards_available: self.rewards = extend(self.rewards, batch.rewards) else: self.rewards = None - if self.logrewards_available and batch.logrewards_available: + if self._logrewards_available and batch._logrewards_available: self.logrewards = extend(self.logrewards, batch.logrewards) else: self.logrewards = None From 0c3d8aac4d2dc5737de789f181be45430f7817b9 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 12 Jul 2024 13:49:39 +0200 Subject: [PATCH 10/11] Rename variables and create additional methods for consistency --- gflownet/utils/batch.py | 128 ++++++++++++++++++++++++++-------------- 1 file changed, 84 insertions(+), 44 deletions(-) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index bd319ba7b..83fc83aa4 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -94,18 +94,18 @@ def __init__( self.states_policy = None self.parents_policy = None # Flags for available items - self.parents_available = False - self.parents_policy_available = False - self.parents_all_available = False - self.masks_forward_available = False - self.masks_backward_available = False + self._parents_available = False + self._parents_policy_available = False + self._parents_all_available = False + self._masks_forward_available = False + self._masks_backward_available = False self._rewards_available = False - self.rewards_parents_available = False - self.rewards_source_available = False + self._rewards_parents_available = False + self._rewards_source_available = False self._logrewards_available = False - self.logrewards_parents_available = False - self.logrewards_source_available = False - 
self.proxy_values_available = False + self._logrewards_parents_available = False + self._logrewards_source_available = False + self._proxy_values_available = False def __len__(self): return self.size @@ -159,6 +159,46 @@ def rewards_available(self, log: bool = False) -> bool: else: return self._rewards_available + def rewards_parents_available(self, log: bool = False) -> bool: + """ + Returns True if the (log)rewards of the parents are available. + + Parameters + ---------- + log : bool + If True, check self._logrewards_parents_available. Otherwise (default), + check self._rewards_parents_available. + + Returns + ------- + bool + True if the (log)rewards of the parents are available, False otherwise. + """ + if log: + return self._logrewards_parents_available + else: + return self._rewards_parents_available + + def rewards_source_available(self, log: bool = False) -> bool: + """ + Returns True if the (log)rewards of the source are available. + + Parameters + ---------- + log : bool + If True, check self._logrewards_source_available. Otherwise (default), + check self._rewards_source_available. + + Returns + ------- + bool + True if the (log)rewards of the source are available, False otherwise. + """ + if log: + return self._logrewards_source_available + else: + return self._rewards_source_available + def set_env(self, env: GFlowNetEnv): """ Sets the generic environment passed as an argument and initializes the @@ -272,10 +312,10 @@ def add_to_batch( # Increment size of batch self.size += 1 # Other variables are not available after new items were added to the batch - self.masks_forward_available = False - self.masks_backward_available = False - self.parents_policy_available = False - self.parents_all_available = False + self._masks_forward_available = False + self._masks_backward_available = False + self._parents_policy_available = False + self._parents_all_available = False self._rewards_available = False self._logrewards_available = False @@ -546,10 +586,10 @@ def get_parents( self.parents or self.parents_policy : torch.tensor The parent of all states in the batch. """ - if self.parents_available is False or force_recompute is True: + if self._parents_available is False or force_recompute is True: self._compute_parents() if policy: - if self.parents_policy_available is False or force_recompute is True: + if self._parents_policy_available is False or force_recompute is True: self._compute_parents_policy() return self.parents_policy else: @@ -568,7 +608,7 @@ def get_parents_indices(self): self.parents_indices The indices in self.states of the parents of self.states. """ - if self.parents_available is False: + if self._parents_available is False: self._compute_parents() return self.parents_indices @@ -589,7 +629,7 @@ def _compute_parents(self): parent is not present in self.states (i.e. it is source), the corresponding index is -1. - self.parents_available is set to True. + self._parents_available is set to True. """ self.parents = [] self.parents_indices = [] @@ -621,7 +661,7 @@ def _compute_parents(self): [self.parents_indices[indices_dict[idx]] for idx in range(len(self))], device=self.device, ) - self.parents_available = True + self._parents_available = True # TODO: consider converting directly from self.parents def _compute_parents_policy(self): @@ -636,7 +676,7 @@ def _compute_parents_policy(self): Shape: [n_states, state_policy_dims] self.parents_policy is stored as a torch tensor and - self.parents_policy_available is set to True. + self._parents_policy_available is set to True. 
""" self.states_policy = self.get_states(policy=True) self.parents_policy = torch.zeros_like(self.states_policy) @@ -652,7 +692,7 @@ def _compute_parents_policy(self): self.parents_policy[batch_indices[1:]] = self.states_policy[ batch_indices[:-1] ] - self.parents_policy_available = True + self._parents_policy_available = True def get_parents_all( self, policy: bool = False, force_recompute: bool = False @@ -664,7 +704,7 @@ def get_parents_all( """ Returns the whole set of parents, their corresponding actions and indices of all states in the batch. If the parents are not available - (self.parents_all_available is False) or if force_recompute is True, then + (self._parents_all_available is False) or if force_recompute is True, then self._compute_parents_all() is called to compute the required components. The parents are returned in "policy format" if policy is True, otherwise they @@ -696,7 +736,7 @@ def get_parents_all( """ if self.continuous: raise Exception("get_parents() is ill-defined for continuous environments!") - if self.parents_all_available is False or force_recompute is True: + if self._parents_all_available is False or force_recompute is True: self._compute_parents_all() if policy: return ( @@ -726,7 +766,7 @@ def _compute_parents_all(self): Shape: [n_parents, state_policy_dims] All the above components are stored as torch tensors and - self.parents_all_available is set to True. + self._parents_all_available is set to True. """ # Iterate over the trajectories to obtain all parents self.parents_all = [] @@ -763,7 +803,7 @@ def _compute_parents_all(self): device=self.device, ) self.parents_all_policy = torch.cat(self.parents_all_policy) - self.parents_all_available = True + self._parents_all_available = True # TODO: opportunity to improve efficiency by caching. def get_masks_forward( @@ -791,7 +831,7 @@ def get_masks_forward( self.masks_invalid_actions_forward : torch.tensor The forward mask of all states in the batch. """ - if self.masks_forward_available is False or force_recompute is True: + if self._masks_forward_available is False or force_recompute is True: self._compute_masks_forward() # Make tensor masks_invalid_actions_forward = tbool( @@ -826,8 +866,8 @@ def get_masks_forward( def _compute_masks_forward(self): """ Computes the forward mask of invalid actions of all states in the batch, by - calling env.get_mask_invalid_actions_forward(). self.masks_forward_available is - set to True. + calling env.get_mask_invalid_actions_forward(). self._masks_forward_available + is set to True. """ # Iterate over the trajectories to compute all forward masks for idx, mask in enumerate(self.masks_invalid_actions_forward): @@ -839,7 +879,7 @@ def _compute_masks_forward(self): self.masks_invalid_actions_forward[idx] = self.envs[ traj_idx ].get_mask_invalid_actions_forward(state, done) - self.masks_forward_available = True + self._masks_forward_available = True # TODO: opportunity to improve efficiency by caching. Note that # env.get_masks_invalid_actions_backward() may be expensive because it calls @@ -863,14 +903,14 @@ def get_masks_backward( self.masks_invalid_actions_backward : torch.tensor The backward mask of all states in the batch. 
""" - if self.masks_backward_available is False or force_recompute is True: + if self._masks_backward_available is False or force_recompute is True: self._compute_masks_backward() return tbool(self.masks_invalid_actions_backward, device=self.device) def _compute_masks_backward(self): """ Computes the backward mask of invalid actions of all states in the batch, by - calling env.get_mask_invalid_actions_backward(). self.masks_backward_available + calling env.get_mask_invalid_actions_backward(). self._masks_backward_available is set to True. """ # Iterate over the trajectories to compute all backward masks @@ -883,7 +923,7 @@ def _compute_masks_backward(self): self.masks_invalid_actions_backward[idx] = self.envs[ traj_idx ].get_mask_invalid_actions_backward(state, done) - self.masks_backward_available = True + self._masks_backward_available = True # TODO: better handling of availability of rewards, logrewards, proxy_values. def get_rewards( @@ -928,7 +968,7 @@ def get_proxy_values( If True, return the actual proxy values of the non-terminating states. If False, non-terminating states will be assigned value inf. """ - if self.proxy_values_available is False or force_recompute is True: + if self._proxy_values_available is False or force_recompute is True: self._compute_rewards(do_non_terminating=do_non_terminating) return self.proxy_values @@ -965,7 +1005,7 @@ def _compute_rewards( ) self.proxy_values = proxy_values - self.proxy_values_available = True + self._proxy_values_available = True if log: self.logrewards = rewards self._logrewards_available = True @@ -987,7 +1027,7 @@ def get_rewards_parents(self, log: bool = False) -> TensorType["n_states"]: self.rewards_parents or self.logrewards_parents A tensor containing the rewards of the parents of self.states. """ - if not self.rewards_parents_available: + if not self.rewards_parents_available(log): self._compute_rewards_parents(log) if log: return self.logrewards_parents @@ -1019,10 +1059,10 @@ def _compute_rewards_parents(self, log: bool = False): rewards_parents[parent_is_source] = rewards_source[parent_is_source] if log: self.logrewards_parents = rewards_parents - self.logrewards_parents_available = True + self._logrewards_parents_available = True else: self.rewards_parents = rewards_parents - self.rewards_parents_available = True + self._rewards_parents_available = True def get_rewards_source(self, log: bool = False) -> TensorType["n_states"]: """ @@ -1038,7 +1078,7 @@ def get_rewards_source(self, log: bool = False) -> TensorType["n_states"]: self.rewards_source or self.logrewards_source A tensor containing the rewards the source states. 
""" - if not self.rewards_source_available: + if not self.rewards_source_available(log): self._compute_rewards_source(log) if log: return self.logrewards_source @@ -1066,10 +1106,10 @@ def _compute_rewards_source(self, log: bool = False): raise NotImplementedError if log: self.logrewards_source = rewards_source - self.logrewards_source_available = True + self._logrewards_source_available = True else: self.rewards_source = rewards_source - self.rewards_source_available = True + self._rewards_source_available = True def get_terminating_states( self, @@ -1202,7 +1242,7 @@ def get_terminating_proxy_values( indices = np.argsort(self.traj_indices) else: raise ValueError("sort_by must be either insert[ion] or traj[ectory]") - if self.proxy_values_available is False or force_recompute is True: + if self._proxy_values_available is False or force_recompute is True: self._compute_rewards(log, do_non_terminating=False) done = self.get_done()[indices] return self.proxy_values[indices][done] @@ -1313,15 +1353,15 @@ def merge(self, batches: List): self.states_policy = extend(self.states_policy, batch.states_policy) else: self.states_policy = None - if self.parents_available and batch.parents_available: + if self._parents_available and batch._parents_available: self.parents = extend(self.parents, batch.parents) else: self.parents = None - if self.parents_policy_available and batch.parents_policy_available: + if self._parents_policy_available and batch._parents_policy_available: self.parents_policy = extend(self.parents_policy, batch.parents_policy) else: self.parents_policy = None - if self.parents_all_available and batch.parents_all_available: + if self._parents_all_available and batch._parents_all_available: self.parents_all = extend(self.parents_all, batch.parents_all) else: self.parents_all = None From 7cd223fa4db330c39746551eecd2eb48f4c38150 Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 14 Jul 2024 08:52:17 +0200 Subject: [PATCH 11/11] Small update in docstring --- gflownet/proxy/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index efe85e9e9..2c9b29a58 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -129,9 +129,9 @@ def rewards( Returns ------- - rewards + rewards : tensor The reward or log-reward of all elements in the batch. - tensor (optional) + proxy_values : tensor (optional) The proxy value of all elements in the batch. Included only if return_proxy is True. """