diff --git a/FlagEmbedding/inference/embedder/decoder_only/base.py b/FlagEmbedding/inference/embedder/decoder_only/base.py
index 930d3ab3..9c3deb6d 100644
--- a/FlagEmbedding/inference/embedder/decoder_only/base.py
+++ b/FlagEmbedding/inference/embedder/decoder_only/base.py
@@ -243,19 +243,16 @@ def encode_single_device(
 
         # adjust batch size
         flag = False
-        max_length_inputs = self.tokenizer.pad(
-            all_inputs_sorted[:1],
-            padding=True,
-            return_tensors='pt',
-            **kwargs
-        ).to(device)
         while flag is False:
             try:
-                test_inputs_batch = {}
-                for k, v in max_length_inputs.items():
-                    test_inputs_batch[k] = v.repeat(batch_size, 1)
-                last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
-                embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
+                inputs_batch = self.tokenizer.pad(
+                    all_inputs_sorted[: batch_size],
+                    padding=True,
+                    return_tensors='pt',
+                    **kwargs
+                ).to(device)
+                last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
+                embeddings = last_token_pool(last_hidden_state, inputs_batch['attention_mask'])
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
diff --git a/FlagEmbedding/inference/embedder/decoder_only/icl.py b/FlagEmbedding/inference/embedder/decoder_only/icl.py
index 57e6d2cf..790829a1 100644
--- a/FlagEmbedding/inference/embedder/decoder_only/icl.py
+++ b/FlagEmbedding/inference/embedder/decoder_only/icl.py
@@ -393,19 +393,16 @@ def encode_queries_single_device(
 
         # adjust batch size
         flag = False
-        max_length_inputs = self.tokenizer.pad(
-            all_inputs_sorted[:1],
-            padding=True,
-            return_tensors='pt',
-            **kwargs
-        ).to(device)
         while flag is False:
             try:
-                test_inputs_batch = {}
-                for k, v in max_length_inputs.items():
-                    test_inputs_batch[k] = v.repeat(batch_size, 1)
-                last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
-                embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
+                inputs_batch = self.tokenizer.pad(
+                    all_inputs_sorted[: batch_size],
+                    padding=True,
+                    return_tensors='pt',
+                    **kwargs
+                ).to(device)
+                last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
+                embeddings = last_token_pool(last_hidden_state, inputs_batch['attention_mask'])
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
@@ -505,19 +502,16 @@ def encode_single_device(
 
         # adjust batch size
         flag = False
-        max_length_inputs = self.tokenizer.pad(
-            all_inputs_sorted[:1],
-            padding=True,
-            return_tensors='pt',
-            **kwargs
-        ).to(device)
         while flag is False:
             try:
-                test_inputs_batch = {}
-                for k, v in max_length_inputs.items():
-                    test_inputs_batch[k] = v.repeat(batch_size, 1)
-                last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
-                embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
+                inputs_batch = self.tokenizer.pad(
+                    all_inputs_sorted[: batch_size],
+                    padding=True,
+                    return_tensors='pt',
+                    **kwargs
+                ).to(device)
+                last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
+                embeddings = last_token_pool(last_hidden_state, inputs_batch['attention_mask'])
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
diff --git a/FlagEmbedding/inference/embedder/encoder_only/base.py b/FlagEmbedding/inference/embedder/encoder_only/base.py
index 43c0595a..fe11b228 100644
--- a/FlagEmbedding/inference/embedder/encoder_only/base.py
+++ b/FlagEmbedding/inference/embedder/encoder_only/base.py
@@ -224,19 +224,16 @@ def encode_single_device(
 
         # adjust batch size
         flag = False
-        max_length_inputs = self.tokenizer.pad(
-            all_inputs_sorted[:1],
-            padding=True,
-            return_tensors='pt',
-            **kwargs
-        ).to(device)
         while flag is False:
             try:
-                test_inputs_batch = {}
-                for k, v in max_length_inputs.items():
-                    test_inputs_batch[k] = v.repeat(batch_size, 1)
-                last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
-                embeddings = self.pooling(last_hidden_state, test_inputs_batch['attention_mask'])
+                inputs_batch = self.tokenizer.pad(
+                    all_inputs_sorted[: batch_size],
+                    padding=True,
+                    return_tensors='pt',
+                    **kwargs
+                ).to(device)
+                last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
+                embeddings = self.pooling(last_hidden_state, inputs_batch['attention_mask'])
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
diff --git a/FlagEmbedding/inference/embedder/encoder_only/m3.py b/FlagEmbedding/inference/embedder/encoder_only/m3.py
index 64a31a69..8a8aa041 100644
--- a/FlagEmbedding/inference/embedder/encoder_only/m3.py
+++ b/FlagEmbedding/inference/embedder/encoder_only/m3.py
@@ -388,19 +388,16 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
 
         # adjust batch size
         flag = False
-        max_length_inputs = self.tokenizer.pad(
-            all_inputs_sorted[:1],
-            padding=True,
-            return_tensors='pt',
-            **kwargs
-        ).to(device)
         while flag is False:
             try:
-                test_inputs_batch = {}
-                for k, v in max_length_inputs.items():
-                    test_inputs_batch[k] = v.repeat(batch_size, 1)
+                inputs_batch = self.tokenizer.pad(
+                    all_inputs_sorted[: batch_size],
+                    padding=True,
+                    return_tensors='pt',
+                    **kwargs
+                ).to(device)
                 outputs = self.model(
-                    test_inputs_batch,
+                    inputs_batch,
                     return_dense=return_dense,
                     return_sparse=return_sparse,
                     return_colbert_vecs=return_colbert_vecs
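
Note on the pattern: all four hunks make the same change. Instead of padding only the single longest input (all_inputs_sorted[:1]) and tiling it batch_size times as a synthetic probe batch, the retry loop now pads the first batch_size real, length-sorted inputs on every attempt and shrinks batch_size to 3/4 of its current value whenever the forward pass raises a RuntimeError (typically CUDA out of memory). A minimal standalone sketch of that retry pattern, with a hypothetical encode_batch callable standing in for tokenizer.pad plus the model forward, not part of the patch itself:

    def find_working_batch_size(sorted_inputs, encode_batch, batch_size):
        # Shrink batch_size by 3/4 until one real batch encodes without a RuntimeError.
        flag = False
        while flag is False:
            try:
                # Probe with the first batch_size items; since inputs are sorted by
                # length, this is the worst-case (longest) batch for memory use.
                encode_batch(sorted_inputs[:batch_size])
                flag = True
            except RuntimeError:
                batch_size = batch_size * 3 // 4  # back off, e.g. after CUDA OOM
                if batch_size == 0:
                    raise  # nothing fits on the device; surface the original error
        return batch_size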