From 3fbde19010ffaa3009a79c8874fb801f7fed9789 Mon Sep 17 00:00:00 2001
From: Aasheesh Singh
Date: Tue, 20 Aug 2024 16:43:16 -0400
Subject: [PATCH] add GraphStates class storing batched graph states as a PyG
 Batch

---
 src/gfn/states.py | 141 +++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 140 insertions(+), 1 deletion(-)

diff --git a/src/gfn/states.py b/src/gfn/states.py
index f4fa1a20..167c59de 100644
--- a/src/gfn/states.py
+++ b/src/gfn/states.py
@@ -3,10 +3,11 @@
 from abc import ABC
 from copy import deepcopy
 from math import prod
-from typing import Callable, ClassVar, List, Optional, Sequence, cast
+from typing import Callable, ClassVar, List, Optional, Sequence, cast, Tuple
 
 import torch
 from torchtyping import TensorType as TT
+from torch_geometric.data import Batch, Data
 
 
 class States(ABC):
@@ -492,3 +493,141 @@ def stack_states(states: List[States]):
     ) + state_example.batch_shape
 
     return stacked_states
+
+
+class GraphStates(ABC):
+    """
+    Base class for a graph-based state representation. A `GraphStates` object
+    is a batched collection of graphs, stored as a PyTorch Geometric `Batch`.
+    """
+
+    s0: ClassVar[Data]
+    sf: ClassVar[Data]
+    node_feature_dim: ClassVar[int]
+    edge_feature_dim: ClassVar[int]
+
+    # Environments that support random states override this hook with a
+    # function returning a single random `Data` object.
+    make_random_graph: Callable = lambda: (_ for _ in ()).throw(
+        NotImplementedError(
+            "The environment does not support initialization of random graph states."
+        )
+    )
+
+    def __init__(self, graphs: Batch):
+        self.data: Batch = graphs
+        self._log_rewards: Optional[torch.Tensor] = None
+
+    @property
+    def batch_shape(self) -> int:
+        return self.data.num_graphs
+
+    @classmethod
+    def from_batch_shape(
+        cls, batch_shape: int, random: bool = False, sink: bool = False
+    ) -> GraphStates:
+        if random and sink:
+            raise ValueError("Only one of `random` and `sink` can be True.")
+        if random:
+            data = cls.make_random_states_graph(batch_shape)
+        elif sink:
+            data = cls.make_sink_states_graph(batch_shape)
+        else:
+            data = cls.make_initial_states_graph(batch_shape)
+        return cls(data)
+
+    @classmethod
+    def make_initial_states_graph(cls, batch_shape: int) -> Batch:
+        return Batch.from_data_list([cls.s0 for _ in range(batch_shape)])
+
+    @classmethod
+    def make_sink_states_graph(cls, batch_shape: int) -> Batch:
+        return Batch.from_data_list([cls.sf for _ in range(batch_shape)])
+
+    @classmethod
+    def make_random_states_graph(cls, batch_shape: int) -> Batch:
+        return Batch.from_data_list(
+            [cls.make_random_graph() for _ in range(batch_shape)]
+        )
+
+    def __len__(self):
+        return self.data.num_graphs
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__} object of batch shape {self.batch_shape}, "
+            f"node feature dim {self.node_feature_dim} and edge feature dim "
+            f"{self.edge_feature_dim}"
+        )
+
+    def __getitem__(self, index: int | Sequence[int] | slice) -> GraphStates:
+        if isinstance(index, int):
+            out = self.__class__(Batch.from_data_list([self.data[index]]))
+        elif isinstance(index, (Sequence, slice)):
+            out = self.__class__(Batch.from_data_list(self.data.index_select(index)))
+        else:
+            raise NotImplementedError(
+                f"Indexing with type {type(index)} is not implemented"
+            )
+
+        if self._log_rewards is not None:
+            out._log_rewards = self._log_rewards[index]
+
+        return out
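+    # Illustrative indexing semantics (a sketch; `RingStates` is a
+    # hypothetical subclass with `s0`/`sf` defined as PyG `Data` objects).
+    # Every access returns a new `GraphStates` wrapping a re-collated `Batch`:
+    #
+    #     states = RingStates.from_batch_shape(4)
+    #     states[0]        # GraphStates holding a single graph
+    #     states[[0, 2]]   # GraphStates holding the 1st and 3rd graphs
+    #     states[:2]       # GraphStates holding the first two graphs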
+    def __setitem__(self, index: int | Sequence[int] | slice, graph: GraphStates):
+        """
+        Set particular states of the batch.
+        """
+        data_list = self.data.to_data_list()
+        if isinstance(index, int):
+            assert (
+                len(graph) == 1
+            ), "GraphStates must have a batch size of 1 for single-index assignment"
+            data_list[index] = graph.data[0]
+        elif isinstance(index, Sequence):
+            assert len(index) == len(
+                graph
+            ), "Index and GraphStates must have the same length"
+            for i, idx in enumerate(index):
+                data_list[idx] = graph.data[i]
+        elif isinstance(index, slice):
+            assert len(data_list[index]) == len(
+                graph
+            ), "Index slice and GraphStates must have the same length"
+            data_list[index] = graph.data.to_data_list()
+        else:
+            raise NotImplementedError(
+                f"Setting with index type {type(index)} is not implemented"
+            )
+        self.data = Batch.from_data_list(data_list)
+
+    @property
+    def device(self) -> torch.device:
+        return self.data.get_example(0).x.device
+
+    def to(self, device: torch.device) -> GraphStates:
+        """
+        Moves and/or casts the graph states to the specified device.
+        """
+        if self.device != device:
+            self.data = self.data.to(device)
+        return self
+
+    def clone(self) -> GraphStates:
+        """Returns a *detached* clone of the current instance using deepcopy."""
+        return deepcopy(self)
+
+    def extend(self, other: GraphStates):
+        """Concatenates another GraphStates object along the batch dimension."""
+        self.data = Batch.from_data_list(
+            self.data.to_data_list() + other.data.to_data_list()
+        )
+        if self._log_rewards is not None:
+            assert other._log_rewards is not None
+            self._log_rewards = torch.cat(
+                (self._log_rewards, other._log_rewards), dim=0
+            )
+
+    @property
+    def log_rewards(self) -> TT["batch_shape", torch.float]:
+        return self._log_rewards
+
+    @log_rewards.setter
+    def log_rewards(self, log_rewards: TT["batch_shape", torch.float]) -> None:
+        self._log_rewards = log_rewards
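
A minimal usage sketch of the new class (illustrative only; `RingStates` and
its `s0`/`sf`/feature-dim definitions below are assumptions for the example,
not part of this patch):

    import torch
    from torch_geometric.data import Data

    from gfn.states import GraphStates

    class RingStates(GraphStates):
        node_feature_dim = 4
        edge_feature_dim = 2
        # Initial state: a single featureless-edge graph with one node.
        s0 = Data(
            x=torch.zeros(1, 4),
            edge_index=torch.empty(2, 0, dtype=torch.long),
            edge_attr=torch.empty(0, 2),
        )
        # Sink state: same shape, sentinel node features.
        sf = Data(
            x=torch.full((1, 4), -1.0),
            edge_index=torch.empty(2, 0, dtype=torch.long),
            edge_attr=torch.empty(0, 2),
        )

    states = RingStates.from_batch_shape(3)  # batch of three copies of s0
    states.log_rewards = torch.zeros(3)

    sinks = RingStates.from_batch_shape(2, sink=True)
    sinks.log_rewards = torch.ones(2)

    states.extend(sinks)  # re-collates the Batch and concatenates log_rewards
    assert len(states) == 5 and states.batch_shape == 5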