Hi, when I use TensorParallelPreTrainedModel I can load the model normally, but when I invoke the generate API it crashes with a segmentation fault. This is my code:
```python
import os
import torch
import platform
import subprocess
from colorama import Fore, Style
from tempfile import NamedTemporaryFile
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
from tensor_parallel import TensorParallelPreTrainedModel

def init_model():
    print("init model ...")
    model_id = 'Baichuan2-7B-Chat'
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        # device_map="auto",
        trust_remote_code=True
    )
    # Shard the loaded model across two GPUs with tensor_parallel.
    model = TensorParallelPreTrainedModel(model, ["cuda:0", "cuda:1"])
    model.generation_config = GenerationConfig.from_pretrained(
        model_id
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        use_fast=False,
        trust_remote_code=True
    )
    return model, tokenizer

def clear_screen():
    if platform.system() == "Windows":
        os.system("cls")
    else:
        os.system("clear")
    print(Fore.YELLOW + Style.BRIGHT + "Welcome to the Baichuan large model. Type to chat, 'vim' for multi-line input, 'clear' to clear history, CTRL+C to interrupt generation, 'stream' to toggle streaming, 'exit' to quit.")
    return []

def vim_input():
    # Open vim on a temp file ('+star' enters insert mode), then read it back.
    with NamedTemporaryFile() as tempfile:
        tempfile.close()
        subprocess.call(['vim', '+star', tempfile.name])
        text = open(tempfile.name).read()
    return text

def main(stream=True):
    model, tokenizer = init_model()
    messages = []
    while True:
        prompt = input(Fore.GREEN + Style.BRIGHT + "\nUser:" + Style.NORMAL)
        if prompt.strip() == "exit":
            break
        if prompt.strip() == "clear":
            messages = clear_screen()
            continue
        if prompt.strip() == 'vim':
            prompt = vim_input()
            print(prompt)
        print(Fore.CYAN + Style.BRIGHT + "\nBaichuan:" + Style.NORMAL, end='')
        if prompt.strip() == "stream":
            stream = not stream
            print(Fore.YELLOW + "(streaming generation {})\n".format("on" if stream else "off"), end='')
            continue
        messages.append({"role": "user", "content": prompt})
        if stream:
            position = 0
            try:
                # The original streaming chat call (commented out below) was
                # replaced with a direct generate() to narrow down the crash;
                # the segfault happens on this generate() call.
                inputs = tokenizer("Hello, who are you?", return_tensors='pt')
                response = model.generate(inputs['input_ids'].cuda(), max_new_tokens=512)
                # for response in model.chat(tokenizer, messages, stream=True):
                # (the position-based slicing below is leftover from that loop)
                print(tokenizer.batch_decode(response[position:].tolist()), end='', flush=True)
                position = len(response)
                if torch.backends.mps.is_available():
                    torch.mps.empty_cache()
            except KeyboardInterrupt:
                pass
            print()
        else:
            response = model.chat(tokenizer, messages)
            print(response)
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})
        print(Style.RESET_ALL)


if __name__ == "__main__":
    main()
```
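For what it's worth, the interactive loop shouldn't be needed to reproduce this. A stripped-down sketch of just the failing path, under the same assumptions as the script above (a local Baichuan2-7B-Chat checkpoint and two visible CUDA devices), would look like:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
from tensor_parallel import TensorParallelPreTrainedModel

model_id = 'Baichuan2-7B-Chat'  # local checkpoint path, as in the script above

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
)
# Same wrapping as init_model(): shard the model over two GPUs.
model = TensorParallelPreTrainedModel(model, ["cuda:0", "cuda:1"])
model.generation_config = GenerationConfig.from_pretrained(model_id)

inputs = tokenizer("Hello, who are you?", return_tensors='pt')
# The reported segfault happens on this call.
output = model.generate(inputs['input_ids'].cuda(), max_new_tokens=32)
print(tokenizer.decode(output[0].tolist()))
```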
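As a sanity check (not a fix), it may also be worth trying the `device_map="auto"` line that is already commented out in `init_model` instead of the tensor_parallel wrapper: if `generate` runs fine with Accelerate's built-in layer sharding, that points the segfault at `TensorParallelPreTrainedModel.generate` specifically. A sketch under the same assumptions:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

model_id = 'Baichuan2-7B-Chat'  # local checkpoint path, as above

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, trust_remote_code=True)
# device_map="auto" places layers across the visible GPUs via accelerate
# (pipeline-style placement, not tensor parallelism).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model.generation_config = GenerationConfig.from_pretrained(model_id)

inputs = tokenizer("Hello, who are you?", return_tensors='pt')
output = model.generate(inputs['input_ids'].cuda(), max_new_tokens=32)
print(tokenizer.decode(output[0].tolist()))
```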