Skip to content

Commit

Permalink
Added iteration for generation in FSL step
Browse files Browse the repository at this point in the history
  • Loading branch information
fcdl94 authored and fcdl94 committed Mar 8, 2021
1 parent 884dd83 commit 8560088
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions methods/generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(self, task, device, logger, opts):

if task.step > 0:
self.gen_weight = 1.
self.gen_factor = 5
if self.gen_mib:
self.generated_criterion = BinaryCrossEntropy()
else:
Expand Down Expand Up @@ -123,16 +124,19 @@ def warm_up_(self, dataset, epochs=1):
model.cls.imprint_weights_step(step=self.task.step, features=features)

def generative_loss(self, images=None, labels=None):
with torch.no_grad():
gen_feat, gen_target = self.generate_synth_feat(images, labels)
score = self.model(gen_feat, only_head=True)
loss = self.generated_criterion(score, gen_target)
loss = 0.
for _ in range(self.gen_factor):
with torch.no_grad():
gen_feat, gen_target = self.generate_synth_feat(images, labels)
score = self.model(gen_feat, only_head=True)
loss += self.generated_criterion(score, gen_target)

if self.model_old is not None:
score_old = self.model_old(gen_feat, only_head=True)
if self.kd_criterion is not None:
loss += self.kd_loss * self.kd_criterion(score, score_old)
if self.model_old is not None:
score_old = self.model_old(gen_feat, only_head=True)
if self.kd_criterion is not None:
loss += self.kd_loss * self.kd_criterion(score, score_old)

loss = loss / self.gen_factor
return self.gen_weight * loss

def generate_synth_feat(self, images=None, labels=None):
Expand Down

0 comments on commit 8560088

Please sign in to comment.