Support for more datasets #5

Open
zzzzzx-1115 opened this issue Nov 22, 2023 · 5 comments

@zzzzzx-1115

Hello, I really like your awesome work!

It seems that the released code covers only the Moving MNIST part. Could you please share more of the code (such as dataset.py and the training arguments for TaxiBJ)? It would be very helpful for reproducing your results.

Thanks for your help!

@SongTang-x (Owner)

Thank you for your appreciation of our work. Regarding TaxiBJ:
0. Download the original dataset from here: https://gitee.com/arislee/taxi-bj
1. For the preprocessing, please refer to https://github.com/Yunbo426/MIM/blob/master/src/data_provider/taxibj.py.
2. Key training configurations are as follows:
patch_size = 4
depths = 12
heads_number = [8]
window_size = 8
drop_rate = 0.
attn_drop_rate = 0.
drop_path_rate = 0.1
batch_size = 4
lr = 1e-4
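
For reference, the settings above can be gathered into a single config object. A minimal sketch (the argument names here are illustrative and may not match the repository's actual argparse flags):

import argparse

# Hypothetical TaxiBJ configuration mirroring the values listed above.
taxibj_args = argparse.Namespace(
    patch_size=4,         # patch-embedding downsampling factor
    depths=12,            # number of Swin Transformer blocks
    heads_number=[8],     # attention heads per stage
    window_size=8,        # window size for windowed attention
    drop_rate=0.0,
    attn_drop_rate=0.0,
    drop_path_rate=0.1,   # stochastic depth rate
    batch_size=4,
    lr=1e-4,              # learning rate
)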

@zzzzzx-1115 (Author)

Great! I would really appreciate it if you could also share the training configs for Human3.6M and KTH. By the way, for how many epochs should I train SwinLSTM on each of the four datasets mentioned in the paper?

Thanks!

@54wb commented Dec 15, 2023

Hi, did you try to run the TaxiBJ dataset? I noticed that with patch_size == 4 the input is downsampled 4x, but PatchInflated only upsamples it 2x, which leads to an error. Did you run into this issue? Thanks very much.

@zzzzzx-1115 (Author)

Hello! In fact, I did run into the same problem, so I fixed it myself by modifying the transposed convolution used in PatchInflated. I think the upsampling factor should depend on the patch size.
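
A minimal sketch of that idea, making the number of 2x transposed-convolution stages depend on the patch size (this is an illustrative module assuming power-of-two patch sizes, not the repository's exact PatchInflated implementation):

import math
import torch
import torch.nn as nn

class PatchInflatedSketch(nn.Module):
    # Upsamples a feature map back to the input resolution for any power-of-two patch size.
    def __init__(self, embed_dim, patch_size):
        super().__init__()
        num_stages = int(math.log2(patch_size))  # e.g. patch_size=4 -> two 2x stages
        layers = []
        for _ in range(num_stages):
            layers += [
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=3,
                                   stride=2, padding=1, output_padding=1),
                nn.GroupNorm(16, embed_dim),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.upsample = nn.Sequential(*layers)

    def forward(self, x):
        return self.upsample(x)

# Shape check: with patch_size=4, an 8x8 feature map is restored to 32x32.
x = torch.randn(1, 128, 8, 8)
print(PatchInflatedSketch(embed_dim=128, patch_size=4)(x).shape)  # torch.Size([1, 128, 32, 32])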

@SongTang-x (Owner)

Hi, here are more details regarding the training configurations and the PatchInflated layer:
0. Human3.6m:
patch_size = 2
depths = 12
heads_number = [8]
window_size = 4
drop_rate = 0.05
attn_drop_rate = 0.
drop_path_rate = 0.1
batch_size = 8
1. KTH:
patch_size = 4
depths = 6
heads_number = [8]
window_size = 4
drop_rate = 0.
attn_drop_rate = 0.
drop_path_rate = 0.1
batch_size = 8
2. Suggested epoch numbers for the four datasets:
MMNIST: 2000
Human3.6m: 100
KTH: 100
TaxiBJ: 100
3. For the PatchInflated layer:
With patch_size set to 4, add an extra transposed-convolution stage in PatchInflated. For larger patch sizes, continue adding stages as needed. The code is as follows:
# Extra upsampling stage inside PatchInflated; with stride=2, padding=1 and
# output_padding=1, each ConvTranspose2d doubles the spatial resolution.
self.Conv_ = nn.Sequential(
    nn.ConvTranspose2d(in_channels=embed_dim, out_channels=embed_dim,
                       kernel_size=(3, 3), stride=stride, padding=padding,
                       output_padding=output_padding),
    nn.GroupNorm(16, embed_dim),
    nn.LeakyReLU(0.2, inplace=True)
)
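
As a quick sanity check that two such stages recover the TaxiBJ resolution (assuming kernel_size=3, stride=2, padding=1, output_padding=1 for every stage, and TaxiBJ's 32x32 frames):

def conv_transpose_out(size, kernel=3, stride=2, padding=1, output_padding=1):
    # Standard output-size formula for ConvTranspose2d.
    return (size - 1) * stride - 2 * padding + kernel + output_padding

h = 32 // 4             # feature-map size after patch embedding with patch_size=4
for _ in range(2):      # two 2x stages for patch_size=4
    h = conv_transpose_out(h)
print(h)                # 32, matching the original frame size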
