Skip to content

Commit

Permalink
[misc] Change prefered backend from tf to torch (#1779)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Nov 18, 2024
1 parent 83f1bc5 commit c56bf41
Show file tree
Hide file tree
Showing 167 changed files with 431 additions and 1,016 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[tf,viz,html]
pip install -e .[torch,viz,html]
pip install -e .[docs]
- name: Build documentation
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.1
rev: v0.7.4
hooks:
- id: ruff
args: [ --fix ]
Expand Down
2 changes: 1 addition & 1 deletion api/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
.PHONY: lock run stop test
# Pin the dependencies
lock:
pip install poetry>=1.0
pip install poetry>=1.0 poetry-plugin-export
poetry lock
poetry export -f requirements.txt --without-hashes --output requirements.txt
poetry export -f requirements.txt --without-hashes --with dev --output requirements-dev.txt
Expand Down
2 changes: 0 additions & 2 deletions api/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ async def get_documents(files: List[UploadFile]) -> Tuple[List[np.ndarray], List
"""Convert a list of UploadFile objects to lists of numpy arrays and their corresponding filenames
Args:
----
files: list of UploadFile objects
Returns:
-------
Tuple[List[np.ndarray], List[str]]: list of numpy arrays and their corresponding filenames
"""
Expand Down
30 changes: 18 additions & 12 deletions api/app/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,34 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


import tensorflow as tf

gpu_devices = tf.config.list_physical_devices("GPU")
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

from typing import Callable, Union

import torch

from doctr.models import kie_predictor, ocr_predictor

from .schemas import DetectionIn, KIEIn, OCRIn, RecognitionIn


def _move_to_device(predictor: Callable) -> Callable:
"""Move the predictor to the desired device
Args:
predictor: the predictor to move
Returns:
Callable: the predictor moved to the desired device
"""
return predictor.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))


def init_predictor(request: Union[KIEIn, OCRIn, RecognitionIn, DetectionIn]) -> Callable:
"""Initialize the predictor based on the request
Args:
----
request: input request
Returns:
-------
Callable: the predictor
"""
params = request.model_dump()
Expand All @@ -36,12 +42,12 @@ def init_predictor(request: Union[KIEIn, OCRIn, RecognitionIn, DetectionIn]) ->
predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
if isinstance(request, DetectionIn):
return predictor.det_predictor
return _move_to_device(predictor.det_predictor)
elif isinstance(request, RecognitionIn):
return predictor.reco_predictor
return predictor
return _move_to_device(predictor.reco_predictor)
return _move_to_device(predictor)
elif isinstance(request, KIEIn):
predictor = kie_predictor(pretrained=True, **params)
predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
return predictor
return _move_to_device(predictor)
2 changes: 0 additions & 2 deletions api/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
version: '3.8'

services:
web:
container_name: api_web
Expand Down
2 changes: 1 addition & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license = "Apache-2.0"

