diff --git a/basenet/vgg16_bn.py b/basenet/vgg16_bn.py index f3f21a7..d207bc7 100644 --- a/basenet/vgg16_bn.py +++ b/basenet/vgg16_bn.py @@ -4,7 +4,10 @@ import torch.nn as nn import torch.nn.init as init from torchvision import models -from torchvision.models.vgg import model_urls + +model_urls = { + 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', +} def init_weights(modules): for m in modules: diff --git a/bbox.py b/bbox.py new file mode 100644 index 0000000..bfdcaf6 --- /dev/null +++ b/bbox.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +import sys +import os +import time +import argparse + +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +from torch.autograd import Variable + +from PIL import Image + +import cv2 +from skimage import io +import numpy as np +import craft_utils +import imgproc +import file_utils +import json +import zipfile + +from craft import CRAFT + +from collections import OrderedDict +def copyStateDict(state_dict): + if list(state_dict.keys())[0].startswith("module"): + start_idx = 1 + else: + start_idx = 0 + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = ".".join(k.split(".")[start_idx:]) + new_state_dict[name] = v + return new_state_dict + +def str2bool(v): + return v.lower() in ("yes", "y", "true", "t", "1") + +# define arguments +class Args: + test_folder = '/data/' + trained_model = 'craft_mlt_25k.pth' + text_threshold = 0.7 + low_text = 0.4 + link_threshold = 0.4 + cuda = False + canvas_size = 1280 + mag_ratio = 1.5 + poly = False + show_time = False + refiner_model = 'craft_refiner_CTW1500.pth' + refine=False + +args = Args() + + +""" For test images in a folder """ +image_list, _, _ = file_utils.get_files(args.test_folder) + +result_folder = './result/' +if not os.path.isdir(result_folder): + os.mkdir(result_folder) + +def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None): + t0 = time.time() + + # resize + 
img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio) + ratio_h = ratio_w = 1 / target_ratio + + # preprocessing + x = imgproc.normalizeMeanVariance(img_resized) + x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w] + x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w] + if cuda: + x = x.cuda() + + # forward pass + with torch.no_grad(): + y, feature = net(x) + + # make score and link map + score_text = y[0,:,:,0].cpu().data.numpy() + score_link = y[0,:,:,1].cpu().data.numpy() + + # refine link + if refine_net is not None: + with torch.no_grad(): + y_refiner = refine_net(y, feature) + score_link = y_refiner[0,:,:,0].cpu().data.numpy() + + t0 = time.time() - t0 + t1 = time.time() + + # Post-processing + boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly) + + # coordinate adjustment + boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h) + polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h) + for k in range(len(polys)): + if polys[k] is None: polys[k] = boxes[k] + + t1 = time.time() - t1 + + # render results (optional) + render_img = score_text.copy() + render_img = np.hstack((render_img, score_link)) + ret_score_text = imgproc.cvt2HeatmapImg(render_img) + + if args.show_time : print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1)) + + return boxes, polys, ret_score_text + + + +# if __name__ == '__main__': +# load net +net = CRAFT() # initialize + +print('Loading weights from checkpoint (' + args.trained_model + ')') +if args.cuda: + net.load_state_dict(copyStateDict(torch.load(args.trained_model))) +else: + net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu'))) + +if args.cuda: + net = net.cuda() + net = torch.nn.DataParallel(net) + cudnn.benchmark = False + +net.eval() + +# LinkRefiner +refine_net = None +if args.refine: + 
from refinenet import RefineNet + refine_net = RefineNet() + print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')') + if args.cuda: + refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model))) + refine_net = refine_net.cuda() + refine_net = torch.nn.DataParallel(refine_net) + else: + refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu'))) + + refine_net.eval() + args.poly = True + +t = time.time() + +def getbboxes(image_path): + image = imgproc.loadImage(image_path) + bboxes, p, s = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net) + return bboxes \ No newline at end of file