Use native PIL/torch in autotagger #5

Open · wants to merge 2 commits into master
app.py: 5 changes (2 additions, 3 deletions)
@@ -2,9 +2,8 @@

 from os import getenv
 from dotenv import load_dotenv
-from autotagger import Autotagger
+from autotagger import Autotagger, read_image
 from base64 import b64encode
-from fastai.vision.core import PILImage
 from flask import Flask, request, render_template, jsonify, abort
 from werkzeug.exceptions import HTTPException
 import torch
@@ -33,7 +32,7 @@ def evaluate():
     output = request.values.get("format", "html")
     limit = int(request.values.get("limit", 50))

-    images = [PILImage.create(file) for file in files]
+    images = [read_image(file) for file in files]
     predictions = autotagger.predict(images, threshold=threshold, limit=limit)

     if output == "html":
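A note on the swap above: like PILImage.create, read_image goes through PIL.Image.open, so it should accept the same file-like objects (such as the uploaded files Flask hands to this view) as well as paths. A minimal sketch, assuming only the read_image helper added in autotagger/autotagger.py below; the in-memory PNG is purely illustrative:

# sketch: read_image returns a plain, fully loaded PIL.Image.Image
from io import BytesIO
from PIL import Image
from autotagger import read_image

buf = BytesIO()
Image.new("RGB", (64, 64), (255, 0, 0)).save(buf, format="PNG")
buf.seek(0)

im = read_image(buf)               # file-like object, as with request uploads
assert isinstance(im, Image.Image)
print(im.size)                     # (64, 64)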
autotag: 5 changes (2 additions, 3 deletions)
@@ -5,8 +5,7 @@ import click
 import itertools
 import logging
 import PIL
-from fastai.vision.core import PILImage
-from autotagger import Autotagger
+from autotagger import Autotagger, read_image
 from pathlib import Path
 from more_itertools import ichunked

@@ -68,7 +67,7 @@ def recurse_dir(directory)
 def open_image(filepath):
     try:
         with click.open_file(filepath, "rb") as file:
-            return (filepath, PILImage.create(file))
+            return (filepath, read_image(file))
     except PIL.UnidentifiedImageError as err:
         logging.warning(f"Skipped {filepath} (not an image)")
         return None
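Same idea for the CLI path above: PIL.Image.open raises PIL.UnidentifiedImageError for files it cannot parse, so the existing except branch in open_image should still fire after the swap. A quick sketch under that assumption:

# sketch: non-image input still surfaces as UnidentifiedImageError
from io import BytesIO
import PIL
from autotagger import read_image

try:
    read_image(BytesIO(b"definitely not an image"))
except PIL.UnidentifiedImageError:
    print("skipped (not an image)")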
autotagger/__init__.py: 2 changes (1 addition, 1 deletion)
@@ -1 +1 @@
-from .autotagger import Autotagger
+from .autotagger import Autotagger, read_image
autotagger/autotagger.py: 139 changes (103 additions, 36 deletions)
@@ -1,45 +1,112 @@
-from fastbook import *
-from pandas import DataFrame, read_csv
-from fastai.imports import noop
-from fastai.callback.progress import ProgressCallback
+from fastbook import create_timm_model
+import pandas as pd
+from pandas import DataFrame
+import timm
+import sys
+import torch
+from PIL import Image
+import torchvision.transforms as transforms

-class Autotagger:
-    def __init__(self, model_path="models/model.pth", data_path="test/tags.csv.gz", tags_path="data/tags.json"):
-        self.model_path = model_path
-        self.learn = self.init_model(data_path=data_path, tags_path=tags_path, model_path=model_path)

-    def init_model(self, model_path="model/model.pth", data_path="test/tags.csv.gz", tags_path="data/tags.json"):
-        df = read_csv(data_path)
-        vocab = json.load(open(tags_path))
+# https://github.com/fastai/fastai/blob/176accfd5ae929d73d183d596c7155d3a9401f2f/fastai/vision/core.py#L96
+# load image and copy to new PIL Image object
+# allows removal of fastai dep
+def read_image(file):
+    im = Image.open(file)
+    im.load()
+    im = im._new(im.im)
+    return im

-        dblock = DataBlock(
-            blocks=(ImageBlock, MultiCategoryBlock(vocab=vocab)),
-            get_x = lambda df: Path("test") / df["filename"],
-            get_y = lambda df: df["tags"].split(" "),
-            item_tfms = Resize(224, method = ResizeMethod.Squish),
-            batch_tfms = [RandomErasing()]
-        )

-        dls = dblock.dataloaders(df)
-        learn = vision_learner(dls, "resnet152", pretrained=False)
-        model_file = open(model_path, "rb")
-        learn.load(model_file, with_opt=False)
-        learn.remove_cb(ProgressCallback)
-        learn.logger = noop
+# take in a single string file path, a single PIL Image instance,
+# or a list mixing both, and handle them using a map-style dataset
+class InferenceDataset(torch.utils.data.Dataset):
+    def __init__(self, files, transform=None):
+        if isinstance(files, (list, tuple)):
+            self.files = files
+        else:
+            self.files = [files]

+        self.transform = transform

+    def __len__(self):
+        return len(self.files)

+    def __getitem__(self, index):
+        image = self.files[index]

+        # file path case
+        if isinstance(image, str):
+            image = Image.open(image)

+        assert isinstance(image, Image.Image), "Dataset got invalid type, supported types: a file path as a string, a PIL Image, or a list of either"

+        # check that the file is valid
+        image.load()

+        # fill transparent backgrounds with white and convert to RGB
+        image = image.convert("RGBA")

+        # may not exactly replicate the behavior of the old implementation
+        color = (255, 255, 255)
+        background = Image.new('RGB', image.size, color)
+        background.paste(image, mask=image.split()[3])
+        image = background

+        if self.transform: image = self.transform(image)

+        return image

+class Autotagger:
+    def __init__(self, model_path="models/model.pth", tags_path="data/tags.json"):

+        # load tags
+        self.classes = pd.read_json(tags_path)

+        # instantiate fastai model
+        self.model, _ = create_timm_model("resnet152", len(self.classes), pretrained=False)

-        return learn
+        # load weights
+        self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

+        # set to eval, script and optimize for inference (~2.5x speedup)
+        # trade off init time for faster inference, scripting/tracing is slow
+        self.model = self.model.eval()

+        # depending on what models are used in the future, use either script or trace
+        # can't script due to fastai model defn, need to use trace
+        #self.model = torch.jit.script(self.model)
+        self.model = torch.jit.trace(self.model, torch.randn(1, 3, 224, 224))
+        self.model = torch.jit.optimize_for_inference(self.model)

     def predict(self, files, threshold=0.01, limit=50, bs=64):
         if not files:
             return

-        dl = self.learn.dls.test_dl(files, bs=bs)
-        batch, _ = self.learn.get_preds(dl=dl)

-        for scores in batch:
-            df = DataFrame({ "tag": self.learn.dls.vocab, "score": scores })
-            df = df[df.score >= threshold].sort_values("score", ascending=False).head(limit)
-            tags = dict(zip(df.tag, df.score))
-            yield tags

+        # instantiate the dataset using files
+        dataset = InferenceDataset(
+            files,
+            transform=transforms.Compose([
+                transforms.Resize((224, 224)),
+                transforms.ToTensor(),
+            ])
+        )

+        # create a dataloader; if predict is called with a large batch, the input
+        # is already split into bs-sized chunks. It may make more sense to create a
+        # dataloader with a batch size of 1, which may save memory / reduce latency
+        # depending on the inputs and use case (autotag with 1 file vs. a list of files)
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            batch_size=bs,
+            shuffle=False,
+            drop_last=False
+        )

+        for batch in dataloader:
+            preds = self.model(batch).sigmoid()
+            for scores in preds:
+                df = DataFrame({ "tag": self.classes[0], "score": scores })
+                df = df[df.score >= threshold].sort_values("score", ascending=False).head(limit)
+                tags = dict(zip(df.tag, df.score))
+                yield tags
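For reviewers, a small usage sketch of the new InferenceDataset as defined above. The fully transparent RGBA input is illustrative only, and importing from autotagger.autotagger is assumed since __init__.py only re-exports Autotagger and read_image:

# sketch: InferenceDataset accepts a path, a PIL.Image, or a list of either
from PIL import Image
import torchvision.transforms as transforms
from autotagger.autotagger import InferenceDataset

tfm = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
rgba = Image.new("RGBA", (32, 32), (0, 0, 0, 0))   # fully transparent image
dataset = InferenceDataset(rgba, transform=tfm)    # a single image is wrapped in a list

x = dataset[0]
print(x.shape)   # torch.Size([3, 224, 224])
print(x.mean())  # ~1.0, since transparency was composited onto white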
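The tracing comments in __init__ describe a generic eval, then trace, then optimize_for_inference pattern. Here is a self-contained sketch of that pattern on a stock torchvision resnet18 (not the create_timm_model network from this PR), just to make the trade-off concrete: tracing is paid once at startup, and later inference calls run the frozen, fused graph.

# sketch: the same eval -> trace -> optimize_for_inference steps on a stock model
import torch
import torchvision

model = torchvision.models.resnet18().eval()
example = torch.randn(1, 3, 224, 224)

traced = torch.jit.trace(model, example)               # record ops for this input shape
optimized = torch.jit.optimize_for_inference(traced)   # freeze and fuse for inference

with torch.inference_mode():
    out = optimized(example)
print(out.shape)   # torch.Size([1, 1000])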
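Finally, a consumer-side sketch: predict() is still a generator that yields one {tag: score} dict per input image, so the call sites in app.py and the autotag script keep the same shape. The model and tag paths are the repository defaults named in the diff; image.jpg is a hypothetical input file.

# sketch: predict() yields one {tag: score} dict per image, as before
from autotagger import Autotagger, read_image

autotagger = Autotagger(model_path="models/model.pth", tags_path="data/tags.json")

with open("image.jpg", "rb") as f:   # hypothetical input file
    images = [read_image(f)]

for tags in autotagger.predict(images, threshold=0.01, limit=50):
    for tag, score in tags.items():
        print(f"{tag}: {float(score):.3f}")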