Skip to content

Commit

Permalink
Merge pull request #38 from sgwhat/streaming-block
Browse files Browse the repository at this point in the history
Enhance streaming-chat in case output blocking
  • Loading branch information
sgwhat authored Apr 25, 2024
2 parents ca0413d + 5517615 commit d3f9222
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion modules/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def default_preset():
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'do_sample': True,
'do_sample': False,
'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0,
'min_length': 0,
Expand Down
16 changes: 9 additions & 7 deletions modules/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import torch
from tqdm import tqdm
from threading import Thread
import transformers
from transformers import LogitsProcessorList, is_torch_xpu_available
from transformers.generation import TextIteratorStreamer
Expand Down Expand Up @@ -400,9 +401,11 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
# Stream the reply 1 token at a time.
# This is based on the trick of using 'stopping_criteria' to create an iterator.
else:
with torch.no_grad():
output = shared.model.generate(**generate_params, streamer=streamer)

generation_kwargs = {**generate_params, "streamer": streamer}

thread = Thread(target=shared.model.generate, kwargs=generation_kwargs)
thread.start()

cumulative_reply = ''
for new_content in tqdm(streamer, "Generating Tokens", unit="token"):
# check the partial unicode character
Expand All @@ -412,15 +415,14 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
cumulative_reply += new_content
yield cumulative_reply

output_tokens = output.shape[1]

except Exception:
traceback.print_exc()
finally:
t1 = time.time()
original_tokens = len(original_input_ids[0])
new_tokens = output_tokens - (original_tokens if not shared.is_seq2seq else 0)
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
if not state['stream']:
new_tokens = output_tokens - (original_tokens if not shared.is_seq2seq else 0)
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return


Expand Down

0 comments on commit d3f9222

Please sign in to comment.