-
Notifications
You must be signed in to change notification settings - Fork 0
/
stepan-bot-v1.py
94 lines (75 loc) · 2.22 KB
/
stepan-bot-v1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import time

import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoModelWithLMHead, AutoTokenizer
# Speaker-role markers used by tinkoff-ai/ruDialoGPT to delimit turns.
SP1 = '@@ПЕРВЫЙ@@'   # "first" participant — the human user
SP2 = '@@ВТОРОЙ@@'   # "second" participant — the bot
BASIC_PATH = './'
# Registry of base checkpoints; only the hub default is configured.
MODELS = {
    'default': 'tinkoff-ai/ruDialoGPT-medium'
}
MODEL_NAME = MODELS['default']


def get_model_name_or_path():
    """Return the local fine-tuned checkpoint directory if present,
    otherwise the Hugging Face hub model id.

    Returns:
        str: filesystem path to the fine-tuned model, or the hub id.
    """
    # Build the candidate path once (the original duplicated the f-string);
    # os.path.join also avoids the double slash from BASIC_PATH ('./') + '/...'.
    local_path = os.path.join(BASIC_PATH, 'training', 'final', MODEL_NAME, 'default')
    if os.path.exists(local_path):
        return local_path
    return MODEL_NAME
def type_slowly(text, delay=0.1):
    """Print *text* one character at a time, imitating typing, then a newline.

    Args:
        text: String to print.
        delay: Seconds to pause between characters (default 0.1, as before).
    """
    for ch in text:
        # flush=True is the fix: stdout is line-buffered on a terminal, so
        # without it all characters appear at once at the final newline,
        # defeating the typewriter effect.
        print(ch, end='', flush=True)
        time.sleep(delay)
    print()
# Keyword arguments passed straight into model.generate(). Only 'default' is
# actually used by run_chat(); 'experiment' appears to be an unused variant
# kept for manual experimentation.
CONFIGS = {
    'default': dict(
        top_k=15,                 # sample from the 15 most likely tokens
        top_p=0.95,               # nucleus sampling cutoff
        num_beams=5,              # beam search combined with sampling
        num_return_sequences=1,   # only one reply is consumed downstream
        do_sample=True,
        no_repeat_ngram_size=2,   # forbid repeating any bigram
        temperature=2.1,          # NOTE(review): unusually high — presumably tuned for variety; confirm
        repetition_penalty=1.2,
        length_penalty=0.5,       # < 1 favors shorter replies
        eos_token_id=50257,       # presumably ruDialoGPT's end-of-turn token id — TODO confirm
        pad_token_id=0,
        max_new_tokens=40
    ),
    'experiment': dict(
        top_k=5,
        top_p=0.97,
        num_beams=5,
        num_return_sequences=1,
        do_sample=True,
        no_repeat_ngram_size=2,
        temperature=3.5,          # even hotter sampling than 'default'
        repetition_penalty=1.2,
        length_penalty=1.2,       # > 1 favors longer replies
        eos_token_id=50257,
        pad_token_id=0,
        max_new_tokens=100        # allow much longer replies than 'default'
    )
}
def run_chat(model, tokenizer):
    """Interactive chat loop: read user input, generate and print bot replies.

    Commands: 'quit' exits the loop, 'restart' clears the dialogue history.
    Generation runs on CPU with CONFIGS['default'].

    Args:
        model: a causal-LM with a .generate() method.
        tokenizer: the matching tokenizer.
    """
    history = f'{SP1} '
    while True:
        inp = input('>> User: ')
        if inp == 'quit':
            break
        if inp == 'restart':
            history = f'{SP1} '
            # Bug fix: previously the literal word 'restart' fell through and
            # was fed to the model as a user message; skip generation instead.
            continue
        inputs = tokenizer(history + inp + ' ' + SP2, return_tensors='pt')
        print('>> StepanBot:', end=' ')
        generated_token_ids = model.cpu().generate(
            **inputs,
            **CONFIGS['default']
        )
        # Keep the rolling context bounded: drop the oldest 100 tokens while
        # the sequence is 300 tokens or longer.
        while len(generated_token_ids[0]) >= 300:
            generated_token_ids = generated_token_ids[:, 100:]
        # Only the first (and only) returned sequence is used.
        history = tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)
        # Extract the bot's latest reply: everything after the last SP2 marker.
        # Bug fix: str.lstrip/rstrip treat their argument as a *character set*,
        # so lstrip(SP2)/rstrip(SP1) could eat leading/trailing letters of the
        # reply itself; strip the exact marker strings instead.
        reply = history[history.rfind(SP2):]
        if reply.startswith(SP2):
            reply = reply[len(SP2):]
        if reply.endswith(SP1):
            reply = reply[:-len(SP1)]
        type_slowly(reply.strip())
if __name__ == '__main__':
    # Tokenizer always comes from the base hub model; the fine-tuned local
    # checkpoint (if any) supplies only the weights.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Bug fix: AutoModelWithLMHead is deprecated and removed in transformers
    # v5; AutoModelForCausalLM is the direct replacement for decoder models.
    model = AutoModelForCausalLM.from_pretrained(get_model_name_or_path())
    run_chat(model, tokenizer)