Below is the code snippet of MLPDecoder.
I think the prediction ends with Eq. 11 in the paper; I can't find the code for Eq. 12.
Am I missing something in this code?
Thanks in advance.
def single_step_forward(self, single_timestep_inputs, rel_rec, rel_send,
                        single_timestep_rel_type):
    # single_timestep_inputs has shape
    # [batch_size, num_timesteps, num_atoms, num_dims]
    # single_timestep_rel_type has shape:
    # [batch_size, num_timesteps, num_atoms*(num_atoms-1), num_edge_types]

    # Node2edge
    receivers = torch.matmul(rel_rec, single_timestep_inputs)
    senders = torch.matmul(rel_send, single_timestep_inputs)
    # Eq 10 [x_i^t, x_j^t]: [#sims (batch_size), #tsteps_indexed, #edges, #dims*2]
    pre_msg = torch.cat([senders, receivers], dim=-1)

    # self.msg_out_shape = #node_features
    all_msgs = Variable(torch.zeros(pre_msg.size(0), pre_msg.size(1),
                                    pre_msg.size(2), self.msg_out_shape))
    if single_timestep_inputs.is_cuda:
        all_msgs = all_msgs.cuda()

    if self.skip_first_edge_type:
        start_idx = 1
    else:
        start_idx = 0

    # Run separate MLP for every edge type
    # NOTE: To exclude one edge type, simply offset range by 1
    # Eq 10 MLP
    for i in range(start_idx, len(self.msg_fc2)):
        msg = F.relu(self.msg_fc1[i](pre_msg))
        msg = F.dropout(msg, p=self.dropout_prob)
        msg = F.relu(self.msg_fc2[i](msg))
        msg = msg * single_timestep_rel_type[:, :, :, i:i + 1]  # element-wise product with broadcast
        all_msgs += msg

    # Aggregate all msgs to receiver
    # Eq 11 / rel_rec: [#edges, #nodes]
    agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1)
    agg_msgs = agg_msgs.contiguous()

    # Skip connection
    aug_inputs = torch.cat([single_timestep_inputs, agg_msgs], dim=-1)

    # Output MLP
    pred = F.dropout(F.relu(self.out_fc1(aug_inputs)), p=self.dropout_prob)
    pred = F.dropout(F.relu(self.out_fc2(pred)), p=self.dropout_prob)
    pred = self.out_fc3(pred)

    # Predict position/velocity difference / Eq 11 >> Where is Eq 12??
    return single_timestep_inputs + pred
def forward(self, inputs, rel_type, rel_rec, rel_send, pred_steps=1):
    # NOTE: Assumes that we have the same graph across all samples.
    # Input shape: [num_sims, num_atoms, num_timesteps, num_dims]
    #           -> [#sims, #tsteps, #nodes, #dims]
    inputs = inputs.transpose(1, 2).contiguous()

    sizes = [rel_type.size(0), inputs.size(1), rel_type.size(1),
             rel_type.size(2)]
    rel_type = rel_type.unsqueeze(1).expand(sizes)

    time_steps = inputs.size(1)
    assert (pred_steps <= time_steps)
    preds = []

    # Only take n-th timesteps as starting points (n: pred_steps)
    last_pred = inputs[:, 0::pred_steps, :, :]
    curr_rel_type = rel_type[:, 0::pred_steps, :, :]
    # NOTE: Assumes rel_type is constant (i.e. same across all time steps).

    # Run n prediction steps / Eq 10~11
    for step in range(0, pred_steps):
        last_pred = self.single_step_forward(last_pred, rel_rec, rel_send,
                                             curr_rel_type)
        preds.append(last_pred)

    sizes = [preds[0].size(0), preds[0].size(1) * pred_steps,
             preds[0].size(2), preds[0].size(3)]
    output = Variable(torch.zeros(sizes))
    if inputs.is_cuda:
        output = output.cuda()

    # Re-assemble correct timeline
    for i in range(len(preds)):
        output[:, i::pred_steps, :, :] = preds[i]

    # Last prediction is one step beyond input
    pred_all = output[:, :(inputs.size(1) - 1), :, :]

    return pred_all.transpose(1, 2).contiguous()
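My current reading (please correct me if I am wrong) is that Eq. 12 only defines the Gaussian output distribution p(x^{t+1} | x^t, z) = N(mu^{t+1}, sigma^2 I), with the value returned by single_step_forward as its mean mu^{t+1}. So it never appears as a separate decoder step; it would only enter through the training objective as a fixed-variance Gaussian negative log-likelihood. A minimal sketch of such a loss (the name gaussian_nll and the fixed variance argument are my own, not taken from the repo):

import numpy as np

def gaussian_nll(preds, target, variance, add_const=False):
    # NLL of `target` under N(preds, variance * I) with a fixed scalar variance,
    # summed over dims and averaged over samples and atoms.
    neg_log_p = (preds - target) ** 2 / (2 * variance)
    if add_const:
        neg_log_p = neg_log_p + 0.5 * np.log(2 * np.pi * variance)
    return neg_log_p.sum() / (target.size(0) * target.size(1))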
Hi,
I may have the same problem. I am looking for a way to predict new, unseen trajectory steps, but something seems to be missing.
Briefly: given x1, x2, ..., xT with estimated edges, I would like to predict xT+1, ..., xT+N.
Since Eq. 12 says x(t+1) follows a normal distribution with mean mu(t+1), did you decide not to draw the value of x(t+1) from that distribution? The most probable value is mu(t+1), so did you simply set x(t+1) = mu(t+1)? In that case Eq. 12 can be skipped.
Concerning prediction on unseen trajectories, do we have to call the MLPDecoder repeatedly, iteratively predicting each x(t+k) from the previous prediction? From my point of view that seems to be the only solution. The idea here is to reproduce the non-transparent lines in Figure 9.
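In case it helps, here is the rough rollout I have in mind, assuming a trained MLPDecoder instance decoder and the usual rel_type / rel_rec / rel_send tensors (the rollout name and n_future argument are just for illustration), with x(t+1) set to mu(t+1) at every step:

import torch

def rollout(decoder, last_frame, rel_type, rel_rec, rel_send, n_future):
    # last_frame: [num_sims, num_atoms, 1, num_dims], the last observed step.
    curr = last_frame.transpose(1, 2).contiguous()   # -> [sims, 1, atoms, dims]
    curr_rel_type = rel_type.unsqueeze(1)            # -> [sims, 1, edges, edge_types]
    future = []
    for _ in range(n_future):
        # Feed each prediction back in as the next input (no teacher forcing);
        # x(t+1) is taken to be the mean mu(t+1) of Eq. 12, no sampling.
        curr = decoder.single_step_forward(curr, rel_rec, rel_send, curr_rel_type)
        future.append(curr)
    out = torch.cat(future, dim=1)                   # [sims, n_future, atoms, dims]
    return out.transpose(1, 2).contiguous()          # back to [sims, atoms, n_future, dims]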
Hi @djoad001, how did you end up doing this? I also don't see how to predict iteratively over several timesteps when I don't have access to the ground truth (and thus cannot do teacher forcing every N timesteps). So I guess the solution is to modify the code so that, in test mode, it uses the last mu (or draws a sample from the distribution) to predict the next step, iteratively, until we get the desired number of timesteps?
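For the sampling variant, I imagine one could simply perturb the predicted mean with the fixed variance used during training (output_var here is an assumed scalar hyperparameter, not a name from the repo):

# mu: decoder output for the next step, i.e. the mean of Eq. 12.
x_next = mu + torch.randn_like(mu) * output_var ** 0.5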