[tool.poetry.dependencies]
python = ">=3.10,<3.13"
python-doctr = {git = "https://github.com/mindee/doctr.git", extras = ['tf'], branch = "main" }
python-doctr = {git = "https://github.com/mindee/doctr.git", extras = ['torch'], branch = "main" }
# Fastapi: minimum version required to avoid pydantic error
# cf. https://github.com/tiangolo/fastapi/issues/4168
fastapi = ">=0.73.0"
Expand Down
201 changes: 103 additions & 98 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,32 @@ def mock_detection_response():
"box": {
"name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg",
"geometries": [
[0.8176307908857315, 0.1787109375, 0.9101580212741838, 0.2080078125],
[0.7471996155154171, 0.1796875, 0.8272978149561669, 0.20703125],
[0.8203927977629988, 0.181640625, 0.9087770178355502, 0.2041015625],
[0.7471996155154171, 0.1806640625, 0.8245358080788996, 0.2060546875],
],
},
"poly": {
"name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg",
"geometries": [
[
0.9063061475753784,
0.17740710079669952,
0.9078840017318726,
0.20474515855312347,
0.8173396587371826,
0.20735852420330048,
0.8157618045806885,
0.18002046644687653,
0.8203927977629988,
0.181640625,
0.906015010958283,
0.181640625,
0.906015010958283,
0.2021484375,
0.8203927977629988,
0.2021484375,
],
[
0.8233299851417542,
0.17740298807621002,
0.8250390291213989,
0.2027825564146042,
0.7470247745513916,
0.20540954172611237,
0.7453157305717468,
0.1800299733877182,
0.7482568619833604,
0.17938309907913208,
0.8208542842026056,
0.1819499135017395,
0.8193355512950555,
0.2034294307231903,
0.7467381290758103,
0.20086261630058289,
],
],
},
Expand All @@ -82,17 +82,17 @@ def mock_kie_response():
"class_name": "words",
"items": [
{
"value": "Hello",
"geometry": [0.7471996155154171, 0.1796875, 0.8272978149561669, 0.20703125],
"objectness_score": 0.39,
"confidence": 1,
"value": "world!",
"geometry": [0.8203927977629988, 0.181640625, 0.9087770178355502, 0.2041015625],
"objectness_score": 0.46,
"confidence": 0.94,
"crop_orientation": {"value": 0, "confidence": None},
},
{
"value": "world!",
"geometry": [0.8176307908857315, 0.1787109375, 0.9101580212741838, 0.2080078125],
"objectness_score": 0.39,
"confidence": 1,
"value": "Hello",
"geometry": [0.7471996155154171, 0.1806640625, 0.8245358080788996, 0.2060546875],
"objectness_score": 0.46,
"confidence": 0.66,
"crop_orientation": {"value": 0, "confidence": None},
},
],
Expand All @@ -109,35 +109,35 @@ def mock_kie_response():
"class_name": "words",
"items": [
{
"value": "Hello",
"value": "world!",
"geometry": [
0.7453157305717468,
0.1800299733877182,
0.8233299851417542,
0.17740298807621002,
0.8250390291213989,
0.2027825564146042,
0.7470247745513916,
0.20540954172611237,
0.8203927977629988,
0.181640625,
0.906015010958283,
0.181640625,
0.906015010958283,
0.2021484375,
0.8203927977629988,
0.2021484375,
],
"objectness_score": 0.5,
"confidence": 0.99,
"objectness_score": 0.52,
"confidence": 1,
"crop_orientation": {"value": 0, "confidence": 1},
},
{
"value": "world!",
"value": "Hello",
"geometry": [
0.8157618045806885,
0.18002046644687653,
0.9063061475753784,
0.17740710079669952,
0.9078840017318726,
0.20474515855312347,
0.8173396587371826,
0.20735852420330048,
0.7482568619833604,
0.17938309907913208,
0.8208542842026056,
0.1819499135017395,
0.8193355512950555,
0.2034294307231903,
0.7467381290758103,
0.20086261630058289,
],
"objectness_score": 0.5,
"confidence": 1,
"objectness_score": 0.57,
"confidence": 0.65,
"crop_orientation": {"value": 0, "confidence": 1},
},
],
Expand All @@ -159,30 +159,35 @@ def mock_ocr_response():
{
"blocks": [
{
"geometry": [0.7471996155154171, 0.1787109375, 0.9101580212741838, 0.2080078125],
"objectness_score": 0.39,
"geometry": [0.7471996155154171, 0.1806640625, 0.9087770178355502, 0.2060546875],
"objectness_score": 0.46,
"lines": [
{
"geometry": [0.7471996155154171, 0.1787109375, 0.9101580212741838, 0.2080078125],
"objectness_score": 0.39,
"geometry": [0.7471996155154171, 0.1806640625, 0.9087770178355502, 0.2060546875],
"objectness_score": 0.46,
"words": [
{
"value": "Hello",
"geometry": [0.7471996155154171, 0.1796875, 0.8272978149561669, 0.20703125],
"objectness_score": 0.39,
"confidence": 1,
"geometry": [
0.7471996155154171,
0.1806640625,
0.8245358080788996,
0.2060546875,
],
"objectness_score": 0.46,
"confidence": 0.66,
"crop_orientation": {"value": 0, "confidence": None},
},
{
"value": "world!",
"geometry": [
0.8176307908857315,
0.1787109375,
0.9101580212741838,
0.2080078125,
0.8203927977629988,
0.181640625,
0.9087770178355502,
0.2041015625,
],
"objectness_score": 0.39,
"confidence": 1,
"objectness_score": 0.46,
"confidence": 0.94,
"crop_orientation": {"value": 0, "confidence": None},
},
],
Expand All @@ -203,59 +208,59 @@ def mock_ocr_response():
"blocks": [
{
"geometry": [
0.7451040148735046,
0.17927837371826172,
0.9062581658363342,
0.17407986521720886,
0.9072266221046448,
0.2041015625,
0.7460724711418152,
0.20930007100105286,
0.7460642457008362,
0.2017778754234314,
0.7464945912361145,
0.17868199944496155,
0.9056554436683655,
0.18164771795272827,
0.9052250981330872,
0.20474359393119812,
],
"objectness_score": 0.5,
"objectness_score": 0.54,
"lines": [
{
"geometry": [
0.7451040148735046,
0.17927837371826172,
0.9062581658363342,
0.17407986521720886,
0.9072266221046448,
0.2041015625,
0.7460724711418152,
0.20930007100105286,
0.7460642457008362,
0.2017778754234314,
0.7464945912361145,
0.17868199944496155,
0.9056554436683655,
0.18164771795272827,
0.9052250981330872,
0.20474359393119812,
],
"objectness_score": 0.5,
"objectness_score": 0.54,
"words": [
{
"value": "Hello",
"geometry": [
0.7453157305717468,
0.1800299733877182,
0.8233299851417542,
0.17740298807621002,
0.8250390291213989,
0.2027825564146042,
0.7470247745513916,
0.20540954172611237,
0.7482568619833604,
0.17938309907913208,
0.8208542842026056,
0.1819499135017395,
0.8193355512950555,
0.2034294307231903,
0.7467381290758103,
0.20086261630058289,
],
"objectness_score": 0.5,
"confidence": 0.99,
"objectness_score": 0.57,
"confidence": 0.65,
"crop_orientation": {"value": 0, "confidence": 1},
},
{
"value": "world!",
"geometry": [
0.8157618045806885,
0.18002046644687653,
0.9063061475753784,
0.17740710079669952,
0.9078840017318726,
0.20474515855312347,
0.8173396587371826,
0.20735852420330048,
0.8203927977629988,
0.181640625,
0.906015010958283,
0.181640625,
0.906015010958283,
0.2021484375,
0.8203927977629988,
0.2021484375,
],
"objectness_score": 0.5,
"objectness_score": 0.52,
"confidence": 1,
"crop_orientation": {"value": 0, "confidence": 1},
},
Expand Down
Loading

0 comments on commit c56bf41

Please sign in to comment.