-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathocroline-recognize
executable file
·108 lines (98 loc) · 3.39 KB
/
ocroline-recognize
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/python
import matplotlib
# matplotlib.use("GTK")
from pylab import *
rc("image", cmap="hot")
import pylab
import os
import re
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import ocropy2
import time
import resource
import psutil
import argparse
import editdistance
from contextlib import closing
from torch.autograd import Variable
import dlinputs as dli
import uuid
import codecs
# Command-line interface of the line recognizer.
#
# The summary text is passed as `description`: the first positional
# parameter of ArgumentParser is `prog`, which would put the text into the
# usage line in place of the program name.
parser = argparse.ArgumentParser(description="Apply an RNN Recognizer")
parser.add_argument("-m", "--model", default="/usr/local/share/ocropy2/ocr-default.pt",
                    help="saved model or model specification")
parser.add_argument("inputs", nargs="+",
                    help="input files for recognition")
# type=int: the value is compared against an integer error count, and
# without a type argparse returns command-line values as strings.
parser.add_argument("-t", "--threshold", type=int, default=-1,
                    help="threshold for outputting errors")
parser.add_argument("-e", "--eval", action="store_true",
                    help="only compute error rate")
parser.add_argument("-g", "--gtext", default="gt.txt",
                    help="ground truth extension")
parser.add_argument("-o", "--outext", default=None,
                    help="extension for output")
parser.add_argument("-n", "--normalization", default="none",
                    help="string normalization prior to error rate computation")
parser.add_argument("-v", "--verbose", action="store_true",
                    help="provide additional information, e.g. about network")
args = parser.parse_args()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
net = dli.loadable.load_net(args.model)
ocr = ocropy2.SimpleOCR(net)
if args.verbose:
print net
for k, v in net.META.items():
print k, repr(v).replace("\n", " ")[:50]
if "test_loss" in net.META:
print "last test_loss", sorted(net.META["test_loss"])[-1]
ocr.gpu()
def read_image(fname):
    """Load an image file as a float array of shape (height, width, 1).

    Integer-typed pixel data is scaled to the range [0, 1], colour channels
    are averaged into one, and the image is inverted when its mean
    brightness exceeds 0.5, so the returned array always has a mean of at
    most 0.5 (presumably ink ends up bright on a dark background - confirm
    against what the recognizer expects).

    Args:
        fname: path of the image file, passed to ``imread``.

    Returns:
        Array with a trailing singleton axis appended.
    """
    image = imread(fname)
    # imread yields 0-1 floats for PNG but integer arrays for other
    # formats; without rescaling, the 0.5 test and the `1.0 - image`
    # inversion below would operate on 0-255 values.
    if np.issubdtype(image.dtype, np.integer):
        image = image / float(np.iinfo(image.dtype).max)
    if image.ndim == 3:
        # Collapse the channel axis to a single grey level.
        image = mean(image, 2)
    if mean(image) > 0.5:
        image = 1.0 - image
    return np.expand_dims(image, 2)
def normalize(s, normalization):
    """Canonicalize a transcription before it is compared.

    Args:
        s: text to normalize.
        normalization: ``"alphanum"`` replaces every run of characters
            outside ``0-9A-Za-z`` with one space; ``"none"`` replaces every
            run of spaces and tildes with one space.

    Returns:
        The substituted text with leading and trailing whitespace removed.

    Raises:
        Exception: if ``normalization`` is neither of the two known modes.
    """
    if normalization not in ("alphanum", "none"):
        raise Exception("%s: unknown normalization" % normalization)
    if normalization == "alphanum":
        separator_run = r"[^0-9A-Za-z]+"
    else:
        separator_run = r"[~ ]+"
    collapsed = re.sub(separator_run, " ", s)
    return collapsed.strip()
def compute_errors(gt, result, normalization="none"):
    """Return the edit distance between ground truth and recognizer output.

    Both strings are passed through ``normalize`` with the given mode
    before ``editdistance.eval`` compares them.

    Args:
        gt: ground-truth transcription.
        result: recognized transcription.
        normalization: normalization mode forwarded to ``normalize``.

    Returns:
        The value returned by ``editdistance.eval`` for the two normalized
        strings.
    """
    return editdistance.eval(
        normalize(gt, normalization),
        normalize(result, normalization),
    )
num_chars = 0
num_lines = 0
num_errors = 0
for fname in args.inputs:
base = re.sub(r"\.[^/]*$", "", fname)
image = read_image(fname)
result = ocr.recognize([image])[0]
gtname = base + "." + args.gtext
# print image.shape
if args.eval and os.path.exists(gtname):
with codecs.open(gtname, "r", "utf-8") as stream:
gt = stream.read().strip()
errors = compute_errors(gt, result, args.normalization)
num_chars += len(gt)
num_lines += 1
num_errors += errors
if errors >= args.threshold:
print "%3d %30s %s" % (errors, fname, result)
else:
print fname, result
if args.outext is not None:
outname = base + "." + args.outext
with codecs.open(outname, "w", "utf-8") as stream:
stream.write(result + "\n")
if args.eval:
print num_errors / float(num_chars), num_errors, num_chars, num_lines