From c96905dba5cc3c1943c5360cecf70400ced5a139 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Tue, 10 Dec 2024 14:04:18 -0500 Subject: [PATCH] select device --- services/finetuning/tsfmfinetuning/finetuning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/services/finetuning/tsfmfinetuning/finetuning.py b/services/finetuning/tsfmfinetuning/finetuning.py index 374048a2..4b63e7d5 100644 --- a/services/finetuning/tsfmfinetuning/finetuning.py +++ b/services/finetuning/tsfmfinetuning/finetuning.py @@ -7,6 +7,7 @@ from typing import Any, Dict, Tuple, Union import pandas as pd +import torch from fastapi import APIRouter, HTTPException from starlette import status from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed @@ -215,7 +216,7 @@ def _finetuning_common( metric_for_best_model="eval_loss", # Metric to monitor for early stopping greater_is_better=False, # For loss label_names=["future_values"], - use_cpu=True, # only needed for testing on Mac :( + use_cpu=not torch.cuda.is_available(), ) callbacks = []