Skip to content

Commit

Permalink
Merge pull request #20 from loeweX/master
Browse files Browse the repository at this point in the history
Inconsistent usage of receiver/sender
  • Loading branch information
tkipf authored Nov 19, 2019
2 parents 4915806 + 415ef47 commit e63fcb0
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
10 changes: 5 additions & 5 deletions modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def node2edge(self, x, rel_rec, rel_send):
# NOTE: Assumes that we have the same graph across all samples.
receivers = torch.matmul(rel_rec, x)
senders = torch.matmul(rel_send, x)
edges = torch.cat([receivers, senders], dim=2)
edges = torch.cat([senders, receivers], dim=2)
return edges

def forward(self, inputs, rel_rec, rel_send):
Expand Down Expand Up @@ -191,7 +191,7 @@ def node2edge_temporal(self, inputs, rel_rec, rel_send):

# receivers and senders have shape:
# [num_sims * num_edges, num_dims, num_timesteps]
edges = torch.cat([receivers, senders], dim=1)
edges = torch.cat([senders, receivers], dim=1)
return edges

def edge2node(self, x, rel_rec, rel_send):
Expand All @@ -203,7 +203,7 @@ def node2edge(self, x, rel_rec, rel_send):
# NOTE: Assumes that we have the same graph across all samples.
receivers = torch.matmul(rel_rec, x)
senders = torch.matmul(rel_send, x)
edges = torch.cat([receivers, senders], dim=2)
edges = torch.cat([senders, receivers], dim=2)
return edges

def forward(self, inputs, rel_rec, rel_send):
Expand Down Expand Up @@ -445,7 +445,7 @@ def single_step_forward(self, single_timestep_inputs, rel_rec, rel_send,
# Node2edge
receivers = torch.matmul(rel_rec, single_timestep_inputs)
senders = torch.matmul(rel_send, single_timestep_inputs)
pre_msg = torch.cat([receivers, senders], dim=-1)
pre_msg = torch.cat([senders, receivers], dim=-1)

all_msgs = Variable(torch.zeros(pre_msg.size(0), pre_msg.size(1),
pre_msg.size(2), self.msg_out_shape))
Expand Down Expand Up @@ -556,7 +556,7 @@ def single_step_forward(self, inputs, rel_rec, rel_send,
# node2edge
receivers = torch.matmul(rel_rec, hidden)
senders = torch.matmul(rel_send, hidden)
pre_msg = torch.cat([receivers, senders], dim=-1)
pre_msg = torch.cat([senders, receivers], dim=-1)

all_msgs = Variable(torch.zeros(pre_msg.size(0), pre_msg.size(1),
self.msg_out_shape))
Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@
# Generate off-diagonal interaction graph
off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms)

rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
rel_rec = torch.FloatTensor(rel_rec)
rel_send = torch.FloatTensor(rel_send)

Expand Down
4 changes: 2 additions & 2 deletions train_dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@
# Generate fully-connected interaction graph (sparse graphs would also work)
off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms)

rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
rel_rec = torch.FloatTensor(rel_rec)
rel_send = torch.FloatTensor(rel_send)

Expand Down
4 changes: 2 additions & 2 deletions train_enc.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@
# Generate off-diagonal interaction graph
off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms)

rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
rel_rec = torch.FloatTensor(rel_rec)
rel_send = torch.FloatTensor(rel_send)

Expand Down

0 comments on commit e63fcb0

Please sign in to comment.