diff --git a/utmos/lightning_module.py b/utmos/lightning_module.py index 0f20a11..4ae8884 100644 --- a/utmos/lightning_module.py +++ b/utmos/lightning_module.py @@ -12,6 +12,12 @@ def __init__(self, cfg): self.cfg = cfg self.construct_model() self.save_hyperparameters() + device = 'cpu' + if torch.cuda.is_available(): + device = 'cuda' + if torch.backends.mps.is_available(): + device = 'mps' + self.device = device def construct_model(self): self.feature_extractors = nn.ModuleList([ @@ -32,7 +38,7 @@ def construct_model(self): def forward(self, inputs): outputs = {} - inputs = {key: value.to('mps') for key, value in inputs.items()} + inputs = {key: value.to(self.device) for key, value in inputs.items()} for feature_extractor in self.feature_extractors: outputs.update(feature_extractor(inputs)) x = outputs