diff --git a/docs/nlpir.rst b/docs/nlpir.rst index abe6d33..1ec4038 100644 --- a/docs/nlpir.rst +++ b/docs/nlpir.rst @@ -97,6 +97,13 @@ nlpir.eye_checker module :undoc-members: :show-inheritance: +nlpir.deep_classifier module +------------------------------- + +.. automodule:: nlpir.deep_classifier + :members: + :undoc-members: + :show-inheritance: nlpir.tools module -------------------- diff --git a/nlpir/deep_classifier.py b/nlpir/deep_classifier.py new file mode 100644 index 0000000..0a634c8 --- /dev/null +++ b/nlpir/deep_classifier.py @@ -0,0 +1,49 @@ +#! coding=utf-8 +""" +High-level toolbox for text classification +""" +import re +import typing +import nlpir +from nlpir import get_instance as __get_instance__ +from nlpir import native + +# class and class instance +__cls__ = native.deep_classifier.DeepClassifier +__instance__: typing.Optional[native.deep_classifier.DeepClassifier] = None +# Location of DLL +__lib__ = None +# Data directory +__data__ = None +# license_code +__license_code__ = None +# encode +__nlpir_encode__ = native.UTF8_CODE + +__handler__ = None + + +@__get_instance__ +def get_native_instance() -> native.deep_classifier.DeepClassifier: + """ + Return the native NLPIR interface, which exposes more functions + + :return: The singleton instance + """ + return __instance__ + + +@__get_instance__ +def classify(txt: str) -> str: + """ + Classify a text, creating and caching the default handler on first use + + :param txt: text to classify + :return: class assigned to the text by the native classifier + """ + global __handler__ + if __handler__ is None: + # default model: create a handler once and load the train result for it + __handler__ = __instance__.new_instance(800) + __instance__.load_train_result(__handler__) + return __instance__.classify(txt, handler=__handler__) diff --git a/nlpir/eye_checker.py b/nlpir/eye_checker.py index 588c14a..f0e17e5 100644 --- a/nlpir/eye_checker.py +++ b/nlpir/eye_checker.py @@ -6,12 +6,9 @@ import os import re import typing -from enum import Enum -from pathlib import Path from pydantic import BaseModel -import nlpir from nlpir import get_instance as __get_instance__ from nlpir import native diff --git 
a/nlpir/native/deep_classifier.py b/nlpir/native/deep_classifier.py index 0db5e8f..053219e 100644 --- a/nlpir/native/deep_classifier.py +++ b/nlpir/native/deep_classifier.py @@ -89,7 +89,11 @@ def add_train(self, classname: str, text: str, handler: int = 0) -> bool: :param handler: classifier handler :return: add success or not """ - return self.get_func("DeepClassifier_AddTrain", [c_char_p, c_char_p, POINTER(c_int)], c_bool)(classname, text, handler) + return self.get_func( + "DeepClassifier_AddTrain", + [c_char_p, c_char_p, POINTER(c_int)], + c_bool + )(classname, text, handler) @NLPIRBase.byte_str_transform def add_train_file(self, classname: str, filename: str, handler: int = 0) -> int: diff --git a/tests/test_deep_classifier.py b/tests/test_deep_classifier.py new file mode 100644 index 0000000..83b2dfc --- /dev/null +++ b/tests/test_deep_classifier.py @@ -0,0 +1,12 @@ +# coding=utf-8 +""" +Tested function: + +- :func:`nlpir.deep_classifier.classify` +""" +from nlpir import deep_classifier + + +def test_classify(): + from tests.strings import test_str + assert deep_classifier.classify(txt=test_str) == "教育" diff --git a/tests/test_eye_checker.py b/tests/test_eye_checker.py index bc12520..f4d0cc1 100644 --- a/tests/test_eye_checker.py +++ b/tests/test_eye_checker.py @@ -1,3 +1,13 @@ +# coding=utf-8 +""" +Tested functions: + +- :func:`nlpir.eye_checker.import_kgb_rules` +- :func:`nlpir.eye_checker.list_rules` +- :func:`nlpir.eye_checker.delete_rules` +- :func:`nlpir.eye_checker.extract_knowledge` +""" + import pytest from nlpir import eye_checker @@ -15,7 +25,7 @@ def test_extract(): @pytest.mark.run(order=1) def test_rule_manage(): - from tests.strings import test_kgb_test_text, test_kgb_rules + from tests.strings import test_kgb_rules rule_set = {1, 2, 3, 4, 6, 7, 9} for rule in rule_set: assert eye_checker.import_kgb_rules(rule_text=test_kgb_rules, report_type=rule, overwrite=True)