Skip to content

Commit

Permalink
explicitly set device
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Dec 10, 2024
1 parent 0afbf8d commit 5330e1e
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions services/inference/tsfminference/hf_service_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Dict, Optional, Union

import pandas as pd
import torch
import transformers
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

Expand Down Expand Up @@ -353,12 +354,14 @@ def _run(
total_periods=model_prediction_length,
)

device = "cpu" if not torch.cuda.is_available() else "cuda"
forecast_pipeline = TimeSeriesForecastingPipeline(
model=self.model,
explode_forecasts=True,
feature_extractor=self.preprocessor,
add_known_ground_truth=False,
freq=self.preprocessor.freq,
device=device,
)
forecasts = forecast_pipeline(data, future_time_series=future_data, inverse_scale_outputs=True)

Expand Down

0 comments on commit 5330e1e

Please sign in to comment.