diff --git a/services/finetuning/tsfmfinetuning/finetuning.py b/services/finetuning/tsfmfinetuning/finetuning.py index 374048a..4b63e7d 100644 --- a/services/finetuning/tsfmfinetuning/finetuning.py +++ b/services/finetuning/tsfmfinetuning/finetuning.py @@ -7,6 +7,7 @@ from typing import Any, Dict, Tuple, Union import pandas as pd +import torch from fastapi import APIRouter, HTTPException from starlette import status from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed @@ -215,7 +216,7 @@ def _finetuning_common( metric_for_best_model="eval_loss", # Metric to monitor for early stopping greater_is_better=False, # For loss label_names=["future_values"], - use_cpu=True, # only needed for testing on Mac :( + use_cpu=not torch.cuda.is_available(), ) callbacks = []