From c2f66cd8e39f4b29fd5ccc5c4387be5c6acac75c Mon Sep 17 00:00:00 2001
From: CaptainVee
Date: Thu, 21 Sep 2023 20:45:51 +0100
Subject: [PATCH] added value error for when there is no spm_model

---
 laser_encoders/models.py               | 4 ++++
 laser_encoders/test_laser_tokenizer.py | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/laser_encoders/models.py b/laser_encoders/models.py
index 678efc5d..e2a81ef9 100644
--- a/laser_encoders/models.py
+++ b/laser_encoders/models.py
@@ -95,6 +95,10 @@ def __call__(self, sentences):
         if self.spm_model:
             sentences = self.tokenizer(sentences)
             return self.encode_sentences(sentences)
+        else:
+            raise ValueError(
+                "Either initialize the encoder with an spm_model or pre-tokenize and use the encode_sentences method."
+            )
 
     def _process_batch(self, batch):
         tokens = batch.tokens
diff --git a/laser_encoders/test_laser_tokenizer.py b/laser_encoders/test_laser_tokenizer.py
index cd36182d..867111cf 100644
--- a/laser_encoders/test_laser_tokenizer.py
+++ b/laser_encoders/test_laser_tokenizer.py
@@ -173,5 +173,5 @@ def test_sentence_encoder(
     sentence_embedding = sentence_encoder.encode_sentences([tokenized_text])
 
     assert isinstance(sentence_embedding, np.ndarray)
-    # assert sentence_embedding.shape == (1, 1024)
+    assert sentence_embedding.shape == (1, 1024)
     assert np.allclose(expected_array, sentence_embedding[:, :10], atol=1e-3)