diff --git a/pyiqa/models/inference_model.py b/pyiqa/models/inference_model.py index 8dcc090..171dc34 100644 --- a/pyiqa/models/inference_model.py +++ b/pyiqa/models/inference_model.py @@ -56,6 +56,7 @@ def __init__( self.net = self.net.to(self.device) self.net.eval() + self.seed = seed if not as_loss: set_random_seed(seed) @@ -63,6 +64,8 @@ def __init__( def forward(self, target, ref=None, **kwargs): device = self.dummy_param.device + if not self.as_loss: + set_random_seed(self.seed) with torch.set_grad_enabled(self.as_loss):