Isn't the reduction="mean" in the loss function supposed to be changed to "batchmean"? #41

Open
mystic-square opened this issue Jan 17, 2025 · 0 comments

@mystic-square

Hello, I really appreciate you sharing this work.

I have a question:

During my De-KD experiments, I noticed that nn.KLDivLoss() in the line D_KL = nn.KLDivLoss()(F.log_softmax(outputs / T, dim=1), F.softmax(teacher_outputs / T, dim=1)) * (T * T) is left at its default reduction="mean". Using reduction="batchmean" seems more appropriate for computing this loss: with "mean", the KL term is averaged over every element (batch size times the number of classes) instead of being summed over classes and averaged over the batch, so the resulting value is much smaller than the actual KL divergence. This does not seem ideal for knowledge distillation.
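
To illustrate the scale difference, here is a minimal sketch (the batch size, class count, and temperature are made up for illustration, not the repository's actual settings):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)
    batch_size, num_classes, T = 128, 100, 6.0   # illustrative values only
    student_logits = torch.randn(batch_size, num_classes)
    teacher_logits = torch.randn(batch_size, num_classes)

    log_p = F.log_softmax(student_logits / T, dim=1)
    q = F.softmax(teacher_logits / T, dim=1)

    # "mean" divides the summed pointwise KL by batch_size * num_classes,
    # "batchmean" divides it only by batch_size (the mathematically correct KL).
    kl_mean = nn.KLDivLoss(reduction="mean")(log_p, q)
    kl_batchmean = nn.KLDivLoss(reduction="batchmean")(log_p, q)

    print(kl_mean.item(), kl_batchmean.item())
    print(torch.isclose(kl_batchmean, kl_mean * num_classes))  # True: a factor of num_classes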

I printed the values of loss_CE and D_KL in the code, and even with the weighting factor alpha, the two terms are clearly not balanced.

import numpy as np
import torch.nn as nn
import torch.nn.functional as F


def loss_kd(outputs, labels, teacher_outputs, params):
    alpha = params.alpha
    T = params.temperature
    # standard cross-entropy on the hard labels
    loss_CE = F.cross_entropy(outputs, labels)
    # KL between softened student and teacher distributions;
    # nn.KLDivLoss() defaults to reduction="mean"
    D_KL = nn.KLDivLoss()(F.log_softmax(outputs / T, dim=1),
                          F.softmax(teacher_outputs / T, dim=1)) * (T * T)
    KD_loss = (1. - alpha) * loss_CE + alpha * D_KL

    print("loss_ce:" + str(np.array(((1. - alpha) * loss_CE).cpu().detach().numpy())))
    print("loss_kd:" + str(np.array((alpha * D_KL).cpu().detach().numpy())))

    return KD_loss

Here’s some of the output:

Student model: mobilenet_v2_distill
Teacher model: resnet18
Starting training for 200 epoch(s)
- Eval metrics, acc:15.4800, loss: 0.0000
>>>>>>>>>The teacher accuracy: 15.48>>>>>>>>>
Epoch 1/200, lr:0.0
  0%|                                                                                                                                       | 0/391 [00:00<?, ?it/s]loss_ce:0.23258547
loss_kd:0.012982355
  0%|▎                                                                                                     | 1/391 [00:01<08:37,  1.33s/it, loss=0.246, lr=0.000256]
loss_ce:0.2298743
loss_kd:0.011845248
  0%|▎                                                                                                     | 1/391 [00:01<08:37,  1.33s/it, loss=0.244, lr=0.000512]
loss_ce:0.23207653
loss_kd:0.011403086
  1%|▊                                                                                                     | 3/391 [00:01<02:28,  2.61it/s, loss=0.244, lr=0.000767]
loss_ce:0.23141216
loss_kd:0.012453535
  1%|▊                                                                                                     | 3/391 [00:01<02:28,  2.61it/s, loss=0.244, lr=0.001023]
loss_ce:0.23330784
loss_kd:0.012101171
  1%|█▎                                                                                                    | 5/391 [00:01<01:22,  4.66it/s, loss=0.244, lr=0.001279]
loss_ce:0.23280807
loss_kd:0.011920918
  1%|█▎                                                                                                    | 5/391 [00:01<01:22,  4.66it/s, loss=0.244, lr=0.001535]
loss_ce:0.23340328
loss_kd:0.0119095445
  2%|█▊                                                                                                    | 7/391 [00:01<00:56,  6.78it/s, loss=0.244, lr=0.001790]
loss_ce:0.23054479
loss_kd:0.013966557
  2%|█▊                                                                                                    | 7/391 [00:01<00:56,  6.78it/s, loss=0.244, lr=0.002046]
loss_ce:0.23117931
loss_kd:0.011783687
  2%|██▎                                                                                                   | 9/391 [00:01<00:44,  8.54it/s, loss=0.244, lr=0.002302]
loss_ce:0.23241459
loss_kd:0.013681319
  2%|██▎                                                                                                   | 9/391 [00:01<00:44,  8.54it/s, loss=0.244, lr=0.002558]
loss_ce:0.23257616
loss_kd:0.011883538
  3%|██▊                                                                                                  | 11/391 [00:01<00:36, 10.37it/s, loss=0.244, lr=0.002813]
loss_ce:0.23141861
loss_kd:0.013294127
  3%|██▊                                                                                                  | 11/391 [00:01<00:36, 10.37it/s, loss=0.244, lr=0.003069]
loss_ce:0.23048294
loss_kd:0.012409204
  3%|███▎                                                                                                 | 13/391 [00:02<00:31, 11.92it/s, loss=0.244, lr=0.003325]
loss_ce:0.23034087
loss_kd:0.0133417435
  3%|███▎                                                                                                 | 13/391 [00:02<00:31, 11.92it/s, loss=0.244, lr=0.003581]
loss_ce:0.22938477
loss_kd:0.011549198
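
For reference, this is roughly the change I have in mind; just a sketch on my side (loss_kd_batchmean is my own name), with everything else kept as in the current code:

    def loss_kd_batchmean(outputs, labels, teacher_outputs, params):
        # Same as loss_kd above, but the KL term is summed over classes and
        # averaged only over the batch dimension via reduction="batchmean".
        alpha = params.alpha
        T = params.temperature
        loss_CE = F.cross_entropy(outputs, labels)
        D_KL = nn.KLDivLoss(reduction="batchmean")(
            F.log_softmax(outputs / T, dim=1),
            F.softmax(teacher_outputs / T, dim=1)) * (T * T)
        return (1. - alpha) * loss_CE + alpha * D_KL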