-
Notifications
You must be signed in to change notification settings - Fork 462
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Language Validation Test #257
Changes from all commits
8fc4b9a
9a3228b
ad9a588
7f32d7a
ff3254b
cb2d91a
f4e84d2
109eac2
2236fe0
472657b
c744030
c71aec7
c816d79
31aa252
c34279d
8b25a3d
302d068
73f873f
5e04a2a
e3552a7
1d74246
1bddd81
e4f3fd0
03284a2
43f4d1a
6ef54c2
89c9dde
d883ee0
e1e22a3
a8f4135
4cd83e8
e0be04f
9ec012f
fbbc6fc
99ebbfd
6356c4d
eac3674
d3935f9
18c1657
c26e775
023eab2
3944556
0a4d983
e5823d6
92345be
87a08e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import os | ||
import tempfile | ||
|
||
import pytest | ||
|
||
from laser_encoders.download_models import LaserModelDownloader | ||
from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE | ||
from laser_encoders.laser_tokenizer import initialize_tokenizer | ||
from laser_encoders.models import initialize_encoder | ||
|
||
|
||
def test_validate_achnese_models_and_tokenize_laser3(lang="acehnese"): | ||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
print(f"Created temporary directory for {lang}", tmp_dir) | ||
|
||
downloader = LaserModelDownloader(model_dir=tmp_dir) | ||
downloader.download_laser3(lang) | ||
encoder = initialize_encoder(lang, model_dir=tmp_dir) | ||
tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir) | ||
|
||
# Test tokenization with a sample sentence | ||
tokenized = tokenizer.tokenize("This is a sample sentence.") | ||
|
||
print(f"{lang} model validated successfully") | ||
|
||
|
||
def test_validate_english_models_and_tokenize_laser2(lang="english"): | ||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
print(f"Created temporary directory for {lang}", tmp_dir) | ||
|
||
downloader = LaserModelDownloader(model_dir=tmp_dir) | ||
downloader.download_laser2() | ||
|
||
encoder = initialize_encoder(lang, model_dir=tmp_dir) | ||
tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir) | ||
|
||
# Test tokenization with a sample sentence | ||
tokenized = tokenizer.tokenize("This is a sample sentence.") | ||
|
||
print(f"{lang} model validated successfully") | ||
|
||
|
||
def test_validate_kashmiri_models_and_tokenize_laser3(lang="kas"): | ||
avidale marked this conversation as resolved.
Show resolved
Hide resolved
|
||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
print(f"Created temporary directory for {lang}", tmp_dir) | ||
|
||
downloader = LaserModelDownloader(model_dir=tmp_dir) | ||
with pytest.raises(ValueError): | ||
downloader.download_laser3(lang) | ||
|
||
encoder = initialize_encoder(lang, model_dir=tmp_dir) | ||
tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir) | ||
|
||
# Test tokenization with a sample sentence | ||
tokenized = tokenizer.tokenize("This is a sample sentence.") | ||
|
||
print(f"{lang} model validated successfully") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import os | ||
import tempfile | ||
|
||
import pytest | ||
|
||
from laser_encoders.download_models import LaserModelDownloader | ||
from laser_encoders.language_list import LASER2_LANGUAGE, LASER3_LANGUAGE | ||
from laser_encoders.laser_tokenizer import initialize_tokenizer | ||
from laser_encoders.models import initialize_encoder | ||
|
||
|
||
@pytest.mark.slow | ||
@pytest.mark.parametrize("lang", LASER3_LANGUAGE) | ||
def test_validate_language_models_and_tokenize_laser3(lang): | ||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
print(f"Created temporary directory for {lang}", tmp_dir) | ||
|
||
downloader = LaserModelDownloader(model_dir=tmp_dir) | ||
if lang in ["kashmiri", "kas", "central kanuri", "knc"]: | ||
with pytest.raises(ValueError) as excinfo: | ||
downloader.download_laser3(lang) | ||
assert "ValueError" in str(excinfo.value) | ||
print(f"{lang} language model raised a ValueError as expected.") | ||
else: | ||
downloader.download_laser3(lang) | ||
encoder = initialize_encoder(lang, model_dir=tmp_dir) | ||
tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir) | ||
|
||
# Test tokenization with a sample sentence | ||
tokenized = tokenizer.tokenize("This is a sample sentence.") | ||
|
||
print(f"{lang} model validated successfully") | ||
|
||
|
||
@pytest.mark.slow | ||
@pytest.mark.parametrize("lang", LASER2_LANGUAGE) | ||
def test_validate_language_models_and_tokenize_laser2(lang): | ||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
print(f"Created temporary directory for {lang}", tmp_dir) | ||
|
||
downloader = LaserModelDownloader(model_dir=tmp_dir) | ||
downloader.download_laser2() | ||
|
||
encoder = initialize_encoder(lang, model_dir=tmp_dir) | ||
tokenizer = initialize_tokenizer(lang, model_dir=tmp_dir) | ||
|
||
# Test tokenization with a sample sentence | ||
tokenized = tokenizer.tokenize("This is a sample sentence.") | ||
|
||
print(f"{lang} model validated successfully") | ||
|
||
|
||
class MockLaserModelDownloader(LaserModelDownloader): | ||
def __init__(self, model_dir): | ||
self.model_dir = model_dir | ||
|
||
def download_laser3(self, lang): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIUC @avidale's suggestion for the mock downloader was just to check if the language codes exist? (and then have a real downloader for a couple of languages like you have in For example, you could parameterise it with the LASER3 langs, but the func There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (resolved in chat) |
||
lang = self.get_language_code(LASER3_LANGUAGE, lang) | ||
file_path = os.path.join(self.model_dir, f"laser3-{lang}.v1.pt") | ||
if not os.path.exists(file_path): | ||
raise FileNotFoundError(f"Could not find {file_path}.") | ||
|
||
def download_laser2(self): | ||
files = ["laser2.pt", "laser2.spm", "laser2.cvocab"] | ||
for file_name in files: | ||
file_path = os.path.join(self.model_dir, file_name) | ||
if not os.path.exists(file_path): | ||
raise FileNotFoundError(f"Could not find {file_path}.") | ||
|
||
|
||
CACHE_DIR = "/home/user/.cache/models" # Change this to the desired cache directory | ||
|
||
# This uses the mock downloader | ||
@pytest.mark.slow | ||
@pytest.mark.parametrize("lang", LASER3_LANGUAGE) | ||
def test_validate_language_models_and_tokenize_mock_laser3(lang): | ||
downloader = MockLaserModelDownloader(model_dir=CACHE_DIR) | ||
|
||
try: | ||
downloader.download_laser3(lang) | ||
except FileNotFoundError as e: | ||
raise pytest.error(str(e)) | ||
|
||
encoder = initialize_encoder(lang, model_dir=CACHE_DIR) | ||
tokenizer = initialize_tokenizer(lang, model_dir=CACHE_DIR) | ||
|
||
tokenized = tokenizer.tokenize("This is a sample sentence.") | ||
|
||
print(f"{lang} model validated successfully") | ||
|
||
|
||
# This uses the mock downloader | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't use the mock downloader? (L112) |
||
@pytest.mark.slow | ||
@pytest.mark.parametrize("lang", LASER2_LANGUAGE) | ||
def test_validate_language_models_and_tokenize_mock_laser2(lang): | ||
downloader = MockLaserModelDownloader(model_dir=CACHE_DIR) | ||
|
||
try: | ||
downloader.download_laser2() | ||
except FileNotFoundError as e: | ||
raise pytest.error(str(e)) | ||
|
||
encoder = initialize_encoder(lang, model_dir=CACHE_DIR) | ||
tokenizer = initialize_tokenizer(lang, model_dir=CACHE_DIR) | ||
|
||
tokenized = tokenizer.tokenize("This is a sample sentence.") | ||
|
||
print(f"{lang} model validated successfully") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch!