diff --git a/laser_encoders/download_models.py b/laser_encoders/download_models.py index 17a5db35..452501d3 100644 --- a/laser_encoders/download_models.py +++ b/laser_encoders/download_models.py @@ -17,7 +17,9 @@ import argparse import logging import os +import shutil import sys +import tempfile from pathlib import Path import requests @@ -46,20 +48,27 @@ def __init__(self, model_dir: str = None): def download(self, filename: str): url = os.path.join(self.base_url, filename) - local_file_path = self.model_dir / filename + local_file_path = os.path.join(self.model_dir, filename) - if local_file_path.exists(): + if os.path.exists(local_file_path): logger.info(f" - {filename} already downloaded") else: logger.info(f" - Downloading {filename}") - response = requests.get(url, stream=True) - total_size = int(response.headers.get("Content-Length", 0)) - progress_bar = tqdm(total=total_size, unit_scale=True, unit="B") - with open(local_file_path, "wb") as f: + + tf = tempfile.NamedTemporaryFile(delete=False) + temp_file_path = tf.name + + with tf: + response = requests.get(url, stream=True) + total_size = int(response.headers.get("Content-Length", 0)) + progress_bar = tqdm(total=total_size, unit_scale=True, unit="B") + for chunk in response.iter_content(chunk_size=1024): - f.write(chunk) + tf.write(chunk) progress_bar.update(len(chunk)) - progress_bar.close() + progress_bar.close() + + shutil.move(temp_file_path, local_file_path) def get_language_code(self, language_list: dict, lang: str) -> str: try: