
Adding time embedding for score / flow matching #118

Open
ListIndexOutOfRange opened this issue Dec 10, 2024 · 1 comment

Comments

@ListIndexOutOfRange

Hey! Thank you for your great work! The code is of great quality; I truly appreciate that!

I'd like to use this architecture as a U-Net-like model for point cloud generation with score matching or flow matching, so I need to add a time embedding. There are many ways to do this, and I wanted to have your opinion.

Currently, I'm doing it inside the pooling layers. Inside the SerializedPooling forward method, it looks like this:

h = self.proj(point.feat)[indices]
if "time_embedding" in point and self.time_proj is not None:
    # project the per-point time embedding and add it before pooling
    time_embedding_proj = self.time_proj(point.time_embedding)[indices]
    h = h + time_embedding_proj
point_dict = Dict(
    # instead of:
    # feat=torch_scatter.segment_csr(self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce),
    feat=torch_scatter.segment_csr(h, idx_ptr, reduce=self.reduce),
    ...
)

where time_proj is a small linear projection defined alongside the existing proj:

self.proj = nn.Linear(in_channels, out_channels)
self.time_proj = nn.Linear(time_embedding_dim, out_channels)

and time_embedding is the output of a small MLP applied before the actual encoder, defined in the PTV3 init as:

self.embed_time = nn.Sequential(
    nn.Linear(enc_channels[0], time_embedding_dim),
    act_layer(),
    nn.Linear(time_embedding_dim, time_embedding_dim),
)
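
For context, here is a minimal sketch of how such a time embedding could be produced and attached to the point dict before the encoder. The sinusoidal_time_embedding helper and the broadcast via point.batch are my assumptions (the issue doesn't show this part), not code from PTV3:

import math
import torch

def sinusoidal_time_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    # standard diffusion-style embedding: (B,) timesteps -> (B, dim); assumes dim is even
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device).float() / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

# Hypothetical usage before the encoder: t holds one timestep per sample,
# and point.batch maps each point to its sample index.
# emb = self.embed_time(sinusoidal_time_embedding(t, enc_channels[0]))  # (B, time_embedding_dim)
# point.time_embedding = emb[point.batch]                               # (N, time_embedding_dim)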

Does it look good to you? I'd really appreciate your feedback.

Thanks again !

@Gofinge
Member

Gofinge commented Jan 14, 2025

Hi, sorry for the late response. I was super busy over the past half year. Back to your question: injecting the time embedding directly before pooling, which is relatively weak at reasoning compared with attention, might cause some issues. For example, suppose two points in the same grid cell carry two different time embeddings, say 1 and 3. After mean pooling, the pooled time embedding would be 2, which is strange. Here are two potential solutions:

  1. Only inject the time embedding before the attention layers, and let attention handle it (a sketch follows below).
  2. Pool separately for points belonging to different timestamps.
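
A minimal sketch of option 1, assuming a hypothetical wrapper around an existing PTV3 attention block; TimeConditionedBlock and its SiLU + Linear projection are illustrative choices, not part of the repo:

import torch.nn as nn

class TimeConditionedBlock(nn.Module):
    # Hypothetical wrapper: add a projected time embedding to the per-point
    # features right before an existing attention block runs, so attention
    # (rather than pooling) gets to reason about time.
    def __init__(self, block: nn.Module, channels: int, time_embedding_dim: int):
        super().__init__()
        self.block = block  # an existing PTV3 attention block
        self.time_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embedding_dim, channels),
        )

    def forward(self, point):
        if "time_embedding" in point:
            point.feat = point.feat + self.time_proj(point.time_embedding)
        return self.block(point)

Because the addition happens per point, points with different timesteps keep distinct features until attention mixes them explicitly, which avoids the averaging problem described above.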
