
Random projector? #1

Closed
yoshitomo-matsubara opened this issue Feb 5, 2024 · 28 comments

@yoshitomo-matsubara

Hi @roymiles

Congratulations on the paper acceptance!

For the ImageNet experiment, the self.embed module in the OurDistillationLoss class looks like a random projector for ResNet-18's embeddings, and it doesn't seem to be updated since it's not included in the optimizer. Is that intentional? If so, why is it required?
https://github.com/roymiles/Simple-Recipe-Distillation/blob/main/imagenet/torchdistill/losses/single.py#L140
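For context, a minimal self-contained sketch (toy/hypothetical module names, not the repo's actual classes) of the failure mode described above: a projector created inside the loss stays at its random initialization if its parameters are never handed to the optimizer.

```python
import torch
from torch import nn

class ToyDistillationLoss(nn.Module):
    """Toy stand-in (hypothetical) for a loss that owns a projector like `self.embed`."""
    def __init__(self, in_dim=512, out_dim=512):
        super().__init__()
        self.embed = nn.Linear(in_dim, out_dim)  # randomly initialized projector

    def forward(self, student_feat, teacher_feat):
        return (self.embed(student_feat) - teacher_feat).pow(2).mean()

student = nn.Linear(512, 512)  # stand-in for the student network
criterion = ToyDistillationLoss()

# Only the student's parameters are passed to the optimizer, so `criterion.embed`
# keeps its random initialization for the whole of training.
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

optimized_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
print(any(id(p) in optimized_ids for p in criterion.embed.parameters()))  # -> False
```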

@roymiles
Owner

roymiles commented Feb 5, 2024

Hi Yoshitomo!

Thank you :) and yes, that must have been an oversight when rewriting the code and putting it in this repo. Unfortunately, I only put time into checking the deit/ code and experiments before putting this up. I will look into fixing the resnet18 imagenet/ code this week and re-run the experiments. Thanks for spotting this! I will reply on this issue once that's done.

@roymiles
Owner

roymiles commented Feb 8, 2024

I have fixed the repo now. The line of interest is here:

trainable_module_list.append(self.criterion.term_dict['it_loss'][0].embed)
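For reference, a minimal sketch (toy modules and hypothetical names, not the exact training script) of what appending the loss's embed module to the trainable-module list accomplishes: its parameters end up in the optimizer alongside the student's, so the projector is no longer frozen.

```python
import torch
from torch import nn

student = nn.Linear(512, 512)          # stand-in for ResNet-18
criterion_embed = nn.Linear(512, 512)  # stand-in for the loss's `embed` projector

trainable_module_list = [student, criterion_embed]

# Every module in the list contributes its parameters to the optimizer,
# so the projector is now updated together with the student.
params = [p for m in trainable_module_list for p in m.parameters()]
optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9, weight_decay=1e-4)
```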

Training is a bit slow with just a single V100, so I figured I would give an update on the progress. Both runs are about 40% done and, as expected, the trainable projector leads to much better performance. I will upload the complete final logs and the checkpoints once they are both done.

Frozen Projector

2024-02-08 09:28:26,407 INFO torchdistill.misc.log Epoch: [40] [4900/5005] eta: 0:00:36 lr: 0.010000000000000002 img/s: 749.388497121768 loss: 1.6882 (1.6121) time: 0.3448 data: 0.0004 max mem: 9402
2024-02-08 09:28:43,604 INFO torchdistill.misc.log Epoch: [40] [4950/5005] eta: 0:00:18 lr: 0.010000000000000002 img/s: 754.2297413578135 loss: 1.6186 (1.6125) time: 0.3424 data: 0.0004 max mem: 9402
2024-02-08 09:29:00,909 INFO torchdistill.misc.log Epoch: [40] [5000/5005] eta: 0:00:01 lr: 0.010000000000000002 img/s: 741.8443645307605 loss: 1.5608 (1.6128) time: 0.3438 data: 0.0001 max mem: 9402
2024-02-08 09:29:02,451 INFO torchdistill.misc.log Epoch: [40] Total time: 0:28:48
2024-02-08 09:29:05,346 INFO torchdistill.misc.log Validation: [ 0/391] eta: 0:18:50 acc1: 85.9375 (85.9375) acc5: 94.5312 (94.5312) time: 2.8921 data: 2.8425 max mem: 9402
2024-02-08 09:29:08,957 INFO torchdistill.misc.log Validation: [ 50/391] eta: 0:00:43 acc1: 76.5625 (72.0435) acc5: 92.1875 (90.3799) time: 0.0847 data: 0.0391 max mem: 9402
2024-02-08 09:29:12,496 INFO torchdistill.misc.log Validation: [100/391] eta: 0:00:28 acc1: 71.0938 (71.5424) acc5: 93.7500 (91.6538) time: 0.0664 data: 0.0211 max mem: 9402
2024-02-08 09:29:16,595 INFO torchdistill.misc.log Validation: [150/391] eta: 0:00:22 acc1: 67.1875 (71.2955) acc5: 91.4062 (91.7736) time: 0.0739 data: 0.0278 max mem: 9402
2024-02-08 09:29:20,473 INFO torchdistill.misc.log Validation: [200/391] eta: 0:00:17 acc1: 51.5625 (68.2369) acc5: 79.6875 (89.7466) time: 0.0681 data: 0.0226 max mem: 9402
2024-02-08 09:29:23,996 INFO torchdistill.misc.log Validation: [250/391] eta: 0:00:12 acc1: 63.2812 (66.6802) acc5: 84.3750 (88.4120) time: 0.0645 data: 0.0190 max mem: 9402
2024-02-08 09:29:27,650 INFO torchdistill.misc.log Validation: [300/391] eta: 0:00:07 acc1: 58.5938 (65.0722) acc5: 81.2500 (87.1366) time: 0.0624 data: 0.0163 max mem: 9402
2024-02-08 09:29:31,231 INFO torchdistill.misc.log Validation: [350/391] eta: 0:00:03 acc1: 57.0312 (63.9312) acc5: 83.5938 (86.2669) time: 0.0761 data: 0.0305 max mem: 9402
2024-02-08 09:29:34,561 INFO torchdistill.misc.log Validation: Total time: 0:00:32
2024-02-08 09:29:34,561 INFO main * Acc@1 63.9700 Acc@5 86.2680

Trainable Projector

2024-02-08 08:57:20,154 INFO torchdistill.misc.log Epoch: [39] [4750/5005] eta: 0:01:27 lr: 0.010000000000000002 img/s: 746.3973317873493 loss: -0.3715 (-0.4433) time: 0.3440 data: 0.0003 max mem: 9402
2024-02-08 08:57:37,386 INFO torchdistill.misc.log Epoch: [39] [4800/5005] eta: 0:01:10 lr: 0.010000000000000002 img/s: 746.8458860041928 loss: -0.4441 (-0.4437) time: 0.3449 data: 0.0004 max mem: 9402
2024-02-08 08:57:54,584 INFO torchdistill.misc.log Epoch: [39] [4850/5005] eta: 0:00:53 lr: 0.010000000000000002 img/s: 748.0040237523999 loss: -0.4193 (-0.4434) time: 0.3434 data: 0.0004 max mem: 9402
2024-02-08 08:58:11,780 INFO torchdistill.misc.log Epoch: [39] [4900/5005] eta: 0:00:36 lr: 0.010000000000000002 img/s: 747.6076887031876 loss: -0.4346 (-0.4434) time: 0.3439 data: 0.0003 max mem: 9402
2024-02-08 08:58:29,019 INFO torchdistill.misc.log Epoch: [39] [4950/5005] eta: 0:00:18 lr: 0.010000000000000002 img/s: 747.4021349935126 loss: -0.4108 (-0.4430) time: 0.3452 data: 0.0003 max mem: 9402
2024-02-08 08:58:46,330 INFO torchdistill.misc.log Epoch: [39] [5000/5005] eta: 0:00:01 lr: 0.010000000000000002 img/s: 747.1722569610195 loss: -0.4533 (-0.4429) time: 0.3436 data: 0.0001 max mem: 9402
2024-02-08 08:58:47,852 INFO torchdistill.misc.log Epoch: [39] Total time: 0:28:46
2024-02-08 08:58:50,845 INFO torchdistill.misc.log Validation: [ 0/391] eta: 0:19:28 acc1: 79.6875 (79.6875) acc5: 92.9688 (92.9688) time: 2.9890 data: 2.9408 max mem: 9402
2024-02-08 08:58:54,453 INFO torchdistill.misc.log Validation: [ 50/391] eta: 0:00:44 acc1: 77.3438 (75.5974) acc5: 93.7500 (92.4173) time: 0.0866 data: 0.0409 max mem: 9402
2024-02-08 08:58:57,784 INFO torchdistill.misc.log Validation: [100/391] eta: 0:00:28 acc1: 75.7812 (74.8762) acc5: 94.5312 (93.1235) time: 0.0662 data: 0.0199 max mem: 9402
2024-02-08 08:59:01,960 INFO torchdistill.misc.log Validation: [150/391] eta: 0:00:22 acc1: 71.8750 (74.9638) acc5: 92.9688 (93.3775) time: 0.0862 data: 0.0403 max mem: 9402
2024-02-08 08:59:05,740 INFO torchdistill.misc.log Validation: [200/391] eta: 0:00:16 acc1: 56.2500 (72.1782) acc5: 82.0312 (91.5384) time: 0.0655 data: 0.0195 max mem: 9402
2024-02-08 08:59:09,170 INFO torchdistill.misc.log Validation: [250/391] eta: 0:00:11 acc1: 67.1875 (70.7483) acc5: 85.9375 (90.2920) time: 0.0607 data: 0.0150 max mem: 9402
2024-02-08 08:59:12,765 INFO torchdistill.misc.log Validation: [300/391] eta: 0:00:07 acc1: 63.2812 (69.5235) acc5: 82.8125 (89.3766) time: 0.0613 data: 0.0156 max mem: 9402
2024-02-08 08:59:16,248 INFO torchdistill.misc.log Validation: [350/391] eta: 0:00:03 acc1: 66.4062 (68.4606) acc5: 87.5000 (88.6040) time: 0.0699 data: 0.0237 max mem: 9402
2024-02-08 08:59:19,536 INFO torchdistill.misc.log Validation: Total time: 0:00:31
2024-02-08 08:59:19,537 INFO main * Acc@1 68.4660 Acc@5 88.6820

@yoshitomo-matsubara
Author

Hi @roymiles

Thanks for the update! Let me know once you finalize the code and config.
I'm reimplementing your method in the unified format used for torchdistill, and if you're interested in contributing to torchdistill like this, I can help you do that and advertise your work.

@roymiles
Owner

That sounds great, thanks! I'll just do a small sweep over some hyperparameters first to get the best results. A proper implementation in torchdistill would be really great, thank you! I'll let you know when I have the results and code to share.

@roymiles
Owner

@yoshitomo-matsubara I have just pushed the changes and everything should be good now. I have also put the logs and model checkpoints in the README.md.

@yoshitomo-matsubara
Author

@roymiles
Great! Do you want to make a PR for the torchdistill repo, or should I do it for you?

@roymiles
Owner

@yoshitomo-matsubara
Yeah, I think it would be really great to have this as part of the torchdistill repo. I do realise my current implementation doesn't quite fit the torchdistill template, and I would have to look through the other implementations you have to see how it should be done properly.

If it is not too much work, it would be really nice if you could implement this in torchdistill for me. That would be really helpful, thanks! 🙂

@yoshitomo-matsubara
Author

No problem, I can do it for you.

What name would you pick for your method? If you don't have any preference, I would use MilesMikolajczyk2024 as part of the module name, e.g., OurDistillationLoss -> MilesMikolajczyk2024Loss

@roymiles
Owner

Thanks! I think I'd prefer BNLogSumLoss as that summarises what it does a bit more clearly.

@yoshitomo-matsubara
Author

BNLogSum may be a good name for the loss module, but I need to add a wrapper that includes an auxiliary trainable module for your method, so I want to use a unique method name as part of the wrapper class name:
https://github.com/yoshitomo-matsubara/torchdistill/blob/main/torchdistill/models/wrapper.py#L210

Similarly, I need the name at other places as well e.g., https://github.com/yoshitomo-matsubara/torchdistill/tree/main/configs/sample/ilsvrc2012
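For illustration, a rough sketch of the kind of wrapper mentioned above. The class and attribute names here are hypothetical and this is not the actual torchdistill wrapper API; it only shows the idea of pairing the student with an auxiliary trainable module so both feed the loss.

```python
import torch
from torch import nn

class ToyStudent(nn.Module):
    """Stand-in for ResNet-18: returns (logits, penultimate features)."""
    def __init__(self, feat_dim=512, num_classes=1000):
        super().__init__()
        self.backbone = nn.Linear(3 * 224 * 224, feat_dim)
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        feats = self.backbone(x.flatten(1))
        return self.fc(feats), feats

class StudentWithAuxProjector(nn.Module):
    """Hypothetical wrapper: runs the student and exposes projected features."""
    def __init__(self, student, feat_dim=512, proj_dim=512):
        super().__init__()
        self.student = student
        self.projector = nn.Linear(feat_dim, proj_dim)  # auxiliary module, trained with the student

    def forward(self, x):
        logits, feats = self.student(x)
        return logits, self.projector(feats)

wrapped = StudentWithAuxProjector(ToyStudent())
logits, proj_feats = wrapped(torch.randn(2, 3, 224, 224))
```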

@roymiles
Owner

For the wrapper class, perhaps Linear4BNLogSum. For the folder I'm not too sure, ha. I think log_sum is unique, but you might have a better idea than me. It is quite hard to think of an acronym that describes all the components (linear + BN + log sum). I don't mind too much :) I'm fine with whatever works.
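To make those components concrete, here is a loose illustrative sketch of a loss built from the three pieces named above (a trainable linear projector, batch norm, and the log of a summed error term). This is only meant to convey the general shape; it is not the exact formulation used in the paper or this repo.

```python
import torch
from torch import nn

class BNLogSumLossSketch(nn.Module):
    """Illustrative only: linear projector + batch norm + log of a summed error term."""
    def __init__(self, student_dim=512, teacher_dim=512):
        super().__init__()
        self.embed = nn.Linear(student_dim, teacher_dim)  # trainable projector
        self.bn_s = nn.BatchNorm1d(teacher_dim)
        self.bn_t = nn.BatchNorm1d(teacher_dim)

    def forward(self, student_feat, teacher_feat):
        # Project and normalize the student features, normalize the teacher features,
        # then take the log of the summed squared difference per sample.
        diff = self.bn_s(self.embed(student_feat)) - self.bn_t(teacher_feat)
        return torch.log(diff.pow(2).sum(dim=1) + 1e-8).mean()

loss_fn = BNLogSumLossSketch()
loss = loss_fn(torch.randn(4, 512), torch.randn(4, 512))  # batch > 1 needed for BatchNorm in train mode
```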

@yoshitomo-matsubara
Author

What about srd, from this repository's name (Simple-Recipe-Distillation)?

@roymiles
Owner

ah yea that sounds good 👍

@yoshitomo-matsubara
Author

Hi @roymiles

I added your method as SRD to the torchdistill repo.

Can you fork the current torchdistill repo and use this config to reproduce the number?
resnet18_from_resnet34.txt
(tentatively using .txt since a .yaml file cannot be uploaded here)

Once you confirm the reproducibility, keep the log and checkpoint files and submit a PR with the yaml file + README.md at configs/official/ilsvrc2012/roymiles/aaai2024/

@roymiles
Owner

Hi @yoshitomo-matsubara

Thanks so much for doing this! I'll give this a go sometime this/next week once I have a few GPUs free.

@yoshitomo-matsubara
Author

yoshitomo-matsubara commented Feb 28, 2024

@roymiles No problem! Let me know then.
I hope to include the official config, log, and checkpoint files when releasing the next version (soon).

EDIT: resnet18_from_resnet34.txt

@yoshitomo-matsubara
Author

Hi @roymiles

How's the experiment going?

@roymiles
Owner

Hi @yoshitomo-matsubara

I am really sorry for the late reply. I had some issues with training before, but switching to DataParallel (as you suggested) seems to have fixed it much more cleanly :D

I then got a bit bogged down with other work and personal events, but I have since started the run and it seems to be training well. I'll have the results in the next few days.
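For reference, the DataParallel switch referred to above is essentially a one-liner in PyTorch (shown here generically; the actual integration in the training script may differ).

```python
import torch
from torch import nn
from torchvision.models import resnet18

model = resnet18()
if torch.cuda.device_count() > 1:
    # Replicate the model across the visible GPUs and split each batch among them.
    model = nn.DataParallel(model)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
```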

@roymiles
Owner

I finished the run, but the results were a bit lower than I expected. I have just realised that this may be due to having a projector on the teacher side, i.e. teacher: auxiliary_model_wrapper:kwargs:linear_kwargs. I am re-running the experiment now, sorry about that! ha

@yoshitomo-matsubara
Author

Hi @roymiles
Good catch! I forgot to remove linear_kwargs from the teacher wrapper. Thanks for pointing it out!

@roymiles
Owner

Hi @yoshitomo-matsubara

Hopefully this is the final update before I push the log, checkpoint file, and yaml. I was getting poor results because my runs were automatically loading the optimiser state and checkpoints from my previous run with the same dst_ckpt. This had never crossed my mind, but the log loss is now on par with this repo's implementation, and it will likely be finished in a day or so.

@yoshitomo-matsubara
Author

Hi @roymiles

Does dst_ckpt work as a file path for loading a checkpoint?

dst_ckpt is literally the file path for storing the checkpoint, not for loading one (src_ckpt is the file path for loading a checkpoint).

@roymiles
Owner

This is what I found when trying to debug: the optimiser ended up starting at a much lower lr than the config specified (in fact, the final lr). The culprit was this line: https://github.com/yoshitomo-matsubara/torchdistill/blob/3799847d0e24b89d22801f75f36f0d075906f928/examples/torchvision/image_classification.py#L132

This was with an empty src_ckpt in the yaml.
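For context, a generic sketch of the resume mechanism being described (hypothetical helper and paths, not the actual torchdistill script): if the checkpoint is loaded from the same path the run writes to, the optimizer and scheduler state from the previous run come back with it, so the learning rate starts at its final value instead of the one in the config.

```python
import os
import torch

def maybe_resume(ckpt_path, model, optimizer, scheduler):
    """Hypothetical resume helper: if a checkpoint already exists at ckpt_path,
    restore everything from it, including optimizer/scheduler state (and thus the LR)."""
    if os.path.isfile(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])      # LR and momentum state return too
        scheduler.load_state_dict(ckpt["lr_scheduler"])

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30)

# Pointing this at the *output* path of a previous run silently resumes that run;
# loading should instead be keyed on a separate source path (e.g. src_ckpt).
maybe_resume("checkpoints/resnet18_from_resnet34.pt", model, optimizer, scheduler)  # hypothetical path
```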

@yoshitomo-matsubara
Author

Ah, src_ckpt should be used there. I will update the scripts soon. Thanks for pointing it out!

@yoshitomo-matsubara
Author

By the way, you're also welcome to advertise your BMVC'22 and AAAI'24 papers at
https://yoshitomo-matsubara.net/torchdistill/projects.html#papers
as those papers use torchdistill :)

@yoshitomo-matsubara
Author

Hi @roymiles

It looks like the previous config did not use the normalized representations for computing the loss.
Can you run it again with srd-resnet18_from_resnet34.txt and confirm the reproducibility?

@yoshitomo-matsubara
Author

Never mind, I reran the experiment with the fixed official config.
With the fixed official config, I achieved 71.93%, which is even better than the previous numbers (71.63% in your paper and 71.65% for the previous torchdistill run).

yoshitomo-matsubara/torchdistill#473

@roymiles
Owner

Hi @yoshitomo-matsubara

Sorry for the late reply, and ah, that's a complete oversight on my part. It is really great that you spotted this, and even better that you got improved results 😂!

I have only just seen your previous post, but I would definitely like to add a link/description for these papers on the project page. Thanks so much for this :D I will add a "show and tell" discussion this week.
