diff --git a/src/gfn/actions.py b/src/gfn/actions.py index d2e9b3b0..c80eb2d3 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -3,7 +3,7 @@ import enum from abc import ABC from math import prod -from typing import ClassVar, Optional, Sequence +from typing import ClassVar, Sequence import torch from tensordict import TensorDict @@ -201,11 +201,14 @@ def __init__(self, tensor: TensorDict): assert torch.all(tensor["action_type"] != GraphActionType.ADD_EDGE) edge_index = torch.zeros((*self.batch_shape, 2), dtype=torch.long) - self.tensor = TensorDict({ - "action_type": tensor["action_type"], - "features": features, - "edge_index": edge_index, - }, batch_size=self.batch_shape) + self.tensor = TensorDict( + { + "action_type": tensor["action_type"], + "features": features, + "edge_index": edge_index, + }, + batch_size=self.batch_shape, + ) def __repr__(self): return f"""GraphAction object with {self.batch_shape} actions.""" @@ -223,7 +226,6 @@ def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActio """Get particular actions of the batch.""" return GraphActions(self.tensor[index]) - def __setitem__( self, index: int | Sequence[int] | Sequence[bool], action: GraphActions ) -> None: @@ -239,9 +241,14 @@ def compare(self, other: GraphActions) -> torch.Tensor: Returns: boolean tensor of shape batch_shape indicating whether the actions are equal. """ compare = torch.all(self.tensor == other.tensor, dim=-1) - return compare["action_type"] & \ - (compare["action_type"] == GraphActionType.EXIT | compare["features"]) & \ - (compare["action_type"] != GraphActionType.ADD_EDGE | compare["edge_index"]) + return ( + compare["action_type"] + & (compare["action_type"] == GraphActionType.EXIT | compare["features"]) + & ( + compare["action_type"] + != GraphActionType.ADD_EDGE | compare["edge_index"] + ) + ) @property def is_exit(self) -> torch.Tensor: @@ -257,25 +264,28 @@ def action_type(self) -> torch.Tensor: def features(self) -> torch.Tensor: """Returns the features tensor.""" return self.tensor["features"] - + @property def edge_index(self) -> torch.Tensor: """Returns the edge index tensor.""" return self.tensor["edge_index"] @classmethod - def make_dummy_actions( - cls, batch_shape: tuple[int] - ) -> GraphActions: + def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions: """Creates an Actions object of dummy actions with the given batch shape.""" return cls( - TensorDict({ - "action_type": torch.full(batch_shape, fill_value=GraphActionType.EXIT), - # "features": torch.zeros((*batch_shape, 0, cls.nodes_features_dim)), - # "edge_index": torch.zeros((2, *batch_shape, 0)), - }, batch_size=batch_shape) + TensorDict( + { + "action_type": torch.full( + batch_shape, fill_value=GraphActionType.EXIT + ), + # "features": torch.zeros((*batch_shape, 0, cls.nodes_features_dim)), + # "edge_index": torch.zeros((2, *batch_shape, 0)), + }, + batch_size=batch_shape, + ) ) - + @classmethod def stack(cls, actions_list: list[GraphActions]) -> GraphActions: """Stacks a list of GraphActions objects into a single GraphActions object.""" @@ -283,4 +293,3 @@ def stack(cls, actions_list: list[GraphActions]) -> GraphActions: [actions.tensor for actions in actions_list], dim=0 ) return cls(actions_tensor) - diff --git a/src/gfn/env.py b/src/gfn/env.py index 38884c97..28734926 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch from tensordict import TensorDict @@ -219,7 +219,7 @@ def reset( batch_shape = (1,) if isinstance(batch_shape, int): batch_shape = (batch_shape,) - + return self.States.from_batch_shape( batch_shape=batch_shape, random=random, sink=sink ) @@ -266,7 +266,7 @@ def _step( not_done_actions = actions[~new_sink_states_idx] new_not_done_states_tensor = self.step(not_done_states, not_done_actions) - + if not isinstance(new_not_done_states_tensor, (torch.Tensor, TensorDict)): raise Exception( "User implemented env.step function *must* return a torch.Tensor!" diff --git a/src/gfn/gym/__init__.py b/src/gfn/gym/__init__.py index 20490566..ebec6f20 100644 --- a/src/gfn/gym/__init__.py +++ b/src/gfn/gym/__init__.py @@ -1,4 +1,4 @@ from gfn.gym.box import Box from gfn.gym.discrete_ebm import DiscreteEBM +from gfn.gym.graph_building import GraphBuilding from gfn.gym.hypergrid import HyperGrid -from gfn.gym.graph_building import GraphBuilding \ No newline at end of file diff --git a/src/gfn/gym/graph_building.py b/src/gfn/gym/graph_building.py index 9a15d468..415c4ec1 100644 --- a/src/gfn/gym/graph_building.py +++ b/src/gfn/gym/graph_building.py @@ -2,8 +2,8 @@ from typing import Callable, Literal, Tuple import torch -from torch_geometric.nn import GCNConv from tensordict import TensorDict +from torch_geometric.nn import GCNConv from gfn.actions import GraphActions, GraphActionType from gfn.env import GraphEnv, NonValidActionsError @@ -17,16 +17,24 @@ def __init__( state_evaluator: Callable[[GraphStates], torch.Tensor] | None = None, device_str: Literal["cpu", "cuda"] = "cpu", ): - s0 = TensorDict({ - "node_feature": torch.zeros((0, feature_dim), dtype=torch.float32), - "edge_feature": torch.zeros((0, feature_dim), dtype=torch.float32), - "edge_index": torch.zeros((0, 2), dtype=torch.long), - }, device=device_str) - sf = TensorDict({ - "node_feature": torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), - "edge_feature": torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"), - "edge_index": torch.zeros((0, 2), dtype=torch.long), - }, device=device_str) + s0 = TensorDict( + { + "node_feature": torch.zeros((0, feature_dim), dtype=torch.float32), + "edge_feature": torch.zeros((0, feature_dim), dtype=torch.float32), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + }, + device=device_str, + ) + sf = TensorDict( + { + "node_feature": torch.ones((1, feature_dim), dtype=torch.float32) + * float("inf"), + "edge_feature": torch.ones((1, feature_dim), dtype=torch.float32) + * float("inf"), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + }, + device=device_str, + ) if state_evaluator is None: state_evaluator = GCNConvEvaluator(feature_dim) @@ -59,15 +67,22 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict: return self.States.make_sink_states_tensor(states.batch_shape) if action_type == GraphActionType.ADD_NODE: - batch_indices = torch.arange(len(states))[actions.action_type == GraphActionType.ADD_NODE] + batch_indices = torch.arange(len(states))[ + actions.action_type == GraphActionType.ADD_NODE + ] state_tensor = self._add_node(state_tensor, batch_indices, actions.features) if action_type == GraphActionType.ADD_EDGE: - state_tensor["edge_feature"] = torch.cat([state_tensor["edge_feature"], actions.features], dim=0) - state_tensor["edge_index"] = torch.cat([ - state_tensor["edge_index"], - actions.edge_index + state_tensor["batch_ptr"][:-1][:, None] - ], dim=0) + state_tensor["edge_feature"] = torch.cat( + [state_tensor["edge_feature"], actions.features], dim=0 + ) + state_tensor["edge_index"] = torch.cat( + [ + state_tensor["edge_index"], + actions.edge_index + state_tensor["batch_ptr"][:-1][:, None], + ], + dim=0, + ) return state_tensor @@ -90,13 +105,17 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDic assert torch.all(actions.action_type == action_type) if action_type == GraphActionType.ADD_NODE: is_equal = torch.any( - torch.all(state_tensor["node_feature"][:, None] == actions.features, dim=-1), - dim=-1 + torch.all( + state_tensor["node_feature"][:, None] == actions.features, dim=-1 + ), + dim=-1, ) state_tensor["node_feature"] = state_tensor["node_feature"][~is_equal] elif action_type == GraphActionType.ADD_EDGE: assert actions.edge_index is not None - global_edge_index = actions.edge_index + state_tensor["batch_ptr"][:-1][:, None] + global_edge_index = ( + actions.edge_index + state_tensor["batch_ptr"][:-1][:, None] + ) is_equal = torch.all( state_tensor["edge_index"] == global_edge_index[:, None], dim=-1 ) @@ -121,67 +140,84 @@ def is_action_valid( add_node_out = torch.all(equal_nodes_per_batch == 1) else: add_node_out = torch.all(equal_nodes_per_batch == 0) - + add_edge_mask = actions.action_type == GraphActionType.ADD_EDGE if not torch.any(add_edge_mask): add_edge_out = True else: add_edge_states = states[add_edge_mask].tensor add_edge_actions = actions[add_edge_mask] - if torch.any(add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]): + if torch.any( + add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1] + ): return False if add_edge_states["node_feature"].shape[0] == 0: return False - if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]): + if torch.any( + add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0] + ): return False - global_edge_index = add_edge_actions.edge_index + add_edge_states["batch_ptr"][:-1][:, None] + global_edge_index = ( + add_edge_actions.edge_index + add_edge_states["batch_ptr"][:-1][:, None] + ) equal_edges_per_batch = torch.all( add_edge_states["edge_index"] == global_edge_index[:, None], dim=-1 ).sum(dim=-1) - + if backward: add_edge_out = torch.all(equal_edges_per_batch == 1) else: add_edge_out = torch.all(equal_edges_per_batch == 0) - + return bool(add_node_out) and bool(add_edge_out) - - def _add_node(self, tensor_dict: TensorDict, batch_indices: torch.Tensor, nodes_to_add: torch.Tensor) -> TensorDict: + + def _add_node( + self, + tensor_dict: TensorDict, + batch_indices: torch.Tensor, + nodes_to_add: torch.Tensor, + ) -> TensorDict: if isinstance(batch_indices, list): batch_indices = torch.tensor(batch_indices) if len(batch_indices) != len(nodes_to_add): - raise ValueError("Number of batch indices must match number of node feature lists") - + raise ValueError( + "Number of batch indices must match number of node feature lists" + ) + modified_dict = tensor_dict.clone() - node_feature_dim = modified_dict['node_feature'].shape[1] - + node_feature_dim = modified_dict["node_feature"].shape[1] + for graph_idx, new_nodes in zip(batch_indices, nodes_to_add): - start_ptr = tensor_dict['batch_ptr'][graph_idx] - end_ptr = tensor_dict['batch_ptr'][graph_idx + 1] + tensor_dict["batch_ptr"][graph_idx] + end_ptr = tensor_dict["batch_ptr"][graph_idx + 1] if new_nodes.ndim == 1: new_nodes = new_nodes.unsqueeze(0) if new_nodes.shape[1] != node_feature_dim: - raise ValueError(f"Node features must have dimension {node_feature_dim}") - + raise ValueError( + f"Node features must have dimension {node_feature_dim}" + ) + # Update batch pointers for subsequent graphs shift = new_nodes.shape[0] - modified_dict['batch_ptr'][graph_idx + 1:] += shift - + modified_dict["batch_ptr"][graph_idx + 1 :] += shift + # Expand node features - modified_dict['node_feature'] = torch.cat([ - modified_dict['node_feature'][:end_ptr], - new_nodes, - modified_dict['node_feature'][end_ptr:] - ]) - + modified_dict["node_feature"] = torch.cat( + [ + modified_dict["node_feature"][:end_ptr], + new_nodes, + modified_dict["node_feature"][end_ptr:], + ] + ) + # Update edge indices # Increment indices for edges after the current graph - edge_mask_0 = modified_dict['edge_index'][:, 0] >= end_ptr - edge_mask_1 = modified_dict['edge_index'][:, 1] >= end_ptr - modified_dict['edge_index'][edge_mask_0, 0] += shift - modified_dict['edge_index'][edge_mask_1, 1] += shift - + edge_mask_0 = modified_dict["edge_index"][:, 0] >= end_ptr + edge_mask_1 = modified_dict["edge_index"][:, 1] >= end_ptr + modified_dict["edge_index"][edge_mask_0, 0] += shift + modified_dict["edge_index"][edge_mask_1, 1] += shift + return modified_dict def reward(self, final_states: GraphStates) -> torch.Tensor: diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 169a1f57..e8058e65 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -9,7 +9,12 @@ from gfn.actions import GraphActionType from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, GraphStates, States -from gfn.utils.distributions import CategoricalActionType, CategoricalIndexes, ComposedDistribution, UnsqueezedCategorical +from gfn.utils.distributions import ( + CategoricalActionType, + CategoricalIndexes, + ComposedDistribution, + UnsqueezedCategorical, +) REDUCTION_FXNS = { "mean": torch.mean, @@ -235,7 +240,9 @@ def forward(self, states: DiscreteStates) -> torch.Tensor: Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). """ out = super().forward(states) - assert out.shape[-1] == self.expected_output_dim, f"Expected output dim: {self.expected_output_dim}, got: {out.shape[-1]}" + assert ( + out.shape[-1] == self.expected_output_dim + ), f"Expected output dim: {self.expected_output_dim}, got: {out.shape[-1]}" return out def to_probability_distribution( @@ -517,7 +524,9 @@ def to_probability_distribution( dists["action_type"] = CategoricalActionType(probs=action_type_probs) edge_index_logits = module_output["edge_index"] - if states.tensor["node_feature"].shape[0] > 1 and torch.any(edge_index_logits != -float("inf")): + if states.tensor["node_feature"].shape[0] > 1 and torch.any( + edge_index_logits != -float("inf") + ): edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1) uniform_dist_probs = ( torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1] diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index d0f580fd..17352f22 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -167,7 +167,7 @@ def sample_trajectories( step = 0 all_estimator_outputs = [] - + while not all(dones): actions = env.actions_from_batch_shape((n_trajectories,)) # Dummy actions. log_probs = torch.full( diff --git a/src/gfn/states.py b/src/gfn/states.py index 7875038c..d7cc96f6 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -293,7 +293,7 @@ def log_rewards(self, log_rewards: torch.Tensor) -> None: def sample(self, n_samples: int) -> States: """Samples a subset of the States object.""" return self[torch.randperm(len(self))[:n_samples]] - + @classmethod def stack(cls, states: List[States]): """Given a list of states, stacks them along a new dimension (0).""" @@ -526,17 +526,23 @@ def __init__(self, tensor: TensorDict): self._log_rewards: Optional[float] = None # TODO logic repeated from env.is_valid_action not_empty = self.tensor["batch_ptr"][:-1] + 1 < self.tensor["batch_ptr"][1:] - self.forward_masks = torch.ones((np.prod(self.batch_shape), 3), dtype=torch.bool) + self.forward_masks = torch.ones( + (np.prod(self.batch_shape), 3), dtype=torch.bool + ) self.forward_masks[..., GraphActionType.ADD_EDGE] = not_empty self.forward_masks[..., GraphActionType.EXIT] = not_empty self.forward_masks = self.forward_masks.view(*self.batch_shape, 3) - self.backward_masks = torch.ones((np.prod(self.batch_shape), 3), dtype=torch.bool) + self.backward_masks = torch.ones( + (np.prod(self.batch_shape), 3), dtype=torch.bool + ) self.backward_masks[..., GraphActionType.ADD_NODE] = not_empty - self.backward_masks[..., GraphActionType.ADD_EDGE] = not_empty # TODO: check at least one edge is present + self.backward_masks[ + ..., GraphActionType.ADD_EDGE + ] = not_empty # TODO: check at least one edge is present self.backward_masks[..., GraphActionType.EXIT] = not_empty self.backward_masks = self.backward_masks.view(*self.batch_shape, 3) - + @property def batch_shape(self) -> tuple: return tuple(self.tensor["batch_shape"].tolist()) @@ -559,13 +565,16 @@ def from_batch_shape( def make_initial_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) - return TensorDict({ - "node_feature": cls.s0["node_feature"].repeat(np.prod(batch_shape), 1), - "edge_feature": cls.s0["edge_feature"].repeat(np.prod(batch_shape), 1), - "edge_index": cls.s0["edge_index"].repeat(np.prod(batch_shape), 1), - "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * cls.s0["node_feature"].shape[0], - "batch_shape": batch_shape - }) + return TensorDict( + { + "node_feature": cls.s0["node_feature"].repeat(np.prod(batch_shape), 1), + "edge_feature": cls.s0["edge_feature"].repeat(np.prod(batch_shape), 1), + "edge_index": cls.s0["edge_index"].repeat(np.prod(batch_shape), 1), + "batch_ptr": torch.arange(np.prod(batch_shape) + 1) + * cls.s0["node_feature"].shape[0], + "batch_shape": batch_shape, + } + ) @classmethod def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: @@ -573,13 +582,16 @@ def make_sink_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: raise NotImplementedError("Sink state is not defined") batch_shape = batch_shape if isinstance(batch_shape, Tuple) else (batch_shape,) - return TensorDict({ - "node_feature": cls.sf["node_feature"].repeat(np.prod(batch_shape), 1), - "edge_feature": cls.sf["edge_feature"].repeat(np.prod(batch_shape), 1), - "edge_index": cls.sf["edge_index"].repeat(np.prod(batch_shape), 1), - "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * cls.sf["node_feature"].shape[0], - "batch_shape": batch_shape - }) + return TensorDict( + { + "node_feature": cls.sf["node_feature"].repeat(np.prod(batch_shape), 1), + "edge_feature": cls.sf["edge_feature"].repeat(np.prod(batch_shape), 1), + "edge_index": cls.sf["edge_index"].repeat(np.prod(batch_shape), 1), + "batch_ptr": torch.arange(np.prod(batch_shape) + 1) + * cls.sf["node_feature"].shape[0], + "batch_shape": batch_shape, + } + ) @classmethod def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: @@ -589,13 +601,21 @@ def make_random_states_tensor(cls, batch_shape: int | Tuple) -> TensorDict: num_edges = np.random.randint(num_nodes * (num_nodes - 1) // 2) node_features_dim = cls.s0["node_feature"].shape[-1] edge_features_dim = cls.s0["edge_feature"].shape[-1] - return TensorDict({ - "node_feature": torch.rand(np.prod(batch_shape) * num_nodes, node_features_dim), - "edge_feature": torch.rand(np.prod(batch_shape) * num_edges, edge_features_dim), - "edge_index": torch.randint(num_nodes, size=(np.prod(batch_shape) * num_edges, 2)), - "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * num_nodes, - "batch_shape": batch_shape - }) + return TensorDict( + { + "node_feature": torch.rand( + np.prod(batch_shape) * num_nodes, node_features_dim + ), + "edge_feature": torch.rand( + np.prod(batch_shape) * num_edges, edge_features_dim + ), + "edge_index": torch.randint( + num_nodes, size=(np.prod(batch_shape) * num_edges, 2) + ), + "batch_ptr": torch.arange(np.prod(batch_shape) + 1) * num_nodes, + "batch_shape": batch_shape, + } + ) def __len__(self) -> int: return int(np.prod(self.batch_shape)) @@ -611,44 +631,48 @@ def __getitem__( ) -> GraphStates: tensor_idx = torch.arange(len(self)).view(*self.batch_shape) index = tensor_idx[index].flatten() - - if torch.any(index >= len(self.tensor['batch_ptr']) - 1): + + if torch.any(index >= len(self.tensor["batch_ptr"]) - 1): raise ValueError("Graph index out of bounds") - - start_ptrs = self.tensor['batch_ptr'][:-1][index] - end_ptrs = self.tensor['batch_ptr'][1:][index] - + + start_ptrs = self.tensor["batch_ptr"][:-1][index] + end_ptrs = self.tensor["batch_ptr"][1:][index] + node_features = [torch.empty(0, self.node_features_dim)] edge_features = [torch.empty(0, self.edge_features_dim)] edge_indices = [torch.empty(0, 2, dtype=torch.long)] batch_ptr = [0] - + for start, end in zip(start_ptrs, end_ptrs): - graph_nodes = self.tensor['node_feature'][start:end] + graph_nodes = self.tensor["node_feature"][start:end] node_features.append(graph_nodes) # Find edges for this graph - edge_mask = ((self.tensor['edge_index'][:, 0] >= start) & - (self.tensor['edge_index'][:, 0] < end)) - graph_edges = self.tensor['edge_feature'][edge_mask] + edge_mask = (self.tensor["edge_index"][:, 0] >= start) & ( + self.tensor["edge_index"][:, 0] < end + ) + graph_edges = self.tensor["edge_feature"][edge_mask] edge_features.append(graph_edges) - + # Adjust edge indices to be local to this graph - graph_edge_index = self.tensor['edge_index'][edge_mask] - graph_edge_index[:, 0] -= (start - batch_ptr[-1]) - graph_edge_index[:, 1] -= (start - batch_ptr[-1]) + graph_edge_index = self.tensor["edge_index"][edge_mask] + graph_edge_index[:, 0] -= start - batch_ptr[-1] + graph_edge_index[:, 1] -= start - batch_ptr[-1] edge_indices.append(graph_edge_index) batch_ptr.append(batch_ptr[-1] + len(graph_nodes)) - out = self.__class__(TensorDict({ - 'node_feature': torch.cat(node_features), - 'edge_feature': torch.cat(edge_features), - 'edge_index': torch.cat(edge_indices), - 'batch_ptr': torch.tensor(batch_ptr), - 'batch_shape': (len(index),) - })) + out = self.__class__( + TensorDict( + { + "node_feature": torch.cat(node_features), + "edge_feature": torch.cat(edge_features), + "edge_index": torch.cat(edge_indices), + "batch_ptr": torch.tensor(batch_ptr), + "batch_shape": (len(index),), + } + ) + ) - if self._log_rewards is not None: out._log_rewards = self._log_rewards[index] @@ -660,64 +684,88 @@ def __setitem__(self, index: int | Sequence[int], graph: GraphStates): """ tensor_idx = torch.arange(len(self)).view(*self.batch_shape) index = tensor_idx[index].flatten() - + # Validate indices - if torch.any(index >= len(self.tensor['batch_ptr']) - 1): + if torch.any(index >= len(self.tensor["batch_ptr"]) - 1): raise ValueError("Target graph index out of bounds") - + # Source graph details source_tensor_dict = graph.tensor - source_num_graphs = torch.prod(source_tensor_dict['batch_shape']) - + source_num_graphs = torch.prod(source_tensor_dict["batch_shape"]) + # Validate source and target indices match if len(index) != source_num_graphs: - raise ValueError("Number of source graphs must match number of target indices") - + raise ValueError( + "Number of source graphs must match number of target indices" + ) + for i, graph_idx in enumerate(index): # Get start and end pointers for the current graph - start_ptr = self.tensor['batch_ptr'][graph_idx] - end_ptr = self.tensor['batch_ptr'][graph_idx + 1] - source_start_ptr = source_tensor_dict['batch_ptr'][i] - source_end_ptr = source_tensor_dict['batch_ptr'][i + 1] + start_ptr = self.tensor["batch_ptr"][graph_idx] + end_ptr = self.tensor["batch_ptr"][graph_idx + 1] + source_start_ptr = source_tensor_dict["batch_ptr"][i] + source_end_ptr = source_tensor_dict["batch_ptr"][i + 1] + + new_nodes = source_tensor_dict["node_feature"][ + source_start_ptr:source_end_ptr + ] - new_nodes = source_tensor_dict['node_feature'][source_start_ptr:source_end_ptr] - # Ensure new nodes have correct feature dimension if new_nodes.ndim == 1: new_nodes = new_nodes.unsqueeze(0) - + if new_nodes.shape[1] != self.node_features_dim: - raise ValueError(f"Node features must have dimension {node_feature_dim}") - + raise ValueError( + f"Node features must have dimension {node_feature_dim}" + ) + # Number of new nodes to add shift = new_nodes.shape[0] - (end_ptr - start_ptr) - + # Concatenate node features - self.tensor['node_feature'] = torch.cat([ - self.tensor['node_feature'][:start_ptr], # Nodes before the current graph - new_nodes, # New nodes to add - self.tensor['node_feature'][end_ptr:] # Nodes after the current graph - ]) - + self.tensor["node_feature"] = torch.cat( + [ + self.tensor["node_feature"][ + :start_ptr + ], # Nodes before the current graph + new_nodes, # New nodes to add + self.tensor["node_feature"][ + end_ptr: + ], # Nodes after the current graph + ] + ) + # Update edge indices for subsequent graphs - edge_mask = self.tensor['edge_index'] >= end_ptr + edge_mask = self.tensor["edge_index"] >= end_ptr assert torch.all(edge_mask[..., 0] == edge_mask[..., 1]) edge_mask = torch.all(edge_mask, dim=-1) - self.tensor['edge_index'][edge_mask] += shift - edge_mask |= torch.all(self.tensor['edge_index'] < start_ptr, dim=-1) - edge_to_add_mask = torch.all(source_tensor_dict['edge_index'] >= source_start_ptr, dim=-1) - edge_to_add_mask &= torch.all(source_tensor_dict['edge_index'] < source_end_ptr, dim=-1) - self.tensor['edge_index'] = torch.cat([ - self.tensor['edge_index'][edge_mask], - source_tensor_dict['edge_index'][edge_to_add_mask] - source_start_ptr + start_ptr, - ], dim=0) - self.tensor['edge_feature'] = torch.cat([ - self.tensor['edge_feature'][edge_mask], - source_tensor_dict['edge_feature'][edge_to_add_mask], - ], dim=0) + self.tensor["edge_index"][edge_mask] += shift + edge_mask |= torch.all(self.tensor["edge_index"] < start_ptr, dim=-1) + edge_to_add_mask = torch.all( + source_tensor_dict["edge_index"] >= source_start_ptr, dim=-1 + ) + edge_to_add_mask &= torch.all( + source_tensor_dict["edge_index"] < source_end_ptr, dim=-1 + ) + self.tensor["edge_index"] = torch.cat( + [ + self.tensor["edge_index"][edge_mask], + source_tensor_dict["edge_index"][edge_to_add_mask] + - source_start_ptr + + start_ptr, + ], + dim=0, + ) + self.tensor["edge_feature"] = torch.cat( + [ + self.tensor["edge_feature"][edge_mask], + source_tensor_dict["edge_feature"][edge_to_add_mask], + ], + dim=0, + ) # Update batch pointers - self.tensor['batch_ptr'][graph_idx + 1:] += shift + self.tensor["batch_ptr"][graph_idx + 1 :] += shift @property def device(self) -> torch.device: @@ -736,13 +784,33 @@ def clone(self) -> GraphStates: def extend(self, other: GraphStates): """Concatenates to another GraphStates object along the batch dimension""" - self.tensor["node_feature"] = torch.cat([self.tensor["node_feature"], other.tensor["node_feature"]], dim=0) - self.tensor["edge_feature"] = torch.cat([self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0) + self.tensor["node_feature"] = torch.cat( + [self.tensor["node_feature"], other.tensor["node_feature"]], dim=0 + ) + self.tensor["edge_feature"] = torch.cat( + [self.tensor["edge_feature"], other.tensor["edge_feature"]], dim=0 + ) # TODO: fix indices - self.tensor["edge_index"] = torch.cat([self.tensor["edge_index"], other.tensor["edge_index"] + self.tensor["batch_ptr"][-1]], dim=0) - self.tensor["batch_ptr"] = torch.cat([self.tensor["batch_ptr"], other.tensor["batch_ptr"][1:] + self.tensor["batch_ptr"][-1]], dim=0) - assert torch.all(self.tensor["batch_shape"][1:] == other.tensor["batch_shape"][1:]) - self.tensor["batch_shape"] = (self.tensor["batch_shape"][0] + other.tensor["batch_shape"][0],) + self.batch_shape[1:] + self.tensor["edge_index"] = torch.cat( + [ + self.tensor["edge_index"], + other.tensor["edge_index"] + self.tensor["batch_ptr"][-1], + ], + dim=0, + ) + self.tensor["batch_ptr"] = torch.cat( + [ + self.tensor["batch_ptr"], + other.tensor["batch_ptr"][1:] + self.tensor["batch_ptr"][-1], + ], + dim=0, + ) + assert torch.all( + self.tensor["batch_shape"][1:] == other.tensor["batch_shape"][1:] + ) + self.tensor["batch_shape"] = ( + self.tensor["batch_shape"][0] + other.tensor["batch_shape"][0], + ) + self.batch_shape[1:] @property def log_rewards(self) -> torch.Tensor: @@ -759,12 +827,22 @@ def _compare(self, other: TensorDict) -> torch.Tensor: if end - start != len(other["node_feature"]): out[i] = False else: - out[i] = torch.all(self.tensor["node_feature"][start:end] == other["node_feature"]) - edge_mask = torch.all((self.tensor["edge_index"] >= start) & (self.tensor["edge_index"] < end), dim=-1) + out[i] = torch.all( + self.tensor["node_feature"][start:end] == other["node_feature"] + ) + edge_mask = torch.all( + (self.tensor["edge_index"] >= start) + & (self.tensor["edge_index"] < end), + dim=-1, + ) edge_index = self.tensor["edge_index"][edge_mask] - start - out[i] &= len(edge_index) == len(other["edge_index"]) and torch.all(edge_index == other["edge_index"]) + out[i] &= len(edge_index) == len(other["edge_index"]) and torch.all( + edge_index == other["edge_index"] + ) edge_feature = self.tensor["edge_feature"][edge_mask] - out[i] &= len(edge_feature) == len(other["edge_feature"]) and torch.all(edge_feature == other["edge_feature"]) + out[i] &= len(edge_feature) == len(other["edge_feature"]) and torch.all( + edge_feature == other["edge_feature"] + ) return out.view(self.batch_shape) @property diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index e6f6beaa..250a275b 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -91,7 +91,6 @@ def log_prob(self, value): class CategoricalActionType(Categorical): # TODO: remove, just to sample 1 action_type - def __init__(self, probs: torch.Tensor): self.batch_len = len(probs) super().__init__(probs[0]) @@ -99,6 +98,6 @@ def __init__(self, probs: torch.Tensor): def sample(self, sample_shape=torch.Size()) -> torch.Tensor: samples = super().sample(sample_shape) return samples.repeat(self.batch_len) - + def log_prob(self, value): - return super().log_prob(value[0]).repeat(self.batch_len) \ No newline at end of file + return super().log_prob(value[0]).repeat(self.batch_len) diff --git a/testing/test_environments.py b/testing/test_environments.py index a2919d50..45768c2e 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -332,54 +332,88 @@ def test_graph_env(): action_cls = env.make_actions_class() with pytest.raises(NonValidActionsError): - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), - "edge_index": torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.randint( + 0, 10, (BATCH_SIZE, 2), dtype=torch.long + ), + }, + batch_size=BATCH_SIZE, + ) + ) states = env.step(states, actions) for _ in range(NUM_NODES): - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + }, + batch_size=BATCH_SIZE, + ) + ) states = env.step(states, actions) states = env.States(states) assert states.tensor["node_feature"].shape == (BATCH_SIZE * NUM_NODES, FEATURE_DIM) with pytest.raises(NonValidActionsError): - first_node_mask = torch.arange(len(states.tensor["node_feature"])) // BATCH_SIZE == 0 - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - "features": states.tensor["node_feature"][first_node_mask], - }, batch_size=BATCH_SIZE)) + first_node_mask = ( + torch.arange(len(states.tensor["node_feature"])) // BATCH_SIZE == 0 + ) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": states.tensor["node_feature"][first_node_mask], + }, + batch_size=BATCH_SIZE, + ) + ) states = env.step(states, actions) with pytest.raises(NonValidActionsError): edge_index = torch.randint(0, 3, (BATCH_SIZE,), dtype=torch.long) - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), - "edge_index": torch.stack([edge_index, edge_index], dim=1), - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.stack([edge_index, edge_index], dim=1), + }, + batch_size=BATCH_SIZE, + ) + ) states = env.step(states, actions) for i in range(NUM_NODES - 1): node_is = states.tensor["batch_ptr"][:-1] + i node_js = states.tensor["batch_ptr"][:-1] + i + 1 - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), - "edge_index": torch.stack([node_is, node_js], dim=1), - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.stack([node_is, node_js], dim=1), + }, + batch_size=BATCH_SIZE, + ) + ) states = env.step(states, actions) states = env.States(states) - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.EXIT), - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.EXIT), + }, + batch_size=BATCH_SIZE, + ) + ) sf_states = env.step(states, actions) sf_states = env.States(sf_states) assert torch.all(sf_states.is_sink_state) @@ -396,36 +430,58 @@ def test_graph_env(): num_edges_per_batch = len(states.tensor["edge_feature"]) // BATCH_SIZE for i in reversed(range(num_edges_per_batch)): edge_idx = torch.arange(i * BATCH_SIZE, (i + 1) * BATCH_SIZE) - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - "features": states.tensor["edge_feature"][edge_idx], - "edge_index": states.tensor["edge_index"][edge_idx], - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": states.tensor["edge_feature"][edge_idx], + "edge_index": states.tensor["edge_index"][edge_idx], + }, + batch_size=BATCH_SIZE, + ) + ) states = env.backward_step(states, actions) states = env.States(states) with pytest.raises(NonValidActionsError): - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), - "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), - "edge_index": torch.randint(0, 10, (BATCH_SIZE, 2), dtype=torch.long), - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_EDGE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + "edge_index": torch.randint( + 0, 10, (BATCH_SIZE, 2), dtype=torch.long + ), + }, + batch_size=BATCH_SIZE, + ) + ) states = env.backward_step(states, actions) for i in reversed(range(1, NUM_NODES + 1)): edge_idx = torch.arange(BATCH_SIZE) * i - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - "features": states.tensor["node_feature"][edge_idx], - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": states.tensor["node_feature"][edge_idx], + }, + batch_size=BATCH_SIZE, + ) + ) states = env.backward_step(states, actions) states = env.States(states) assert states.tensor["node_feature"].shape == (0, FEATURE_DIM) with pytest.raises(NonValidActionsError): - actions = action_cls(TensorDict({ - "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), - "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), - }, batch_size=BATCH_SIZE)) + actions = action_cls( + TensorDict( + { + "action_type": torch.full((BATCH_SIZE,), GraphActionType.ADD_NODE), + "features": torch.rand((BATCH_SIZE, FEATURE_DIM)), + }, + batch_size=BATCH_SIZE, + ) + ) states = env.backward_step(states, actions) diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 470c8b09..d23dbd61 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -225,6 +225,7 @@ def test_replay_buffer( # ------ GRAPH TESTS ------ + class GraphActionNet(nn.Module): def __init__(self, feature_dim: int): super().__init__() @@ -243,7 +244,9 @@ def forward(self, states: GraphStates) -> TensorDict: features = torch.zeros((len(states), self.feature_dim)) else: action_type = self.action_type_conv(node_feature, edge_index) - action_type = action_type.reshape(len(states), -1, action_type.shape[-1]).mean(dim=1) + action_type = action_type.reshape( + len(states), -1, action_type.shape[-1] + ).mean(dim=1) features = self.features_conv(node_feature, edge_index) features = features.reshape(len(states), -1, features.shape[-1]).mean(dim=1) @@ -277,3 +280,5 @@ def test_graph_building(): save_logprobs=True, save_estimator_outputs=False, ) + + assert len(trajectories) == 7 diff --git a/tutorials/examples/test_graph_ring.py b/tutorials/examples/test_graph_ring.py index c6b591c2..b6d401f7 100644 --- a/tutorials/examples/test_graph_ring.py +++ b/tutorials/examples/test_graph_ring.py @@ -3,17 +3,19 @@ import math import time from typing import Optional + +import matplotlib.pyplot as plt import torch +from tensordict import TensorDict from torch import nn -from gfn.actions import Actions, GraphActionType, GraphActions +from torch_geometric.nn import GCNConv + +from gfn.actions import Actions, GraphActions, GraphActionType from gfn.gflownet.flow_matching import FMGFlowNet from gfn.gym import GraphBuilding from gfn.modules import DiscretePolicyEstimator from gfn.preprocessors import Preprocessor from gfn.states import GraphStates -from tensordict import TensorDict -from torch_geometric.nn import GCNConv -import matplotlib.pyplot as plt def state_evaluator(states: GraphStates) -> torch.Tensor: @@ -26,11 +28,15 @@ def state_evaluator(states: GraphStates) -> torch.Tensor: out = torch.zeros(len(states)) for i in range(len(states)): start, end = states.tensor["batch_ptr"][i], states.tensor["batch_ptr"][i + 1] - edge_index_mask = torch.all(states.tensor["edge_index"] >= start, dim=-1) & torch.all(states.tensor["edge_index"] < end, dim=-1) + edge_index_mask = torch.all( + states.tensor["edge_index"] >= start, dim=-1 + ) & torch.all(states.tensor["edge_index"] < end, dim=-1) edge_index = states.tensor["edge_index"][edge_index_mask] arange = torch.arange(start, end) # TODO: not correct, accepts multiple rings - if torch.all(torch.sort(edge_index[:, 0])[0] == arange) and torch.all(torch.sort(edge_index[:, 1])[0] == arange): + if torch.all(torch.sort(edge_index[:, 0])[0] == arange) and torch.all( + torch.sort(edge_index[:, 1])[0] == arange + ): out[i] = 1 else: out[i] = eps @@ -46,7 +52,11 @@ def __init__(self, n_nodes: int, edge_hidden_dim: int = 128): self.edge_hidden_dim = edge_hidden_dim def _group_sum(self, tensor: torch.Tensor, batch_ptr: torch.Tensor) -> torch.Tensor: - cumsum = torch.zeros((len(tensor) + 1, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) + cumsum = torch.zeros( + (len(tensor) + 1, *tensor.shape[1:]), + dtype=tensor.dtype, + device=tensor.device, + ) cumsum[1:] = torch.cumsum(tensor, dim=0) return cumsum[batch_ptr[1:]] - cumsum[batch_ptr[:-1]] @@ -59,82 +69,102 @@ def forward(self, states_tensor: TensorDict) -> torch.Tensor: action_type = self._group_sum(action_type, batch_ptr) edge_index = self.edge_index_conv(node_feature, edge_index) - edge_index = edge_index.reshape(*states_tensor["batch_shape"], self.n_nodes, self.edge_hidden_dim) + edge_index = edge_index.reshape( + *states_tensor["batch_shape"], self.n_nodes, self.edge_hidden_dim + ) edge_index = torch.einsum("bnf,bmf->bnm", edge_index, edge_index) - edge_actions = edge_index.reshape(*states_tensor["batch_shape"], self.n_nodes * self.n_nodes) + edge_actions = edge_index.reshape( + *states_tensor["batch_shape"], self.n_nodes * self.n_nodes + ) return torch.cat([action_type, edge_actions], dim=-1) + class RingGraphBuilding(GraphBuilding): def __init__(self, n_nodes: int = 10): self.n_nodes = n_nodes self.n_actions = 1 + n_nodes * n_nodes super().__init__(feature_dim=1, state_evaluator=state_evaluator) - def make_actions_class(self) -> type[Actions]: env = self + class RingActions(Actions): action_shape = (1,) dummy_action = torch.tensor([env.n_actions]) - exit_action = torch.zeros(1,) + exit_action = torch.zeros( + 1, + ) return RingActions - def make_states_class(self) -> type[GraphStates]: env = self class RingStates(GraphStates): - s0 = TensorDict({ - "node_feature": torch.arange(env.n_nodes).unsqueeze(-1), - "edge_feature": torch.ones((0, 1)), - "edge_index": torch.ones((0, 2), dtype=torch.long), - }, batch_size=()) - sf = TensorDict({ - "node_feature": torch.zeros((env.n_nodes, 1)), - "edge_feature": torch.zeros((0, 1)), - "edge_index": torch.zeros((0, 2), dtype=torch.long), - }, batch_size=()) + s0 = TensorDict( + { + "node_feature": torch.arange(env.n_nodes).unsqueeze(-1), + "edge_feature": torch.ones((0, 1)), + "edge_index": torch.ones((0, 2), dtype=torch.long), + }, + batch_size=(), + ) + sf = TensorDict( + { + "node_feature": torch.zeros((env.n_nodes, 1)), + "edge_feature": torch.zeros((0, 1)), + "edge_index": torch.zeros((0, 2), dtype=torch.long), + }, + batch_size=(), + ) def __init__(self, tensor: TensorDict): self.tensor = tensor self.node_features_dim = tensor["node_feature"].shape[-1] self.edge_features_dim = tensor["edge_feature"].shape[-1] self._log_rewards: Optional[float] = None - + self.n_nodes = env.n_nodes self.n_actions = env.n_actions @property def forward_masks(self): forward_masks = torch.ones(len(self), self.n_actions, dtype=torch.bool) - forward_masks[:, 1::self.n_nodes + 1] = False + forward_masks[:, 1 :: self.n_nodes + 1] = False for i in range(len(self)): existing_edges = self[i].tensor["edge_index"] - forward_masks[i, 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1]] = False - + forward_masks[ + i, + 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], + ] = False + return forward_masks.view(*self.batch_shape, self.n_actions) - + @forward_masks.setter def forward_masks(self, value: torch.Tensor): - pass # fwd masks is computed on the fly + pass # fwd masks is computed on the fly @property def backward_masks(self): - backward_masks = torch.zeros(len(self), self.n_actions, dtype=torch.bool) + backward_masks = torch.zeros( + len(self), self.n_actions, dtype=torch.bool + ) for i in range(len(self)): existing_edges = self[i].tensor["edge_index"] - backward_masks[i, 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1]] = True - + backward_masks[ + i, + 1 + existing_edges[:, 0] * self.n_nodes + existing_edges[:, 1], + ] = True + return backward_masks.view(*self.batch_shape, self.n_actions) - + @backward_masks.setter def backward_masks(self, value: torch.Tensor): - pass # bwd masks is computed on the fly - + pass # bwd masks is computed on the fly + return RingStates - + def _step(self, states: GraphStates, actions: Actions) -> GraphStates: actions = self.convert_actions(actions) return super()._step(states, actions) @@ -145,20 +175,26 @@ def _backward_step(self, states: GraphStates, actions: Actions) -> GraphStates: def convert_actions(self, actions: Actions) -> GraphActions: action_tensor = actions.tensor.squeeze(-1) - action_type = torch.where(action_tensor == 0, GraphActionType.EXIT, GraphActionType.ADD_EDGE) + action_type = torch.where( + action_tensor == 0, GraphActionType.EXIT, GraphActionType.ADD_EDGE + ) edge_index_i0 = (action_tensor - 1) // (self.n_nodes) edge_index_i1 = (action_tensor - 1) % (self.n_nodes) edge_index = torch.stack([edge_index_i0, edge_index_i1], dim=-1) - return GraphActions(TensorDict({ - "action_type": action_type, - "features": torch.ones(action_tensor.shape + (1,)), - "edge_index": edge_index, - }, batch_size=action_tensor.shape)) + return GraphActions( + TensorDict( + { + "action_type": action_type, + "features": torch.ones(action_tensor.shape + (1,)), + "edge_index": edge_index, + }, + batch_size=action_tensor.shape, + ) + ) class GraphPreprocessor(Preprocessor): - def __init__(self, feature_dim: int = 1): super().__init__(output_dim=feature_dim) @@ -184,28 +220,36 @@ def render_states(states: GraphStates): y = radius * math.sin(angle) xs.append(x) ys.append(y) - current_ax.add_patch(plt.Circle((x, y), 0.5, facecolor='none', edgecolor='black')) - + current_ax.add_patch( + plt.Circle((x, y), 0.5, facecolor="none", edgecolor="black") + ) + for edge in state.tensor["edge_index"]: start_x, start_y = xs[edge[0]], ys[edge[0]] end_x, end_y = xs[edge[1]], ys[edge[1]] dx = end_x - start_x dy = end_y - start_y length = math.sqrt(dx**2 + dy**2) - dx, dy = dx/length, dy/length - + dx, dy = dx / length, dy / length + circle_radius = 0.5 head_thickness = 0.2 start_x += dx * (circle_radius) start_y += dy * (circle_radius) end_x -= dx * (circle_radius + head_thickness) end_y -= dy * (circle_radius + head_thickness) - - current_ax.arrow(start_x, start_y, - end_x - start_x, end_y - start_y, - head_width=head_thickness, head_length=head_thickness, - fc='black', ec='black') - + + current_ax.arrow( + start_x, + start_y, + end_x - start_x, + end_y - start_y, + head_width=head_thickness, + head_length=head_thickness, + fc="black", + ec="black", + ) + current_ax.set_title(f"State {i}, $r={rewards[i]:.2f}$") current_ax.set_xlim(-(radius + 1), radius + 1) current_ax.set_ylim(-(radius + 1), radius + 1) @@ -213,7 +257,6 @@ def render_states(states: GraphStates): current_ax.set_xticks([]) current_ax.set_yticks([]) - plt.show() @@ -224,11 +267,15 @@ def render_states(states: GraphStates): env = RingGraphBuilding(n_nodes=N_NODES) module = RingPolicyEstimator(env.n_nodes) - pf_estimator = DiscretePolicyEstimator(module=module, n_actions=env.n_actions, preprocessor=GraphPreprocessor()) + pf_estimator = DiscretePolicyEstimator( + module=module, n_actions=env.n_actions, preprocessor=GraphPreprocessor() + ) gflownet = FMGFlowNet(pf_estimator) optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-2) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_ITERATIONS, eta_min=1e-4) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=N_ITERATIONS, eta_min=1e-4 + ) losses = [] @@ -246,6 +293,3 @@ def render_states(states: GraphStates): t2 = time.time() print("Time:", t2 - t1) render_states(trajectories.last_states[:8]) - - -