This repository has been archived by the owner on Jan 16, 2022. It is now read-only.

Commit 81f570e — "deploy"
mgrankin committed Oct 9, 2019
1 parent 775cbca

Showing 5 changed files with 32 additions and 17 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -116,3 +116,5 @@ apex/
 dataset/
 .vscode/tags
 corpus/Untitled.ipynb
+gpt2/
+Untitled.ipynb
15 changes: 12 additions & 3 deletions README.md
@@ -129,9 +129,9 @@
 done
 
 ```
 
-### 8. Deploy your model
+### 8. Save trained model
 
 # upload
 ``` bash
 aws s3 cp output_s/config.json s3://models.dobro.ai/gpt2/ru/small/
 aws s3 cp output_s/encoder.model s3://models.dobro.ai/gpt2/ru/small/
 aws s3 cp output_s/pytorch_model.bin s3://models.dobro.ai/gpt2/ru/small/
@@ -143,6 +143,15 @@ aws s3 cp output_m/pytorch_model.bin s3://models.dobro.ai/gpt2/ru/medium/
 aws s3 cp output_l/config.json s3://models.dobro.ai/gpt2/ru/large/
 aws s3 cp output_l/encoder.model s3://models.dobro.ai/gpt2/ru/large/
 aws s3 cp output_l/pytorch_model.bin s3://models.dobro.ai/gpt2/ru/large/
 ```
 
-# download
+### 9. Deploy the model
+
+``` bash
+git clone https://github.com/mgrankin/ru_transformers.git
+cd ru_transformers
+aws s3 sync --no-sign-request s3://models.dobro.ai/gpt2/ru gpt2
+conda env create -f environment.yml
+conda activate gpt
+uvicorn rest:app --reload --host 0.0.0.0
+```
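With the server from step 9 running, the new endpoint can be exercised from a short client. A sketch, with assumptions called out: uvicorn's default port is 8000, the route is built as `"/" + model_path + "/{prompt}"` (so `gpt2/medium` yields `/gpt2/medium/<prompt>`), and `make_url` is an illustrative helper, not part of the repo:

```python
from urllib.parse import quote

def make_url(prompt: str, host: str = "http://localhost:8000",
             model_path: str = "gpt2/medium",
             length: int = 10, num_samples: int = 3) -> str:
    # Percent-encode the prompt so Cyrillic text survives as a path segment
    return (f"{host}/{model_path}/{quote(prompt)}"
            f"?length={length}&num_samples={num_samples}")

url = make_url("Привет")
# To actually query a running server:
# import json, urllib.request
# replies = json.load(urllib.request.urlopen(url))["replies"]
```

The query parameters mirror the defaults of `gen_sample` in rest.py.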
2 changes: 1 addition & 1 deletion environment.yml
@@ -26,4 +26,4 @@ dependencies:
   - tendo
   - schedule
   - fastapi>=0.41
-  - awscli
\ No newline at end of file
+  - awscli
28 changes: 15 additions & 13 deletions rest.py
@@ -6,15 +6,20 @@
 
 import logging
 
-logging.basicConfig(filename="stihbot.log", level=logging.INFO)
+logging.basicConfig(filename="rest.log", level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-device="cuda"
-path = 'gpt2/medium'
+import yaml
+cfg = yaml.safe_load(open('rest_config.yaml'))
+device = cfg['device']
+model_path = cfg['model_path']
 
-lock = threading.RLock()
+tokenizer = SPEncoder.from_pretrained(model_path)
+model = GPT2LMHeadModel.from_pretrained(model_path)
+model.to(device)
+model.eval()
 
-def get_sample(prompt, model, tokenizer, device, length:int, num_samples:int, allow_linebreak:bool):
+def get_sample(prompt, length:int, num_samples:int, allow_linebreak:bool):
     logger.info("*" * 200)
     logger.info(prompt)
 
@@ -46,11 +51,6 @@ def get_sample(prompt, model, tokenizer, device, length:int, num_samples:int, al
     logger.info(result)
     return result
 
-tokenizer = SPEncoder.from_pretrained(path)
-model = GPT2LMHeadModel.from_pretrained(path)
-model.to(device)
-model.eval()
-
 from fastapi import FastAPI
 
 app = FastAPI(title="Russian GPT-2", version="0.1",)
@@ -59,8 +59,10 @@
 def read_root():
     return {"Hello": "World"}
 
+lock = threading.RLock()
+
-@app.get("/gpt2-large/{prompt}")
-def gen_(prompt:str, length:int=5, num_samples:int=3, allow_linebreak:bool=False):
+@app.get("/" + model_path + "/{prompt}")
+def gen_sample(prompt:str, length:int=10, num_samples:int=3, allow_linebreak:bool=False):
     with lock:
-        return {"replies": get_sample(prompt, model, tokenizer, device, length, num_samples, allow_linebreak)}
+        return {"replies": get_sample(prompt, length, num_samples, allow_linebreak)}

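A note on the lock kept in rest.py: a single GPU model cannot safely run concurrent forward passes from multiple request threads, so every request serializes through one module-level `threading.RLock`. A minimal stand-alone sketch of the pattern, with a dummy string reversal standing in for model generation (`handle_request` and the counters are illustrative, not the repo's code):

```python
import threading

lock = threading.RLock()
active = 0       # how many threads are "generating" right now
max_active = 0   # the peak concurrency ever observed

def handle_request(prompt: str) -> str:
    global active, max_active
    with lock:                  # only one thread generates at a time
        active += 1
        max_active = max(max_active, active)
        reply = prompt[::-1]    # stand-in for the model's forward pass
        active -= 1
    return reply

threads = [threading.Thread(target=handle_request, args=(f"p{i}",))
           for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert max_active == 1          # the lock serialized every "generation"
```

The trade-off is throughput: under this design FastAPI can accept requests concurrently, but generation itself is strictly one at a time.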
2 changes: 2 additions & 0 deletions rest_config.yaml
@@ -0,0 +1,2 @@
+model_path: gpt2/medium
+device: cuda
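rest.py reads this file with `yaml.safe_load`. For a flat two-key file like this one, the parse can be sketched with the standard library alone (`load_flat_config` is an illustrative stand-in, not the repo's code, and handles only `key: value` lines, not nested YAML):

```python
def load_flat_config(text: str) -> dict:
    # Parse "key: value" lines, skipping blanks and comments
    cfg = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        key, _, value = line.partition(":")
        cfg[key.strip()] = value.strip()
    return cfg

cfg = load_flat_config("model_path: gpt2/medium\ndevice: cuda")
# cfg == {"model_path": "gpt2/medium", "device": "cuda"}
```

Keeping `device` in the config is what lets the same rest.py serve from CPU (`device: cpu`) or GPU without a code change.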
