line break filter
mgrankin committed Oct 9, 2019
1 parent 6c6f47e commit 775cbca
Showing 2 changed files with 14 additions and 12 deletions.
rest.py (23 changes: 12 additions & 11 deletions)
@@ -14,14 +14,17 @@

lock = threading.RLock()

def get_sample(prompt, model, tokenizer, device, length:int=5, num_samples:int=3):
def get_sample(prompt, model, tokenizer, device, length:int, num_samples:int, allow_linebreak:bool):
logger.info("*" * 200)
logger.info(prompt)

model.to(device)
model.eval()

filter_n = tokenizer.encode('\n')[-1:]
filter_single = [tokenizer.sp.unk_id()]
filter_single += [] if allow_linebreak else filter_n

context_tokens = tokenizer.encode(prompt)
out = sample_sequence(
model=model,
@@ -31,35 +34,33 @@ def get_sample(prompt, model, tokenizer, device, length:int=5, num_samples:int=3
top_k=0,
top_p=0.9,
device=device,
filter_single=filter_single,
filter_double=filter_n,
num_samples=num_samples
)
out = out.to('cpu')
print(out)
num_samples=num_samples,
).to('cpu')
replies = [out[item, len(context_tokens):].tolist() for item in range(len(out))]
text = [tokenizer.decode(item) for item in replies]
reg_text = [re.match(r'[\w\W]*[\.!?]\n', item) for item in text]
result = [reg_item[0] if reg_item else item for reg_item, item in zip(reg_text,text)]
logger.info("=" * 200)
logger.info(result)
return result

tokenizer = SPEncoder.from_pretrained(path)

tokenizer = SPEncoder.from_pretrained(path)
model = GPT2LMHeadModel.from_pretrained(path)
model.to(device)
model.eval()


from fastapi import FastAPI

app = FastAPI()
app = FastAPI(title="Russian GPT-2", version="0.1",)

@app.get("/")
def read_root():
return {"Hello": "World"}

@app.get("/gpt2-large/{prompt}")
def gen_(prompt:str, length:int=5, num_samples:int=3):
return {"replies": get_sample(prompt, model, tokenizer, device, length, num_samples)}
def gen_(prompt:str, length:int=5, num_samples:int=3, allow_linebreak:bool=False):
with lock:
return {"replies": get_sample(prompt, model, tokenizer, device, length, num_samples, allow_linebreak)}

run_generation.py (3 changes: 2 additions & 1 deletion)
@@ -103,7 +103,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
return logits

def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0,
is_xlnet=False, device='cpu', max_input=1023, filter_double=[]):
is_xlnet=False, device='cpu', max_input=1023, filter_single=[], filter_double=[]):
context = torch.tensor(context, dtype=torch.long, device=device)
context = context.unsqueeze(0).repeat(num_samples, 1)
generated = context
@@ -125,6 +125,7 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
next_tokens = torch.zeros(num_samples, dtype=torch.long).to(device)
for isample in range(num_samples):
next_token_logits = outputs[0][isample, -1, :] / temperature
next_token_logits[filter_single] = FILTER_VALUE
# filter blank line = double \n
if generated[isample, -1] in filter_double:
next_token_logits[generated[isample, -1]] = FILTER_VALUE
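
For context, a self-contained sketch of the masking idea this hunk extends: ids in filter_single are always pushed to FILTER_VALUE (negative infinity) before sampling, and an id in filter_double is additionally masked when it was the token just generated, which is what prevents a blank line (two '\n' in a row). The vocabulary size, ids and logits below are invented for illustration and are not the project's values.

# Toy illustration of the logit-masking technique used in sample_sequence; ids and logits are made up.
import torch

FILTER_VALUE = -float('Inf')
filter_single = [0, 7]               # e.g. <unk> plus '\n' when allow_linebreak is False
filter_double = [7]                  # e.g. '\n': never allowed twice in a row

next_token_logits = torch.randn(50)  # fake logits over a 50-token vocabulary
previous_token = 7                   # pretend the last sampled token was '\n'

next_token_logits[filter_single] = FILTER_VALUE
if previous_token in filter_double:
    next_token_logits[previous_token] = FILTER_VALUE

probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)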
