diff --git a/laser_encoders/download_models.py b/laser_encoders/download_models.py
index c4ace448..17a5db35 100644
--- a/laser_encoders/download_models.py
+++ b/laser_encoders/download_models.py
@@ -117,7 +117,7 @@ def initialize_encoder(
     model_dir: str = None,
     spm: bool = True,
     laser: str = None,
-    tokenize: bool = None,
+    tokenize: bool = False,
 ):
     downloader = LaserModelDownloader(model_dir)
     if laser is not None:
@@ -147,17 +147,18 @@ def initialize_encoder(
         model_dir = downloader.model_dir
 
     model_path = os.path.join(model_dir, f"{file_path}.pt")
-    spm_path = os.path.join(model_dir, f"{file_path}.cvocab")
+    spm_vocab = os.path.join(model_dir, f"{file_path}.cvocab")
     spm_model = None
+    if not os.path.exists(spm_vocab):
+        # if there is no cvocab for the laser3 lang use laser2 cvocab
+        spm_vocab = os.path.join(model_dir, "laser2.cvocab")
     if tokenize:
         spm_model = os.path.join(model_dir, f"{file_path}.spm")
+        if not os.path.exists(spm_model):
+            spm_model = os.path.join(model_dir, "laser2.spm")
 
-    if not os.path.exists(spm_path):
-        # if there is no cvocab for the laser3 lang use laser2 cvocab
-        spm_path = os.path.join(model_dir, "laser2.cvocab")
-        spm_model = os.path.join(model_dir, "laser2.spm")
     return SentenceEncoder(
-        model_path=model_path, spm_vocab=spm_path, spm_model=spm_model
+        model_path=model_path, spm_vocab=spm_vocab, spm_model=spm_model
     )
 
 
diff --git a/laser_encoders/models.py b/laser_encoders/models.py
index 0a36d49b..678efc5d 100644
--- a/laser_encoders/models.py
+++ b/laser_encoders/models.py
@@ -55,6 +55,9 @@ def __init__(
         if verbose:
             logger.info(f"loading encoder: {model_path}")
         self.spm_model = spm_model
+        if self.spm_model:
+            self.tokenizer = LaserTokenizer(spm_model=Path(self.spm_model))
+
         self.use_cuda = torch.cuda.is_available() and not cpu
         self.max_sentences = max_sentences
         self.max_tokens = max_tokens
@@ -88,6 +91,11 @@ def __init__(
         self.encoder.eval()
         self.sort_kind = sort_kind
 
+    def __call__(self, sentences):
+        if self.spm_model:
+            sentences = self.tokenizer(sentences)
+        return self.encode_sentences(sentences)
+
     def _process_batch(self, batch):
         tokens = batch.tokens
         lengths = batch.lengths
@@ -153,10 +161,6 @@ def batch(tokens, lengths, indices):
             yield batch(batch_tokens, batch_lengths, batch_indices)
 
     def encode_sentences(self, sentences):
-        if self.spm_model:
-            tokenizer = LaserTokenizer(spm_model=Path(self.spm_model))
-            sentences = tokenizer(sentences)
-
        indices = []
        results = []
        for batch, batch_indices in self._make_batches(sentences):
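
A minimal usage sketch of the tokenize-and-encode path introduced above, assuming the package layout in this diff; the `laser="laser2"` choice, the sample sentence, and the printed output are illustrative only:

```python
from laser_encoders.download_models import initialize_encoder

# Build an encoder with tokenize=True so a sentencepiece model is attached
# (falling back to the laser2 .spm/.cvocab files when a language-specific
# one is missing, per the change above).
encoder = initialize_encoder(laser="laser2", tokenize=True)

# With spm_model set, __call__ first runs LaserTokenizer on the raw strings
# and then delegates to encode_sentences(), so plain text can be passed in.
embeddings = encoder(["This is a sample sentence."])
print(embeddings.shape)  # one embedding vector per input sentence
```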