Skip to content

Commit

Permalink
chore: Added CI jobs to check classification training (mindee#457)
Browse files Browse the repository at this point in the history
* feat: Added options in classification script

Added options to change the font family and the number of generated samples

* chore: Added CI job to check the classification training script

* chore: Increased sleep for CI job checking the demo

* chore: Ensured the training is run in CI for only a single epoch

* chore: Installed free fonts before CI checks

* chore: Making sure to avoid permission denied
  • Loading branch information
fg-mindee authored Sep 3, 2021
1 parent 64c7864 commit 3e7e9de
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 10 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/demo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ jobs:
run: |
streamlit --version
screen -dm streamlit run demo/app.py
sleep 5
curl http://localhost:8501
sleep 10
curl http://localhost:8501
77 changes: 77 additions & 0 deletions .github/workflows/references.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
name: references

on:
push:
branches: main
pull_request:
branches: main

jobs:
test-classification-tf:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: [3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python }}
architecture: x64
- name: Cache python modules
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('references/requirements.txt') }}-${{ hashFiles('**/*.py') }}
restore-keys: |
${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('references/requirements.txt') }}-
${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-
${{ runner.os }}-pkg-deps-${{ matrix.python }}-
${{ runner.os }}-pkg-deps-
${{ runner.os }}-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[tf] --upgrade
pip install -r references/requirements.txt
sudo apt-get update && sudo apt-get install fonts-freefont-ttf -y
- name: Train for a short epoch
run: python references/classification/train_tensorflow.py mobilenet_v3_small -b 32 --val-samples 1 --train-samples 1 --epochs 1

test-classification-torch:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: [3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python }}
architecture: x64
- name: Cache python modules
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('references/requirements.txt') }}-${{ hashFiles('**/*.py') }}
restore-keys: |
${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('references/requirements.txt') }}-
${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-
${{ runner.os }}-pkg-deps-${{ matrix.python }}-
${{ runner.os }}-pkg-deps-
${{ runner.os }}-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[torch] --upgrade
pip install -r references/requirements.txt
pip install contiguous-params
sudo apt-get update && sudo apt-get install fonts-freefont-ttf -y
- name: Train for a short epoch
run: python references/classification/train_pytorch.py mobilenet_v3_small -b 32 --val-samples 1 --train-samples 1 --epochs 1
23 changes: 19 additions & 4 deletions references/classification/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def main(args):
st = time.time()
val_set = CharacterGenerator(
vocab=vocab,
num_samples=20 * len(vocab),
num_samples=args.val_samples * len(vocab),
cache_samples=True,
sample_transforms=T.Resize((args.input_size, args.input_size)),
font_family="FreeMono.ttf",
font_family=args.font,
)
val_loader = DataLoader(
val_set,
Expand Down Expand Up @@ -123,7 +123,7 @@ def main(args):
# Load train data generator
train_set = CharacterGenerator(
vocab=vocab,
num_samples=1000 * len(vocab),
num_samples=args.train_samples * len(vocab),
cache_samples=True,
sample_transforms=Compose([
T.Resize((args.input_size, args.input_size)),
Expand All @@ -132,7 +132,7 @@ def main(args):
T.RandomApply(T.ColorInversion(), .7),
ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
]),
font_family="FreeMono.ttf",
font_family=args.font,
)

train_loader = DataLoader(
Expand Down Expand Up @@ -224,6 +224,21 @@ def parse_args():
parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay')
parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading')
parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint')
parser.add_argument('--font', type=str, default="FreeMono.ttf", help='Font family to be used')
parser.add_argument(
'--train-samples',
dest='train_samples',
type=int,
default=1000,
help='Multiplied by the vocab length gets you the number of training samples that will be used.'
)
parser.add_argument(
'--val-samples',
dest='val_samples',
type=int,
default=20,
help='Multiplied by the vocab length gets you the number of validation samples that will be used.'
)
parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop")
parser.add_argument('--show-samples', dest='show_samples', action='store_true',
help='Display unormalized training samples')
Expand Down
23 changes: 19 additions & 4 deletions references/classification/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def main(args):
st = time.time()
val_set = CharacterGenerator(
vocab=vocab,
num_samples=20 * len(vocab),
num_samples=args.val_samples * len(vocab),
cache_samples=True,
sample_transforms=T.Resize((args.input_size, args.input_size)),
font_family="FreeMono.ttf",
font_family=args.font,
)
val_loader = DataLoader(
val_set,
Expand Down Expand Up @@ -123,7 +123,7 @@ def main(args):
# Load train data generator
train_set = CharacterGenerator(
vocab=vocab,
num_samples=1000 * len(vocab),
num_samples=args.train_samples * len(vocab),
cache_samples=True,
sample_transforms=T.Compose([
T.Resize((args.input_size, args.input_size)),
Expand All @@ -134,7 +134,7 @@ def main(args):
T.RandomContrast(.3),
T.RandomBrightness(.3),
]),
font_family="FreeMono.ttf",
font_family=args.font,
)
train_loader = DataLoader(
train_set,
Expand Down Expand Up @@ -228,6 +228,21 @@ def parse_args():
parser.add_argument('--lr', type=float, default=0.001, help='learning rate for the optimizer (Adam)')
parser.add_argument('-j', '--workers', type=int, default=4, help='number of workers used for dataloading')
parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint')
parser.add_argument('--font', type=str, default="FreeMono.ttf", help='Font family to be used')
parser.add_argument(
'--train-samples',
dest='train_samples',
type=int,
default=1000,
help='Multiplied by the vocab length gets you the number of training samples that will be used.'
)
parser.add_argument(
'--val-samples',
dest='val_samples',
type=int,
default=20,
help='Multiplied by the vocab length gets you the number of validation samples that will be used.'
)
parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop")
parser.add_argument('--show-samples', dest='show_samples', action='store_true',
help='Display unormalized training samples')
Expand Down

0 comments on commit 3e7e9de

Please sign in to comment.