Where can I find the code of Eq. 12 in the paper?? #30

Open
hoon0528 opened this issue Mar 10, 2021 · 2 comments

@hoon0528

Below is the code of MLPDecoder.
As far as I can tell, the prediction ends with Eq. 11 in the paper.
I can't find any code that corresponds to Eq. 12.
Am I missing something in this code?

Thanks in advance.

    def single_step_forward(self, single_timestep_inputs, rel_rec, rel_send,
                            single_timestep_rel_type):

        # single_timestep_inputs has shape
        # [batch_size, num_timesteps, num_atoms, num_dims]

        # single_timestep_rel_type has shape:
        # [batch_size, num_timesteps, num_atoms*(num_atoms-1), num_edge_types]

        # Node2edge 
        receivers = torch.matmul(rel_rec, single_timestep_inputs)
        senders = torch.matmul(rel_send, single_timestep_inputs)
        # Eq 10 [x_i^t, x_j^t] [#sims(batch_size), #tsteps_indexed, #edges, #dims*2]
        pre_msg = torch.cat([senders, receivers], dim=-1)
        # self.msg_out_shape = #node_features
        all_msgs = Variable(torch.zeros(pre_msg.size(0), pre_msg.size(1),
                                        pre_msg.size(2), self.msg_out_shape))
        if single_timestep_inputs.is_cuda:
            all_msgs = all_msgs.cuda()

        if self.skip_first_edge_type:
            start_idx = 1
        else:
            start_idx = 0

        # Run separate MLP for every edge type
        # NOTE: To exclude one edge type, simply offset range by 1
        # Eq 10 MLP
        for i in range(start_idx, len(self.msg_fc2)):
            msg = F.relu(self.msg_fc1[i](pre_msg))
            msg = F.dropout(msg, p=self.dropout_prob)
            msg = F.relu(self.msg_fc2[i](msg))
            msg = msg * single_timestep_rel_type[:, :, :, i:i + 1] #element-wise product with broadcast
            all_msgs += msg

        # Aggregate all msgs to receiver
        # Eq 11 / rel_rec [#edges, #nodes]
        agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1)
        agg_msgs = agg_msgs.contiguous()

        # Skip connection
        aug_inputs = torch.cat([single_timestep_inputs, agg_msgs], dim=-1)

        # Output MLP
        pred = F.dropout(F.relu(self.out_fc1(aug_inputs)), p=self.dropout_prob)
        pred = F.dropout(F.relu(self.out_fc2(pred)), p=self.dropout_prob)
        pred = self.out_fc3(pred)

        # Predict position/velocity difference / Eq 11 >> Where is Eq 12??
        return single_timestep_inputs + pred

    def forward(self, inputs, rel_type, rel_rec, rel_send, pred_steps=1):
        # NOTE: Assumes that we have the same graph across all samples.
        # Input shape: [num_sims, num_atoms, num_timesteps, num_dims] > [#sims, #tsteps, #nodes, #dims]
        inputs = inputs.transpose(1, 2).contiguous()

        sizes = [rel_type.size(0), inputs.size(1), rel_type.size(1),
                 rel_type.size(2)]
        rel_type = rel_type.unsqueeze(1).expand(sizes)

        time_steps = inputs.size(1)
        assert (pred_steps <= time_steps)
        preds = []

        # Only take n-th timesteps as starting points (n: pred_steps)
        last_pred = inputs[:, 0::pred_steps, :, :]
        curr_rel_type = rel_type[:, 0::pred_steps, :, :]
        # NOTE: Assumes rel_type is constant (i.e. same across all time steps).

        # Run n prediction steps / Eq 10~11
        for step in range(0, pred_steps):
            last_pred = self.single_step_forward(last_pred, rel_rec, rel_send,
                                                 curr_rel_type)
            preds.append(last_pred)

        sizes = [preds[0].size(0), preds[0].size(1) * pred_steps,
                 preds[0].size(2), preds[0].size(3)]

        output = Variable(torch.zeros(sizes))
        if inputs.is_cuda:
            output = output.cuda()

        # Re-assemble correct timeline
        for i in range(len(preds)):
            output[:, i::pred_steps, :, :] = preds[i]
        # last prediction is one step beyond input
        pred_all = output[:, :(inputs.size(1) - 1), :, :]

        return pred_all.transpose(1, 2).contiguous()
@djoad001

djoad001 commented Mar 11, 2021

Hi,
I may have the same problem. I am looking for a way to predict new, unseen trajectory steps, but something seems to be missing.

Briefly: given x1, x2, ..., xT with estimated edges, I would like to predict xT+1, ..., xT+N.

  • In Eq. 12, x(t+1) follows a normal distribution with mean mu(t+1). Did you decide not to draw the value of x(t+1) from this distribution? The most probable value is mu(t+1), so did you simply set x(t+1) = mu(t+1)? In that case Eq. 12 can be skipped.

  • For prediction on unseen trajectories, do we have to call the MLPDecoder repeatedly to predict each x(t+k) iteratively, starting from x(t)? From my point of view this seems to be the only solution. The idea is to reproduce the non-transparent lines in Figure 9.
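
To make the first point concrete, here is a minimal sketch in plain Python of what a Gaussian with mean mu(t+1) and a fixed variance implies. It is not taken from the repository: the names `gaussian_nll` and `predict_next` and the choice of a scalar `variance` are assumptions made for illustration. With a fixed variance, the negative log-likelihood is a squared error on the mean plus a constant, so a decoder that only outputs mu(t+1) is enough for training, and sampling becomes optional at prediction time.

```python
import math
import random


def gaussian_nll(mu, target, variance):
    """Negative log-likelihood of `target` under N(mu, variance), per scalar.

    Hypothetical helper: with a fixed variance the second term is constant,
    so minimising this is equivalent to minimising the squared error.
    """
    squared_error_term = (target - mu) ** 2 / (2.0 * variance)
    constant_term = 0.5 * math.log(2.0 * math.pi * variance)
    return squared_error_term + constant_term


def predict_next(mu, variance, sample=False, rng=random):
    """Return the mean (the most probable value) or a draw from N(mu, variance)."""
    if sample:
        # random.gauss expects a standard deviation, not a variance.
        return rng.gauss(mu, math.sqrt(variance))
    return mu
```

Calling `predict_next(mu, variance)` returns the mean unchanged, which corresponds to setting x(t+1) = mu(t+1); passing `sample=True` draws from the distribution instead.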

Thanks a lot for your help.

@paulinesert

Hi,
@djoad001, how did you end up doing this? I also don't see how to predict successive timesteps iteratively when I don't have access to the ground truth (and thus cannot do teacher forcing every N timesteps). So I guess the solution is to modify the code so that, in test mode, it uses the last mu (or a sample drawn from the distribution) to predict the next step, repeating this until we reach the desired number of timesteps?
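
A minimal sketch of that loop, written in plain Python so it does not depend on the decoder internals. Everything here is an assumption for illustration: `rollout`, `step_fn`, and `sample_fn` are hypothetical names, and `step_fn` stands in for a wrapper around `single_step_forward` with `rel_rec`, `rel_send`, and the edge types already bound.

```python
def rollout(step_fn, x_last, num_steps, sample_fn=None):
    """Autoregressively predict `num_steps` future states starting from `x_last`.

    step_fn:   hypothetical callable mapping the current state to the
               predicted mean of the next state.
    sample_fn: optional hypothetical callable turning a mean into a sample;
               if None, the mean itself is fed back as the next input.
    """
    preds = []
    current = x_last
    for _ in range(num_steps):
        mu = step_fn(current)
        # No ground truth is used here: each prediction becomes the next input.
        current = sample_fn(mu) if sample_fn is not None else mu
        preds.append(current)
    return preds


# Toy usage with a stand-in "decoder" that adds 1 to the state.
future = rollout(lambda x: x + 1, 0, 3)
```

With the decoder posted above, `x_last` would presumably be the last observed step shaped like `single_timestep_inputs` with a single timestep, and the edge types would need the matching shape. Note also that `F.dropout` is called there without a `training` argument, and its default is `training=True`, so dropout would probably stay active during such a rollout unless that is changed.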

Thanks for your help.


3 participants