Skip to content

Commit

Permalink
add device to init args
Browse files Browse the repository at this point in the history
  • Loading branch information
Rose Lightheart committed Oct 10, 2022
1 parent aca8c42 commit 15d3cc0
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions hgraph/inc_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def add_edge(self, i, j, feature=None):

class IncTree(IncBase):

def __init__(self, batch_size, node_fdim, edge_fdim, max_nodes=100, max_edges=200, max_nb=12, max_sub_nodes=20):
super(IncTree, self).__init__(batch_size, node_fdim, edge_fdim, max_nodes, max_edges, max_nb)
def __init__(self, batch_size, node_fdim, edge_fdim, max_nodes=100, max_edges=200, max_nb=12, max_sub_nodes=20, device=''):
super(IncTree, self).__init__(batch_size, node_fdim, edge_fdim, max_nodes, max_edges, max_nb, device=device)
self.cgraph = self.fnode.new_zeros(max_nodes * batch_size, max_sub_nodes)

def get_tensors(self):
Expand Down Expand Up @@ -88,8 +88,8 @@ def get_cluster_edges(self, node_list):

class IncGraph(IncBase):

def __init__(self, avocab, batch_size, node_fdim, edge_fdim, max_nodes=100, max_edges=300, max_nb=10):
super(IncGraph, self).__init__(batch_size, node_fdim, edge_fdim, max_nodes, max_edges, max_nb)
def __init__(self, avocab, batch_size, node_fdim, edge_fdim, max_nodes=100, max_edges=300, max_nb=10, device=''):
super(IncGraph, self).__init__(batch_size, node_fdim, edge_fdim, max_nodes, max_edges, max_nb, device=device)
self.avocab = avocab
self.mol = Chem.RWMol()
self.mol.AddAtom( Chem.Atom('C') ) #make sure node is 1 index, consistent to self.graph
Expand Down

0 comments on commit 15d3cc0

Please sign in to comment.