diff --git a/laser_encoders/download_models.py b/laser_encoders/download_models.py index 1167d7c1..fbd731db 100644 --- a/laser_encoders/download_models.py +++ b/laser_encoders/download_models.py @@ -71,10 +71,10 @@ def download(self, filename: str): def get_language_code(self, language_list: dict, lang: str) -> str: try: lang_3_4 = language_list[lang] - if isinstance(lang_3_4, tuple): + if isinstance(lang_3_4, list): options = ", ".join(f"'{opt}'" for opt in lang_3_4) raise ValueError( - f"Language '{lang_3_4}' has multiple options: {options}. Please specify using --lang." + f"Language '{lang}' has multiple options: {options}. Please specify using the 'lang' argument." ) return lang_3_4 except KeyError: @@ -88,7 +88,14 @@ def download_laser2(self): self.download("laser2.cvocab") def download_laser3(self, lang: str, spm: bool = False): - lang = self.get_language_code(LASER3_LANGUAGE, lang) + result = self.get_language_code(LASER3_LANGUAGE, lang) + + if isinstance(result, list): + raise ValueError( + f"There are script-specific models available for {lang}. Please choose one from the following: {result}" + ) + + lang = result self.download(f"laser3-{lang}.v1.pt") if spm: if lang in SPM_LANGUAGE: diff --git a/laser_encoders/laser_tokenizer.py b/laser_encoders/laser_tokenizer.py index 0488cb2c..763a5aae 100644 --- a/laser_encoders/laser_tokenizer.py +++ b/laser_encoders/laser_tokenizer.py @@ -153,12 +153,14 @@ def initialize_tokenizer(lang: str = None, model_dir: str = None, laser: str = N f"Unsupported laser model: {laser}. Choose either laser2 or laser3." ) else: - if lang in LASER3_LANGUAGE or lang in LASER2_LANGUAGE: + if lang in LASER3_LANGUAGE: lang = downloader.get_language_code(LASER3_LANGUAGE, lang) if lang in SPM_LANGUAGE: filename = f"laser3-{lang}.v1.spm" else: filename = "laser2.spm" + elif lang in LASER2_LANGUAGE: + filename = "laser2.spm" else: raise ValueError( f"Unsupported language name: {lang}. Please specify a supported language name." 
diff --git a/laser_encoders/models.py b/laser_encoders/models.py index 037a4f9f..d1617a53 100644 --- a/laser_encoders/models.py +++ b/laser_encoders/models.py @@ -350,8 +350,8 @@ def initialize_encoder( f"Unsupported laser model: {laser}. Choose either laser2 or laser3." ) else: - lang = downloader.get_language_code(LASER3_LANGUAGE, lang) if lang in LASER3_LANGUAGE: + lang = downloader.get_language_code(LASER3_LANGUAGE, lang) downloader.download_laser3(lang=lang, spm=spm) file_path = f"laser3-{lang}.v1" elif lang in LASER2_LANGUAGE: diff --git a/laser_encoders/test_models_initialization.py b/laser_encoders/test_models_initialization.py new file mode 100644 index 00000000..88e898fa --- /dev/null +++ b/laser_encoders/test_models_initialization.py @@ -0,0 +1,57 @@ +import os +import tempfile + +import pytest + +from laser_encoders.download_models import LaserModelDownloader +from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE +from laser_encoders.laser_tokenizer import initialize_tokenizer +from laser_encoders.models import initialize_encoder + + +def test_validate_achnese_models_and_tokenize_laser3(lang="acehnese"): + with tempfile.TemporaryDirectory() as tmp_dir: + print(f"Created temporary directory for {lang}", tmp_dir) + + downloader = LaserModelDownloader(model_dir=tmp_dir) + downloader.download_laser3(lang) + encoder = initialize_encoder(lang, model_dir=tmp_dir) + tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir) + + # Test tokenization with a sample sentence + tokenized = tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully") + + +def test_validate_english_models_and_tokenize_laser2(lang="english"): + with tempfile.TemporaryDirectory() as tmp_dir: + print(f"Created temporary directory for {lang}", tmp_dir) + + downloader = LaserModelDownloader(model_dir=tmp_dir) + downloader.download_laser2() + + encoder = initialize_encoder(lang, model_dir=tmp_dir) + tokenizer = 
initialize_tokenizer(lang, model_dir=tmp_dir) + + # Test tokenization with a sample sentence + tokenized = tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully") + + +def test_validate_kashmiri_models_and_tokenize_laser3(lang="kas"): + with tempfile.TemporaryDirectory() as tmp_dir: + print(f"Created temporary directory for {lang}", tmp_dir) + + downloader = LaserModelDownloader(model_dir=tmp_dir) + with pytest.raises(ValueError): + downloader.download_laser3(lang) + + encoder = initialize_encoder(lang, model_dir=tmp_dir) + tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir) + + # Test tokenization with a sample sentence + tokenized = tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully") diff --git a/laser_encoders/validate_models.py b/laser_encoders/validate_models.py new file mode 100644 index 00000000..0748dfee --- /dev/null +++ b/laser_encoders/validate_models.py @@ -0,0 +1,108 @@ +import os +import tempfile + +import pytest + +from laser_encoders.download_models import LaserModelDownloader +from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE +from laser_encoders.laser_tokenizer import initialize_tokenizer +from laser_encoders.models import initialize_encoder + + +@pytest.mark.slow +@pytest.mark.parametrize("lang", LASER3_LANGUAGE) +def test_validate_language_models_and_tokenize_laser3(lang): + with tempfile.TemporaryDirectory() as tmp_dir: + print(f"Created temporary directory for {lang}", tmp_dir) + + downloader = LaserModelDownloader(model_dir=tmp_dir) + if lang in ["kashmiri", "kas", "central kanuri", "knc"]: + with pytest.raises(ValueError) as excinfo: + downloader.download_laser3(lang) + assert "multiple options" in str(excinfo.value) + print(f"{lang} language model raised a ValueError as expected.") + else: + downloader.download_laser3(lang) + encoder = initialize_encoder(lang, model_dir=tmp_dir) + tokenizer = initialize_tokenizer(lang, 
model_dir=tmp_dir) + + # Test tokenization with a sample sentence + tokenized = tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully") + + +@pytest.mark.slow +@pytest.mark.parametrize("lang", LASER2_LANGUAGE) +def test_validate_language_models_and_tokenize_laser2(lang): + with tempfile.TemporaryDirectory() as tmp_dir: + print(f"Created temporary directory for {lang}", tmp_dir) + + downloader = LaserModelDownloader(model_dir=tmp_dir) + downloader.download_laser2() + + encoder = initialize_encoder(lang, model_dir=tmp_dir) + tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir) + + # Test tokenization with a sample sentence + tokenized = tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully") + + +class MockLaserModelDownloader(LaserModelDownloader): + def __init__(self, model_dir): + self.model_dir = model_dir + + def download_laser3(self, lang): + lang = self.get_language_code(LASER3_LANGUAGE, lang) + file_path = os.path.join(self.model_dir, f"laser3-{lang}.v1.pt") + if not os.path.exists(file_path): + raise FileNotFoundError(f"Could not find {file_path}.") + + def download_laser2(self): + files = ["laser2.pt", "laser2.spm", "laser2.cvocab"] + for file_name in files: + file_path = os.path.join(self.model_dir, file_name) + if not os.path.exists(file_path): + raise FileNotFoundError(f"Could not find {file_path}.") + + +CACHE_DIR = "/home/user/.cache/models" # Change this to the desired cache directory + +# This uses the mock downloader +@pytest.mark.slow +@pytest.mark.parametrize("lang", LASER3_LANGUAGE) +def test_validate_language_models_and_tokenize_mock_laser3(lang): + downloader = MockLaserModelDownloader(model_dir=CACHE_DIR) + + try: + downloader.download_laser3(lang) + except FileNotFoundError as e: + pytest.fail(str(e)) + + encoder = initialize_encoder(lang, model_dir=CACHE_DIR) + tokenizer = initialize_tokenizer(lang, model_dir=CACHE_DIR) + + tokenized = 
tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully") + + +# This uses the mock downloader +@pytest.mark.slow +@pytest.mark.parametrize("lang", LASER2_LANGUAGE) +def test_validate_language_models_and_tokenize_mock_laser2(lang): + downloader = MockLaserModelDownloader(model_dir=CACHE_DIR) + + try: + downloader.download_laser2() + except FileNotFoundError as e: + pytest.fail(str(e)) + + encoder = initialize_encoder(lang, model_dir=CACHE_DIR) + tokenizer = initialize_tokenizer(lang, model_dir=CACHE_DIR) + + tokenized = tokenizer.tokenize("This is a sample sentence.") + + print(f"{lang} model validated successfully")