From 0e337c742292b8fca1b56e5777b6a8a3cb1a0ebf Mon Sep 17 00:00:00 2001 From: Chielo Newctle Date: Fri, 13 Oct 2023 22:28:08 +0800 Subject: [PATCH 1/2] add config to release-plz and git-cliff --- Cargo.toml | 1 + cliff.toml | 2 ++ release-plz.toml | 3 +++ 3 files changed, 6 insertions(+) create mode 100644 cliff.toml create mode 100644 release-plz.toml diff --git a/Cargo.toml b/Cargo.toml index b17e222..ebed18d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ homepage = "https://github.com/ModelTC/general-sam" documentation = "https://docs.rs/general-sam" readme = "README.md" authors = ["Chielo Newctle "] +exclude = ["release-plz.toml", "cliff.tolm"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] diff --git a/cliff.toml b/cliff.toml new file mode 100644 index 0000000..66621b2 --- /dev/null +++ b/cliff.toml @@ -0,0 +1,2 @@ +[git] +filter_unconventional = true diff --git a/release-plz.toml b/release-plz.toml new file mode 100644 index 0000000..26ffbf1 --- /dev/null +++ b/release-plz.toml @@ -0,0 +1,3 @@ +[workspace] +changelog_config = "cliff.toml" +dependencies_update = true From 7f74fb51edc4a39cff2845a35edb8c035ddc60bf Mon Sep 17 00:00:00 2001 From: Chielo Newctle Date: Fri, 13 Oct 2023 22:30:00 +0800 Subject: [PATCH 2/2] chore: move out pybind to an independent repo New pybind repository: https://github.com/ModelTC/general-sam-py --- pybind/.gitignore | 72 ----- pybind/Cargo.lock | 311 --------------------- pybind/Cargo.toml | 24 -- pybind/LICENSE-APACHE | 201 -------------- pybind/LICENSE-MIT | 21 -- pybind/README.md | 164 ----------- pybind/general_sam/__init__.py | 35 --- pybind/general_sam/general_sam.pyi | 165 ----------- pybind/general_sam/trie_utils.py | 89 ------ pybind/general_sam/vocab_prefix.py | 148 ---------- pybind/pyproject.toml | 16 -- pybind/src/lib.rs | 429 ----------------------------- pybind/tests/test_general_sam.py | 47 ---- pybind/tests/test_token_healing.py | 134 --------- pybind/tests/test_vocab_prefix.py | 74 ----- 15 files changed, 1930 deletions(-) delete mode 100644 pybind/.gitignore delete mode 100644 pybind/Cargo.lock delete mode 100644 pybind/Cargo.toml delete mode 100644 pybind/LICENSE-APACHE delete mode 100644 pybind/LICENSE-MIT delete mode 100644 pybind/README.md delete mode 100644 pybind/general_sam/__init__.py delete mode 100644 pybind/general_sam/general_sam.pyi delete mode 100644 pybind/general_sam/trie_utils.py delete mode 100644 pybind/general_sam/vocab_prefix.py delete mode 100644 pybind/pyproject.toml delete mode 100644 pybind/src/lib.rs delete mode 100644 pybind/tests/test_general_sam.py delete mode 100644 pybind/tests/test_token_healing.py delete mode 100644 pybind/tests/test_vocab_prefix.py diff --git a/pybind/.gitignore b/pybind/.gitignore deleted file mode 100644 index af3ca5e..0000000 --- a/pybind/.gitignore +++ /dev/null @@ -1,72 +0,0 @@ -/target - -# Byte-compiled / optimized / DLL files -__pycache__/ -.pytest_cache/ -*.py[cod] - -# C extensions -*.so - -# Distribution / packaging -.Python -.venv/ -env/ -bin/ -build/ -develop-eggs/ -dist/ -eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -include/ -man/ -venv/ -*.egg-info/ -.installed.cfg -*.egg - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt -pip-selfcheck.json - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.cache -nosetests.xml -coverage.xml - -# Translations -*.mo - -# Mr Developer -.mr.developer.cfg -.project -.pydevproject - -# Rope -.ropeproject - -# Django stuff: -*.log -*.pot - -.DS_Store - -# Sphinx documentation -docs/_build/ - -# PyCharm -.idea/ - -# VSCode -.vscode/ - -# Pyenv -.python-version \ No newline at end of file diff --git a/pybind/Cargo.lock b/pybind/Cargo.lock deleted file mode 100644 index 813c92a..0000000 --- a/pybind/Cargo.lock +++ /dev/null @@ -1,311 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "cc" -version = "1.0.83" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" -dependencies = [ - "libc", -] - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "either" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" - -[[package]] -name = "general-sam" -version = "0.1.0" - -[[package]] -name = "general-sam-py" -version = "0.1.0" -dependencies = [ - "either", - "general-sam", - "pyo3", -] - -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - -[[package]] -name = "indoc" -version = "2.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" - -[[package]] -name = "libc" -version = "0.2.149" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" - -[[package]] -name = "lock_api" -version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" -dependencies = [ - "autocfg", - "scopeguard", -] - -[[package]] -name = "memoffset" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" -dependencies = [ - "autocfg", -] - -[[package]] -name = "once_cell" -version = "1.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" - -[[package]] -name = "parking_lot" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets", -] - -[[package]] -name = "proc-macro2" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "pyo3" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" -dependencies = [ - "cfg-if", - "indoc", - "libc", - "memoffset", - "parking_lot", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" -dependencies = [ - "once_cell", - "python3-dll-a", - "target-lexicon", -] - -[[package]] -name = "pyo3-ffi" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" -dependencies = [ - "libc", - "pyo3-build-config", -] - -[[package]] -name = "pyo3-macros" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" -dependencies = [ - "proc-macro2", - "pyo3-macros-backend", - "quote", - "syn", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "python3-dll-a" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5f07cd4412be8fa09a721d40007c483981bbe072cd6a21f2e83e04ec8f8343f" -dependencies = [ - "cc", -] - -[[package]] -name = "quote" -version = "1.0.33" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "redox_syscall" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" -dependencies = [ - "bitflags", -] - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "smallvec" -version = "1.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" - -[[package]] -name = "syn" -version = "2.0.38" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "target-lexicon" -version = "0.12.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a" - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "unindent" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" - -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/pybind/Cargo.toml b/pybind/Cargo.toml deleted file mode 100644 index 9e9da66..0000000 --- a/pybind/Cargo.toml +++ /dev/null @@ -1,24 +0,0 @@ -[package] -name = "general-sam-py" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" -description = "Python bindings for general-sam and some utilities" -repository = "https://github.com/ModelTC/general-sam" -homepage = "https://github.com/ModelTC/general-sam/tree/main/pybind" -readme = "README.md" -authors = ["Chielo Newctle "] - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html -[lib] -name = "general_sam" -crate-type = ["cdylib"] - -[dependencies] -general-sam = { path = ".." } -pyo3 = { version = "0.20.0", features = [ - "extension-module", - "abi3-py38", - "generate-import-lib", -] } -either = "1.9.0" diff --git a/pybind/LICENSE-APACHE b/pybind/LICENSE-APACHE deleted file mode 100644 index c98d27d..0000000 --- a/pybind/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - https://www.apache.org/licenses/LICENSE-2.0 - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/pybind/LICENSE-MIT b/pybind/LICENSE-MIT deleted file mode 100644 index dd66111..0000000 --- a/pybind/LICENSE-MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2023 Chielo Newctle -Copyright (c) 2023 ModelTC Team - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is furnished -to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS -OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF -OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/pybind/README.md b/pybind/README.md deleted file mode 100644 index 7e1167b..0000000 --- a/pybind/README.md +++ /dev/null @@ -1,164 +0,0 @@ -# general-sam-py - -![License](https://img.shields.io/badge/license-MIT%2FApache--2.0-informational?style=flat-square) - -Python bindings for [`general-sam`](https://github.com/ModelTC/general-sam) -and some utilities. - -| [![the suffix automaton of abcbc][sam-of-abcbc]][sam-oi-wiki] | -| :----------------------------------------------------------------------------: | -| The suffix automaton of abcbc, image from [后缀自动机 - OI Wiki][sam-oi-wiki]. | - -[sam-of-abcbc]: https://oi-wiki.org/string/images/SAM/SA_suffix_links.svg -[sam-oi-wiki]: https://oi-wiki.org/string/sam/ - -## Usage - -### `GeneralSAM` - -```python -from general_sam import GeneralSAM - - -sam = GeneralSAM.construct_from_bytes(b'abcbc') - -state = sam.get_root_state() -state.feed_bytes(b'cbc') -assert state.is_accepting() - -state = sam.get_root_state() -state.feed_bytes(b'bcb') -assert not state.is_accepting() -``` - -```python -from general_sam import GeneralSAM - - -sam = GeneralSAM.construct_from_chars('abcbc') -state = sam.get_root_state() - -state.feed_chars('b') -assert not state.is_accepting() -state.feed_chars('c') -assert state.is_accepting() -state.feed_chars('bc') -assert state.is_accepting() -state.feed_chars('bc') -assert not state.is_accepting() and state.is_nil() -``` - -```python -from general_sam import GeneralSAM, GeneralSAMState, construct_trie_from_chars - - -trie, _ = construct_trie_from_chars(['hello', 'Chielo']) -sam = GeneralSAM.construct_from_trie(trie) - -def fetch_state(s: str) -> GeneralSAMState: - state = sam.get_root_state() - state.feed_chars(s) - return state - -assert fetch_state('lo').is_accepting() -assert fetch_state('ello').is_accepting() -assert fetch_state('elo').is_accepting() - -state = fetch_state('el') -assert not state.is_accepting() and not state.is_nil() - -state = fetch_state('bye') -assert not state.is_accepting() and state.is_nil() -``` - -### `VocabPrefixAutomaton` - -```python -from general_sam import VocabPrefixAutomaton, CountInfo - - -vocab = ['歌曲', '聆听歌曲', '播放歌曲', '歌词', '查看歌词'] -automaton = VocabPrefixAutomaton(vocab, bytes_or_chars='chars') - -# NOTE: CountInfo is related to the sorted vocab: -_ = ['播放歌曲', '查看歌词', '歌曲', '歌词', '聆听歌曲'] - -# 一起 | 聆 | 听 | 歌 -state = automaton.get_root_state() - -# feed 歌 -cnt_info = automaton.prepend_feed(state, '歌') -assert cnt_info is not None and cnt_info == CountInfo( - str_cnt=2, tot_cnt_lower=2, tot_cnt_upper=4 -) - -selected_idx = automaton.get_order_slice(cnt_info) -assert frozenset(selected_idx) == {0, 3} -selected_vocab = [vocab[i] for i in selected_idx] -assert frozenset(selected_vocab) == {'歌曲', '歌词'} - -# feed 听 -cnt_info = automaton.prepend_feed(state, '听') -assert cnt_info is None -assert not state.is_nil() - -# feed 聆 -cnt_info = automaton.prepend_feed(state, '聆') -assert cnt_info is not None and cnt_info == CountInfo( - str_cnt=1, tot_cnt_lower=4, tot_cnt_upper=5 -) - -selected_idx = automaton.get_order_slice(cnt_info) -assert frozenset(selected_idx) == {1} -selected_vocab = [vocab[i] for i in selected_idx] -assert frozenset(selected_vocab) == {'聆听歌曲'} - -# feed 一起 -assert not state.is_nil() -cnt_info = automaton.prepend_feed(state, '一起') -assert state.is_nil() - -# 来 | 查看 | 歌词 -state = automaton.get_root_state() - -# feed 歌词 -cnt_info = automaton.prepend_feed(state, '歌词') -assert cnt_info is not None and cnt_info == CountInfo( - str_cnt=1, tot_cnt_lower=3, tot_cnt_upper=4 -) - -selected_idx = automaton.get_order_slice(cnt_info) -assert frozenset(selected_idx) == {3} -selected_vocab = [vocab[i] for i in selected_idx] -assert frozenset(selected_vocab) == {'歌词'} - -# feed 查看 -cnt_info = automaton.prepend_feed(state, '查看') -assert cnt_info is not None and cnt_info == CountInfo( - str_cnt=1, tot_cnt_lower=1, tot_cnt_upper=2 -) - -selected_idx = automaton.get_order_slice(cnt_info) -assert frozenset(selected_idx) == {4} -selected_vocab = [vocab[i] for i in selected_idx] -assert frozenset(selected_vocab) == {'查看歌词'} - -# feed 来 -assert not state.is_nil() -cnt_info = automaton.prepend_feed(state, '来') -assert state.is_nil() -``` - -## License - -- © 2023 Chielo Newctle \ -- © 2023 ModelTC Team - -This project is licensed under either of - -- [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0) ([`LICENSE-APACHE`](LICENSE-APACHE)) -- [MIT license](https://opensource.org/licenses/MIT) ([`LICENSE-MIT`](LICENSE-MIT)) - -at your option. - -The [SPDX](https://spdx.dev) license identifier for this project is `MIT OR Apache-2.0`. diff --git a/pybind/general_sam/__init__.py b/pybind/general_sam/__init__.py deleted file mode 100644 index f4259ff..0000000 --- a/pybind/general_sam/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -from .general_sam import ( - GeneralSAM, - GeneralSAMState, - Trie, - TrieNode, -) -from .trie_utils import ( - CountInfo, - SortResult, - construct_trie_from_bytes, - construct_trie_from_chars, - sort_bytes, - sort_chars, - sort_seq_via_trie, -) -from .vocab_prefix import ( - VocabPrefixAutomaton, - VocabPrefixBytesOrChars, -) - -__all__ = [ - 'GeneralSAM', - 'GeneralSAMState', - 'Trie', - 'TrieNode', - 'CountInfo', - 'SortResult', - 'construct_trie_from_chars', - 'construct_trie_from_bytes', - 'sort_chars', - 'sort_bytes', - 'sort_seq_via_trie', - 'VocabPrefixAutomaton', - 'VocabPrefixBytesOrChars', -] diff --git a/pybind/general_sam/general_sam.pyi b/pybind/general_sam/general_sam.pyi deleted file mode 100644 index 7181913..0000000 --- a/pybind/general_sam/general_sam.pyi +++ /dev/null @@ -1,165 +0,0 @@ -from typing import Callable, Mapping, Optional, Sequence, Union - - -class TrieNode: - def is_in_chars(self) -> bool: - ... - - def is_in_bytes(self) -> bool: - ... - - def get_node_id(self) -> int: - ... - - def is_accepting(self) -> bool: - ... - - def get_trans(self) -> Mapping[Union[str, int], int]: - ... - - def get_parent(self) -> int: - ... - - -class Trie: - @staticmethod - def in_chars() -> 'Trie': - ... - - @staticmethod - def in_bytes() -> 'Trie': - ... - - def is_in_chars(self) -> bool: - ... - - def is_in_bytes(self) -> bool: - ... - - def num_of_nodes(self) -> int: - ... - - def insert_chars(self, s: str) -> int: - ... - - def insert_bytes(self, s: bytes) -> int: - ... - - def get_bfs_order(self) -> Sequence[int]: - ... - - def get_root(self) -> TrieNode: - ... - - def get_node(self, node_id: int) -> Optional[TrieNode]: - ... - - def dfs_travel( - self, - in_stack_callback: Callable[[int, Optional[str]], None], - out_stack_callback: Callable[[int], None], - root_node_id: Optional[int] = None, - ) -> TrieNode: - ... - - def bfs_travel( - self, - in_queue_callback: Callable[[int, Optional[str]], None], - out_queue_callback: Callable[[int], None], - root_node_id: Optional[int] = None, - ) -> TrieNode: - ... - - -class GeneralSAMState: - def is_in_str(self) -> bool: - ... - - def is_in_bytes(self) -> bool: - ... - - def get_node_id(self) -> int: - ... - - def is_nil(self) -> bool: - ... - - def is_root(self) -> bool: - ... - - def is_accepting(self) -> bool: - ... - - def get_trans(self) -> Mapping[Union[str, int], int]: - ... - - def get_suffix_parent_id(self) -> int: - ... - - def copy(self) -> 'GeneralSAMState': - ... - - def goto_suffix_parent(self): - ... - - def goto_char(self, t: str): - ... - - def goto_byte(self, t: int): - ... - - def feed_chars(self, s: str): - ... - - def feed_bytes(self, s: bytes): - ... - - def dfs_along( - self, - trie: Trie, - in_stack_callback: Callable[['GeneralSAMState', int, Optional[str]], None], - out_stack_callback: Callable[['GeneralSAMState', int], None], - trie_node_id: Optional[int] = None, - ) -> TrieNode: - ... - - def bfs_along( - self, - trie: Trie, - in_queue_callback: Callable[['GeneralSAMState', int, Optional[str]], None], - out_queue_callback: Callable[['GeneralSAMState', int], None], - trie_node_id: Optional[int] = None, - ) -> TrieNode: - ... - - -class GeneralSAM: - @staticmethod - def construct_from_chars(s: str) -> 'GeneralSAM': - ... - - @staticmethod - def construct_from_bytes(s: bytes) -> 'GeneralSAM': - ... - - @staticmethod - def construct_from_trie(trie: Trie) -> 'GeneralSAM': - ... - - def is_in_str(self) -> bool: - ... - - def is_in_bytes(self) -> bool: - ... - - def num_of_nodes(self) -> int: - ... - - def get_root_state(self) -> GeneralSAMState: - ... - - def get_state(self, node_id: int) -> GeneralSAMState: - ... - - def get_topo_order(self) -> Sequence[GeneralSAMState]: - ... diff --git a/pybind/general_sam/trie_utils.py b/pybind/general_sam/trie_utils.py deleted file mode 100644 index 1c3469d..0000000 --- a/pybind/general_sam/trie_utils.py +++ /dev/null @@ -1,89 +0,0 @@ -from dataclasses import dataclass -from typing import Collection, Sequence, Tuple - -from .general_sam import Trie - - -def construct_trie_from_chars( - strings: Collection[str], -) -> Tuple[Trie, Sequence[int]]: - trie = Trie.in_chars() - node_ids = [trie.insert_chars(s) for s in strings] - return trie, node_ids - - -def construct_trie_from_bytes( - strings: Collection[bytes], -) -> Tuple[Trie, Sequence[int]]: - trie = Trie.in_bytes() - node_ids = [trie.insert_bytes(s) for s in strings] - return trie, node_ids - - -@dataclass -class CountInfo: - str_cnt: int - tot_cnt_lower: int - tot_cnt_upper: int - - -@dataclass -class SortResult: - trie: Trie - node_ids: Sequence[int] - cnt_info_on_nodes: Sequence[CountInfo] - cnt_info_on_strings: Sequence[CountInfo] - order: Sequence[int] - rank: Sequence[int] - - -def sort_chars(strings: Collection[str]) -> SortResult: - trie, node_ids = construct_trie_from_chars(strings) - return sort_seq_via_trie(trie, node_ids) - - -def sort_bytes(strings: Collection[bytes]) -> SortResult: - trie, node_ids = construct_trie_from_bytes(strings) - return sort_seq_via_trie(trie, node_ids) - - -def sort_seq_via_trie(trie: Trie, node_ids: Sequence[int]) -> SortResult: - num_of_seq = len(node_ids) - - cnt_info_on_nodes = [CountInfo(0, 0, 0) for _ in range(trie.num_of_nodes())] - for k in node_ids: - cnt_info_on_nodes[k].str_cnt += 1 - - tot_str_cnt = 0 - - def in_stack(node_id: int, _): - nonlocal tot_str_cnt - node_info = cnt_info_on_nodes[node_id] - node_info.tot_cnt_lower = tot_str_cnt - tot_str_cnt += node_info.str_cnt - - def out_stack(node_id: int): - nonlocal tot_str_cnt - node_info = cnt_info_on_nodes[node_id] - node_info.tot_cnt_upper = tot_str_cnt - - trie.dfs_travel(in_stack, out_stack) - - cnt_info_on_strings = [cnt_info_on_nodes[node_ids[i]] for i in range(num_of_seq)] - - order = sorted( - range(num_of_seq), - key=lambda i: cnt_info_on_strings[i].tot_cnt_lower, - ) - rank = [0] * num_of_seq - for k, i in enumerate(order): - rank[i] = k - - return SortResult( - trie=trie, - node_ids=node_ids, - cnt_info_on_nodes=cnt_info_on_nodes, - cnt_info_on_strings=cnt_info_on_strings, - order=order, - rank=rank, - ) diff --git a/pybind/general_sam/vocab_prefix.py b/pybind/general_sam/vocab_prefix.py deleted file mode 100644 index dcf98af..0000000 --- a/pybind/general_sam/vocab_prefix.py +++ /dev/null @@ -1,148 +0,0 @@ -import enum -from dataclasses import replace -from typing import ( - Callable, - Iterable, - List, - Optional, - Sequence, - Tuple, - Union, - cast, -) - -from .general_sam import GeneralSAM, GeneralSAMState, Trie -from .trie_utils import ( - CountInfo, - SortResult, - construct_trie_from_bytes, - construct_trie_from_chars, - sort_bytes, - sort_chars, -) - - -class VocabPrefixBytesOrChars(enum.Enum): - BYTES = enum.auto() - CHARS = enum.auto() - - -class VocabPrefixAutomaton(object): - def __init__( - self, - vocab: Iterable[Union[str, bytes]], - bytes_or_chars: Union[ - str, VocabPrefixBytesOrChars - ] = VocabPrefixBytesOrChars.CHARS, - ) -> None: - if isinstance(bytes_or_chars, str): - bytes_or_chars = getattr(VocabPrefixBytesOrChars, bytes_or_chars.upper()) - - self.bytes_or_chars = cast(VocabPrefixBytesOrChars, bytes_or_chars) - - self.vocab: Sequence[Union[str, bytes]] = list(vocab) - - if self.bytes_or_chars == VocabPrefixBytesOrChars.BYTES and isinstance( - self.vocab[0], str - ): - self.vocab = list(cast(str, i).encode() for i in self.vocab) - if self.bytes_or_chars == VocabPrefixBytesOrChars.CHARS and isinstance( - self.vocab[0], bytes - ): - self.vocab = list(cast(bytes, i).decode() for i in self.vocab) - - self.vocab_rev: Sequence[Union[str, bytes]] = list(s[::-1] for s in vocab) - - sort_seq, construct_trie = { - VocabPrefixBytesOrChars.BYTES: (sort_bytes, construct_trie_from_bytes), - VocabPrefixBytesOrChars.CHARS: (sort_chars, construct_trie_from_chars), - }[self.bytes_or_chars] - self.vocab_sort_res = cast(SortResult, sort_seq(self.vocab)) - self.trie_rev, self.trie_rev_node_ids = cast( - Tuple[Trie, Sequence[int]], - construct_trie(self.vocab_rev), - ) - - self.sam_rev = GeneralSAM.construct_from_trie(self.trie_rev) - self._gen_cnt_info_in_sam() - - @property - def _state_feed_fn(self) -> Callable[[GeneralSAMState, Union[bytes, str]], None]: - return { - VocabPrefixBytesOrChars.BYTES: GeneralSAMState.feed_bytes, - VocabPrefixBytesOrChars.CHARS: GeneralSAMState.feed_chars, - }[self.bytes_or_chars] - - def _gen_cnt_info_in_sam(self): - self.cnt_info_in_sam: List[Optional[CountInfo]] = [ - None for _ in range(self.sam_rev.num_of_nodes()) - ] - - for token_rev, cnt_info in zip( - self.vocab_rev, self.vocab_sort_res.cnt_info_on_strings - ): - state = self.sam_rev.get_root_state() - self._state_feed_fn(state, token_rev) - state_id = state.get_node_id() - self.cnt_info_in_sam[state_id] = replace(cnt_info, str_cnt=1) - - for sam_state in reversed(self.sam_rev.get_topo_order()): - assert not sam_state.is_nil() - if sam_state.is_root(): - continue - - state_id = sam_state.get_node_id() - state_cnt_info = self.cnt_info_in_sam[state_id] - if state_cnt_info is None: - continue - - link_id = sam_state.get_suffix_parent_id() - link_cnt_info = self.cnt_info_in_sam[link_id] - - if link_cnt_info is None: - self.cnt_info_in_sam[link_id] = replace(state_cnt_info) - continue - - link_cnt_info.str_cnt += state_cnt_info.str_cnt - link_cnt_info.tot_cnt_lower = min( - link_cnt_info.tot_cnt_lower, - state_cnt_info.tot_cnt_lower, - ) - link_cnt_info.tot_cnt_upper = max( - link_cnt_info.tot_cnt_upper, - state_cnt_info.tot_cnt_upper, - ) - - for state_id in range(self.sam_rev.num_of_nodes()): - sam_state = self.sam_rev.get_state(state_id) - state_cnt_info = self.cnt_info_in_sam[state_id] - if sam_state.is_nil() or sam_state.is_root() or state_cnt_info is None: - continue - - link_id = sam_state.get_suffix_parent_id() - link_cnt_info = self.cnt_info_in_sam[link_id] - - assert link_cnt_info is not None - assert link_cnt_info.tot_cnt_lower <= state_cnt_info.tot_cnt_lower - assert link_cnt_info.tot_cnt_upper >= state_cnt_info.tot_cnt_upper - - def get_root_state(self) -> GeneralSAMState: - return self.sam_rev.get_root_state() - - def prepend_feed( - self, state: GeneralSAMState, token: Union[str, bytes] - ) -> Optional[CountInfo]: - if self.bytes_or_chars == VocabPrefixBytesOrChars.BYTES and isinstance( - token, str - ): - token = token.encode() - self._state_feed_fn(state, token[::-1]) - return self.cnt_info_in_sam[state.get_node_id()] - - def get_order(self) -> Sequence[int]: - return self.vocab_sort_res.order - - def get_order_slice(self, cnt_info: CountInfo) -> Sequence[int]: - return self.vocab_sort_res.order[ - cnt_info.tot_cnt_lower : cnt_info.tot_cnt_upper - ] diff --git a/pybind/pyproject.toml b/pybind/pyproject.toml deleted file mode 100644 index d33302d..0000000 --- a/pybind/pyproject.toml +++ /dev/null @@ -1,16 +0,0 @@ -[build-system] -requires = ["maturin>=1.3,<2.0"] -build-backend = "maturin" - -[project] -name = "general_sam" -requires-python = ">=3.7" -classifiers = [ - "Programming Language :: Rust", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", -] -dynamic = ["version"] - -[tool.maturin] -features = ["pyo3/extension-module"] diff --git a/pybind/src/lib.rs b/pybind/src/lib.rs deleted file mode 100644 index 5e8d8f9..0000000 --- a/pybind/src/lib.rs +++ /dev/null @@ -1,429 +0,0 @@ -extern crate general_sam as general_sam_rs; - -use std::{convert::Infallible, str::from_utf8, sync::Arc}; - -use either::{for_both, Either}; -use pyo3::{prelude::*, types::PyDict}; - -use general_sam_rs::{ - sam, trie, - trie_alike::{TravelEvent, TrieNodeAlike}, -}; - -#[pyclass] -struct Trie(Either, trie::Trie>); - -#[pyclass] -struct TrieNode(usize, Either, trie::Node>); - -#[pymethods] -impl TrieNode { - fn is_in_chars(&self) -> bool { - self.1.is_left() - } - - fn is_in_bytes(&self) -> bool { - self.1.is_right() - } - - fn get_node_id(&self) -> usize { - self.0 - } - - fn is_accepting(&self) -> bool { - for_both!(self.1.as_ref(), x => x.accept) - } - - fn get_trans(&self) -> PyObject { - Python::with_gil(|py| { - for_both!(self.1.as_ref(), x => { - x.get_trans().clone().into_py(py) - }) - }) - } - - fn get_parent(&self) -> usize { - for_both!(self.1.as_ref(), x => x.get_parent()) - } -} - -#[pymethods] -impl Trie { - #[staticmethod] - fn in_chars() -> Self { - Trie(Either::Left(trie::Trie::default())) - } - - #[staticmethod] - fn in_bytes() -> Self { - Trie(Either::Right(trie::Trie::default())) - } - - fn is_in_chars(&self) -> bool { - self.0.is_left() - } - - fn is_in_bytes(&self) -> bool { - self.0.is_right() - } - - fn num_of_nodes(&self) -> usize { - for_both!(self.0.as_ref(), x => x.num_of_nodes()) - } - - fn insert_chars(&mut self, s: &str) -> usize { - match self.0.as_mut() { - Either::Left(trie_chars) => trie_chars.insert_iter(s.chars()), - Either::Right(trie_bytes) => trie_bytes.insert_ref_iter(s.as_bytes().iter()), - } - } - - fn insert_bytes(&mut self, b: &[u8]) -> usize { - match self.0.as_mut() { - Either::Left(trie_chars) => trie_chars.insert_iter(from_utf8(b).unwrap().chars()), - Either::Right(trie_bytes) => trie_bytes.insert_ref_iter(b.iter()), - } - } - - fn get_bfs_order(&self) -> Vec { - for_both!(self.0.as_ref(), trie => { - let state = trie.get_root_state(); - let mut res = Vec::new(); - state - .bfs_travel(|event| -> Result<(), Infallible> { - if let TravelEvent::Push(s, _) = event { - res.push(s.node_id); - } - Ok(()) - }) - .unwrap(); - res - }) - } - - fn get_root(&self) -> TrieNode { - self.get_node(trie::TRIE_ROOT_NODE_ID).unwrap() - } - - fn get_node(&self, node_id: usize) -> Option { - match self.0.as_ref() { - Either::Left(trie) => trie - .get_node(node_id) - .map(|node| TrieNode(node_id, Either::Left(node.clone()))), - Either::Right(trie) => trie - .get_node(node_id) - .map(|node| TrieNode(node_id, Either::Right(node.clone()))), - } - } - - #[pyo3(signature = (in_stack_callback, out_stack_callback, root_node_id=None))] - fn dfs_travel( - &self, - in_stack_callback: PyObject, - out_stack_callback: PyObject, - root_node_id: Option, - ) -> Result<(), PyErr> { - for_both!(self.0.as_ref(), trie => { - let root_state = trie.get_state(root_node_id.unwrap_or(trie::TRIE_ROOT_NODE_ID)); - if root_state.is_nil() { - return Ok(()); - } - root_state.dfs_travel(|event| match event { - TravelEvent::Push(tn, key_opt) => Python::with_gil(|py| { - in_stack_callback.call1(py, (tn.node_id, key_opt.copied())) - }) - .map(|_| ()), - TravelEvent::Pop(tn) => { - Python::with_gil(|py| out_stack_callback.call1(py, (tn.node_id,))).map(|_| ()) - } - }) - }) - } - - #[pyo3(signature = (in_stack_callback, out_stack_callback, root_node_id=None))] - fn bfs_travel( - &self, - in_stack_callback: PyObject, - out_stack_callback: PyObject, - root_node_id: Option, - ) -> Result<(), PyErr> { - for_both!(self.0.as_ref(), trie => { - let root_state = trie.get_state(root_node_id.unwrap_or(trie::TRIE_ROOT_NODE_ID)); - if root_state.is_nil() { - return Ok(()); - } - root_state.bfs_travel(|event| match event { - TravelEvent::Push(tn, key_opt) => Python::with_gil(|py| { - in_stack_callback.call1(py, (tn.node_id, key_opt.copied())) - }) - .map(|_| ()), - TravelEvent::Pop(tn) => { - Python::with_gil(|py| out_stack_callback.call1(py, (tn.node_id,))).map(|_| ()) - } - }) - }) - } -} - -#[pyclass] -struct GeneralSAM(Arc, sam::GeneralSAM>>); - -#[pyclass] -#[derive(Clone)] -struct GeneralSAMState( - Arc, sam::GeneralSAM>>, - usize, -); - -impl GeneralSAMState { - fn get_state(&self) -> Either, sam::State> { - self.0 - .as_ref() - .as_ref() - .map_either(|x| x.get_state(self.1), |x| x.get_state(self.1)) - } -} - -#[pymethods] -impl GeneralSAMState { - fn is_in_chars(&self) -> bool { - self.0.is_left() - } - - fn is_in_bytes(&self) -> bool { - self.0.is_right() - } - - fn get_node_id(&self) -> usize { - self.1 - } - - fn is_nil(&self) -> bool { - for_both!(self.get_state().as_ref(), x => x.is_nil()) - } - - fn is_root(&self) -> bool { - for_both!(self.get_state().as_ref(), x => x.is_root()) - } - - fn is_accepting(&self) -> bool { - for_both!(self.get_state().as_ref(), x => x.is_accepting()) - } - - fn get_trans(&self) -> PyObject { - Python::with_gil(|py| { - for_both!(self.get_state().as_ref(), state => { - if let Some(node) = state.get_node() { - node.get_trans().clone().into_py(py) - } else { - PyDict::new(py).into_py(py) - } - }) - }) - } - - fn get_suffix_parent_id(&self) -> usize { - for_both!(self.get_state().as_ref() , x => { - x.get_node() - .map(|node| node.get_suffix_parent_id()) - .unwrap_or(sam::SAM_NIL_NODE_ID) - }) - } - - fn copy(&self) -> Self { - self.clone() - } - - fn goto_suffix_parent(&mut self) { - for_both!(self.get_state(), mut state => { - state.goto_suffix_parent(); - self.1 = state.node_id; - }) - } - - fn goto_char(&mut self, t: char) { - let mut state = self.get_state().left().unwrap(); - state.goto(&t); - self.1 = state.node_id; - } - - fn goto_byte(&mut self, t: u8) { - let mut state = self.get_state().right().unwrap(); - state.goto(&t); - self.1 = state.node_id; - } - - fn feed_chars(&mut self, s: &str) { - match self.get_state() { - Either::Left(state_chars) => { - let state_chars = state_chars.feed_chars(s); - self.1 = state_chars.node_id; - } - Either::Right(state_bytes) => { - let state_bytes = state_bytes.feed_ref_iter(s.as_bytes().iter()); - self.1 = state_bytes.node_id; - } - } - } - - fn feed_bytes(&mut self, s: &[u8]) { - match self.get_state() { - Either::Left(state_chars) => { - let state_chars = state_chars.feed_iter(from_utf8(s).unwrap().chars()); - self.1 = state_chars.node_id; - } - Either::Right(state_bytes) => { - let state_bytes = state_bytes.feed_ref_iter(s.iter()); - self.1 = state_bytes.node_id; - } - } - } - - #[pyo3(signature = (trie, in_stack_callback, out_stack_callback, trie_node_id=None))] - fn dfs_along( - &self, - trie: &Trie, - in_stack_callback: PyObject, - out_stack_callback: PyObject, - trie_node_id: Option, - ) -> Result<(), PyErr> { - assert!(trie.is_in_chars() == self.is_in_chars()); - let sam_and_trie = self.0.as_ref().as_ref().map_either( - |sam_chars| (sam_chars, trie.0.as_ref().left().unwrap()), - |sam_bytes| (sam_bytes, trie.0.as_ref().right().unwrap()), - ); - for_both!(sam_and_trie, (sam, trie) => { - let tn = trie.get_state(trie_node_id.unwrap_or(trie::TRIE_ROOT_NODE_ID)); - sam.dfs_along(tn, self.1, |event| match event { - TravelEvent::Push((st, tn), key_opt) => Python::with_gil(|py| { - in_stack_callback - .call1( - py, - ( - GeneralSAMState(self.0.clone(), st.node_id), - tn.node_id, - key_opt.copied(), - ), - ) - .map(|_| ()) - }) - .map(|_| ()), - TravelEvent::Pop((st, tn)) => Python::with_gil(|py| { - out_stack_callback - .call1( - py, - (GeneralSAMState(self.0.clone(), st.node_id), tn.node_id), - ) - .map(|_| ()) - }), - }) - }) - } - - #[pyo3(signature = (trie, in_stack_callback, out_stack_callback, trie_node_id=None))] - fn bfs_along( - &self, - trie: &Trie, - in_stack_callback: PyObject, - out_stack_callback: PyObject, - trie_node_id: Option, - ) -> Result<(), PyErr> { - assert!(trie.is_in_chars() == self.is_in_chars()); - let sam_and_trie = self.0.as_ref().as_ref().map_either( - |sam_chars| (sam_chars, trie.0.as_ref().left().unwrap()), - |sam_bytes| (sam_bytes, trie.0.as_ref().right().unwrap()), - ); - for_both!(sam_and_trie, (sam, trie) => { - let tn = trie.get_state(trie_node_id.unwrap_or(trie::TRIE_ROOT_NODE_ID)); - sam.bfs_along(tn, self.1, |event| match event { - TravelEvent::Push((st, tn), key_opt) => Python::with_gil(|py| { - in_stack_callback - .call1( - py, - ( - GeneralSAMState(self.0.clone(), st.node_id), - tn.node_id, - key_opt.copied(), - ), - ) - .map(|_| ()) - }) - .map(|_| ()), - TravelEvent::Pop((st, tn)) => Python::with_gil(|py| { - out_stack_callback - .call1( - py, - (GeneralSAMState(self.0.clone(), st.node_id), tn.node_id), - ) - .map(|_| ()) - }), - }) - }) - } -} - -#[pymethods] -impl GeneralSAM { - #[staticmethod] - fn construct_from_chars(s: &str) -> Self { - GeneralSAM(Arc::new(Either::Left( - sam::GeneralSAM::construct_from_chars(s.chars()), - ))) - } - - #[staticmethod] - fn construct_from_bytes(s: &[u8]) -> Self { - GeneralSAM(Arc::new(Either::Right( - sam::GeneralSAM::construct_from_bytes(s), - ))) - } - - #[staticmethod] - fn construct_from_trie(trie: &Trie) -> Self { - match trie.0.as_ref() { - Either::Left(trie_chars) => GeneralSAM(Arc::new(Either::Left( - sam::GeneralSAM::construct_from_trie(trie_chars.get_root_state()), - ))), - Either::Right(trie_bytes) => GeneralSAM(Arc::new(Either::Right( - sam::GeneralSAM::construct_from_trie(trie_bytes.get_root_state()), - ))), - } - } - - fn is_in_chars(&self) -> bool { - self.0.is_left() - } - - fn is_in_bytes(&self) -> bool { - self.0.is_right() - } - - fn num_of_nodes(&self) -> usize { - for_both!(self.0.as_ref(), x => x.num_of_nodes()) - } - - fn get_root_state(&self) -> GeneralSAMState { - GeneralSAMState(self.0.clone(), sam::SAM_ROOT_NODE_ID) - } - - fn get_state(&self, node_id: usize) -> GeneralSAMState { - GeneralSAMState(self.0.clone(), node_id) - } - - fn get_topo_order(&self) -> Vec { - for_both!(self.0.as_ref(), x => { - x.get_topo_order() - .map(|s| self.get_state(s.node_id)) - .collect() - }) - } -} - -#[pymodule] -fn general_sam(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - Ok(()) -} diff --git a/pybind/tests/test_general_sam.py b/pybind/tests/test_general_sam.py deleted file mode 100644 index f3db3af..0000000 --- a/pybind/tests/test_general_sam.py +++ /dev/null @@ -1,47 +0,0 @@ -from general_sam import GeneralSAM, GeneralSAMState, construct_trie_from_chars - - -def test_bytes_abcbc(): - sam = GeneralSAM.construct_from_bytes(b'abcbc') - - state = sam.get_root_state() - state.feed_bytes(b'cbc') - assert state.is_accepting() - - state = sam.get_root_state() - state.feed_bytes(b'bcb') - assert not state.is_accepting() - - -def test_chars_abcbc(): - sam = GeneralSAM.construct_from_chars('abcbc') - state = sam.get_root_state() - - state.feed_chars('b') - assert not state.is_accepting() - state.feed_chars('c') - assert state.is_accepting() - state.feed_chars('bc') - assert state.is_accepting() - state.feed_chars('bc') - assert not state.is_accepting() and state.is_nil() - - -def test_simple_sam_from_trie(): - trie, _ = construct_trie_from_chars(['hello', 'Chielo']) - sam = GeneralSAM.construct_from_trie(trie) - - def fetch_state(s: str) -> GeneralSAMState: - state = sam.get_root_state() - state.feed_chars(s) - return state - - assert fetch_state('lo').is_accepting() - assert fetch_state('ello').is_accepting() - assert fetch_state('elo').is_accepting() - - state = fetch_state('el') - assert not state.is_accepting() and not state.is_nil() - - state = fetch_state('bye') - assert not state.is_accepting() and state.is_nil() diff --git a/pybind/tests/test_token_healing.py b/pybind/tests/test_token_healing.py deleted file mode 100644 index 86c5ec4..0000000 --- a/pybind/tests/test_token_healing.py +++ /dev/null @@ -1,134 +0,0 @@ -from typing import Collection, Iterable, Optional, Sequence, Union - -from general_sam import ( - CountInfo, - GeneralSAMState, - VocabPrefixAutomaton, - VocabPrefixBytesOrChars, -) - - -def _test_token_healing_batch( - vocab: Collection[Union[str, bytes]], - token_sequences: Iterable[Union[Sequence[str], Sequence[bytes]]], - bytes_or_chars: VocabPrefixBytesOrChars, -): - automaton = VocabPrefixAutomaton(vocab, bytes_or_chars=bytes_or_chars) - - vocab_sorted = sorted(vocab) - - def validate( - query: Union[str, bytes], state: GeneralSAMState, cnt_info: Optional[CountInfo] - ): - import bisect - - expected_l = bisect.bisect_left( - vocab_sorted, query, key=lambda x: x[: len(query)] - ) - expected_r = bisect.bisect_right( - vocab_sorted, query, key=lambda x: x[: len(query)] - ) - - if expected_l < expected_r: - expected_cnt_info = CountInfo( - str_cnt=expected_r - expected_l, - tot_cnt_lower=expected_l, - tot_cnt_upper=expected_r, - ) - else: - expected_cnt_info = None - - assert cnt_info == expected_cnt_info, (query, cnt_info, expected_cnt_info) - - assert state.is_nil() ^ any(query in i for i in vocab) # pyright: ignore - - def check(tokens: Sequence[Union[str, bytes]]): - state = automaton.get_root_state() - query = '' if isinstance(tokens[0], str) else b'' - - # NOTE: tokens are prepended in the reverse order - for token in reversed(tokens): - query = token + query # pyright: ignore - cnt_info = automaton.prepend_feed(state, token) - validate(query, state, cnt_info) - - for tokens in token_sequences: - check(tokens) - - -def _test_batch( - vocab: Collection[str], - token_sequences: Iterable[Union[Sequence[str], Sequence[bytes]]], -): - _test_token_healing_batch( - vocab, - tuple(filter(lambda x: isinstance(x[0], str), token_sequences)), - VocabPrefixBytesOrChars.CHARS, - ) - _test_token_healing_batch( - tuple(i.encode() for i in vocab), - tuple( - tuple(i.encode() if isinstance(i, str) else i for i in s) - for s in token_sequences - ), - VocabPrefixBytesOrChars.BYTES, - ) - - -def test_simple_token_healing(): - _test_batch( - ['bb', 'ca', 'ab', 'c', 'aa', 'bbaa', 'a', 'cc', 'b'], - [ - ['bb', 'a'], - ['b', 'b', 'b'], - ['b', 'b', 'a'], - ['b', 'ba'], - ['ca', 'c', 'ab'], - ['c', 'c', 'c'], - ], - ) - - -def test_simple_chinese_token_healing(): - _test_batch( - ['歌曲', '聆听歌曲', '播放歌曲', '歌词', '查看歌词'], - [ - ['歌曲'], - ['聆听歌曲'], - ['聆听', '歌曲'], - ['聆', '听', '歌曲'], - ['播放歌曲'], - ['播', '放歌曲'], - ['播放', '歌曲'], - ['歌词'], - ['查看歌词'], - ['查看', '歌词'], - ['听歌曲'], - ['听', '歌曲'], - ['放歌曲'], - ['听歌'], - ['放歌'], - ['词'], - ['查看'], - ['bb', 'a'], - ['b', 'b', 'b'], - ['b', 'b', 'a'], - ['b', 'ba'], - ['ca', 'c', 'ab'], - ['c', 'c', 'c'], - ], - ) - - -def test_simple_utf8_token_healing(): - # '䨻'.encode('utf8') == b'\xe4\xa8\xbb' - _test_batch( - ['䨻'], - [ - ['䨻'], - [b'\xe4', b'\xa8', b'\xbb'], - [b'\xe4', b'\xa8\xbb'], - [b'\xe4\xa8', b'\xbb'], - [b'\xe4\xa8\xbb'], - ], - ) diff --git a/pybind/tests/test_vocab_prefix.py b/pybind/tests/test_vocab_prefix.py deleted file mode 100644 index f741008..0000000 --- a/pybind/tests/test_vocab_prefix.py +++ /dev/null @@ -1,74 +0,0 @@ -from general_sam import VocabPrefixAutomaton, CountInfo - - -def test_chinese_chars_vocab_prefix(): - vocab = ['歌曲', '聆听歌曲', '播放歌曲', '歌词', '查看歌词'] - automaton = VocabPrefixAutomaton(vocab, bytes_or_chars='chars') - - # NOTE: CountInfo is related to the sorted vocab: - _ = ['播放歌曲', '查看歌词', '歌曲', '歌词', '聆听歌曲'] - - # 一起 | 聆 | 听 | 歌 - state = automaton.get_root_state() - - # feed 歌 - cnt_info = automaton.prepend_feed(state, '歌') - assert cnt_info is not None and cnt_info == CountInfo( - str_cnt=2, tot_cnt_lower=2, tot_cnt_upper=4 - ) - - selected_idx = automaton.get_order_slice(cnt_info) - assert frozenset(selected_idx) == {0, 3} - selected_vocab = [vocab[i] for i in selected_idx] - assert frozenset(selected_vocab) == {'歌曲', '歌词'} - - # feed 听 - cnt_info = automaton.prepend_feed(state, '听') - assert cnt_info is None - assert not state.is_nil() - - # feed 聆 - cnt_info = automaton.prepend_feed(state, '聆') - assert cnt_info is not None and cnt_info == CountInfo( - str_cnt=1, tot_cnt_lower=4, tot_cnt_upper=5 - ) - - selected_idx = automaton.get_order_slice(cnt_info) - assert frozenset(selected_idx) == {1} - selected_vocab = [vocab[i] for i in selected_idx] - assert frozenset(selected_vocab) == {'聆听歌曲'} - - # feed 一起 - assert not state.is_nil() - cnt_info = automaton.prepend_feed(state, '一起') - assert state.is_nil() - - # 来 | 查看 | 歌词 - state = automaton.get_root_state() - - # feed 歌词 - cnt_info = automaton.prepend_feed(state, '歌词') - assert cnt_info is not None and cnt_info == CountInfo( - str_cnt=1, tot_cnt_lower=3, tot_cnt_upper=4 - ) - - selected_idx = automaton.get_order_slice(cnt_info) - assert frozenset(selected_idx) == {3} - selected_vocab = [vocab[i] for i in selected_idx] - assert frozenset(selected_vocab) == {'歌词'} - - # feed 查看 - cnt_info = automaton.prepend_feed(state, '查看') - assert cnt_info is not None and cnt_info == CountInfo( - str_cnt=1, tot_cnt_lower=1, tot_cnt_upper=2 - ) - - selected_idx = automaton.get_order_slice(cnt_info) - assert frozenset(selected_idx) == {4} - selected_vocab = [vocab[i] for i in selected_idx] - assert frozenset(selected_vocab) == {'查看歌词'} - - # feed 来 - assert not state.is_nil() - cnt_info = automaton.prepend_feed(state, '来') - assert state.is_nil()