diff --git a/modules.py b/modules.py
index b11c059..1f351a4 100644
--- a/modules.py
+++ b/modules.py
@@ -119,7 +119,7 @@ def node2edge(self, x, rel_rec, rel_send):
         # NOTE: Assumes that we have the same graph across all samples.
         receivers = torch.matmul(rel_rec, x)
         senders = torch.matmul(rel_send, x)
-        edges = torch.cat([receivers, senders], dim=2)
+        edges = torch.cat([senders, receivers], dim=2)
         return edges
 
     def forward(self, inputs, rel_rec, rel_send):
@@ -191,7 +191,7 @@ def node2edge_temporal(self, inputs, rel_rec, rel_send):
 
         # receivers and senders have shape:
         # [num_sims * num_edges, num_dims, num_timesteps]
-        edges = torch.cat([receivers, senders], dim=1)
+        edges = torch.cat([senders, receivers], dim=1)
         return edges
 
     def edge2node(self, x, rel_rec, rel_send):
@@ -203,7 +203,7 @@ def node2edge(self, x, rel_rec, rel_send):
         # NOTE: Assumes that we have the same graph across all samples.
         receivers = torch.matmul(rel_rec, x)
         senders = torch.matmul(rel_send, x)
-        edges = torch.cat([receivers, senders], dim=2)
+        edges = torch.cat([senders, receivers], dim=2)
         return edges
 
     def forward(self, inputs, rel_rec, rel_send):
@@ -445,7 +445,7 @@ def single_step_forward(self, single_timestep_inputs, rel_rec, rel_send,
         # Node2edge
         receivers = torch.matmul(rel_rec, single_timestep_inputs)
         senders = torch.matmul(rel_send, single_timestep_inputs)
-        pre_msg = torch.cat([receivers, senders], dim=-1)
+        pre_msg = torch.cat([senders, receivers], dim=-1)
 
         all_msgs = Variable(torch.zeros(pre_msg.size(0), pre_msg.size(1),
                                         pre_msg.size(2), self.msg_out_shape))
@@ -556,7 +556,7 @@ def single_step_forward(self, inputs, rel_rec, rel_send,
         # node2edge
         receivers = torch.matmul(rel_rec, hidden)
         senders = torch.matmul(rel_send, hidden)
-        pre_msg = torch.cat([receivers, senders], dim=-1)
+        pre_msg = torch.cat([senders, receivers], dim=-1)
 
         all_msgs = Variable(torch.zeros(pre_msg.size(0), pre_msg.size(1),
                                         self.msg_out_shape))
diff --git a/train.py b/train.py
index ba44aab..ab321bb 100644
--- a/train.py
+++ b/train.py
@@ -109,8 +109,8 @@
 
 # Generate off-diagonal interaction graph
 off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms)
-rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
-rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
+rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
+rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
 rel_rec = torch.FloatTensor(rel_rec)
 rel_send = torch.FloatTensor(rel_send)
 
diff --git a/train_dec.py b/train_dec.py
index 16d23bb..9d32aec 100644
--- a/train_dec.py
+++ b/train_dec.py
@@ -95,8 +95,8 @@
 
 # Generate fully-connected interaction graph (sparse graphs would also work)
 off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms)
-rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
-rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
+rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
+rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
 rel_rec = torch.FloatTensor(rel_rec)
 rel_send = torch.FloatTensor(rel_send)
 
diff --git a/train_enc.py b/train_enc.py
index 1469c75..57f7642 100644
--- a/train_enc.py
+++ b/train_enc.py
@@ -89,8 +89,8 @@
 
 # Generate off-diagonal interaction graph
 off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms)
-rel_rec = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
-rel_send = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
+rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
+rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
 rel_rec = torch.FloatTensor(rel_rec)
 rel_send = torch.FloatTensor(rel_send)