Skip to content

Commit

Permalink
Add pytorch demo (mindee#1008)
Browse files Browse the repository at this point in the history
  • Loading branch information
odulcy-mindee authored Aug 26, 2022
1 parent 7c73cdf commit 61d0f1c
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 44 deletions.
41 changes: 34 additions & 7 deletions .github/workflows/demo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
matrix:
os: [ubuntu-latest]
python: [3.8]
framework: [tensorflow, pytorch]
steps:
- if: matrix.os == 'macos-latest'
name: Install MacOS prerequisites
Expand All @@ -24,22 +25,48 @@ jobs:
with:
python-version: ${{ matrix.python }}
architecture: x64
- name: Cache python modules
- if: matrix.framework == 'tensorflow'
name: Cache python modules (TF)
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('demo/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-
- name: Install dependencies
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('demo/tf-requirements.txt') }}
- if: matrix.framework == 'pytorch'
name: Cache python modules (PT)
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('demo/pt-requirements.txt') }}

- if: matrix.framework == 'tensorflow'
name: Install dependencies (TF)
run: |
python -m pip install --upgrade pip
pip install -e .[tf] --upgrade
pip install -r demo/requirements.txt
pip install -r demo/tf-requirements.txt
- if: matrix.framework == 'pytorch'
name: Install dependencies (PT)
run: |
python -m pip install --upgrade pip
pip install -e .[torch] --upgrade
pip install -r demo/pt-requirements.txt
- name: Run demo
- if: matrix.framework == 'tensorflow'
name: Run demo (TF)
env:
USE_TF: 1
run: |
streamlit --version
screen -dm streamlit run demo/app.py
sleep 10
curl http://localhost:8501/docs
- if: matrix.framework == 'pytorch'
name: Run demo (PT)
env:
USE_TORCH: 1
run: |
streamlit --version
screen -dm streamlit run demo/app.py
sleep 10
curl http://localhost:8501/docs
18 changes: 15 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,26 @@ Check it out [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%2

#### Running it locally

If you prefer to use it locally, there is an extra dependency ([Streamlit](https://streamlit.io/)) that is required:
If you prefer to use it locally, there is an extra dependency ([Streamlit](https://streamlit.io/)) that is required.

##### Tensorflow version
```shell
pip install -r demo/tf-requirements.txt
```
Then run your app in your default browser with:

```shell
USE_TF=1 streamlit run demo/app.py
```

##### PyTorch version
```shell
pip install -r demo/requirements.txt
pip install -r demo/pt-requirements.txt
```
Then run your app in your default browser with:

```shell
streamlit run demo/app.py
USE_TORCH=1 streamlit run demo/app.py
```

#### TensorFlow.js
Expand Down
68 changes: 34 additions & 34 deletions demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,35 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from doctr.file_utils import is_tf_available
from doctr.io import DocumentFile
from doctr.utils.visualization import visualize_page

import cv2
import tensorflow as tf
if is_tf_available():
import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
from backend.tensorflow import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor

from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from doctr.utils.visualization import visualize_page
if any(tf.config.experimental.list_physical_devices("gpu")):
forward_device = tf.device("/gpu:0")
else:
forward_device = tf.device("/cpu:0")

else:
import torch

DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large", "linknet_resnet18_rotation"]
RECO_ARCHS = ["crnn_vgg16_bn", "crnn_mobilenet_v3_small", "master", "sar_resnet31"]
from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor

forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def main():

def main(det_archs, reco_archs):
"""Build a streamlit layout"""

# Wide mode
st.set_page_config(layout="wide")
Expand Down Expand Up @@ -56,12 +62,14 @@ def main():
else:
doc = DocumentFile.from_images(uploaded_file.read())
page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
cols[0].image(doc[page_idx])
page = doc[page_idx]
cols[0].image(page)

# Model selection
st.sidebar.title("Model selection")
det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS)
reco_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS)
st.sidebar.markdown("**Backend**: " + ("TensorFlow" if is_tf_available() else "PyTorch"))
det_arch = st.sidebar.selectbox("Text detection model", det_archs)
reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)

# For newline
st.sidebar.write("\n")
Expand All @@ -73,37 +81,29 @@ def main():

else:
with st.spinner("Loading model..."):
predictor = ocr_predictor(
det_arch,
reco_arch,
pretrained=True,
assume_straight_pages=(det_arch != "linknet_resnet18_rotation"),
)
predictor = load_predictor(det_arch, reco_arch, forward_device)

with st.spinner("Analyzing..."):

# Forward the image to the model
processed_batches = predictor.det_predictor.pre_processor([doc[page_idx]])
out = predictor.det_predictor.model(processed_batches[0], return_model_output=True)
seg_map = out["out_map"]
seg_map = tf.squeeze(seg_map[0, ...], axis=[2])
seg_map = cv2.resize(
seg_map.numpy(), (doc[page_idx].shape[1], doc[page_idx].shape[0]), interpolation=cv2.INTER_LINEAR
)
seg_map = forward_image(predictor, page, forward_device)
seg_map = np.squeeze(seg_map)
seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)

# Plot the raw heatmap
fig, ax = plt.subplots()
ax.imshow(seg_map)
ax.axis("off")
cols[1].pyplot(fig)

# Plot OCR output
out = predictor([doc[page_idx]])
fig = visualize_page(out.pages[0].export(), doc[page_idx], interactive=False)
out = predictor([page])
fig = visualize_page(out.pages[0].export(), page, interactive=False)
cols[2].pyplot(fig)

# Page reconsitution under input page
page_export = out.pages[0].export()
if det_arch != "linknet_resnet18_rotation":
if "rotation" not in det_arch:
img = out.pages[0].synthesize()
cols[3].image(img, clamp=True)

Expand All @@ -113,4 +113,4 @@ def main():


if __name__ == "__main__":
main()
main(DET_ARCHS, RECO_ARCHS)
37 changes: 37 additions & 0 deletions demo/backend/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (C) 2021-2022, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import numpy as np
import torch

from doctr.models import ocr_predictor
from doctr.models.predictor import OCRPredictor

DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large", "linknet_resnet50_rotation"]
RECO_ARCHS = ["crnn_vgg16_bn", "crnn_mobilenet_v3_small", "master", "sar_resnet31"]


def load_predictor(det_arch: str, reco_arch: str, device) -> OCRPredictor:
"""
Args:
device is torch.device
"""
predictor = ocr_predictor(
det_arch, reco_arch, pretrained=True, assume_straight_pages=("rotation" not in det_arch)
).to(device)
return predictor


def forward_image(predictor: OCRPredictor, image: np.ndarray, device) -> np.ndarray:
"""
Args:
device is torch.device
"""
with torch.no_grad():
processed_batches = predictor.det_predictor.pre_processor([image])
out = predictor.det_predictor.model(processed_batches[0].to(device), return_model_output=True)
seg_map = out["out_map"].to("cpu").numpy()

return seg_map
41 changes: 41 additions & 0 deletions demo/backend/tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (C) 2021-2022, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import numpy as np
import tensorflow as tf

from doctr.models import ocr_predictor
from doctr.models.predictor import OCRPredictor

DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large", "linknet_resnet18_rotation"]
RECO_ARCHS = ["crnn_vgg16_bn", "crnn_mobilenet_v3_small", "master", "sar_resnet31"]


def load_predictor(det_arch: str, reco_arch: str, device) -> OCRPredictor:
"""
Args:
device is tf.device
"""
with device:
predictor = ocr_predictor(
det_arch, reco_arch, pretrained=True, assume_straight_pages=("rotation" not in det_arch)
)
return predictor


def forward_image(predictor: OCRPredictor, image: np.ndarray, device) -> np.ndarray:
"""
Args:
device is tf.device
"""
with device:
processed_batches = predictor.det_predictor.pre_processor([image])
out = predictor.det_predictor.model(processed_batches[0], return_model_output=True)
seg_map = out["out_map"]

with tf.device("/cpu:0"):
seg_map = tf.identity(seg_map).numpy()

return seg_map
2 changes: 2 additions & 0 deletions demo/pt-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-e git+https://github.com/mindee/doctr.git#egg=python-doctr[torch]
streamlit>=1.0.0
File renamed without changes.

0 comments on commit 61d0f1c

Please sign in to comment.