diff --git a/methods/imprinting.py b/methods/imprinting.py index 6e6ec0e..62fe36f 100644 --- a/methods/imprinting.py +++ b/methods/imprinting.py @@ -67,14 +67,9 @@ class WeightImprinting(Trainer): """ def __init__(self, task, device, logger, opts): super().__init__(task, device, logger, opts) - self.pixel = opts.pixel_imprinting self.masking = True - if opts.weight_mix: - self.normalize_weight = True - self.compute_score = True - else: - self.normalize_weight = False - self.compute_score = False + self.normalize_weight = False + self.compute_score = False def warm_up_(self, dataset, epochs=5): model = self.model.module if self.distributed else self.model