diff --git a/.github/workflows/demo.yml b/.github/workflows/demo.yml index 6baac72f40..9a399628ee 100644 --- a/.github/workflows/demo.yml +++ b/.github/workflows/demo.yml @@ -45,5 +45,5 @@ jobs: run: | streamlit --version screen -dm streamlit run demo/app.py - sleep 5 - curl http://localhost:8501 \ No newline at end of file + sleep 10 + curl http://localhost:8501 diff --git a/.github/workflows/references.yml b/.github/workflows/references.yml new file mode 100644 index 0000000000..f8be64a205 --- /dev/null +++ b/.github/workflows/references.yml @@ -0,0 +1,77 @@ +name: references + +on: + push: + branches: main + pull_request: + branches: main + +jobs: + test-classification-tf: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: [3.8] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('references/requirements.txt') }}-${{ hashFiles('**/*.py') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('references/requirements.txt') }}- + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}- + ${{ runner.os }}-pkg-deps-${{ matrix.python }}- + ${{ runner.os }}-pkg-deps- + ${{ runner.os }}- + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + pip install -r references/requirements.txt + sudo apt-get update && sudo apt-get install fonts-freefont-ttf -y + - name: Train for a short epoch + run: python references/classification/train_tensorflow.py mobilenet_v3_small -b 32 --val-samples 1 --train-samples 1 --epochs 1 + + test-classification-torch: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: [3.8] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('references/requirements.txt') }}-${{ hashFiles('**/*.py') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('references/requirements.txt') }}- + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}- + ${{ runner.os }}-pkg-deps-${{ matrix.python }}- + ${{ runner.os }}-pkg-deps- + ${{ runner.os }}- + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[torch] --upgrade + pip install -r references/requirements.txt + pip install contiguous-params + sudo apt-get update && sudo apt-get install fonts-freefont-ttf -y + - name: Train for a short epoch + run: python references/classification/train_pytorch.py mobilenet_v3_small -b 32 --val-samples 1 --train-samples 1 --epochs 1 diff --git a/references/classification/train_pytorch.py b/references/classification/train_pytorch.py index 0c33793938..c265fc9e51 100644 --- a/references/classification/train_pytorch.py +++ b/references/classification/train_pytorch.py @@ -85,10 +85,10 @@ def main(args): st = time.time() val_set = CharacterGenerator( vocab=vocab, - num_samples=20 * len(vocab), + num_samples=args.val_samples * len(vocab), cache_samples=True, sample_transforms=T.Resize((args.input_size, args.input_size)), - font_family="FreeMono.ttf", + font_family=args.font, ) val_loader = DataLoader( val_set, @@ -123,7 +123,7 @@ def main(args): # Load train data generator train_set = CharacterGenerator( vocab=vocab, - num_samples=1000 * len(vocab), + num_samples=args.train_samples * len(vocab), cache_samples=True, sample_transforms=Compose([ T.Resize((args.input_size, args.input_size)), @@ -132,7 +132,7 @@ def main(args): T.RandomApply(T.ColorInversion(), .7), ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02), ]), - font_family="FreeMono.ttf", + font_family=args.font, ) train_loader = DataLoader( @@ -224,6 +224,21 @@ def parse_args(): parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay') parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading') parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint') + parser.add_argument('--font', type=str, default="FreeMono.ttf", help='Font family to be used') + parser.add_argument( + '--train-samples', + dest='train_samples', + type=int, + default=1000, + help='Multiplied by the vocab length gets you the number of training samples that will be used.' + ) + parser.add_argument( + '--val-samples', + dest='val_samples', + type=int, + default=20, + help='Multiplied by the vocab length gets you the number of validation samples that will be used.' + ) parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop") parser.add_argument('--show-samples', dest='show_samples', action='store_true', help='Display unormalized training samples') diff --git a/references/classification/train_tensorflow.py b/references/classification/train_tensorflow.py index c52847c953..10148da4c2 100644 --- a/references/classification/train_tensorflow.py +++ b/references/classification/train_tensorflow.py @@ -81,10 +81,10 @@ def main(args): st = time.time() val_set = CharacterGenerator( vocab=vocab, - num_samples=20 * len(vocab), + num_samples=args.val_samples * len(vocab), cache_samples=True, sample_transforms=T.Resize((args.input_size, args.input_size)), - font_family="FreeMono.ttf", + font_family=args.font, ) val_loader = DataLoader( val_set, @@ -123,7 +123,7 @@ def main(args): # Load train data generator train_set = CharacterGenerator( vocab=vocab, - num_samples=1000 * len(vocab), + num_samples=args.train_samples * len(vocab), cache_samples=True, sample_transforms=T.Compose([ T.Resize((args.input_size, args.input_size)), @@ -134,7 +134,7 @@ def main(args): T.RandomContrast(.3), T.RandomBrightness(.3), ]), - font_family="FreeMono.ttf", + font_family=args.font, ) train_loader = DataLoader( train_set, @@ -228,6 +228,21 @@ def parse_args(): parser.add_argument('--lr', type=float, default=0.001, help='learning rate for the optimizer (Adam)') parser.add_argument('-j', '--workers', type=int, default=4, help='number of workers used for dataloading') parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint') + parser.add_argument('--font', type=str, default="FreeMono.ttf", help='Font family to be used') + parser.add_argument( + '--train-samples', + dest='train_samples', + type=int, + default=1000, + help='Multiplied by the vocab length gets you the number of training samples that will be used.' + ) + parser.add_argument( + '--val-samples', + dest='val_samples', + type=int, + default=20, + help='Multiplied by the vocab length gets you the number of validation samples that will be used.' + ) parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop") parser.add_argument('--show-samples', dest='show_samples', action='store_true', help='Display unormalized training samples')