Skip to content

Commit

Permalink
Update lightning_module.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fakerybakery authored May 8, 2024
1 parent 96e5219 commit 5598ab1
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion utmos/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ def __init__(self, cfg):
self.cfg = cfg
self.construct_model()
self.save_hyperparameters()
device = 'cpu'
if torch.cuda.is_available():
device = 'cuda'
if torch.backends.mps.is_available():
device = 'mps'
self.device = device

def construct_model(self):
self.feature_extractors = nn.ModuleList([
Expand All @@ -32,7 +38,7 @@ def construct_model(self):

def forward(self, inputs):
outputs = {}
inputs = {key: value.to('mps') for key, value in inputs.items()}
inputs = {key: value.to(self.device) for key, value in inputs.items()}
for feature_extractor in self.feature_extractors:
outputs.update(feature_extractor(inputs))
x = outputs
Expand Down

0 comments on commit 5598ab1

Please sign in to comment.