From 21f71652edf6d6b90bf952091faa46489f447008 Mon Sep 17 00:00:00 2001 From: jmisilo Date: Sun, 29 Oct 2023 20:29:16 +0100 Subject: [PATCH 1/2] fix: model weights urls mismatch --- src/utils/downloads.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils/downloads.py b/src/utils/downloads.py index 9f1464e..75ef840 100644 --- a/src/utils/downloads.py +++ b/src/utils/downloads.py @@ -11,9 +11,9 @@ def download_weights(checkpoint_fpath, model_size="L"): """ download_id = ( - "1pSQruQyg8KJq6VmzhMLFbT_VaHJMdlWF" + "1Gh32arzhW06C1ZJyzcJSSfdJDi3RgWoG" if model_size.strip().upper() == "L" - else "1Gh32arzhW06C1ZJyzcJSSfdJDi3RgWoG" + else "1pSQruQyg8KJq6VmzhMLFbT_VaHJMdlWF" ) gdown.download( From 46635475cea16ffb2aeba5386fb5eafe5dec465c Mon Sep 17 00:00:00 2001 From: jmisilo Date: Sun, 29 Oct 2023 20:30:39 +0100 Subject: [PATCH 2/2] update: model weight urls definition --- src/utils/downloads.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/utils/downloads.py b/src/utils/downloads.py index 75ef840..b0ef2bc 100644 --- a/src/utils/downloads.py +++ b/src/utils/downloads.py @@ -4,17 +4,18 @@ import gdown +MODEL_WEIGHTS = { + "L": "1Gh32arzhW06C1ZJyzcJSSfdJDi3RgWoG", + "S": "1pSQruQyg8KJq6VmzhMLFbT_VaHJMdlWF", +} + def download_weights(checkpoint_fpath, model_size="L"): """ Downloads weights from Google Drive. """ - download_id = ( - "1Gh32arzhW06C1ZJyzcJSSfdJDi3RgWoG" - if model_size.strip().upper() == "L" - else "1pSQruQyg8KJq6VmzhMLFbT_VaHJMdlWF" - ) + download_id = MODEL_WEIGHTS[model_size.strip().upper()] gdown.download( f"https://drive.google.com/uc?id={download_id}", checkpoint_fpath, quiet=False