From 7d8500f3d2fb22e9e258dcef295238d9a76b4365 Mon Sep 17 00:00:00 2001 From: sgwhat Date: Thu, 25 Apr 2024 20:00:19 +0800 Subject: [PATCH 1/2] enhance streaming without any thread blocking --- modules/text_generation.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index 8072fdf987..edb498441a 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -10,6 +10,7 @@ import numpy as np import torch from tqdm import tqdm +from threading import Thread import transformers from transformers import LogitsProcessorList, is_torch_xpu_available from transformers.generation import TextIteratorStreamer @@ -400,9 +401,11 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings # Stream the reply 1 token at a time. # This is based on the trick of using 'stopping_criteria' to create an iterator. else: - with torch.no_grad(): - output = shared.model.generate(**generate_params, streamer=streamer) - + generation_kwargs = {**generate_params, "streamer": streamer} + + thread = Thread(target=shared.model.generate, kwargs=generation_kwargs) + thread.start() + cumulative_reply = '' for new_content in tqdm(streamer, "Generating Tokens", unit="token"): # check the partial unicode character @@ -412,15 +415,14 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings cumulative_reply += new_content yield cumulative_reply - output_tokens = output.shape[1] - except Exception: traceback.print_exc() finally: t1 = time.time() original_tokens = len(original_input_ids[0]) - new_tokens = output_tokens - (original_tokens if not shared.is_seq2seq else 0) - print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') + if not state['stream']: + new_tokens = output_tokens - (original_tokens if not shared.is_seq2seq else 0) + print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') return From 5517615d5fb570c58547a78b165533fd4073e666 Mon Sep 17 00:00:00 2001 From: sgwhat Date: Thu, 25 Apr 2024 20:13:41 +0800 Subject: [PATCH 2/2] set do_sample to False --- modules/presets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/presets.py b/modules/presets.py index 42ca782001..38a6e3b92d 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -31,7 +31,7 @@ def default_preset(): 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, - 'do_sample': True, + 'do_sample': False, 'encoder_repetition_penalty': 1, 'no_repeat_ngram_size': 0, 'min_length': 0,