diff --git a/Dockerfile.cpu b/Dockerfile.cpu
index 4d8d9985e..bb4bcb712 100644
--- a/Dockerfile.cpu
+++ b/Dockerfile.cpu
@@ -1,12 +1,8 @@
 FROM tensorflow/tensorflow:1.12.0-py3
 
 ENV LANG=C.UTF-8
-RUN mkdir /gpt-2
+RUN mkdir /gpt-2
 WORKDIR /gpt-2
-COPY requirements.txt download_model.sh /gpt-2/
-RUN apt-get update && \
-    apt-get install -y curl && \
-    sh download_model.sh 117M
-RUN pip3 install -r requirements.txt
-
 ADD . /gpt-2
+RUN pip3 install -r requirements.txt
+RUN python3 download_model.py 117M
diff --git a/Dockerfile.gpu b/Dockerfile.gpu
index 4ea7d3701..b7b013bd3 100644
--- a/Dockerfile.gpu
+++ b/Dockerfile.gpu
@@ -12,10 +12,6 @@ ENV NVIDIA_VISIBLE_DEVICES=all \
 
 RUN mkdir /gpt-2
 WORKDIR /gpt-2
-COPY requirements.txt download_model.sh /gpt-2/
-RUN apt-get update && \
-    apt-get install -y curl && \
-    sh download_model.sh 117M
-RUN pip3 install -r requirements.txt
-
 ADD . /gpt-2
+RUN pip3 install -r requirements.txt
+RUN python3 download_model.py 117M
diff --git a/README.md b/README.md
index 8ae00301f..390032625 100644
--- a/README.md
+++ b/README.md
@@ -17,12 +17,7 @@ Then, follow instructions for either native or Docker installation.
 
 ### Native Installation
 
-Download the model data
-```
-sh download_model.sh 117M
-```
-
-The remaining steps can optionally be done in a virtual environment using tools such as `virtualenv` or `conda`.
+All steps can optionally be done in a virtual environment using tools such as `virtualenv` or `conda`.
 
 Install tensorflow 1.12 (with GPU support, if you have a GPU and want everything to run faster)
 ```
@@ -38,6 +33,11 @@ Install other python packages:
 pip3 install -r requirements.txt
 ```
 
+Download the model data
+```
+python3 download_model.py 117M
+```
+
 ### Docker Installation
 
 Build the Dockerfile and tag the created image as `gpt-2`:
diff --git a/download_model.py b/download_model.py
index 79c960f39..2a3829431 100644
--- a/download_model.py
+++ b/download_model.py
@@ -1,24 +1,27 @@
-#!/usr/bin/env python
 import os
 import sys
 import requests
 from tqdm import tqdm
 
-if len(sys.argv)!=2:
+if len(sys.argv) != 2:
     print('You must enter the model name as a parameter, e.g.: download_model.py 117M')
     sys.exit(1)
+
 model = sys.argv[1]
-#Create directory if it does not exist already, then do nothing
-if not os.path.exists('models/'+model):
-    os.makedirs('models/'+model)
-#download all the files
+
+subdir = os.path.join('models', model)
+if not os.path.exists(subdir):
+    os.makedirs(subdir)
+
 for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']:
-    r = requests.get("https://storage.googleapis.com/gpt-2/models/"+model+"/"+filename,stream=True)
-    #wb flag required for windows
-    with open('models/'+model+'/'+filename,'wb') as currentFile:
-        fileSize = int(r.headers["content-length"])
-        with tqdm(ncols=100,desc="Fetching "+filename,total=fileSize,unit_scale=True) as pbar:
-            #went for 1k for chunk_size. Motivation -> Ethernet packet size is around 1500 bytes.
-            for chunk in r.iter_content(chunk_size=1000):
-                currentFile.write(chunk)
-                pbar.update(1000)
+
+    r = requests.get("https://storage.googleapis.com/gpt-2/" + subdir + "/" + filename, stream=True)
+
+    with open(os.path.join(subdir, filename), 'wb') as f:
+        file_size = int(r.headers["content-length"])
+        chunk_size = 1000
+        with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
+            # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
+            for chunk in r.iter_content(chunk_size=chunk_size):
+                f.write(chunk)
+                pbar.update(chunk_size)
diff --git a/download_model.sh b/download_model.sh
deleted file mode 100755
index 690463f2e..000000000
--- a/download_model.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-#!/bin/sh
-
-if [ "$#" -ne 1 ]; then
-    echo "You must enter the model name as a parameter, e.g.: sh download_model.sh 117M"
-    exit 1
-fi
-
-model=$1
-
-mkdir -p models/$model
-
-# TODO: gsutil rsync -r gs://gpt-2/models/ models/
-for filename in checkpoint encoder.json hparams.json model.ckpt.data-00000-of-00001 model.ckpt.index model.ckpt.meta vocab.bpe; do
-    fetch=$model/$filename
-    echo "Fetching $fetch"
-    curl --output models/$fetch https://storage.googleapis.com/gpt-2/models/$fetch
-done
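One note on the new download loop: `pbar.update(chunk_size)` advances the bar by a full 1000 bytes even for the final chunk, which is usually shorter, so the bar can finish slightly past `total`. Below is a minimal sketch of the same loop that counts the bytes actually received; it is not part of this diff, and `model` and `filename` are hardcoded here purely for illustration:

```python
import os
import requests
from tqdm import tqdm

model = '117M'             # hardcoded for illustration
filename = 'hparams.json'  # one of the seven files fetched in the loop
subdir = os.path.join('models', model)
os.makedirs(subdir, exist_ok=True)

# Build the URL with explicit slashes rather than os.path.join, which
# would produce backslashes on Windows and break the URL.
url = "https://storage.googleapis.com/gpt-2/models/" + model + "/" + filename
r = requests.get(url, stream=True)

with open(os.path.join(subdir, filename), 'wb') as f:
    file_size = int(r.headers["content-length"])
    with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
        for chunk in r.iter_content(chunk_size=1000):
            f.write(chunk)
            pbar.update(len(chunk))  # advance by bytes actually written, so the bar ends exactly at total
```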