
Commit

fix formatting
younik committed Jan 12, 2025
1 parent 5d99739 commit f4fc3ab
Showing 11 changed files with 517 additions and 281 deletions.
51 changes: 30 additions & 21 deletions src/gfn/actions.py
@@ -3,7 +3,7 @@
import enum
from abc import ABC
from math import prod
from typing import ClassVar, Optional, Sequence
from typing import ClassVar, Sequence

import torch
from tensordict import TensorDict
@@ -201,11 +201,14 @@ def __init__(self, tensor: TensorDict):
assert torch.all(tensor["action_type"] != GraphActionType.ADD_EDGE)
edge_index = torch.zeros((*self.batch_shape, 2), dtype=torch.long)

self.tensor = TensorDict({
"action_type": tensor["action_type"],
"features": features,
"edge_index": edge_index,
}, batch_size=self.batch_shape)
self.tensor = TensorDict(
{
"action_type": tensor["action_type"],
"features": features,
"edge_index": edge_index,
},
batch_size=self.batch_shape,
)

def __repr__(self):
return f"""GraphAction object with {self.batch_shape} actions."""
@@ -223,7 +226,6 @@ def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActio
"""Get particular actions of the batch."""
return GraphActions(self.tensor[index])


def __setitem__(
self, index: int | Sequence[int] | Sequence[bool], action: GraphActions
) -> None:
@@ -239,9 +241,14 @@ def compare(self, other: GraphActions) -> torch.Tensor:
Returns: boolean tensor of shape batch_shape indicating whether the actions are equal.
"""
compare = torch.all(self.tensor == other.tensor, dim=-1)
return compare["action_type"] & \
(compare["action_type"] == GraphActionType.EXIT | compare["features"]) & \
(compare["action_type"] != GraphActionType.ADD_EDGE | compare["edge_index"])
return (
compare["action_type"]
& (compare["action_type"] == GraphActionType.EXIT | compare["features"])
& (
compare["action_type"]
!= GraphActionType.ADD_EDGE | compare["edge_index"]
)
)

@property
def is_exit(self) -> torch.Tensor:
@@ -257,30 +264,32 @@ def action_type(self) -> torch.Tensor:
def features(self) -> torch.Tensor:
"""Returns the features tensor."""
return self.tensor["features"]

@property
def edge_index(self) -> torch.Tensor:
"""Returns the edge index tensor."""
return self.tensor["edge_index"]

@classmethod
def make_dummy_actions(
cls, batch_shape: tuple[int]
) -> GraphActions:
def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions:
"""Creates an Actions object of dummy actions with the given batch shape."""
return cls(
TensorDict({
"action_type": torch.full(batch_shape, fill_value=GraphActionType.EXIT),
# "features": torch.zeros((*batch_shape, 0, cls.nodes_features_dim)),
# "edge_index": torch.zeros((2, *batch_shape, 0)),
}, batch_size=batch_shape)
TensorDict(
{
"action_type": torch.full(
batch_shape, fill_value=GraphActionType.EXIT
),
# "features": torch.zeros((*batch_shape, 0, cls.nodes_features_dim)),
# "edge_index": torch.zeros((2, *batch_shape, 0)),
},
batch_size=batch_shape,
)
)

@classmethod
def stack(cls, actions_list: list[GraphActions]) -> GraphActions:
"""Stacks a list of GraphActions objects into a single GraphActions object."""
actions_tensor = torch.stack(
[actions.tensor for actions in actions_list], dim=0
)
return cls(actions_tensor)
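
For context, a minimal usage sketch of the GraphActions class touched above. This is not part of the commit: the feature dimension (4) and the exact shapes are assumptions based on the TensorDict built in __init__ above.

import torch
from tensordict import TensorDict

from gfn.actions import GraphActions, GraphActionType

# Hypothetical batch of 3 EXIT actions; shapes mirror the TensorDict stored by
# __init__: action_type (*batch_shape,), features (*batch_shape, feature_dim),
# edge_index (*batch_shape, 2).
batch_shape = (3,)
td = TensorDict(
    {
        "action_type": torch.full(batch_shape, fill_value=GraphActionType.EXIT),
        "features": torch.zeros((*batch_shape, 4)),
        "edge_index": torch.zeros((*batch_shape, 2), dtype=torch.long),
    },
    batch_size=batch_shape,
)
actions = GraphActions(td)
print(actions.is_exit)  # expected: tensor([True, True, True])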

6 changes: 3 additions & 3 deletions src/gfn/env.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union

import torch
from tensordict import TensorDict
@@ -219,7 +219,7 @@ def reset(
batch_shape = (1,)
if isinstance(batch_shape, int):
batch_shape = (batch_shape,)

return self.States.from_batch_shape(
batch_shape=batch_shape, random=random, sink=sink
)
@@ -266,7 +266,7 @@ def _step(
not_done_actions = actions[~new_sink_states_idx]

new_not_done_states_tensor = self.step(not_done_states, not_done_actions)

if not isinstance(new_not_done_states_tensor, (torch.Tensor, TensorDict)):
raise Exception(
"User implemented env.step function *must* return a torch.Tensor!"
2 changes: 1 addition & 1 deletion src/gfn/gym/__init__.py
@@ -1,4 +1,4 @@
from gfn.gym.box import Box
from gfn.gym.discrete_ebm import DiscreteEBM
from gfn.gym.graph_building import GraphBuilding
from gfn.gym.hypergrid import HyperGrid
from gfn.gym.graph_building import GraphBuilding
134 changes: 85 additions & 49 deletions src/gfn/gym/graph_building.py
@@ -2,8 +2,8 @@
from typing import Callable, Literal, Tuple

import torch
from torch_geometric.nn import GCNConv
from tensordict import TensorDict
from torch_geometric.nn import GCNConv

from gfn.actions import GraphActions, GraphActionType
from gfn.env import GraphEnv, NonValidActionsError
@@ -17,16 +17,24 @@ def __init__(
state_evaluator: Callable[[GraphStates], torch.Tensor] | None = None,
device_str: Literal["cpu", "cuda"] = "cpu",
):
s0 = TensorDict({
"node_feature": torch.zeros((0, feature_dim), dtype=torch.float32),
"edge_feature": torch.zeros((0, feature_dim), dtype=torch.float32),
"edge_index": torch.zeros((0, 2), dtype=torch.long),
}, device=device_str)
sf = TensorDict({
"node_feature": torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"),
"edge_feature": torch.ones((1, feature_dim), dtype=torch.float32) * float("inf"),
"edge_index": torch.zeros((0, 2), dtype=torch.long),
}, device=device_str)
s0 = TensorDict(
{
"node_feature": torch.zeros((0, feature_dim), dtype=torch.float32),
"edge_feature": torch.zeros((0, feature_dim), dtype=torch.float32),
"edge_index": torch.zeros((0, 2), dtype=torch.long),
},
device=device_str,
)
sf = TensorDict(
{
"node_feature": torch.ones((1, feature_dim), dtype=torch.float32)
* float("inf"),
"edge_feature": torch.ones((1, feature_dim), dtype=torch.float32)
* float("inf"),
"edge_index": torch.zeros((0, 2), dtype=torch.long),
},
device=device_str,
)

if state_evaluator is None:
state_evaluator = GCNConvEvaluator(feature_dim)
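
A hypothetical instantiation sketch, not part of this commit: it assumes the constructor takes the feature_dim used for s0/sf above as a keyword argument and that reset accepts an integer batch_shape, as in env.py.

from gfn.gym import GraphBuilding

# Build the environment with an assumed feature dimension of 8; when no
# state_evaluator is passed, the default GCNConvEvaluator above is used.
env = GraphBuilding(feature_dim=8, device_str="cpu")
states = env.reset(batch_shape=4)  # batch of 4 copies of the empty initial graph s0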
@@ -59,15 +67,22 @@ def step(self, states: GraphStates, actions: GraphActions) -> TensorDict:
return self.States.make_sink_states_tensor(states.batch_shape)

if action_type == GraphActionType.ADD_NODE:
batch_indices = torch.arange(len(states))[actions.action_type == GraphActionType.ADD_NODE]
batch_indices = torch.arange(len(states))[
actions.action_type == GraphActionType.ADD_NODE
]
state_tensor = self._add_node(state_tensor, batch_indices, actions.features)

if action_type == GraphActionType.ADD_EDGE:
state_tensor["edge_feature"] = torch.cat([state_tensor["edge_feature"], actions.features], dim=0)
state_tensor["edge_index"] = torch.cat([
state_tensor["edge_index"],
actions.edge_index + state_tensor["batch_ptr"][:-1][:, None]
], dim=0)
state_tensor["edge_feature"] = torch.cat(
[state_tensor["edge_feature"], actions.features], dim=0
)
state_tensor["edge_index"] = torch.cat(
[
state_tensor["edge_index"],
actions.edge_index + state_tensor["batch_ptr"][:-1][:, None],
],
dim=0,
)

return state_tensor
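
A small illustration, not part of this commit, of the batch_ptr offsetting used in the ADD_EDGE branch above: per-graph node indices are shifted into the global numbering of the flat node tensor.

import torch

batch_ptr = torch.tensor([0, 3, 5])      # graph 0 owns nodes 0..2, graph 1 owns nodes 3..4
edge_index = torch.tensor([[0, 2],       # edge 0 -> 2 requested in graph 0
                           [0, 1]])      # edge 0 -> 1 requested in graph 1
global_edge_index = edge_index + batch_ptr[:-1][:, None]
print(global_edge_index)                 # tensor([[0, 2], [3, 4]])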

@@ -90,13 +105,17 @@ def backward_step(self, states: GraphStates, actions: GraphActions) -> TensorDic
assert torch.all(actions.action_type == action_type)
if action_type == GraphActionType.ADD_NODE:
is_equal = torch.any(
torch.all(state_tensor["node_feature"][:, None] == actions.features, dim=-1),
dim=-1
torch.all(
state_tensor["node_feature"][:, None] == actions.features, dim=-1
),
dim=-1,
)
state_tensor["node_feature"] = state_tensor["node_feature"][~is_equal]
elif action_type == GraphActionType.ADD_EDGE:
assert actions.edge_index is not None
global_edge_index = actions.edge_index + state_tensor["batch_ptr"][:-1][:, None]
global_edge_index = (
actions.edge_index + state_tensor["batch_ptr"][:-1][:, None]
)
is_equal = torch.all(
state_tensor["edge_index"] == global_edge_index[:, None], dim=-1
)
@@ -121,67 +140,84 @@ def is_action_valid(
add_node_out = torch.all(equal_nodes_per_batch == 1)
else:
add_node_out = torch.all(equal_nodes_per_batch == 0)

add_edge_mask = actions.action_type == GraphActionType.ADD_EDGE
if not torch.any(add_edge_mask):
add_edge_out = True
else:
add_edge_states = states[add_edge_mask].tensor
add_edge_actions = actions[add_edge_mask]
if torch.any(add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]):
if torch.any(
add_edge_actions.edge_index[:, 0] == add_edge_actions.edge_index[:, 1]
):
return False
if add_edge_states["node_feature"].shape[0] == 0:
return False
if torch.any(add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]):
if torch.any(
add_edge_actions.edge_index > add_edge_states["node_feature"].shape[0]
):
return False
global_edge_index = add_edge_actions.edge_index + add_edge_states["batch_ptr"][:-1][:, None]
global_edge_index = (
add_edge_actions.edge_index + add_edge_states["batch_ptr"][:-1][:, None]
)
equal_edges_per_batch = torch.all(
add_edge_states["edge_index"] == global_edge_index[:, None], dim=-1
).sum(dim=-1)

if backward:
add_edge_out = torch.all(equal_edges_per_batch == 1)
else:
add_edge_out = torch.all(equal_edges_per_batch == 0)

return bool(add_node_out) and bool(add_edge_out)

def _add_node(self, tensor_dict: TensorDict, batch_indices: torch.Tensor, nodes_to_add: torch.Tensor) -> TensorDict:

def _add_node(
self,
tensor_dict: TensorDict,
batch_indices: torch.Tensor,
nodes_to_add: torch.Tensor,
) -> TensorDict:
if isinstance(batch_indices, list):
batch_indices = torch.tensor(batch_indices)
if len(batch_indices) != len(nodes_to_add):
raise ValueError("Number of batch indices must match number of node feature lists")

raise ValueError(
"Number of batch indices must match number of node feature lists"
)

modified_dict = tensor_dict.clone()
node_feature_dim = modified_dict['node_feature'].shape[1]
node_feature_dim = modified_dict["node_feature"].shape[1]

for graph_idx, new_nodes in zip(batch_indices, nodes_to_add):
start_ptr = tensor_dict['batch_ptr'][graph_idx]
end_ptr = tensor_dict['batch_ptr'][graph_idx + 1]
tensor_dict["batch_ptr"][graph_idx]
end_ptr = tensor_dict["batch_ptr"][graph_idx + 1]

if new_nodes.ndim == 1:
new_nodes = new_nodes.unsqueeze(0)
if new_nodes.shape[1] != node_feature_dim:
raise ValueError(f"Node features must have dimension {node_feature_dim}")

raise ValueError(
f"Node features must have dimension {node_feature_dim}"
)

# Update batch pointers for subsequent graphs
shift = new_nodes.shape[0]
modified_dict['batch_ptr'][graph_idx + 1:] += shift
modified_dict["batch_ptr"][graph_idx + 1 :] += shift

# Expand node features
modified_dict['node_feature'] = torch.cat([
modified_dict['node_feature'][:end_ptr],
new_nodes,
modified_dict['node_feature'][end_ptr:]
])

modified_dict["node_feature"] = torch.cat(
[
modified_dict["node_feature"][:end_ptr],
new_nodes,
modified_dict["node_feature"][end_ptr:],
]
)

# Update edge indices
# Increment indices for edges after the current graph
edge_mask_0 = modified_dict['edge_index'][:, 0] >= end_ptr
edge_mask_1 = modified_dict['edge_index'][:, 1] >= end_ptr
modified_dict['edge_index'][edge_mask_0, 0] += shift
modified_dict['edge_index'][edge_mask_1, 1] += shift
edge_mask_0 = modified_dict["edge_index"][:, 0] >= end_ptr
edge_mask_1 = modified_dict["edge_index"][:, 1] >= end_ptr
modified_dict["edge_index"][edge_mask_0, 0] += shift
modified_dict["edge_index"][edge_mask_1, 1] += shift

return modified_dict
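
A small worked illustration, not part of this commit, of the pointer bookkeeping in _add_node: inserting nodes into one graph shifts the batch_ptr entries (and edge indices) of every subsequent graph by the number of nodes added.

import torch

batch_ptr = torch.tensor([0, 3, 5])   # two graphs with 3 and 2 nodes
graph_idx, shift = 0, 2               # two new nodes inserted into graph 0
batch_ptr[graph_idx + 1:] += shift
print(batch_ptr)                      # tensor([0, 5, 7]); graph 1 now owns nodes 5..6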

def reward(self, final_states: GraphStates) -> torch.Tensor:
15 changes: 12 additions & 3 deletions src/gfn/modules.py
@@ -9,7 +9,12 @@
from gfn.actions import GraphActionType
from gfn.preprocessors import IdentityPreprocessor, Preprocessor
from gfn.states import DiscreteStates, GraphStates, States
from gfn.utils.distributions import CategoricalActionType, CategoricalIndexes, ComposedDistribution, UnsqueezedCategorical
from gfn.utils.distributions import (
CategoricalActionType,
CategoricalIndexes,
ComposedDistribution,
UnsqueezedCategorical,
)

REDUCTION_FXNS = {
"mean": torch.mean,
@@ -235,7 +240,9 @@ def forward(self, states: DiscreteStates) -> torch.Tensor:
Returns the output of the module, as a tensor of shape (*batch_shape, output_dim).
"""
out = super().forward(states)
assert out.shape[-1] == self.expected_output_dim, f"Expected output dim: {self.expected_output_dim}, got: {out.shape[-1]}"
assert (
out.shape[-1] == self.expected_output_dim
), f"Expected output dim: {self.expected_output_dim}, got: {out.shape[-1]}"
return out

def to_probability_distribution(
@@ -517,7 +524,9 @@ def to_probability_distribution(
dists["action_type"] = CategoricalActionType(probs=action_type_probs)

edge_index_logits = module_output["edge_index"]
if states.tensor["node_feature"].shape[0] > 1 and torch.any(edge_index_logits != -float("inf")):
if states.tensor["node_feature"].shape[0] > 1 and torch.any(
edge_index_logits != -float("inf")
):
edge_index_probs = torch.softmax(edge_index_logits / temperature, dim=-1)
uniform_dist_probs = (
torch.ones_like(edge_index_probs) / edge_index_probs.shape[-1]
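
A standalone illustration, not part of this commit, of the temperature-scaled softmax above. The diff is cut off before the probabilities are used; the uniform mixing step below (with an assumed epsilon weight) is only a guess at the intent suggested by uniform_dist_probs.

import torch

logits = torch.tensor([2.0, 0.5, -1.0])
temperature, epsilon = 2.0, 0.1          # epsilon is an assumed exploration weight
probs = torch.softmax(logits / temperature, dim=-1)
uniform = torch.ones_like(probs) / probs.shape[-1]
mixed = (1 - epsilon) * probs + epsilon * uniform
print(mixed.sum())                       # tensor(1.0000) up to rounding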
2 changes: 1 addition & 1 deletion src/gfn/samplers.py
@@ -167,7 +167,7 @@ def sample_trajectories(

step = 0
all_estimator_outputs = []

while not all(dones):
actions = env.actions_from_batch_shape((n_trajectories,)) # Dummy actions.
log_probs = torch.full(
