-
Notifications
You must be signed in to change notification settings - Fork 26
/
predict.py
72 lines (61 loc) · 2.7 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#-*- coding:utf-8 -*-
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from ctpn import config
from ctpn.ctpn import CTPN_Model
from ctpn.utils import gen_anchor, transform_bbox, clip_bbox, filter_bbox, nms, TextProposalConnectorOriented
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weights = './weights/ctpn.pth'
model = CTPN_Model().to(device)
model.load_state_dict(torch.load(weights, map_location=device)['model_state_dict'])
model.eval()
def get_text_boxes(image, display = True, prob_thresh = 0.5):
h, w= image.shape[:2]
rescale_fac = max(h, w) / 1000
if rescale_fac > 1.0:
h = int(h / rescale_fac)
w = int(w / rescale_fac)
image = cv2.resize(image, (w,h))
h, w = image.shape[:2]
image_c = image.copy()
image = image.astype(np.float32) - config.IMAGE_MEAN
image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float().to(device)
with torch.no_grad():
cls, regr = model(image)
cls_prob = F.softmax(cls, dim=-1).cpu().numpy()
regr = regr.cpu().numpy()
anchor = gen_anchor((int(h / 16), int(w / 16)), 16)
bbox = transform_bbox(anchor, regr)
bbox = clip_bbox(bbox, [h, w])
fg = np.where(cls_prob[0, :, 1] > prob_thresh)[0]
select_anchor = bbox[fg, :]
select_score = cls_prob[0, fg, 1]
select_anchor = select_anchor.astype(np.int32)
keep_index = filter_bbox(select_anchor, 16)
select_anchor = select_anchor[keep_index]
select_score = select_score[keep_index]
select_score = np.reshape(select_score, (select_score.shape[0], 1))
nmsbox = np.hstack((select_anchor, select_score))
keep = nms(nmsbox, 0.3)
select_anchor = select_anchor[keep]
select_score = select_score[keep]
textConn = TextProposalConnectorOriented()
text = textConn.get_text_lines(select_anchor, select_score, [h, w])
if display:
for i in text:
s = str(round(i[-1] * 100, 2)) + '%'
i = [int(j) for j in i]
cv2.line(image_c, (i[0], i[1]), (i[2], i[3]), (0, 0, 255), 2)
cv2.line(image_c, (i[0], i[1]), (i[4], i[5]), (0, 0, 255), 2)
cv2.line(image_c, (i[6], i[7]), (i[2], i[3]), (0, 0, 255), 2)
cv2.line(image_c, (i[4], i[5]), (i[6], i[7]), (0, 0, 255), 2)
cv2.putText(image_c, s, (i[0]+13, i[1]+13), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,0,0), 2, cv2.LINE_AA)
return text, image_c
if __name__ == '__main__':
img_path = 'images/21.jpg'
input_img = cv2.imread(img_path)
text, out_img = get_text_boxes(input_img)
cv2.imwrite('results/21.jpg', out_img)