Commit
Merge pull request #1229 from 545999961/master
update adjust batch size
545999961 authored Nov 15, 2024
2 parents fde4abd + 3e9c603 commit 3c223e5
Showing 4 changed files with 39 additions and 54 deletions.
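As the hunks below show, each of the four embedders previously estimated a safe batch size with a synthetic probe: it padded only the single longest input (all_inputs_sorted[:1]) and repeated that row batch_size times before running the model. The new code drops the probe and instead pads and encodes the actual first batch_size length-sorted inputs, shrinking batch_size on RuntimeError exactly as before.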
19 changes: 8 additions & 11 deletions FlagEmbedding/inference/embedder/decoder_only/base.py
@@ -243,19 +243,16 @@ def encode_single_device(
 
         # adjust batch size
         flag = False
-        max_length_inputs = self.tokenizer.pad(
-            all_inputs_sorted[:1],
-            padding=True,
-            return_tensors='pt',
-            **kwargs
-        ).to(device)
         while flag is False:
             try:
-                test_inputs_batch = {}
-                for k, v in max_length_inputs.items():
-                    test_inputs_batch[k] = v.repeat(batch_size, 1)
-                last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
-                embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
+                inputs_batch = self.tokenizer.pad(
+                    all_inputs_sorted[: batch_size],
+                    padding=True,
+                    return_tensors='pt',
+                    **kwargs
+                ).to(device)
+                last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
+                embeddings = last_token_pool(last_hidden_state, inputs_batch['attention_mask'])
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
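For readers following along outside the diff, here is a minimal self-contained sketch of the back-off pattern this hunk implements. The model and tokenizer below are hypothetical stand-ins for the FlagEmbedding objects (the real hunk also pools the trial embeddings, omitted here), and all_inputs_sorted is assumed to be sorted longest-first so the leading slice is the worst case for padded length:

def adjust_batch_size(model, tokenizer, all_inputs_sorted, batch_size, device):
    """Shrink batch_size until one real forward pass fits on the device."""
    flag = False
    while flag is False:
        try:
            # Pad the actual first batch (the longest sequences) rather than
            # repeating a single max-length probe row, as the old code did.
            inputs_batch = tokenizer.pad(
                all_inputs_sorted[:batch_size],
                padding=True,
                return_tensors='pt',
            ).to(device)
            model(**inputs_batch, return_dict=True)
            flag = True
        except RuntimeError:
            # Typically CUDA out of memory: retry with 3/4 of the batch.
            # Note batch_size would reach 0 if even a single input never fits.
            batch_size = batch_size * 3 // 4
    return batch_size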
38 changes: 16 additions & 22 deletions FlagEmbedding/inference/embedder/decoder_only/icl.py
@@ -393,19 +393,16 @@ def encode_queries_single_device(
 
         # adjust batch size
         flag = False
-        max_length_inputs = self.tokenizer.pad(
-            all_inputs_sorted[:1],
-            padding=True,
-            return_tensors='pt',
-            **kwargs
-        ).to(device)
         while flag is False:
             try:
-                test_inputs_batch = {}
-                for k, v in max_length_inputs.items():
-                    test_inputs_batch[k] = v.repeat(batch_size, 1)
-                last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
-                embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
+                inputs_batch = self.tokenizer.pad(
+                    all_inputs_sorted[: batch_size],
+                    padding=True,
+                    return_tensors='pt',
+                    **kwargs
+                ).to(device)
+                last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
+                embeddings = last_token_pool(last_hidden_state, inputs_batch['attention_mask'])
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
@@ -505,19 +502,16 @@ def encode_single_device(
 
         # adjust batch size
         flag = False
-        max_length_inputs = self.tokenizer.pad(
-            all_inputs_sorted[:1],
-            padding=True,
-            return_tensors='pt',
-            **kwargs
-        ).to(device)
         while flag is False:
             try:
-                test_inputs_batch = {}
-                for k, v in max_length_inputs.items():
-                    test_inputs_batch[k] = v.repeat(batch_size, 1)
-                last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
-                embeddings = last_token_pool(last_hidden_state, test_inputs_batch['attention_mask'])
+                inputs_batch = self.tokenizer.pad(
+                    all_inputs_sorted[: batch_size],
+                    padding=True,
+                    return_tensors='pt',
+                    **kwargs
+                ).to(device)
+                last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
+                embeddings = last_token_pool(last_hidden_state, inputs_batch['attention_mask'])
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
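Both hunks in this file use the same retry schedule as base.py: each RuntimeError scales the batch by 3/4 with integer division, so a starting batch_size of 256 falls back through 192, 144, 108, 81, 60, 45, ... until a trial forward pass succeeds.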
19 changes: 8 additions & 11 deletions FlagEmbedding/inference/embedder/encoder_only/base.py
@@ -224,19 +224,16 @@ def encode_single_device(
 
         # adjust batch size
         flag = False
-        max_length_inputs = self.tokenizer.pad(
-            all_inputs_sorted[:1],
-            padding=True,
-            return_tensors='pt',
-            **kwargs
-        ).to(device)
         while flag is False:
             try:
-                test_inputs_batch = {}
-                for k, v in max_length_inputs.items():
-                    test_inputs_batch[k] = v.repeat(batch_size, 1)
-                last_hidden_state = self.model(**test_inputs_batch, return_dict=True).last_hidden_state
-                embeddings = self.pooling(last_hidden_state, test_inputs_batch['attention_mask'])
+                inputs_batch = self.tokenizer.pad(
+                    all_inputs_sorted[: batch_size],
+                    padding=True,
+                    return_tensors='pt',
+                    **kwargs
+                ).to(device)
+                last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
+                embeddings = self.pooling(last_hidden_state, inputs_batch['attention_mask'])
                 flag = True
             except RuntimeError as e:
                 batch_size = batch_size * 3 // 4
17 changes: 7 additions & 10 deletions FlagEmbedding/inference/embedder/encoder_only/m3.py
@@ -388,19 +388,16 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
 
         # adjust batch size
         flag = False
-        max_length_inputs = self.tokenizer.pad(
-            all_inputs_sorted[:1],
-            padding=True,
-            return_tensors='pt',
-            **kwargs
-        ).to(device)
         while flag is False:
             try:
-                test_inputs_batch = {}
-                for k, v in max_length_inputs.items():
-                    test_inputs_batch[k] = v.repeat(batch_size, 1)
+                inputs_batch = self.tokenizer.pad(
+                    all_inputs_sorted[: batch_size],
+                    padding=True,
+                    return_tensors='pt',
+                    **kwargs
+                ).to(device)
                 outputs = self.model(
-                    test_inputs_batch,
+                    inputs_batch,
                     return_dense=return_dense,
                     return_sparse=return_sparse,
                     return_colbert_vecs=return_colbert_vecs
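The m3.py hunk makes the same substitution; the only difference is that the padded trial batch is passed positionally to self.model together with the return_dense / return_sparse / return_colbert_vecs flags instead of being pooled. Since the identical try/except loop now appears in four files, one conceivable follow-up, not part of this commit and with hypothetical names, would be a shared helper:

def adjust_batch_size(batch_size: int, try_forward) -> int:
    """Generic back-off: call try_forward(bs) until it stops raising RuntimeError."""
    while True:
        try:
            try_forward(batch_size)  # one trial forward pass at this size
            return batch_size
        except RuntimeError:
            batch_size = batch_size * 3 // 4

Each embedder would then supply its own try_forward closure, e.g. one that pads all_inputs_sorted[:bs] and calls self.model with its model-specific flags.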
