diff --git a/predict.py b/predict.py index 36e4d65..ccfe2bc 100644 --- a/predict.py +++ b/predict.py @@ -8,9 +8,10 @@ def predict_image( image_path: str = typer.Argument(help="image path", show_default=True), - model_path: str = typer.Argument("checkpoint/checkpoint_notebook.pth", help="model path (pth)", show_default=True) + model_path: str = typer.Argument("checkpoint/checkpoint_notebook.pth", help="model path (pth)", show_default=True), + device: str = typer.Argument("cpu", help="use cuda if your device has cuda", show_default=True) ): - predictor = ImageRecognition(model_path=model_path) + predictor = ImageRecognition(model_path=model_path, device=device) result = predictor.predict(image=image_path) typer.echo(f"Prediction: {result}")