Skip to content

Commit

Permalink
Merge pull request #103 from breezedeus/dev-v1.2
Browse files Browse the repository at this point in the history
fix: allow to initialize multiple instances
  • Loading branch information
breezedeus authored May 29, 2020
2 parents 433f369 + 8905759 commit 742efd3
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 35 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ English [README](./README_en.md).



# 最近更新 【2020.05.25】:V1.2.1
# 最近更新 【2020.05.29】:V1.2.2

主要变更:

Expand All @@ -21,6 +21,7 @@ English [README](./README_en.md).
* 默认模型由之前的`conv-lite-fc`改为`densenet-lite-fc`
* 预测支持使用GPU。
* bugfixs:
* 修复同时初始化多个实例时会报错的问题;
* Web 调用时的内存泄露。感谢 [@myuanz](https://github.com/myuanz)
* 输入图片宽度很小时导致异常;
* 去掉 `f-print`
Expand Down Expand Up @@ -150,6 +151,7 @@ class CnOcr(object):
cand_alphabet=None,
root=data_dir(),
context='cpu',
name=None,
):
```

Expand All @@ -159,10 +161,12 @@ class CnOcr(object):
* `model_epoch`: 模型迭代次数。默认为 `None`,表示使用默认的迭代次数值。对于模型名称 `densenet-lite-fc`就是 `40`
* `cand_alphabet`: 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围。`cnocr.consts`中内置了两个候选集合:(1) 数字和标点 `NUMBERS`;(2) 英文字母、数字和标点 `ENG_LETTERS`
* 例如对于图片 ![examples/hybrid.png](./examples/hybrid.png) ,不做约束时识别结果为 `o12345678`;如果加入数字约束时(`ocr = CnOcr(cand_alphabet=NUMBERS)`),识别结果为 `012345678`
* `cand_alphabet`也可以初始化后通过类函数 `CnOcr.set_cand_alphabet(cand_alphabet)` 进行设置。这样同一个实例也可以指定不同的`cand_alphabet`进行识别。
* `root`: 模型文件所在的根目录。
* Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/1.2.0/densenet-lite-fc`
* Windows下默认值为 `C:\Users\<username>\AppData\Roaming\cnocr`
* `context`:预测使用的机器资源,可取值为字符串`cpu``gpu`,或者 `mx.Context`实例。
* `name`:正在初始化的这个实例的名称。如果需要同时初始化多个实例,需要为不同的实例指定不同的名称。



Expand Down
15 changes: 13 additions & 2 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
# Release Notes

### Update 2020.05.29: 发布 cnocr V1.2.2

主要变更:

* `CnOcr`加入类函数 `CnOcr.set_cand_alphabet(cand_alphabet) `。可通过此类函数设置`cand_alphabet`。这样同一个实例也可以指定不同的`cand_alphabet`进行识别。
* bugfix:
* 修复同时初始化多个实例时会报错的问题。



### Update 2020.05.25: 发布 cnocr V1.2.1

bugfix:
主要变更:

* 修复了zip文件名的typo。
* bugfix:
* 修复了zip文件名的typo。



