From 3f4c553f74c658f1d78e6fbdcdb239d43ff3d801 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Rana <91743459+NIXBLACK11@users.noreply.github.com> Date: Sun, 19 Nov 2023 21:55:51 +0530 Subject: [PATCH] add tests to test_laser_tokenizer.py --- laser_encoders/test_laser_tokenizer.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/laser_encoders/test_laser_tokenizer.py b/laser_encoders/test_laser_tokenizer.py index 08ad789e..5a60c7e3 100644 --- a/laser_encoders/test_laser_tokenizer.py +++ b/laser_encoders/test_laser_tokenizer.py @@ -20,6 +20,7 @@ import numpy as np import pytest +import warnings from laser_encoders import ( LaserEncoderPipeline, @@ -285,3 +286,18 @@ def test_encoder_non_normalization(tmp_path: Path, test_readme_params: dict): norm = np.linalg.norm(non_normalized_embeddings[0]) assert not np.isclose(norm, 1) + + +def test_optional_lang_with_laser2(tmp_path: Path): + with pytest.warns(UserWarning, match="The 'lang' parameter is optional when using 'laser2'. It will be ignored."): + encoder = LaserEncoderPipeline(lang="en", laser="laser2", model_dir=tmp_path) + + +def test_required_lang_with_laser3(tmp_path: Path): + with pytest.raises(ValueError, match="For 'laser3', the 'lang' parameter is required."): + encoder = LaserEncoderPipeline(laser="laser3", model_dir=tmp_path) + + +def test_missing_lang_and_laser(tmp_path: Path): + with pytest.raises(ValueError, match="Either 'laser' or 'lang' should be provided."): + encoder = LaserEncoderPipeline(model_dir=tmp_path)