add support for including Graphs as States of GFlowNet #183

Draft · wants to merge 1 commit into master
141 changes: 140 additions & 1 deletion src/gfn/states.py
@@ -3,10 +3,11 @@
from abc import ABC
from copy import deepcopy
from math import prod
from typing import Callable, ClassVar, List, Optional, Sequence, cast
from typing import Callable, ClassVar, List, Optional, Sequence, cast, Tuple

import torch
from torchtyping import TensorType as TT
from torch_geometric.data import Batch, Data


class States(ABC):
@@ -492,3 +493,141 @@ def stack_states(states: List[States]):
    ) + state_example.batch_shape

    return stacked_states


class GraphStates(ABC):
    """
    Base class for graph-structured states. A `GraphStates` object is a batched
    collection of graphs, represented with the `Batch` object from PyTorch
    Geometric.
    """

    s0: ClassVar[Data]
    sf: ClassVar[Data]
    node_feature_dim: ClassVar[int]
    edge_feature_dim: ClassVar[int]

    # Per-graph factory; environments that support random states should override this.
    make_random_graph: Callable[[], Data] = lambda: (_ for _ in ()).throw(
        NotImplementedError(
            "The environment does not support initialization of random Graph states."
        )
    )

    def __init__(self, graphs: Batch):
        self.data: Batch = graphs
        self.batch_shape: int = self.data.num_graphs
        self._log_rewards: Optional[TT["batch_shape", torch.float]] = None

    @classmethod
    def from_batch_shape(cls, batch_shape: int, random: bool = False, sink: bool = False) -> GraphStates:
        if random and sink:
            raise ValueError("Only one of `random` and `sink` should be True.")
        if random:
            data = cls.make_random_states_graph(batch_shape)
        elif sink:
            data = cls.make_sink_states_graph(batch_shape)
        else:
            data = cls.make_initial_states_graph(batch_shape)
        return cls(data)

    @classmethod
    def make_initial_states_graph(cls, batch_shape: int) -> Batch:
        data = Batch.from_data_list([cls.s0 for _ in range(batch_shape)])
        return data

    @classmethod
    def make_sink_states_graph(cls, batch_shape: int) -> Batch:
        data = Batch.from_data_list([cls.sf for _ in range(batch_shape)])
        return data

    @classmethod
    def make_random_states_graph(cls, batch_shape: int) -> Batch:
        # Builds the batch from the per-graph factory `make_random_graph`, which
        # raises NotImplementedError unless the environment overrides it.
        data = Batch.from_data_list([cls.make_random_graph() for _ in range(batch_shape)])
        return data

    def __len__(self):
        return self.data.num_graphs

    def __repr__(self):
        return (f"{self.__class__.__name__} object of batch shape {self.batch_shape} and "
                f"node feature dim {self.node_feature_dim} and edge feature dim {self.edge_feature_dim}")

    def __getitem__(self, index: int | Sequence[int] | slice) -> GraphStates:
        if isinstance(index, int):
            out = self.__class__(Batch.from_data_list([self.data[index]]))
        elif isinstance(index, (Sequence, slice)):
            out = self.__class__(Batch.from_data_list(self.data.index_select(index)))
        else:
            raise NotImplementedError(f"Indexing with type {type(index)} is not implemented")

        if self._log_rewards is not None:
            out._log_rewards = self._log_rewards[index]

        return out

    def __setitem__(self, index: int | Sequence[int] | slice, graph: GraphStates):
        """
        Set particular states of the Batch.
        """
        data_list = self.data.to_data_list()
        if isinstance(index, int):
            assert len(graph) == 1, "GraphStates must have a batch size of 1 for single-index assignment"
            data_list[index] = graph.data[0]
        elif isinstance(index, Sequence):
            assert len(index) == len(graph), "Index and GraphStates must have the same length"
            for i, idx in enumerate(index):
                data_list[idx] = graph.data[i]
        elif isinstance(index, slice):
            assert index.stop - index.start == len(graph), "Index slice and GraphStates must have the same length"
            data_list[index] = graph.data.to_data_list()
        else:
            raise NotImplementedError(f"Setting with type {type(index)} is not implemented")
        self.data = Batch.from_data_list(data_list)

    @property
    def device(self) -> torch.device:
        return self.data.get_example(0).x.device

    def to(self, device: torch.device) -> GraphStates:
        """
        Moves and/or casts the graph states to the specified device.
        """
        if self.device != device:
            self.data = self.data.to(device)
        return self

    def clone(self) -> GraphStates:
        """Returns a *detached* clone of the current instance using deepcopy."""
        return deepcopy(self)

    def extend(self, other: GraphStates):
        """Concatenates another GraphStates object along the batch dimension."""
        self.data = Batch.from_data_list(self.data.to_data_list() + other.data.to_data_list())
        self.batch_shape = self.data.num_graphs  # keep the cached batch size in sync
        if self._log_rewards is not None:
            assert other._log_rewards is not None
            self._log_rewards = torch.cat(
                (self._log_rewards, other._log_rewards), dim=0
            )

    @property
    def log_rewards(self) -> TT["batch_shape", torch.float]:
        return self._log_rewards

    @log_rewards.setter
    def log_rewards(self, log_rewards: TT["batch_shape", torch.float]) -> None:
        self._log_rewards = log_rewards

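For illustration, a minimal sketch of how an environment could plug into this class by defining `s0` and `sf` as PyTorch Geometric `Data` objects. The `RingStates` subclass and its feature dimensions are hypothetical and not part of this PR:

import torch
from torch_geometric.data import Data

from gfn.states import GraphStates


class RingStates(GraphStates):
    # Hypothetical environment-specific subclass, for illustration only.
    node_feature_dim = 4
    edge_feature_dim = 2

    # Initial state s0: a single node with zeroed features and no edges.
    s0 = Data(
        x=torch.zeros(1, 4),
        edge_index=torch.zeros(2, 0, dtype=torch.long),
        edge_attr=torch.zeros(0, 2),
    )
    # Sink state sf: a sentinel graph marking completed trajectories.
    sf = Data(
        x=torch.full((1, 4), -1.0),
        edge_index=torch.zeros(2, 0, dtype=torch.long),
        edge_attr=torch.zeros(0, 2),
    )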

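Basic batched operations, assuming the hypothetical `RingStates` sketch above:

states = RingStates.from_batch_shape(batch_shape=4)   # batch of four s0 copies
assert len(states) == 4
print(states.device)                                  # device of the node features

sub_batch = states[[0, 1]]                            # select a sub-batch of two graphs
sinks = RingStates.from_batch_shape(batch_shape=2, sink=True)
states[[2, 3]] = sinks                                # overwrite two entries with sink states

states.extend(sub_batch)                              # concatenate along the batch dimension
assert states.batch_shape == 6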