Expand Down
2 changes: 1 addition & 1 deletion cnocr/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.2.1'
__version__ = '1.2.2'
50 changes: 42 additions & 8 deletions cnocr/cn_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
import os
import re
import logging
import mxnet as mx
import numpy as np
Expand All @@ -25,7 +26,6 @@
from cnocr.hyperparams.cn_hyperparams import CnHyperparams as Hyperparams
from cnocr.fit.lstm import init_states
from cnocr.fit.ctc_metrics import CtcMetrics
from cnocr.data_utils.data_iter import SimpleBatch
from cnocr.symbols.crnn import gen_network
from cnocr.utils import (
data_dir,
Expand Down Expand Up @@ -83,7 +83,16 @@ def lstm_init_states(batch_size, hp):
return init_names, init_arrays


def load_module(prefix, epoch, data_names, data_shapes, network=None, context='cpu'):
def load_module(
prefix,
epoch,
data_names,
data_shapes,
*,
network=None,
net_prefix=None,
context='cpu'
):
"""
Loads the model from checkpoint specified by prefix and epoch, binds it
to an executor, and sets its parameters and returns a mx.mod.Module
Expand All @@ -92,9 +101,13 @@ def load_module(prefix, epoch, data_names, data_shapes, network=None, context='c
if network is not None:
sym = network

net_prefix = net_prefix or ''
if net_prefix:
arg_params = {rename_params(k, net_prefix): v for k, v in arg_params.items()}
aux_params = {rename_params(k, net_prefix): v for k, v in aux_params.items()}
# We don't need CTC loss for prediction, just a simple softmax will suffice.
# We get the output of the layer just before the loss layer ('pred_fc') and add softmax on top
pred_fc = sym.get_internals()['pred_fc_output']
pred_fc = sym.get_internals()[net_prefix + 'pred_fc_output']
sym = mx.sym.softmax(data=pred_fc)

if not check_context(context):
Expand All @@ -110,6 +123,12 @@ def load_module(prefix, epoch, data_names, data_shapes, network=None, context='c
return mod


def rename_params(k, net_prefix):
pat = re.compile(r'^(densenet|crnn|gru|lstm)\d*_')
k = pat.sub('', k, 1)
return net_prefix + k


class CnOcr(object):
MODEL_FILE_PREFIX = 'cnocr-v{}'.format(MODEL_VERSION)

Expand All @@ -120,6 +139,7 @@ def __init__(
cand_alphabet=None,
root=data_dir(),
context='cpu',
name=None,
):
"""
Expand All @@ -130,6 +150,7 @@ def __init__(
Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/1.1.0/conv-lite-fc-0027`。
Windows下默认值为 ``。
:param context: 'cpu', or 'gpu'。表明预测时是使用CPU还是GPU。默认为CPU。
:param name: 正在初始化的这个实例名称。如果需要同时初始化多个实例,需要为不同的实例指定不同的名称。
"""
check_model_name(model_name)
self._model_name = model_name
Expand All @@ -139,17 +160,17 @@ def __init__(
root = os.path.join(root, MODEL_VERSION)
self._model_dir = os.path.join(root, self._model_name)
self._assert_and_prepare_model_files()
self._alphabet, inv_alph_dict = read_charset(
self._alphabet, self._inv_alph_dict = read_charset(
os.path.join(self._model_dir, 'label_cn.txt')
)

self._cand_alph_idx = None
if cand_alphabet is not None:
self._cand_alph_idx = [0] + [inv_alph_dict[word] for word in cand_alphabet]
self._cand_alph_idx.sort()
self.set_cand_alphabet(cand_alphabet)

self._hp = Hyperparams()
self._hp._loss_type = None # infer mode
# 传入''的话,也改成传入None
self._net_prefix = None if name == '' else name

self._mod = self._get_module(context)

Expand All @@ -174,7 +195,7 @@ def _assert_and_prepare_model_files(self):
get_model_file(model_dir)

def _get_module(self, context):
network, self._hp = gen_network(self._model_name, self._hp)
network, self._hp = gen_network(self._model_name, self._hp, self._net_prefix)
hp = self._hp
prefix = os.path.join(self._model_dir, self._model_file_prefix)
data_names = ['data']
Expand All @@ -186,10 +207,23 @@ def _get_module(self, context):
data_names,
data_shapes,
network=network,
net_prefix=self._net_prefix,
context=context,
)
return mod

def set_cand_alphabet(self, cand_alphabet):
"""
设置待识别字符的候选集合。
:param cand_alphabet: 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围
:return: None
"""
if cand_alphabet is None:
self._cand_alph_idx = None
else:
self._cand_alph_idx = [0] + [self._inv_alph_dict[word] for word in cand_alphabet]
self._cand_alph_idx.sort()

def ocr(self, img_fp):
"""
:param img_fp: image file path; or color image mx.nd.NDArray or np.ndarray,
Expand Down
59 changes: 40 additions & 19 deletions cnocr/symbols/crnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..fit.ctc_loss import add_ctc_loss


def gen_network(model_name, hp):
def gen_network(model_name, hp, net_prefix=None):
hp = deepcopy(hp)
hp.seq_model_type = model_name.rsplit('-', maxsplit=1)[-1]
model_name = model_name.lower()
Expand All @@ -45,23 +45,30 @@ def gen_network(model_name, hp):
)
seq_len = hp.img_width // 8 if shorter else hp.img_width // 4
hp.set_seq_length(seq_len)
densenet = DenseNet(layer_channels, shorter=shorter)
densenet = DenseNet(layer_channels, shorter=shorter, prefix=net_prefix)
densenet.hybridize()
model = CRnn(hp, densenet)
model = CRnn(hp, densenet, prefix=net_prefix)
elif model_name.startswith('conv-lite'):
hp.seq_len_cmpr_ratio = 4
shorter = model_name.startswith('conv-lite-s-')
seq_len = hp.img_width // 8 if shorter else hp.img_width // 4 - 1
hp.set_seq_length(seq_len)
model = lambda data: crnn_lstm_lite(hp, data, shorter=shorter)

def model(data):
with mx.name.Prefix(net_prefix or ''):
return crnn_lstm_lite(hp, data, shorter=shorter)

elif model_name.startswith('conv'):
hp.seq_len_cmpr_ratio = 8
hp.set_seq_length(hp.img_width // 8)
model = lambda data: crnn_lstm(hp, data)

def model(data):
with mx.name.Prefix(net_prefix or ''):
return crnn_lstm(hp, data)
else:
raise NotImplementedError('bad model_name: %s', model_name)

return pipline(model, hp), hp
return pipline(model, hp, net_prefix=net_prefix), hp


def get_infer_shape(sym_model, hp):
Expand All @@ -75,18 +82,25 @@ def get_infer_shape(sym_model, hp):
return shape_dict


def gen_seq_model(hp):
def gen_seq_model(hp, **kw):
if hp.seq_model_type.lower() == 'lstm':
seq_model = LSTM(hp.num_hidden, hp.num_lstm_layer, bidirectional=True)
seq_model = LSTM(hp.num_hidden, hp.num_lstm_layer, bidirectional=True, **kw)
elif hp.seq_model_type.lower() == 'gru':
seq_model = GRU(hp.num_hidden, hp.num_lstm_layer, bidirectional=True)
seq_model = GRU(hp.num_hidden, hp.num_lstm_layer, bidirectional=True, **kw)
else:

def fc_seq_model(data):
fc = mx.sym.FullyConnected(
data, num_hidden=hp.num_hidden, flatten=False, name='seq-fc'
)
net = mx.sym.Activation(data=fc, act_type='relu', name='seq-relu')
if kw.get('prefix', None):
with mx.name.Prefix(kw['prefix']):
fc = mx.sym.FullyConnected(
data, num_hidden=hp.num_hidden, flatten=False, name='seq-fc'
)
net = mx.sym.Activation(data=fc, act_type='relu', name='seq-relu')
else:
fc = mx.sym.FullyConnected(
data, num_hidden=hp.num_hidden, flatten=False, name='seq-fc'
)
net = mx.sym.Activation(data=fc, act_type='relu', name='seq-relu')
return net

seq_model = fc_seq_model
Expand All @@ -100,7 +114,7 @@ def __init__(self, hp, emb_model, **kw):
self.emb_model = emb_model
self.dropout = nn.Dropout(hp.dropout)

self.seq_model = gen_seq_model(hp)
self.seq_model = gen_seq_model(hp, **kw)

def hybrid_forward(self, F, X):
embs = self.emb_model(X) # res: bz x emb_size x 1 x seq_len
Expand All @@ -114,15 +128,22 @@ def hybrid_forward(self, F, X):
return self.seq_model(embs)


def pipline(model, hp, data=None):
def pipline(model, hp, data=None, *, net_prefix=''):
# 构建用于训练的整个计算图
data = data if data is not None else mx.sym.Variable('data')

output = model(data)
output = mx.symbol.reshape(output, shape=(-3, -2)) # res: (seq_len * bz, c)
pred = mx.sym.FullyConnected(
data=output, num_hidden=hp.num_classes, name='pred_fc'
) # (bz x 35) x num_classes
if net_prefix:
with mx.name.Prefix(net_prefix):
output = mx.symbol.reshape(output, shape=(-3, -2)) # res: (seq_len * bz, c)
pred = mx.sym.FullyConnected(
data=output, num_hidden=hp.num_classes, name='pred_fc'
) # (bz x 35) x num_classes
else:
output = mx.symbol.reshape(output, shape=(-3, -2)) # res: (seq_len * bz, c)
pred = mx.sym.FullyConnected(
data=output, num_hidden=hp.num_classes, name='pred_fc'
) # (bz x 35) x num_classes
# print('pred', pred.infer_shape()[1])
# import pdb; pdb.set_trace()

Expand Down
41 changes: 37 additions & 4 deletions tests/test_cnocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))

from cnocr import CnOcr
from cnocr.consts import NUMBERS, AVAILABLE_MODELS
from cnocr.line_split import line_split
from cnocr.data_utils.aug import GrayAug

Expand Down Expand Up @@ -176,19 +177,51 @@ def test_gray_aug(img_fp, expected):
print(res_img.shape, res_img.dtype)


def test_cand_alphabet():
from cnocr import NUMBERS
def test_cand_alphabet1():
img_fp = os.path.join(example_dir, 'hybrid.png')

ocr = CnOcr(name='instance1')
pred = ocr.ocr(img_fp)
pred = [''.join(line_p) for line_p in pred]
print("Predicted Chars:", pred)
assert len(pred) == 1 and pred[0] == 'o12345678'

ocr = CnOcr(name='instance2', cand_alphabet=NUMBERS)
pred = ocr.ocr(img_fp)
pred = [''.join(line_p) for line_p in pred]
print("Predicted Chars:", pred)
assert len(pred) == 1 and pred[0] == '012345678'


def test_cand_alphabet2():
img_fp = os.path.join(example_dir, 'hybrid.png')

ocr = CnOcr()
ocr = CnOcr(name='instance1')
pred = ocr.ocr(img_fp)
pred = [''.join(line_p) for line_p in pred]
print("Predicted Chars:", pred)
assert len(pred) == 1 and pred[0] == 'o12345678'

ocr = CnOcr(cand_alphabet=NUMBERS)
ocr.set_cand_alphabet(NUMBERS)
pred = ocr.ocr(img_fp)
pred = [''.join(line_p) for line_p in pred]
print("Predicted Chars:", pred)
assert len(pred) == 1 and pred[0] == '012345678'


INSTANCE_ID = 0


@pytest.mark.parametrize('model_name', AVAILABLE_MODELS.keys())
def test_multiple_instances(model_name):
global INSTANCE_ID
print('test multiple instances for model_name: %s' % model_name)
img_fp = os.path.join(example_dir, 'hybrid.png')
INSTANCE_ID += 1
print('instance id: %d' % INSTANCE_ID)
cnocr1 = CnOcr(model_name, name='instance-%d' % INSTANCE_ID)
print_preds(cnocr1.ocr(img_fp))
INSTANCE_ID += 1
print('instance id: %d' % INSTANCE_ID)
cnocr2 = CnOcr(model_name, name='instance-%d' % INSTANCE_ID, cand_alphabet=NUMBERS)
print_preds(cnocr2.ocr(img_fp))

0 comments on commit 742efd3

Please sign in to comment.