From 81094a5b362e8cfa80fb5775883dec8deb87aa3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sindy=20L=C3=B6we?= Date: Tue, 19 Nov 2019 10:42:21 +0100 Subject: [PATCH 1/5] Update train.py --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index ba44aab..ab321bb 100644 --- a/train.py +++ b/train.py @@ -109,8 +109,8 @@ # Generate off-diagonal interaction graph off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms) -rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32) -rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32) +rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32) +rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32) rel_rec = torch.FloatTensor(rel_rec) rel_send = torch.FloatTensor(rel_send) From a0f4dfbe9c35f603da3ae858c80916a267997e05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sindy=20L=C3=B6we?= Date: Tue, 19 Nov 2019 10:46:15 +0100 Subject: [PATCH 2/5] Update modules.py --- modules.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules.py b/modules.py index b11c059..a7aba10 100644 --- a/modules.py +++ b/modules.py @@ -119,7 +119,7 @@ def node2edge(self, x, rel_rec, rel_send): # NOTE: Assumes that we have the same graph across all samples. receivers = torch.matmul(rel_rec, x) senders = torch.matmul(rel_send, x) - edges = torch.cat([receivers, senders], dim=2) + edges = torch.cat([senders, receivers], dim=2) return edges def forward(self, inputs, rel_rec, rel_send): @@ -191,7 +191,7 @@ def node2edge_temporal(self, inputs, rel_rec, rel_send): # receivers and senders have shape: # [num_sims * num_edges, num_dims, num_timesteps] - edges = torch.cat([receivers, senders], dim=1) + edges = torch.cat([senders, receivers], dim=1) return edges def edge2node(self, x, rel_rec, rel_send): @@ -203,7 +203,7 @@ def node2edge(self, x, rel_rec, rel_send): # NOTE: Assumes that we have the same graph across all samples. receivers = torch.matmul(rel_rec, x) senders = torch.matmul(rel_send, x) - edges = torch.cat([receivers, senders], dim=2) + edges = torch.cat([senders, receivers], dim=2) return edges def forward(self, inputs, rel_rec, rel_send): @@ -445,7 +445,7 @@ def single_step_forward(self, single_timestep_inputs, rel_rec, rel_send, # Node2edge receivers = torch.matmul(rel_rec, single_timestep_inputs) senders = torch.matmul(rel_send, single_timestep_inputs) - pre_msg = torch.cat([receivers, senders], dim=-1) + pre_msg = torch.cat([senders, receivers], dim=-1) all_msgs = Variable(torch.zeros(pre_msg.size(0), pre_msg.size(1), pre_msg.size(2), self.msg_out_shape)) From a00e33dc15fb66773815a0139a41d2afa8507f81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sindy=20L=C3=B6we?= Date: Tue, 19 Nov 2019 10:47:36 +0100 Subject: [PATCH 3/5] Update modules.py --- modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules.py b/modules.py index a7aba10..1f351a4 100644 --- a/modules.py +++ b/modules.py @@ -556,7 +556,7 @@ def single_step_forward(self, inputs, rel_rec, rel_send, # node2edge receivers = torch.matmul(rel_rec, hidden) senders = torch.matmul(rel_send, hidden) - pre_msg = torch.cat([receivers, senders], dim=-1) + pre_msg = torch.cat([senders, receivers], dim=-1) all_msgs = Variable(torch.zeros(pre_msg.size(0), pre_msg.size(1), self.msg_out_shape)) From 6a461eda2583336a0afcde91d0385e92e300988a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sindy=20L=C3=B6we?= Date: Tue, 19 Nov 2019 10:51:42 +0100 Subject: [PATCH 4/5] Update train_enc.py --- train_enc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_enc.py b/train_enc.py index 1469c75..57f7642 100644 --- a/train_enc.py +++ b/train_enc.py @@ -89,8 +89,8 @@ # Generate off-diagonal interaction graph off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms) -rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32) -rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32) +rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32) +rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32) rel_rec = torch.FloatTensor(rel_rec) rel_send = torch.FloatTensor(rel_send) From 415ef47e58b56aa364701e82b69dd62c85711f80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sindy=20L=C3=B6we?= Date: Tue, 19 Nov 2019 10:52:25 +0100 Subject: [PATCH 5/5] Update train_dec.py --- train_dec.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_dec.py b/train_dec.py index 16d23bb..9d32aec 100644 --- a/train_dec.py +++ b/train_dec.py @@ -95,8 +95,8 @@ # Generate fully-connected interaction graph (sparse graphs would also work) off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms) -rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32) -rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32) +rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32) +rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32) rel_rec = torch.FloatTensor(rel_rec) rel_send = torch.FloatTensor(rel_send)