From 356df99d5dafdf1b4bc738b1834ed029e124d0af Mon Sep 17 00:00:00 2001 From: Klaus Greff Date: Mon, 23 Sep 2024 09:59:35 -0700 Subject: [PATCH] standalone parser PiperOrigin-RevId: 677836507 --- LICENSE | 140 +- kauldron/typing/__init__.py | 3 +- kauldron/typing/shape_parser.py | 311 +++ kauldron/typing/shape_spec.lark | 81 + kauldron/typing/shape_spec.py | 399 +--- kauldron/typing/shape_spec_test.py | 7 +- kauldron/typing/type_check.py | 5 +- kauldron/typing/utils.py | 46 + kauldron/utils/standalone_parser.py | 3390 +++++++++++++++++++++++++++ 9 files changed, 3982 insertions(+), 400 deletions(-) create mode 100644 kauldron/typing/shape_parser.py create mode 100644 kauldron/typing/shape_spec.lark create mode 100644 kauldron/typing/utils.py create mode 100644 kauldron/utils/standalone_parser.py diff --git a/LICENSE b/LICENSE index 7a4a3ea2..b0b81461 100644 --- a/LICENSE +++ b/LICENSE @@ -1,3 +1,4 @@ +Files: * except those noted below Apache License Version 2.0, January 2004 @@ -199,4 +200,141 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file + limitations under the License. + +--- +Files: utils/standalone_parser.py + +Mozilla Public License 2.0 + +1.1. "Contributor" means each individual or legal entity that creates, contributes to the creation of, or owns Covered Software. +1.2. "Contributor Version" means the combination of the Contributions of others (if any) used by a Contributor and that particular Contributor's Contribution. +1.3. "Contribution" means Covered Software of a particular Contributor. +1.4. "Covered Software" means Source Code Form to which the initial Contributor has attached the notice in Exhibit A, the Executable Form of such Source Code Form, and Modifications of such Source Code Form, in each case including portions thereof. +1.5. "Incompatible With Secondary Licenses" means +(a) that the initial Contributor has attached the notice described in Exhibit B to the Covered Software; or +(b) that the Covered Software was made available under the terms of version 1.1 or earlier of the License, but not also under the terms of a Secondary License. +1.6. "Executable Form" means any form of the work other than Source Code Form. +1.7. "Larger Work" means a work that combines Covered Software with other material, in a separate file or files, that is not Covered Software. +1.8. "License" means this document. +1.9. "Licensable" means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently, any and all of the rights conveyed by this License. +1.10. "Modifications" means any of the following: +(a) any file in Source Code Form that results from an addition to, deletion from, or modification of the contents of Covered Software; or +(b) any new file in Source Code Form that contains any Covered Software. +1.11. "Patent Claims" of a Contributor means any patent claim(s), including without limitation, method, process, and apparatus claims, in any patent Licensable by such Contributor that would be infringed, but for the grant of the License, by the making, using, selling, offering for sale, having made, import, or transfer of either its Contributions or its Contributor Version. +1.12. 
"Secondary License" means either the GNU General Public License, Version 2.0, the GNU Lesser General Public License, Version 2.1, the GNU Affero General Public License, Version 3.0, or any later versions of those licenses. +1.13. "Source Code Form" means the form of the work preferred for making modifications. +1.14. "You" (or "Your") means an individual or a legal entity exercising rights under this License. For legal entities, "You" includes any entity that controls, is controlled by, or is under common control with You. For purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity. +2. License Grants and Conditions + +-------------------------------- + +2.1. Grants +Each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) Licensable by such Contributor to use, reproduce, make available, modify, display, perform, distribute, and otherwise exploit its Contributions, either on an unmodified basis, with Modifications, or as part of a Larger Work; and +(b) under Patent Claims of such Contributor to make, use, sell, offer for sale, have made, import, and otherwise transfer either its Contributions or its Contributor Version. +2.2. Effective Date +The licenses granted in Section 2.1 with respect to any Contribution become effective for each Contribution on the date the Contributor first distributes such Contribution. + +2.3. Limitations on Grant Scope +The licenses granted in this Section 2 are the only rights granted under this License. No additional rights or licenses will be implied from the distribution or licensing of Covered Software under this License. Notwithstanding Section 2.1(b) above, no patent license is granted by a Contributor: + +(a) for any code that a Contributor has removed from Covered Software; or +(b) for infringements caused by: (i) Your and any other third party's modifications of Covered Software, or (ii) the combination of its Contributions with other software (except as part of its Contributor Version); or +(c) under Patent Claims infringed by Covered Software in the absence of its Contributions. +This License does not grant any rights in the trademarks, service marks, or logos of any Contributor (except as may be necessary to comply with the notice requirements in Section 3.4). + +2.4. Subsequent Licenses +No Contributor makes additional grants as a result of Your choice to distribute the Covered Software under a subsequent version of this License (see Section 10.2) or under the terms of a Secondary License (if permitted under the terms of Section 3.3). + +2.5. Representation +Each Contributor represents that the Contributor believes its Contributions are its original creation(s) or it has sufficient rights to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use +This License is not intended to limit any rights You have under applicable copyright doctrines of fair use, fair dealing, or other equivalents. + +2.7. Conditions +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in Section 2.1. + +3. Responsibilities + +------------------- + +3.1. 
Distribution of Source Form +All distribution of Covered Software in Source Code Form, including any Modifications that You create or to which You contribute, must be under the terms of this License. You must inform recipients that the Source Code Form of the Covered Software is governed by the terms of this License, and how they can obtain a copy of this License. You may not attempt to alter or restrict the recipients' rights in the Source Code Form. + +3.2. Distribution of Executable Form +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code Form, as described in Section 3.1, and You must inform recipients of the Executable Form how they can obtain a copy of such Source Code Form by reasonable means in a timely manner, at a charge no more than the cost of distribution to the recipient; and +(b) You may distribute such Executable Form under the terms of this License, or sublicense it under different terms, provided that the license for the Executable Form does not attempt to limit or alter the recipients' rights in the Source Code Form under this License. +3.3. Distribution of a Larger Work +You may create and distribute a Larger Work under terms of Your choice, provided that You also comply with the requirements of this License for the Covered Software. If the Larger Work is a combination of Covered Software with a work governed by one or more Secondary Licenses, and the Covered Software is not Incompatible With Secondary Licenses, this License permits You to additionally distribute such Covered Software under the terms of such Secondary License(s), so that the recipient of the Larger Work may, at their option, further distribute the Covered Software under the terms of either this License or such Secondary License(s). + +3.4. Notices +You may not remove or alter the substance of any license notices (including copyright notices, patent notices, disclaimers of warranty, or limitations of liability) contained within the Source Code Form of the Covered Software, except that You may alter any license notices to the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms +You may choose to offer, and to charge a fee for, warranty, support, indemnity or liability obligations to one or more recipients of Covered Software. However, You may do so only on Your own behalf, and not on behalf of any Contributor. You must make it absolutely clear that any such warranty, support, indemnity, or liability obligation is offered by You alone, and You hereby agree to indemnify every Contributor for any liability incurred by such Contributor as a result of warranty, support, indemnity or liability terms You offer. You may include additional disclaimers of warranty and limitations of liability specific to any jurisdiction. + +4. Inability to Comply Due to Statute or Regulation + +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this License with respect to some or all of the Covered Software due to statute, judicial order, or regulation then You must: (a) comply with the terms of this License to the maximum extent possible; and (b) describe the limitations and the code they affect. Such description must be placed in a text file included with all distributions of the Covered Software under this License. 
Except to the extent prohibited by statute or regulation, such description must be sufficiently detailed for a recipient of ordinary skill to be able to understand it. + +5. Termination + +-------------- + +5.1. The rights granted under this License will terminate automatically if You fail to comply with any of its terms. However, if You become compliant, then the rights granted under this License from a particular Contributor are reinstated (a) provisionally, unless and until such Contributor explicitly and finally terminates Your grants, and (b) on an ongoing basis, if such Contributor fails to notify You of the non-compliance by some reasonable means prior to 60 days after You have come back into compliance. Moreover, Your grants from a particular Contributor are reinstated on an ongoing basis if such Contributor notifies You of the non-compliance by some reasonable means, this is the first time You have received notice of non-compliance with this License from such Contributor, and You become compliant prior to 30 days after Your receipt of the notice. +5.2. If You initiate litigation against any entity by asserting a patent infringement claim (excluding declaratory judgment actions, counter-claims, and cross-claims) alleging that a Contributor Version directly or indirectly infringes any patent, then the rights granted to You by any and all Contributors for the Covered Software under Section 2.1 of this License shall terminate. +5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user license agreements (excluding distributors and resellers) which have been validly granted by You or Your distributors under this License prior to termination shall survive termination. +************************************************************************ + +6. Disclaimer of Warranty + +* ------------------------- * + +Covered Software is provided under this License on an "as is" basis, without warranty of any kind, either expressed, implied, or statutory, including, without limitation, warranties that the Covered Software is free of defects, merchantable, fit for a particular purpose or non-infringing. The entire risk as to the quality and performance of the Covered Software is with You. Should any Covered Software prove defective in any respect, You (not any Contributor) assume the cost of any necessary servicing, repair, or correction. This disclaimer of warranty constitutes an essential part of this License. No use of any Covered Software is authorized under this License except under this disclaimer. + +************************************************************************ + +************************************************************************ + +7. Limitation of Liability + +* -------------------------- * + +Under no circumstances and under no legal theory, whether tort (including negligence), contract, or otherwise, shall any Contributor, or anyone who distributes Covered Software as permitted above, be liable to You for any direct, indirect, special, incidental, or consequential damages of any character including, without limitation, damages for lost profits, loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses, even if such party shall have been informed of the possibility of such damages. This limitation of liability shall not apply to liability for death or personal injury resulting from such party's negligence to the extent applicable law prohibits such limitation. 
Some jurisdictions do not allow the exclusion or limitation of incidental or consequential damages, so this exclusion and limitation may not apply to You. + +************************************************************************ + +8. Litigation + +------------- + +Any litigation relating to this License may be brought only in the courts of a jurisdiction where the defendant maintains its principal place of business and such litigation shall be governed by laws of that jurisdiction, without reference to its conflict-of-law provisions. Nothing in this Section shall prevent a party's ability to bring cross-claims or counter-claims. + +9. Miscellaneous + +---------------- + +This License represents the complete agreement concerning the subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not be used to construe this License against a Contributor. + +10. Versions of the License + +--------------------------- + +10.1. New Versions +Mozilla Foundation is the license steward. Except as provided in Section 10.3, no one other than the license steward has the right to modify or publish new versions of this License. Each version will be given a distinguishing version number. + +10.2. Effect of New Versions +You may distribute the Covered Software under the terms of the version of the License under which You originally received the Covered Software, or under the terms of any subsequent version published by the license steward. + +10.3. Modified Versions +If you create software not governed by this License, and you want to create a new license for such software, you may create and use a modified version of this License if you rename the license and remove any references to the name of the license steward (except to note that such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary Licenses +If You choose to distribute Source Code Form that is Incompatible With Secondary Licenses under the terms of this version of the License, the notice described in Exhibit B of this License must be attached. \ No newline at end of file diff --git a/kauldron/typing/__init__.py b/kauldron/typing/__init__.py index ec28e7cd..836135a7 100644 --- a/kauldron/typing/__init__.py +++ b/kauldron/typing/__init__.py @@ -42,8 +42,9 @@ UInt8, XArray, ) -from kauldron.typing.shape_spec import Dim, Memo, Shape # pylint: disable=g-multiple-import,g-importing-member +from kauldron.typing.shape_spec import Dim, Shape # pylint: disable=g-multiple-import,g-importing-member from kauldron.typing.type_check import TypeCheckError, typechecked # pylint: disable=g-multiple-import,g-importing-member +from kauldron.typing.utils import Memo # pylint: disable=g-importing-member import numpy as np import typeguard as _typeguard # make typeguard.check_type accessible in this namespace diff --git a/kauldron/typing/shape_parser.py b/kauldron/typing/shape_parser.py new file mode 100644 index 00000000..ce8f45c5 --- /dev/null +++ b/kauldron/typing/shape_parser.py @@ -0,0 +1,311 @@ +# Copyright 2024 The kauldron Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A parser for shape specs.""" + +from __future__ import annotations + +import abc +import base64 +import dataclasses +import enum +import itertools +import math +import operator +import pickle +from typing import Any, Callable, Optional +import zlib + +from kauldron.typing import utils +from kauldron.utils import standalone_parser + +# Serialized parser from grammar in shape_spec.lark (lark==1.2.1) +DATA = b'eJztmFlvG8kRx3UMb+q+T96XdR+2aV2O4PVijZEow5Lgp8VgRI1JwhRJzJFIDwLyJMRAP3a+YD5Jqrs50n9lJZCDRYIFogf9WFM1XVd3s9l/Dfz9H1M98u+WF1mwY9qOZXPxOdq0ri3bqLZbX6QccS37qtEymw7/lRdvOev9E9d7nFteD+u9Cn0K/QqaQkAhqBBSCCtEFKIKMYW4woDCoMKQwrDCiGOxYKPWatsWBcMGaoZt1axr40vTrDkUFYt4jmVc3LiWw7/5mbg3HYuzKCXkWteuZzY5CxvyqWFwFjkSRu9Eth6LqTo8JB+wvabVTZz8j6owxhTGFSYUJhWmFKYVZhRmFeYU5hUWFBYVEgpJhZRCWiGjkFXIKeQVCgpFhZLCC4UlhWWFFYVVhTWFdYUNhU2FLYVthZcKrxReK5QV3ijsKOwq7CnsKxwovKXWBRzXtF2qZv3MvK+1ao/WNJs2rydZ9KN8rBpS75Wz021/tVqOaAi1OGp8PjU+VI4+VN5zvZdp2feVn7jexzQxT7nez7TTs8NPXNeY9svh6S9cDzDt5/PKO64HmWZdd2yuh1j/h8oZ18Ms9GfTNi4b9F6EBTrtv9Aa0KMsel756f2n03cnn8hFTMyWw8pJxdjgepxpRx/F6AMs0nbrVvflwXubda4P3QtbXB9mmum2yWSEBbyWad9wffRev8n1MaZVDo/JzTgLHH+onJ9yfYKFaVDD6VhVrk8y7ZN0OMWCp+fHxslHrk+z4MeTz/LjDAsenx/Jj7Ms8O7k+PiQ63Ns0DBkqY1O03NETPNsxDBMu2Y0G073KTlfeDAU/0WCi/UzPcHCvimtiSCpxHqS1Zf/9N7u/L2jBoDgic9FvY8sNknVTzwgasQlYoA4TAwS14kh4gtimNhDjBAHiFFijBgjviLGiTniALFIHCQuE4eIO8Rh4iqRRtEHiSPEEeIoMUQcI5aI48QUcYK4BaFPy9D7MLU1TG1N6vuFXoRaINXjFP0Q/JR8lyLVCUjND8EvhZ9aN2RP136kgCKRqf9hIf0CzcoCBW5ptpJm4Y47elDk4etfiGpOoTCNwgwKPSjMojAJgqeHblU8K8JbWHibIynajf4lRDcl7SPY3l1s767UR/+bM1c0LvI7FH5Shh4ToQuTaajfCqa4Iu3iws4v6T6WdF/qB7BEGXw/gw3MYM8y2KYMjpmRYw7iPHjjyBqAEENhFIVxFPpR6EFhDIUACiEUoigMoRBHQQPB04eemhL/aqWLqRDvTpHFH1nxw7fq4baYyCPYghS2IIUtSGELUtiCFLYgJdMYxTE3cMwNqR8T+nmyKMtxevQhGOK1NBn319f+E+trTJpMCJMpUu1CNAcYzYG0m6R0p4TtAtmOk1oMu/bEsBPSfBqDX8bgl6V+BvVZ1GexYFksWBYLlsUQs3LMWTHm4/6Kfr1+Rl+788PT5zCyTYxsU3qZx4KJfTMoni7geiliCkXcN4uYTxFXRRGTK2JyRel28ak+7aHdnrRLCLsZspuFwZOYRhKjS6LbJA6XlMMl/Rmk3X3/7elvs/4UENvjG9hu/Rr/6La7SJx8Yvt9vO0+d7udk6mkcHqIV2fglWdPjzROjyWs6xKWcglLuST9Zx77z/4n/rPoP4/+89jXPE60PEaWx8jyMrIcbpiPN0rRzPkfiPR+g8z7+9OejMufuXKf8vQCfqm9xZjeypiKmOci5rko9SVccCVMvYQLroR1KOGCK2FRShhASTp4gQ4K6KCADgrooIAOCuiggA4K0sGSPx/mutUdFU+XMe9VzHtVvrWC+i3Ub0n9qtBPk0UGQsmhXQ5zyWH4OYw4hxHn5Nhrf8RfCeMy9HXcPv2yzGNZ5qXdBpZ3G/XbUr/51PaaQLsEljeBFU1gRRNyuC10t4PD7Ej99nOOMglioFv5/N1vv/pG/s2KFZ0Ze+iAp7/EKV/GQ18ZD31lPPSV8dBXxkNfGRdDGQ99ZTz0lfHQV8ZDXxkPfWU89JXx0FeWlXrlH6b7wOs6VnRd2r3GiqdRn8bGpXFdpLGLaexiWo5ZVntZj94PqldS9cY/MoXv1JGp9+77I9OMNN3p/jLaEAfKXb/xibvvz6zP/LJ42Ir3ngrwpfS6/0f8/TQqQz/oFixNBfNY/P4+RF161M/oCMOiVuvyN88mHa+eZEPisqrRqv1si8vF1iX36tn/3xX+LneFoXbHbbTVrR8LXFoXXo1/E5dRdqPq0qehr5bVMcxm0+jeDn5jUde2LKPaNB2HV1igalbrFj0OddqO27SueaXeW//MAvJWmNcTLObaZsv50ravSK7Uz36V95Lhjt1o2w33hrNgi3TivjhiXl00ap58qJme2+YsIG+eafixjt3umDWaGgY5aqigKb7uNTpFeGFWv4o82OiVeXNBZk2zatXbzUvLFpaD1mXDNR5u2St1Ou3UF76xeNsmE4smnuU6/G8s3rjqtMWlnunWxZ00izltz65a8gFlHBYXe7WGrJmYndqRaX/l3uo/AduonHg=' +DATA = pickle.loads(zlib.decompress(base64.b64decode(DATA))) +MEMO = 
b'eJytWMtvG0UY9/uR1ElTCgUKtI0p2E7t9AGUtkAVpUGNNruO7FiVSKPRNt5kNl3vWvtoG+QiHlLUVgsCuj33wqUSEn8AiDM3br1x4w8oEidOzOyuvbOPeOyCFUX2eL/f4/tmvpnx5+lHfx+J2a+7VonB/8yUzHcEy0wuc2uWme3yui6osoW/Sd/iJQN9NVG6fHH9dPXCRnnOMtNbEr+tWRsoQuVvW5yZAbfFtg7RSImJP8jH3FdcMHMA6LtdAQDLzK86uI0lyzBzXVVUVFHftZgYLJiTa4LaEWVeuiJsWQYTR9QwYWaaLRbUVy2Ywp8zZh6JuD7Xu14tWzC3YcEJDk7alExcgAU4ZcBpjAYPGkzCRWBbKyRCASHMz/euV3rzvZMhlEQQJemirNavkSiJSsUNNVOzlcqshRxMuO6aumphhISDkHIQUh+3uMVB/GGkoqsq7R6y3Ovwd3qa0QmISTKpgRgXKu1CcQvs0gBqHtdlofrJRm+dr366UUYfnT9ysGdXrgfK5UrAMlmrgPeMQzcBrjXBMreyzC35sohLcbyXL6P5MDJk1oFMs8tcqzlAi1f7yUzOVnEuYZGIybmuV1YXGl5IyQsphULybkjDF1L2QsqhkAk3pLlGhlS8kEooZNLNT4u7stRoLtYbXn7iwAsEocADTiBaGQtcnQOnB2HJWq3WD0zPog+h0II/9AwxIwExI8OcU66/qwvNq57MoiezGAqZ9nOd9biKxOwvhhNz0B94jhBZJESGGWfc6bFYZ9kFT+UpT+WpUMwh3KQyqJNsi3bDwhWRePVmTRLuCCpqVWvKTUG2HuIm12itLKERTedV3WKaaI3FdauBF+8kp8j9DoSaU1640+VlTVRk3NEc1CkA7EDQlQwNF60AvzRcp22xA7SusEkODp7H/3GlBl/BPfu9YKYVtY00MjEzzUsir6FGmlW6OqLV7N47fVMQuoCXJKBjE5p138zaytpnrPtwmjOndaHTlXhdAJpiqJsCAiigEX0XiHJb3BQ0q4zNNQxJqLu4BhpI4QHUZV/AcnYdTV/YmuBXfb+uSOf9PeL9A0c8NJk4/JqD3+DRb+/D77Ag+D0HHyJKiGngI4M5/DwMewOGBJ3hxfEYMGqSjvrSMNRwNoK6U3SGI+MxYNQ0HfXlUVE9rRk66iuRqCQoj4CydKBXhwDtDYBydKCjkUAoNk+PfY2IvVuCnxHLl/mFyTAJpyeQ4swUWnaqNVAY87Hci2J5nWD5PYCWvYU6AiL1AON0wDeGAOYVHQoByAQd8lgwE45N9nGcKUSnAR1wOmOl4ThB8WdgQpJZ9Z/3zIktUUJcQDF01OkKZs5rziElbisy8zfQA+ouULoWRdSJkG8bjZ1JMIcifacNjDyW8VmC45+QcU994Jj6DLmFfxlh2vFdFoMuXTj2cYI5Em2zq9wWxpvmb3okO1PxWCzYc87DZ/cojpKysE3zcjLkxZHKxpLM0eipyuvKeFP1LcLKsbAVEtEtWv9W8D8W7e3Q1LRZ2R+TzIlIn/aVra/AdWpmRVl3usFwthJh+VyUZefWQvqDH/ormuPVbSCJmk4ecBb7z/QTkNoy5E2amvJ+alwxzr2HsIpanJnDN9dRvFaGew36CvYmv6ckfTrNBSs56PrsrRQz1z96+kUsB0T4LAuheRy5v50KEhO7A/triqlFU9dHoEbVlhU5kO1IEVUi23yolh6Zu/ntbOFnhkPWhkO2SEg8Mw7YWgdb7XDw+f3A3eys07KTHJfx9HA766QddJ7c2aVn6AzFBE8zgU6YZuGGqvDtTV4LdpBIyrPjUfq6lXv6HJPxHIVxm2YSnVPNGY9yxHK9Q6Ht0Ghzz0X7bnAxe/2WfZphLgX2hP0OVjMA9OOc++tZb4MaYWt8jzD/Q8Q+4T/QOs896T8XvipGcpzfj8MZ+qk/NOLZ9n0icXuBncT3kxaxzEZorhciYcPXzZ2fsV7/LkjDvkhgPxiCHfy9YcR94dIw+PA91rHx2/5VjCT5gKjik4hlou27y47o4sPxCPwzc0QTHw3n8MaiKHeekikb4ReMy/+J7Q+Sjf7LhlH7F8KzwHI=' +MEMO = pickle.loads(zlib.decompress(base64.b64decode(MEMO))) +_parser = standalone_parser.get_parser((DATA, MEMO)) + + +class _Priority(enum.IntEnum): + ADD = enum.auto() + MUL = enum.auto() + POW = enum.auto() + UNARY = enum.auto() + ATOM = enum.auto() + + +class DimSpec(abc.ABC): + + def evaluate(self, memo: utils.Memo) -> tuple[int, ...]: + raise NotImplementedError() + + @property + def priority(self) -> int: + return _Priority.ATOM + + +@dataclasses.dataclass(init=False) +class ShapeSpec: + """Parsed shape specification.""" + + dim_specs: tuple[DimSpec, ...] 
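+  # Illustrative evaluation (values assumed for illustration, not from this
+  # change): parsing '*b h w' and evaluating against
+  # Memo(single={'h': 32, 'w': 32}, variadic={'b': (2, 3)}) returns the
+  # concrete shape (2, 3, 32, 32).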
+
+  def __init__(self, *dim_specs: DimSpec):
+    self.dim_specs = tuple(dim_specs)
+
+  def evaluate(self, memo: utils.Memo) -> tuple[int, ...]:
+    return tuple(
+        itertools.chain.from_iterable(s.evaluate(memo) for s in self.dim_specs)
+    )
+
+  def __repr__(self):
+    return ' '.join(repr(ds) for ds in self.dim_specs)
+
+
+@dataclasses.dataclass
+class IntDim(DimSpec):
+  value: int
+  broadcastable: bool = False
+
+  def evaluate(self, memo: utils.Memo) -> tuple[int, ...]:
+    if self.broadcastable:
+      raise utils.ShapeError(f'Cannot evaluate a broadcastable dim: {self!r}')
+    return (self.value,)
+
+  def __repr__(self):
+    # Broadcastable int dims are written with a '#' prefix in specs (e.g. "#3").
+    prefix = '#' if self.broadcastable else ''
+    return prefix + str(self.value)
+
+
+@dataclasses.dataclass
+class SingleDim(DimSpec):
+  """Simple individual dimensions like "height", "_a" or "#c"."""
+
+  name: Optional[str] = None
+  broadcastable: bool = False
+  anonymous: bool = False
+
+  def evaluate(self, memo: utils.Memo) -> tuple[int, ...]:
+    if self.anonymous:
+      raise utils.ShapeError(f'Cannot evaluate anonymous dimension: {self!r}')
+    elif self.broadcastable:
+      raise utils.ShapeError(
+          f'Cannot evaluate a broadcastable dimension: {self!r}'
+      )
+    elif self.name not in memo.single:
+      raise utils.ShapeError(
+          f'No value known for {self!r}. '
+          f'Known values are: {sorted(memo.single.keys())}'
+      )
+    else:
+      return (memo.single[self.name],)
+
+  def __repr__(self):
+    return (
+        ('#' if self.broadcastable else '')
+        + ('_' if self.anonymous else '')
+        + (self.name if self.name else '')
+    )
+
+
+@dataclasses.dataclass
+class VariadicDim(DimSpec):
+  """Variable size dimension specs like "*batch" or "...".""" 
+
+  name: Optional[str] = None
+  anonymous: bool = False
+  broadcastable: bool = False
+
+  def evaluate(self, memo: utils.Memo) -> tuple[int, ...]:
+    if self.anonymous:
+      raise utils.ShapeError(f'Cannot evaluate anonymous dimension: {self!r}')
+    if self.broadcastable:
+      raise utils.ShapeError(
+          f'Cannot evaluate a broadcastable variadic dimension: {self!r}'
+      )
+    if self.name not in memo.variadic:
+      raise utils.ShapeError(
+          f'No value known for {self!r}. Known values are:'
+          f' {sorted(memo.variadic.keys())}'
+      )
+    return memo.variadic[self.name]
+
+  def __repr__(self):
+    if self.anonymous:
+      return '...'
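+    # '*#name' below is the canonical printed form; the grammar in
+    # shape_spec.lark also accepts the equivalent spelling '#*name'.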
+    if self.broadcastable:
+      return '*#' + self.name
+    else:
+      return '*' + self.name
+
+
+BinOp = Callable[[Any, Any], Any]
+
+
+@dataclasses.dataclass
+class Operator:
+  symbol: str
+  fn: BinOp
+  priority: _Priority
+
+
+OPERATORS = [
+    Operator('+', operator.add, _Priority.ADD),
+    Operator('-', operator.sub, _Priority.ADD),
+    Operator('*', operator.mul, _Priority.MUL),
+    Operator('/', operator.truediv, _Priority.MUL),
+    Operator('//', operator.floordiv, _Priority.MUL),
+    Operator('%', operator.mod, _Priority.MUL),
+    Operator('**', operator.pow, _Priority.POW),
+]
+
+SYMBOL_2_OPERATOR = {o.symbol: o for o in OPERATORS}
+
+
+@dataclasses.dataclass
+class FunctionDim(DimSpec):
+  """Function-based dimension specs like "min(a,b)" or "sum(*batch)"."""
+
+  name: str
+  fn: Callable[..., int]
+  arguments: list[DimSpec]
+
+  def evaluate(self, memo: utils.Memo) -> tuple[int]:  # pylint: disable=g-one-element-tuple
+    vals = itertools.chain.from_iterable(
+        arg.evaluate(memo) for arg in self.arguments
+    )
+    return (self.fn(vals),)
+
+  def __repr__(self):
+    arg_list = ','.join(repr(a) for a in self.arguments)
+    return f'{self.name}({arg_list})'
+
+
+NAME_2_FUNC = {'sum': sum, 'min': min, 'max': max, 'prod': math.prod}
+
+
+@dataclasses.dataclass
+class BinaryOpDim(DimSpec):
+  """Binary ops for dim specs such as "H*W" or "C+1"."""
+
+  op: Operator
+  left: DimSpec
+  right: DimSpec
+
+  def evaluate(self, memo: utils.Memo) -> tuple[int]:  # pylint: disable=g-one-element-tuple
+    (left,) = self.left.evaluate(memo)  # unpack tuple (has to be 1-dim)
+    (right,) = self.right.evaluate(memo)  # unpack tuple (has to be 1-dim)
+    return (self.op.fn(left, right),)
+
+  @property
+  def priority(self) -> int:
+    return self.op.priority
+
+  def __repr__(self):
+    left_repr = (
+        repr(self.left)
+        if self.priority < self.left.priority
+        else f'({self.left!r})'
+    )
+    right_repr = (
+        repr(self.right)
+        if self.priority < self.right.priority
+        else f'({self.right!r})'
+    )
+    return f'{left_repr}{self.op.symbol}{right_repr}'
+
+
+@dataclasses.dataclass
+class NegDim(DimSpec):
+  """Negation of a dim spec, e.g. "-h"."""
+
+  child: DimSpec
+
+  def evaluate(self, memo: utils.Memo) -> tuple[int]:  # pylint: disable=g-one-element-tuple
+    return (-self.child.evaluate(memo)[0],)
+
+  @property
+  def priority(self) -> int:
+    return _Priority.UNARY
+
+  def __repr__(self):
+    if self.priority < self.child.priority:
+      return f'-{self.child!r}'
+    else:
+      return f'-({self.child!r})'
+
+
+class ShapeSpecTransformer(standalone_parser.Transformer):
+  """Transform a standalone_parser.Tree into a ShapeSpec."""
+
+  @staticmethod
+  def start(args: list[DimSpec]) -> ShapeSpec:
+    return ShapeSpec(*args)
+
+  @staticmethod
+  def int_dim(args: list[Any]) -> IntDim:
+    return IntDim(value=int(args[0]))
+
+  @staticmethod
+  def name_dim(args: list[Any]) -> SingleDim:
+    return SingleDim(name=args[0])
+
+  @staticmethod
+  def anon_dim(args: list[Any]) -> SingleDim:
+    name = args[0] if args else None
+    return SingleDim(name=name, anonymous=True)
+
+  @staticmethod
+  def anon_var_dim(args: list[Any]) -> VariadicDim:
+    name = args[0] if args else None
+    return VariadicDim(name=name, anonymous=True)
+
+  @staticmethod
+  def var_dim(args: list[Any]) -> VariadicDim:
+    return VariadicDim(name=args[0])
+
+  @staticmethod
+  def broadcast_dim(args: list[Any]) -> DimSpec:
+    try:
+      return IntDim(value=int(args[0]), broadcastable=True)
+    except ValueError:
+      return SingleDim(name=args[0], broadcastable=True)
+
+  @staticmethod
+  def broadcast_var_dim(args: list[Any]) -> VariadicDim:
+    return VariadicDim(name=args[0], broadcastable=True)
+
+  @staticmethod
+  def binary_op(args: list[Any]) -> BinaryOpDim:
+    left, op, right = args
+    return BinaryOpDim(left=left, right=right, op=SYMBOL_2_OPERATOR[str(op)])
+
+  @staticmethod
+  def neg(args: list[Any]) -> NegDim:
+    return NegDim(child=args[0])
+
+  @staticmethod
+  def func(args: list[Any]) -> FunctionDim:
+    name, arguments = args
+    return FunctionDim(name=name, fn=NAME_2_FUNC[name], arguments=arguments)
+
+  @staticmethod
+  def arg_list(args: list[Any]) -> list[Any]:
+    return args
+
+
+def parse(spec: str) -> ShapeSpec:
+  tree = _parser.parse(spec)
+  return ShapeSpecTransformer().transform(tree)
diff --git a/kauldron/typing/shape_spec.lark b/kauldron/typing/shape_spec.lark
new file mode 100644
index 00000000..ce6863d7
--- /dev/null
+++ b/kauldron/typing/shape_spec.lark
@@ -0,0 +1,81 @@
+// To generate the serialized parser run:
+// python -m lark.tools.standalone -c shape_spec.lark > tmp.py
+// then copy the DATA and MEMO lines from the end of the file into shape_parser.py
+// IMPORTANT: make sure to use lark==1.2.1
+
+// shape_spec is a list of dim_specs separated by whitespace
+// e.g. "*b h w//2 3"
+start: _WS_INLINE* dim_spec (_WS_INLINE+ dim_spec)* _WS_INLINE*
+     | _WS_INLINE*  // allow empty
+
+?dim_spec: expr
+         | var_dim
+         | other_dim
+
+// Dim expressions are sub-structured into term, unary, power, and atom
+// to account for operator precedence:
+// expr (lowest precedence): sum operations (+, -)
+?expr: term
+     | expr SUM_OP term -> binary_op
+SUM_OP: "+" | "-"
+
+// multiplication operations (*, /, //, %)
+?term: unary
+     | term MUL_OP unary -> binary_op
+MUL_OP: "*" | "/" | "//" | "%"
+
+// unary operators (we only support "-", not "+" or "~")
+?unary: power
+      | "-" unary -> neg
+
+// raising a value to the power of another (**)
+?power: atom
+      | atom POW_OP unary -> binary_op
+POW_OP.2: "**"
+
+// atoms (highest precedence): include ints, named dims, parenthesized
+// expressions, and functions.
+?atom: INT -> int_dim + | FUNC "(" arg_list ")" -> func + | NAME -> name_dim + | "(" expr ")" + + +FUNC.2: "min" | "max" | "sum" | "prod" + + +// named variadic dim spec (can be part of a function) +var_dim: "*" NAME + +// Other dim specs (cannot be part of an expression) +other_dim: "_" NAME? -> anon_dim + | "..." -> anon_var_dim + | "*_" NAME? -> anon_var_dim + | "#" NAME -> broadcast_dim + | "#" INT -> broadcast_dim + | "#*" NAME -> broadcast_var_dim + | "*#" NAME -> broadcast_var_dim + +// argument list for min, max, sum etc. can be either +// - a single variadic dim e.g. min(*channel) +// - a list of at least two normal dims e.g. min(a,b,c) +// (but not a single normal dim like min(a)) +// - a combination: e.g. sum(a,*b) +?arg_list: expr ("," (expr | var_dim))+ + | var_dim ("," (expr | var_dim))* + +// TODO: maybe add composition to atom? +// composition: "(" name_dim (_WS_INLINE (name_dim | var_dim))+ ")" +// | "(" var_dim (_WS_INLINE (name_dim | var_dim))* ")" + + + +// dimension names consist of letters, digits and underscores but have to start +// with a letter (underscores are used to indicate anonymous dims) +NAME: LETTER ("_"|LETTER|DIGIT)* + +_WS_INLINE: (" "|/\t/)+ + +%import common.INT +%import common.LETTER +%import common.DIGIT \ No newline at end of file diff --git a/kauldron/typing/shape_spec.py b/kauldron/typing/shape_spec.py index d06b5fa2..739a592f 100644 --- a/kauldron/typing/shape_spec.py +++ b/kauldron/typing/shape_spec.py @@ -16,19 +16,12 @@ from __future__ import annotations -import abc -import dataclasses -import enum import inspect -import itertools -import math -import operator import sys import typing -from typing import Any, Callable, List, Optional -import jaxtyping -import lark +from kauldron.typing import shape_parser +from kauldron.typing import utils if typing.TYPE_CHECKING: @@ -52,7 +45,7 @@ def foo(x: Float["*b h w c"], y: Float["h w c"]): def __new__(cls, spec_str: str) -> tuple[int, ...]: _assert_caller_is_typechecked_func() spec = parse_shape_spec(spec_str) - memo = Memo.from_current_context() + memo = utils.Memo.from_current_context() return spec.evaluate(memo) @@ -92,392 +85,14 @@ def Dim(spec_str: str) -> int: # pylint: disable=invalid-name """Helper to construct concrete Dim (for single-axis Shape).""" _assert_caller_is_typechecked_func() spec = parse_shape_spec(spec_str) - memo = Memo.from_current_context() + memo = utils.Memo.from_current_context() ret = spec.evaluate(memo) if len(ret) != 1: - raise ShapeError( + raise utils.ShapeError( f"Dim expects a single-axis string, but got : {ret!r}" ) return ret[0] # pytype: disable=bad-return-type -# try grammar online: https://www.lark-parser.org/ide/# -shape_parser = lark.Lark( - start="shape_spec", - regex=True, - grammar=r""" -// shape_spe is a list of dim_specs separated by whitespace -// e.g. "*b h w//2 3" -shape_spec: (_WS_INLINE* dim_spec)? 
(_WS_INLINE+ dim_spec)* - -?dim_spec: expr - | var_dim - | other_dim - -// Dim expressions are sub-structured into term, factor, unary, power, and atom -// to account for operator precedence: -// expr (lowest precedence): sum operations (+, -) -?expr: term - | expr SUM_OP term -> binary_op -SUM_OP: "+" | "-" - -// multiplication operations (*, /, //, %) -?term: unary - | term MUL_OP unary -> binary_op -MUL_OP: "*" | "/" | "//" | "%" - -// unary operators (we only support "-", not "+" or "~") -?unary: power - | "-" unary -> neg - -// raising a value to the power of another (**) -?power: atom - | atom POW_OP unary -> binary_op -POW_OP: "**" - -// atoms (highest precedence): include ints, named dims, parenthesized -// expressions, and functions. -?atom: INT -> int_dim - | NAME -> name_dim - | "(" expr ")" - | FUNC "(" arg_list ")" -> func - -FUNC: "min" | "max" | "sum" | "prod" - - -// named variadic dim spec (can be part of a function) -var_dim: "*" NAME - -// Other dim specs (cannot be part of an expression) -other_dim: "_" NAME? -> anon_dim - | "..." -> anon_var_dim - | "*_" NAME? -> anon_var_dim - | "#" NAME -> broadcast_dim - | "#" INT -> broadcast_dim - | "#*" NAME -> broadcast_var_dim - | "*#" NAME -> broadcast_var_dim - -// argument list for min, max, sum etc. can be either -// - a single variadic dim e.g. min(*channel) -// - a list of at least two normal dims e.g. min(a,b,c) -// (but not a single normal dim like min(a)) -// - a combination: e.g. sum(a,*b) -?arg_list: expr ("," (expr | var_dim))+ - | var_dim ("," (expr | var_dim))* - -// TODO: maybe add composition to atom? -// composition: "(" name_dim (_WS_INLINE (name_dim | var_dim))+ ")" -// | "(" var_dim (_WS_INLINE (name_dim | var_dim))* ")" - - - -// dimension names consist of letters, digits and underscores but have to start -// with a letter (underscores are used to indicate anonymous dims) -NAME: LETTER ("_"|LETTER|DIGIT)* - -_WS_INLINE: (" "|/\t/)+ - -%import common.INT -%import common.LETTER -%import common.DIGIT -""", -) - - -class ShapeError(ValueError): - pass - - -class _Priority(enum.IntEnum): - ADD = enum.auto() - MUL = enum.auto() - POW = enum.auto() - UNARY = enum.auto() - ATOM = enum.auto() - - -class DimSpec(abc.ABC): - - def evaluate(self, memo: Memo) -> tuple[int, ...]: - raise NotImplementedError() - - @property - def priority(self) -> int: - return _Priority.ATOM - - -@dataclasses.dataclass(init=False) -class ShapeSpec: - """Parsed shape specification.""" - - dim_specs: tuple[DimSpec, ...] 
- - def __init__(self, *dim_specs: DimSpec): - self.dim_specs = tuple(dim_specs) - - def evaluate(self, memo: Memo) -> tuple[int, ...]: - return tuple( - itertools.chain.from_iterable(s.evaluate(memo) for s in self.dim_specs) - ) - - def __repr__(self): - return " ".join(repr(ds) for ds in self.dim_specs) - - -@dataclasses.dataclass -class IntDim(DimSpec): - value: int - broadcastable: bool = False - - def evaluate(self, memo: Memo) -> tuple[int, ...]: - if self.broadcastable: - raise ShapeError(f"Cannot evaluate a broadcastable dim: {self!r}") - return (self.value,) - - def __repr__(self): - prefix = "_" if self.broadcastable else "" - return prefix + str(self.value) - - -@dataclasses.dataclass -class SingleDim(DimSpec): - """Simple individual dimensions like "height", "_a" or "#c".""" - - name: Optional[str] = None - broadcastable: bool = False - anonymous: bool = False - - def evaluate(self, memo: Memo) -> tuple[int, ...]: - if self.anonymous: - raise ShapeError(f"Cannot evaluate anonymous dimension: {self!r}") - elif self.broadcastable: - raise ShapeError(f"Cannot evaluate a broadcastable dimension: {self!r}") - elif self.name not in memo.single: - raise ShapeError( - f"No value known for {self!r}. " - f"Known values are: {sorted(memo.single.keys())}" - ) - else: - return (memo.single[self.name],) - - def __repr__(self): - return ( - ("#" if self.broadcastable else "") - + ("_" if self.anonymous else "") - + (self.name if self.name else "") - ) - - -@dataclasses.dataclass -class VariadicDim(DimSpec): - """Variable size dimension specs like "*batch" or "...".""" - - name: Optional[str] = None - anonymous: bool = False - broadcastable: bool = False - - def evaluate(self, memo: Memo) -> tuple[int, ...]: - if self.anonymous: - raise ShapeError(f"Cannot evaluate anonymous dimension: {self!r}") - if self.broadcastable: - raise ShapeError( - f"Cannot evaluate a broadcastable variadic dimension: {self!r}" - ) - if self.name not in memo.variadic: - raise ShapeError( - f"No value known for {self!r}. Known values are:" - f" {sorted(memo.variadic.keys())}" - ) - return memo.variadic[self.name] - - def __repr__(self): - if self.anonymous: - return "..." 
- if self.broadcastable: - return "*#" + self.name - else: - return "*" + self.name - - -BinOp = Callable[[Any, Any], Any] - - -@dataclasses.dataclass -class Operator: - symbol: str - fn: BinOp - priority: _Priority - - -OPERATORS = [ - Operator("+", operator.add, _Priority.ADD), - Operator("-", operator.sub, _Priority.ADD), - Operator("*", operator.mul, _Priority.MUL), - Operator("/", operator.truediv, _Priority.MUL), - Operator("//", operator.floordiv, _Priority.MUL), - Operator("%", operator.mod, _Priority.MUL), - Operator("**", operator.pow, _Priority.POW), -] - -SYMBOL_2_OPERATOR = {o.symbol: o for o in OPERATORS} - - -@dataclasses.dataclass -class Memo: - """Jaxtyping information about the shapes in the current scope.""" - - single: dict[str, int] - variadic: dict[str, tuple[int, ...]] - - @classmethod - def from_current_context(cls): - """Create a Memo from the current typechecking context.""" - single_memo, variadic_memo, *_ = jaxtyping._storage.get_shape_memo() # pylint: disable=protected-access - - variadic_memo = {k: tuple(dims) for k, (_, dims) in variadic_memo.items()} - return cls( - single=single_memo.copy(), - variadic=variadic_memo.copy(), - ) - - def __repr__(self) -> str: - out = {k: v for k, v in self.single.items()} - out.update({f"*{k}": v for k, v in self.variadic.items()}) - return repr(out) - - -@dataclasses.dataclass -class FunctionDim(DimSpec): - """Function based dimension specs like "min(a,b)" or "sum(*batch).""" - - name: str - fn: Callable[..., int] - arguments: list[DimSpec] - - def evaluate(self, memo: Memo) -> tuple[int]: # pylint: disable=g-one-element-tuple - vals = itertools.chain.from_iterable( - arg.evaluate(memo) for arg in self.arguments - ) - return (self.fn(vals),) - - def __repr__(self): - arg_list = ",".join(repr(a) for a in self.arguments) - return f"{self.name}({arg_list})" - - -NAME_2_FUNC = {"sum": sum, "min": min, "max": max, "prod": math.prod} - - -@dataclasses.dataclass -class BinaryOpDim(DimSpec): - """Binary ops for dim specs such as "H*W" or "C+1".""" - - op: Operator - left: DimSpec - right: DimSpec - - def evaluate(self, memo: Memo) -> tuple[int]: # pylint: disable=g-one-element-tuple - (left,) = self.left.evaluate(memo) # unpack tuple (has to be 1-dim) - (right,) = self.right.evaluate(memo) # unpack tuple (has to be 1-dim) - return (self.op.fn(left, right),) - - @property - def priority(self) -> int: - return self.op.priority - - def __repr__(self): - left_repr = ( - repr(self.left) - if self.priority < self.left.priority - else f"({self.left!r})" - ) - right_repr = ( - repr(self.right) - if self.priority < self.right.priority - else f"({self.right!r})" - ) - return f"{left_repr}{self.op.symbol}{right_repr}" - - -@dataclasses.dataclass -class NegDim(DimSpec): - """Negation of a dim spec, e.g. 
"-h".""" - - child: DimSpec - - def evaluate(self, memo: Memo) -> tuple[int]: # pylint: disable=g-one-element-tuple - return (-self.child.evaluate(memo)[0],) - - @property - def priority(self) -> int: - return _Priority.UNARY - - def __repr__(self): - if self.priority < self.child.priority: - return f"-{self.child!r}" - else: - return f"-({self.child!r})" - - -class ShapeSpecTransformer(lark.Transformer): - """Transform a lark.Tree into a ShapeSpec.""" - - @staticmethod - def shape_spec(args: List[DimSpec]) -> ShapeSpec: - return ShapeSpec(*args) - - @staticmethod - def int_dim(args: List[Any]) -> IntDim: - return IntDim(value=int(args[0])) - - @staticmethod - def name_dim(args: List[Any]) -> SingleDim: - return SingleDim(name=args[0]) - - @staticmethod - def anon_dim(args: List[Any]) -> SingleDim: - name = args[0] if args else None - return SingleDim(name=name, anonymous=True) - - @staticmethod - def anon_var_dim(args: List[Any]) -> VariadicDim: - name = args[0] if args else None - return VariadicDim(name=name, anonymous=True) - - @staticmethod - def var_dim(args: List[Any]) -> VariadicDim: - return VariadicDim(name=args[0]) - - @staticmethod - def broadcast_dim(args: List[Any]) -> DimSpec: - try: - return IntDim(value=int(args[0]), broadcastable=True) - except ValueError: - return SingleDim(name=args[0], broadcastable=True) - - @staticmethod - def broadcast_var_dim(args: List[Any]) -> VariadicDim: - return VariadicDim(name=args[0], broadcastable=True) - - @staticmethod - def binary_op(args: List[Any]) -> BinaryOpDim: - left, op, right = args - return BinaryOpDim(left=left, right=right, op=SYMBOL_2_OPERATOR[str(op)]) - - @staticmethod - def neg(args: List[Any]) -> NegDim: - return NegDim(child=args[0]) - - @staticmethod - def func(args: List[Any]) -> FunctionDim: - name, arguments = args - return FunctionDim(name=name, fn=NAME_2_FUNC[name], arguments=arguments) - - @staticmethod - def arg_list(args: List[Any]) -> List[Any]: - return args - - -def parse_shape_spec(spec: str) -> ShapeSpec: - tree = shape_parser.parse(spec) - return ShapeSpecTransformer().transform(tree) +def parse_shape_spec(spec: str) -> shape_parser.ShapeSpec: + return shape_parser.parse(spec) diff --git a/kauldron/typing/shape_spec_test.py b/kauldron/typing/shape_spec_test.py index 109d5631..c4ee7fcd 100644 --- a/kauldron/typing/shape_spec_test.py +++ b/kauldron/typing/shape_spec_test.py @@ -14,20 +14,19 @@ # pylint: disable=g-importing-member from kauldron.typing import Float, Shape, typechecked # pylint: disable=g-multiple-import -from kauldron.typing.shape_spec import ( # pylint: disable=g-multiple-import +from kauldron.typing.shape_parser import ( # pylint: disable=g-multiple-import BinaryOpDim, - Dim, FunctionDim, IntDim, - Memo, NAME_2_FUNC, NegDim, SYMBOL_2_OPERATOR, ShapeSpec, SingleDim, VariadicDim, - parse_shape_spec, ) +from kauldron.typing.shape_spec import Dim, parse_shape_spec # pylint: disable=g-multiple-import +from kauldron.typing.utils import Memo import numpy as np import pytest diff --git a/kauldron/typing/type_check.py b/kauldron/typing/type_check.py index d9ac4244..d3b01983 100644 --- a/kauldron/typing/type_check.py +++ b/kauldron/typing/type_check.py @@ -29,6 +29,7 @@ from etils import epy import jaxtyping from kauldron.typing import shape_spec +from kauldron.typing import utils import typeguard @@ -49,7 +50,7 @@ def __init__( return_value: Any, annotations: dict[str, Any], return_annotation: Any, - memo: shape_spec.Memo, + memo: utils.Memo, ): super().__init__(message) self.arguments = arguments 
@@ -136,7 +137,7 @@ def _reraise_with_shape_info(*args, _typecheck: bool = True, **kwargs):
         return_value=retval,
         annotations=annotations,
         return_annotation=sig.return_annotation,
-        memo=shape_spec.Memo.from_current_context(),
+        memo=utils.Memo.from_current_context(),
     ) from e
 
   return _reraise_with_shape_info
diff --git a/kauldron/typing/utils.py b/kauldron/typing/utils.py
new file mode 100644
index 00000000..8620d38b
--- /dev/null
+++ b/kauldron/typing/utils.py
@@ -0,0 +1,46 @@
+# Copyright 2024 The kauldron Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shape-spec related utilities."""
+
+import dataclasses
+import jaxtyping
+
+
+class ShapeError(ValueError):
+  pass
+
+
+@dataclasses.dataclass
+class Memo:
+  """Jaxtyping information about the shapes in the current scope."""
+
+  single: dict[str, int]
+  variadic: dict[str, tuple[int, ...]]
+
+  @classmethod
+  def from_current_context(cls):
+    """Create a Memo from the current typechecking context."""
+    single_memo, variadic_memo, *_ = jaxtyping._storage.get_shape_memo()  # pylint: disable=protected-access
+
+    variadic_memo = {k: tuple(dims) for k, (_, dims) in variadic_memo.items()}
+    return cls(
+        single=single_memo.copy(),
+        variadic=variadic_memo.copy(),
+    )
+
+  def __repr__(self) -> str:
+    out = {k: v for k, v in self.single.items()}
+    out.update({f'*{k}': v for k, v in self.variadic.items()})
+    return repr(out)
diff --git a/kauldron/utils/standalone_parser.py b/kauldron/utils/standalone_parser.py
new file mode 100644
index 00000000..ec4cac19
--- /dev/null
+++ b/kauldron/utils/standalone_parser.py
@@ -0,0 +1,3390 @@
+# Copyright 2024 The kauldron Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Standalone lark LALR parser.
+
+The file was automatically generated by Lark v1.2.1 and
+then adapted by klausg@google.com.
+"""
+
+# pytype: skip-file
+__version__ = '1.2.1'
+
+#
+#
+# Lark Stand-alone Generator Tool
+# ----------------------------------
+# Generates a stand-alone LALR(1) parser
+#
+# Git: https://github.com/erezsh/lark
+# Author: Erez Shinan (erezshin@gmail.com)
+#
+#
+# >>> LICENSE
+#
+# This tool and its generated code use a separate license from Lark,
+# and are subject to the terms of the Mozilla Public License, v. 2.0.
+# If a copy of the MPL was not distributed with this
+# file, You can obtain one at https://mozilla.org/MPL/2.0/.
+#
+# If you wish to purchase a commercial license for this tool and its
+# generated code, you may contact me via email or otherwise.
+# +# If MPL2 is incompatible with your free or open-source project, +# contact me and we'll work it out. +# +# +# pylint: disable=missing-class-docstring,missing-function-docstring,g-multiple-import,g-importing-member,g-bare-generic,invalid-name,protected-access,g-bad-exception-name,raise-missing-from + +from abc import ABC, abstractmethod +import contextlib +import copy +from functools import partial, update_wrapper, wraps +from inspect import getmembers, getmro +from itertools import product +import re +from types import ModuleType +from typing import ( + Any, + Callable, + ClassVar, + Collection, + Dict, + FrozenSet, + Generic, + Iterable, + Iterator, + List, + Literal, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeAlias, + TypeVar, + Union, + cast, + overload, +) + +# pylint: disable=g-import-not-at-top,deprecated-module,g-bad-import-order +# pytype: disable=import-error +# if sys.version_info >= (3, 11): # TODO(klausg): version check should work +try: + import re._constants as sre_constants + import re._parser as sre_parse +except ImportError: + import sre_constants + import sre_parse +# pylint: enable=g-import-not-at-top,deprecated-module,g-bad-import-order +# pytype: enable=import-error + + +class LarkError(Exception): + pass + + +class ConfigurationError(LarkError, ValueError): + pass + + +def assert_config(value, options: Collection, msg='Got %r, expected one of %s'): + if value not in options: + raise ConfigurationError(msg % (value, options)) + + +class GrammarError(LarkError): + pass + + +class ParseError(LarkError): + pass + + +class LexError(LarkError): + pass + + +T = TypeVar('T') + + +class UnexpectedInput(LarkError): + # -- + line: int + column: int + pos_in_stream = None + state: Any + _terminals_by_name = None + interactive_parser: 'InteractiveParser' + + def get_context(self, text: str, span: int = 40) -> str: + # -- + assert self.pos_in_stream is not None, self + pos = self.pos_in_stream + start = max(pos - span, 0) + end = pos + span + if not isinstance(text, bytes): + before = text[start:pos].rsplit('\n', 1)[-1] + after = text[pos:end].split('\n', 1)[0] + return before + after + '\n' + ' ' * len(before.expandtabs()) + '^\n' + else: + before = text[start:pos].rsplit(b'\n', 1)[-1] + after = text[pos:end].split(b'\n', 1)[0] + return ( + before + after + b'\n' + b' ' * len(before.expandtabs()) + b'^\n' + ).decode('ascii', 'backslashreplace') + + def match_examples( + self, + parse_fn: 'Callable[[str], Tree]', + examples: Union[ + Mapping[T, Iterable[str]], Iterable[Tuple[T, Iterable[str]]] + ], + token_type_match_fallback: bool = False, + use_accepts: bool = True, + ) -> Optional[T]: + # -- + assert self.state is not None, 'Not supported for this exception' + + if isinstance(examples, Mapping): + examples = examples.items() + + candidate = (None, False) + for label, example in examples: + assert not isinstance(example, str), 'Expecting a list' + + for malformed in example: + try: + parse_fn(malformed) + except UnexpectedInput as ut: + if ut.state == self.state: + if ( + use_accepts + and isinstance(self, UnexpectedToken) + and isinstance(ut, UnexpectedToken) + and ut.accepts != self.accepts + ): + continue + if isinstance( + self, (UnexpectedToken, UnexpectedEOF) + ) and isinstance(ut, (UnexpectedToken, UnexpectedEOF)): + if ut.token == self.token: ## + return label + + if token_type_match_fallback: + ## + + if (ut.token.type == self.token.type) and not candidate[-1]: + candidate = label, True + + if candidate[0] is None: + candidate = label, False + 
+ return candidate[0] + + def _format_expected(self, expected): + if self._terminals_by_name: + d = self._terminals_by_name + expected = [ + d[t_name].user_repr() if t_name in d else t_name + for t_name in expected + ] + return 'Expected one of: \n\t* %s\n' % '\n\t* '.join(expected) + + +class UnexpectedEOF(ParseError, UnexpectedInput): + # -- + expected: 'list[Token]' + + def __init__(self, expected, state=None, terminals_by_name=None): + super(UnexpectedEOF, self).__init__() + + self.expected = expected + self.state = state + + self.token = Token('', '') ## + + self.pos_in_stream = -1 + self.line = -1 + self.column = -1 + self._terminals_by_name = terminals_by_name + + def __str__(self): + message = 'Unexpected end-of-input. ' + message += self._format_expected(self.expected) + return message + + +class UnexpectedCharacters(LexError, UnexpectedInput): + # -- + + allowed: Set[str] + considered_tokens: Set[Any] + + def __init__( + self, + seq, + lex_pos, + line, + column, + allowed=None, + considered_tokens=None, + state=None, + token_history=None, + terminals_by_name=None, + considered_rules=None, + ): + super(UnexpectedCharacters, self).__init__() + + ## + + self.line = line + self.column = column + self.pos_in_stream = lex_pos + self.state = state + self._terminals_by_name = terminals_by_name + + self.allowed = allowed + self.considered_tokens = considered_tokens + self.considered_rules = considered_rules + self.token_history = token_history + + if isinstance(seq, bytes): + self.char = seq[lex_pos : lex_pos + 1].decode('ascii', 'backslashreplace') + else: + self.char = seq[lex_pos] + self._context = self.get_context(seq) + + def __str__(self): + message = ( + "No terminal matches '%s' in the current parser context, at line %d" + ' col %d' % (self.char, self.line, self.column) + ) + message += '\n\n' + self._context + if self.allowed: + message += self._format_expected(self.allowed) + if self.token_history: + message += '\nPrevious tokens: %s\n' % ', '.join( + repr(t) for t in self.token_history + ) + return message + + +class UnexpectedToken(ParseError, UnexpectedInput): + # -- + + expected: Set[str] + considered_rules: Set[str] + + def __init__( + self, + token, + expected, + considered_rules=None, + state=None, + interactive_parser=None, + terminals_by_name=None, + token_history=None, + ): + super(UnexpectedToken, self).__init__() + + ## + + self.line = getattr(token, 'line', '?') + self.column = getattr(token, 'column', '?') + self.pos_in_stream = getattr(token, 'start_pos', None) + self.state = state + + self.token = token + self.expected = expected ## + + self._accepts = NO_VALUE + self.considered_rules = considered_rules + self.interactive_parser = interactive_parser + self._terminals_by_name = terminals_by_name + self.token_history = token_history + + @property + def accepts(self): # -> Set[str]: + if self._accepts is NO_VALUE: + self._accepts = ( + self.interactive_parser and self.interactive_parser.accepts() + ) + return self._accepts + + def __str__(self): + message = 'Unexpected token %r at line %s, column %s.\n%s' % ( + self.token, + self.line, + self.column, + self._format_expected(self.accepts or self.expected), + ) + if self.token_history: + message += 'Previous tokens: %r\n' % self.token_history + + return message + + +class VisitError(LarkError): + # -- + + obj: 'Union[Tree, Token]' + orig_exc: Exception + + def __init__(self, rule, obj, orig_exc): + message = 'Error trying to process rule "%s":\n\n%s' % (rule, orig_exc) + super(VisitError, self).__init__(message) + + 
self.rule = rule + self.obj = obj + self.orig_exc = orig_exc + + +class MissingVariableError(LarkError): + pass + + +NO_VALUE = object() + +T = TypeVar('T') + + +def classify( + seq: Iterable, + key: Optional[Callable] = None, + value: Optional[Callable] = None, +) -> Dict: + d: Dict[Any, Any] = {} + for item in seq: + k = key(item) if (key is not None) else item + v = value(item) if (value is not None) else item + try: + d[k].append(v) + except KeyError: + d[k] = [v] + return d + + +def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any: + if isinstance(data, dict): + if '__type__' in data: ## + + class_ = namespace[data['__type__']] + return class_.deserialize(data, memo) + elif '@' in data: + return memo[data['@']] + return { + key: _deserialize(value, namespace, memo) for key, value in data.items() + } + elif isinstance(data, list): + return [_deserialize(value, namespace, memo) for value in data] + return data + + +_T = TypeVar('_T', bound='Serialize') + + +class Serialize: + # -- + + @classmethod + def deserialize( + cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any] + ) -> _T: + namespace = getattr(cls, '__serialize_namespace__', []) + namespace = {c.__name__: c for c in namespace} + + fields = getattr(cls, '__serialize_fields__') + + if '@' in data: + return memo[data['@']] + + inst = cls.__new__(cls) + for f in fields: + try: + setattr(inst, f, _deserialize(data[f], namespace, memo)) + except KeyError as e: + raise KeyError('Cannot find key for class', cls, e) from e + + if hasattr(inst, '_deserialize'): + inst._deserialize() + + return inst + + +class Enumerator(Serialize): + + def __init__(self) -> None: + self.enums: Dict[Any, int] = {} + + def get(self, item) -> int: + if item not in self.enums: + self.enums[item] = len(self.enums) + return self.enums[item] + + def __len__(self): + return len(self.enums) + + def reversed(self) -> Dict[int, Any]: + r = {v: k for k, v in self.enums.items()} + assert len(r) == len(self.enums) + return r + + +class SerializeMemoizer(Serialize): + # -- + + __serialize_fields__ = ('memoized',) + + def __init__(self, types_to_memoize: List) -> None: + self.types_to_memoize = tuple(types_to_memoize) + self.memoized = Enumerator() + + def in_types(self, value: Serialize) -> bool: + return isinstance(value, self.types_to_memoize) + + @classmethod + def deserialize( + cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any] + ) -> Dict[int, Any]: ## + + return _deserialize(data, namespace, memo) + + +categ_pattern = re.compile(r'\\p{[A-Za-z_]+}') + + +def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]: + if re.search(categ_pattern, expr): + raise ImportError( + '`regex` module must be installed in order to use Unicode categories.', + expr, + ) + regexp_final = expr + try: + return [int(x) for x in sre_parse.parse(regexp_final).getwidth()] + except sre_constants.error: + raise ValueError(expr) + + +class Meta: + + empty: bool + line: int + column: int + start_pos: int + end_line: int + end_column: int + end_pos: int + orig_expansion: 'List[TerminalDef]' + match_tree: bool + + def __init__(self): + self.empty = True + + +_Leaf_T = TypeVar('_Leaf_T') +Branch = Union[_Leaf_T, 'Tree[_Leaf_T]'] + + +class Tree(Generic[_Leaf_T]): + # -- + + data: str + children: 'List[Branch[_Leaf_T]]' + + def __init__( + self, + data: str, + children: 'List[Branch[_Leaf_T]]', + meta: Optional[Meta] = None, + ) -> None: + self.data = data + self.children = children + self._meta = meta + + @property + def meta(self) 
-> Meta: + if self._meta is None: + self._meta = Meta() + return self._meta + + def __repr__(self): + return 'Tree(%r, %r)' % (self.data, self.children) + + def _pretty_label(self): + return self.data + + def _pretty(self, level, indent_str): + yield f'{indent_str*level}{self._pretty_label()}' + if len(self.children) == 1 and not isinstance(self.children[0], Tree): + yield f'\t{self.children[0]}\n' + else: + yield '\n' + for n in self.children: + if isinstance(n, Tree): + yield from n._pretty(level + 1, indent_str) + else: + yield f'{indent_str*(level+1)}{n}\n' + + def pretty(self, indent_str: str = ' ') -> str: + # -- + return ''.join(self._pretty(0, indent_str)) + + def __eq__(self, other): + try: + return self.data == other.data and self.children == other.children + except AttributeError: + return False + + def __ne__(self, other): + return not (self == other) + + def __hash__(self) -> int: + return hash((self.data, tuple(self.children))) + + def iter_subtrees(self) -> 'Iterator[Tree[_Leaf_T]]': + # -- + queue = [self] + subtrees = dict() + for subtree in queue: + subtrees[id(subtree)] = subtree + queue += [ + c + for c in reversed(subtree.children) + if isinstance(c, Tree) and id(c) not in subtrees + ] + + del queue + return reversed(list(subtrees.values())) + + def iter_subtrees_topdown(self): + # -- + stack = [self] + stack_append = stack.append + stack_pop = stack.pop + while stack: + node = stack_pop() + if not isinstance(node, Tree): + continue + yield node + for child in reversed(node.children): + stack_append(child) + + def find_pred( + self, pred: 'Callable[[Tree[_Leaf_T]], bool]' + ) -> 'Iterator[Tree[_Leaf_T]]': + # -- + return filter(pred, self.iter_subtrees()) + + def find_data(self, data: str) -> 'Iterator[Tree[_Leaf_T]]': + # -- + return self.find_pred(lambda t: t.data == data) + + +_Return_T = TypeVar('_Return_T') +_Return_V = TypeVar('_Return_V') +_Leaf_T = TypeVar('_Leaf_T') +_Leaf_U = TypeVar('_Leaf_U') +_R = TypeVar('_R') +_FUNC = Callable[..., _Return_T] +_DECORATED = Union[_FUNC, type] + + +class _DiscardType: + # -- + + def __repr__(self): + return 'lark.visitors.Discard' + + +Discard = _DiscardType() + +## + + +class _Decoratable: + # -- + + @classmethod + def _apply_v_args(cls, visit_wrapper): + mro = getmro(cls) + assert mro[0] is cls + libmembers = {name for _cls in mro[1:] for name, _ in getmembers(_cls)} + for name, value in getmembers(cls): + + ## + + if name.startswith('_') or ( + name in libmembers and name not in cls.__dict__ + ): + continue + if not callable(value): + continue + + ## + + if isinstance(cls.__dict__[name], _VArgsWrapper): + continue + + setattr(cls, name, _VArgsWrapper(cls.__dict__[name], visit_wrapper)) + return cls + + def __class_getitem__(cls, _): + return cls + + +class Transformer(_Decoratable, ABC, Generic[_Leaf_T, _Return_T]): + # -- + __visit_tokens__ = True ## + + def __init__(self, visit_tokens: bool = True) -> None: + self.__visit_tokens__ = visit_tokens + + def _call_userfunc(self, tree, new_children=None): + ## + + children = new_children if new_children is not None else tree.children + try: + f = getattr(self, tree.data) + except AttributeError: + return self.__default__(tree.data, children, tree.meta) + else: + try: + wrapper = getattr(f, 'visit_wrapper', None) + if wrapper is not None: + return f.visit_wrapper(f, tree.data, children, tree.meta) + else: + return f(children) + except GrammarError: + raise + except Exception as e: + raise VisitError(tree.data, tree, e) from e + + def _call_userfunc_token(self, token): + try: 
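+      # Transformer handlers are looked up by token type, so defining e.g.
+      # `def INT(self, tok):` on a subclass intercepts every INT token.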
+ f = getattr(self, token.type) + except AttributeError: + return self.__default_token__(token) + else: + try: + return f(token) + except GrammarError: + raise + except Exception as e: + raise VisitError(token.type, token, e) from e + + def _transform_children(self, children): + for c in children: + if isinstance(c, Tree): + res = self._transform_tree(c) + elif self.__visit_tokens__ and isinstance(c, Token): + res = self._call_userfunc_token(c) + else: + res = c + + if res is not Discard: + yield res + + def _transform_tree(self, tree): + children = list(self._transform_children(tree.children)) + return self._call_userfunc(tree, children) + + def transform(self, tree: Tree[_Leaf_T]) -> _Return_T: + # -- + res = list(self._transform_children([tree])) + if not res: + return None ## + + assert len(res) == 1 + return res[0] + + def __mul__( + self: 'Transformer[_Leaf_T, Tree[_Leaf_U]]', + other: 'Union[Transformer[_Leaf_U, _Return_V], TransformerChain[_Leaf_U, _Return_V,]]', + ) -> 'TransformerChain[_Leaf_T, _Return_V]': + # -- + return TransformerChain(self, other) + + def __default__(self, data, children, meta): + # -- + return Tree(data, children, meta) + + def __default_token__(self, token): + # -- + return token + + +def merge_transformers(base_transformer=None, **transformers_to_merge): + # -- + if base_transformer is None: + base_transformer = Transformer() + for prefix, transformer in transformers_to_merge.items(): + for method_name in dir(transformer): + method = getattr(transformer, method_name) + if not callable(method): + continue + if method_name.startswith('_') or method_name == 'transform': + continue + prefixed_method = prefix + '__' + method_name + if hasattr(base_transformer, prefixed_method): + raise AttributeError( + "Cannot merge: method '%s' appears more than once" % prefixed_method + ) + + setattr(base_transformer, prefixed_method, method) + + return base_transformer + + +class InlineTransformer(Transformer): ## + + def _call_userfunc(self, tree, new_children=None): + ## + + children = new_children if new_children is not None else tree.children + try: + f = getattr(self, tree.data) + except AttributeError: + return self.__default__(tree.data, children, tree.meta) + else: + return f(*children) + + +class TransformerChain(Generic[_Leaf_T, _Return_T]): + + transformers: 'Tuple[Union[Transformer, TransformerChain], ...]' + + def __init__( + self, *transformers: 'Union[Transformer, TransformerChain]' + ) -> None: + self.transformers = transformers + + def transform(self, tree: Tree[_Leaf_T]) -> _Return_T: + for t in self.transformers: + tree = t.transform(tree) + return cast(_Return_T, tree) + + def __mul__( + self: 'TransformerChain[_Leaf_T, Tree[_Leaf_U]]', + other: 'Union[Transformer[_Leaf_U, _Return_V], TransformerChain[_Leaf_U, _Return_V]]', + ) -> 'TransformerChain[_Leaf_T, _Return_V]': + return TransformerChain(*self.transformers + (other,)) + + +class Transformer_InPlace(Transformer[_Leaf_T, _Return_T]): + # -- + def _transform_tree(self, tree): ## + + return self._call_userfunc(tree) + + def transform(self, tree: Tree[_Leaf_T]) -> _Return_T: + for subtree in tree.iter_subtrees(): + subtree.children = list(self._transform_children(subtree.children)) + + return self._transform_tree(tree) + + +class Transformer_NonRecursive(Transformer[_Leaf_T, _Return_T]): + # -- + + def transform(self, tree: Tree[_Leaf_T]) -> _Return_T: + ## + + rev_postfix = [] + q: List[Branch[_Leaf_T]] = [tree] + while q: + t = q.pop() + rev_postfix.append(t) + if isinstance(t, Tree): + q += 
t.children + + ## + + stack: List = [] + for x in reversed(rev_postfix): + if isinstance(x, Tree): + size = len(x.children) + if size: + args = stack[-size:] + del stack[-size:] + else: + args = [] + + res = self._call_userfunc(x, args) + if res is not Discard: + stack.append(res) + + elif self.__visit_tokens__ and isinstance(x, Token): + res = self._call_userfunc_token(x) + if res is not Discard: + stack.append(res) + else: + stack.append(x) + + (result,) = stack ## + + ## + + ## + + ## + + return cast(_Return_T, result) + + +class Transformer_InPlaceRecursive(Transformer): + # -- + def _transform_tree(self, tree): + tree.children = list(self._transform_children(tree.children)) + return self._call_userfunc(tree) + + +## + + +class VisitorBase: + + def _call_userfunc(self, tree): + return getattr(self, tree.data, self.__default__)(tree) + + def __default__(self, tree): + # -- + return tree + + def __class_getitem__(cls, _): + return cls + + +class Visitor(VisitorBase, ABC, Generic[_Leaf_T]): + # -- + + def visit(self, tree: Tree[_Leaf_T]) -> Tree[_Leaf_T]: + # -- + for subtree in tree.iter_subtrees(): + self._call_userfunc(subtree) + return tree + + def visit_topdown(self, tree: Tree[_Leaf_T]) -> Tree[_Leaf_T]: + # -- + for subtree in tree.iter_subtrees_topdown(): + self._call_userfunc(subtree) + return tree + + +class Visitor_Recursive(VisitorBase, Generic[_Leaf_T]): + # -- + + def visit(self, tree: Tree[_Leaf_T]) -> Tree[_Leaf_T]: + # -- + for child in tree.children: + if isinstance(child, Tree): + self.visit(child) + + self._call_userfunc(tree) + return tree + + def visit_topdown(self, tree: Tree[_Leaf_T]) -> Tree[_Leaf_T]: + # -- + self._call_userfunc(tree) + + for child in tree.children: + if isinstance(child, Tree): + self.visit_topdown(child) + + return tree + + +class Interpreter(_Decoratable, ABC, Generic[_Leaf_T, _Return_T]): + # -- + + def visit(self, tree: Tree[_Leaf_T]) -> _Return_T: + ## + + ## + + ## + + return self._visit_tree(tree) + + def _visit_tree(self, tree: Tree[_Leaf_T]): + f = getattr(self, tree.data) + wrapper = getattr(f, 'visit_wrapper', None) + if wrapper is not None: + return f.visit_wrapper(f, tree.data, tree.children, tree.meta) + else: + return f(tree) + + def visit_children(self, tree: Tree[_Leaf_T]) -> List: + return [ + self._visit_tree(child) if isinstance(child, Tree) else child + for child in tree.children + ] + + def __getattr__(self, name): + return self.__default__ + + def __default__(self, tree): + return self.visit_children(tree) + + +_InterMethod = Callable[[Type[Interpreter], _Return_T], _R] + + +def visit_children_decor(func: _InterMethod) -> _InterMethod: + # -- + @wraps(func) + def inner(cls, tree): + values = cls.visit_children(tree) + return func(cls, values) + + return inner + + +## + + +def _apply_v_args(obj, visit_wrapper): + try: + _apply = obj._apply_v_args + except AttributeError: + return _VArgsWrapper(obj, visit_wrapper) + else: + return _apply(visit_wrapper) + + +class _VArgsWrapper: + # -- + base_func: Callable + + def __init__( + self, + func: Callable, + visit_wrapper: Callable[[Callable, str, list, Any], Any], + ): + if isinstance(func, _VArgsWrapper): + func = func.base_func + self.base_func = func + self.visit_wrapper = visit_wrapper + update_wrapper(self, func) + + def __call__(self, *args, **kwargs): + return self.base_func(*args, **kwargs) + + def __get__(self, instance, owner=None): + try: + ## + + ## + + g = type(self.base_func).__get__ # pytype: disable=attribute-error + except AttributeError: + return self + else: + 
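+      # base_func is itself a descriptor (a plain function, classmethod, etc.):
+      # bind it to the instance and re-wrap it, so visit_wrapper survives
+      # attribute access on transformer instances.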
return _VArgsWrapper( + g(self.base_func, instance, owner), self.visit_wrapper + ) + + def __set_name__(self, owner, name): + try: + f = type(self.base_func).__set_name__ # pytype: disable=attribute-error + except AttributeError: + return + else: + f(self.base_func, owner, name) + + +def _vargs_inline(f, _data, children, _meta): + return f(*children) + + +def _vargs_meta_inline(f, _data, children, meta): + return f(meta, *children) + + +def _vargs_meta(f, _data, children, meta): + return f(meta, children) + + +def _vargs_tree(f, data, children, meta): + return f(Tree(data, children, meta)) + + +def v_args( + inline: bool = False, + meta: bool = False, + tree: bool = False, + wrapper: Optional[Callable] = None, +) -> Callable[[_DECORATED], _DECORATED]: + # -- + if tree and (meta or inline): + raise ValueError( + "Visitor functions cannot combine 'tree' with 'meta' or 'inline'." + ) + + func = None + if meta: + if inline: + func = _vargs_meta_inline + else: + func = _vargs_meta + elif inline: + func = _vargs_inline + elif tree: + func = _vargs_tree + + if wrapper is not None: + if func is not None: + raise ValueError( + "Cannot use 'wrapper' along with 'tree', 'meta' or 'inline'." + ) + func = wrapper + + def _visitor_args_dec(obj): + return _apply_v_args(obj, func) + + return _visitor_args_dec + + +TOKEN_DEFAULT_PRIORITY = 0 + + +class Symbol(Serialize): + __slots__ = ('name',) + + name: str + is_term: ClassVar[bool] = NotImplemented + + def __init__(self, name: str) -> None: + self.name = name + + def __eq__(self, other): + assert isinstance(other, Symbol), other + return self.is_term == other.is_term and self.name == other.name + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash(self.name) + + def __repr__(self): + return '%s(%r)' % (type(self).__name__, self.name) + + fullrepr = property(__repr__) + + def renamed(self, f): + return type(self)(f(self.name)) + + +class Terminal(Symbol): + __serialize_fields__ = 'name', 'filter_out' + + is_term: ClassVar[bool] = True + + def __init__(self, name, filter_out=False): + super().__init__(name) + self.name = name + self.filter_out = filter_out + + @property + def fullrepr(self): + return '%s(%r, %r)' % (type(self).__name__, self.name, self.filter_out) + + def renamed(self, f): + return type(self)(f(self.name), self.filter_out) + + +class NonTerminal(Symbol): + __serialize_fields__ = ('name',) + + is_term: ClassVar[bool] = False + + +class RuleOptions(Serialize): + __serialize_fields__ = ( + 'keep_all_tokens', + 'expand1', + 'priority', + 'template_source', + 'empty_indices', + ) + + keep_all_tokens: bool + expand1: bool + priority: Optional[int] + template_source: Optional[str] + empty_indices: Tuple[bool, ...] + + def __init__( + self, + keep_all_tokens: bool = False, + expand1: bool = False, + priority: Optional[int] = None, + template_source: Optional[str] = None, + empty_indices: Tuple[bool, ...] 
= (), + ) -> None: + self.keep_all_tokens = keep_all_tokens + self.expand1 = expand1 + self.priority = priority + self.template_source = template_source + self.empty_indices = empty_indices + + def __repr__(self): + return 'RuleOptions(%r, %r, %r, %r)' % ( + self.keep_all_tokens, + self.expand1, + self.priority, + self.template_source, + ) + + +class Rule(Serialize): + # -- + __slots__ = ('origin', 'expansion', 'alias', 'options', 'order', '_hash') + + __serialize_fields__ = 'origin', 'expansion', 'order', 'alias', 'options' + __serialize_namespace__ = Terminal, NonTerminal, RuleOptions + + origin: NonTerminal + expansion: Sequence[Symbol] + order: int + alias: Optional[str] + options: RuleOptions + _hash: int + + def __init__( + self, + origin: NonTerminal, + expansion: Sequence[Symbol], + order: int = 0, + alias: Optional[str] = None, + options: Optional[RuleOptions] = None, + ): + self.origin = origin + self.expansion = expansion + self.alias = alias + self.order = order + self.options = options or RuleOptions() + self._hash = hash((self.origin, tuple(self.expansion))) + + def _deserialize(self): + self._hash = hash((self.origin, tuple(self.expansion))) + + def __str__(self): + return '<%s : %s>' % ( + self.origin.name, + ' '.join(x.name for x in self.expansion), + ) + + def __repr__(self): + return 'Rule(%r, %r, %r, %r)' % ( + self.origin, + self.expansion, + self.alias, + self.options, + ) + + def __hash__(self): + return self._hash + + def __eq__(self, other): + if not isinstance(other, Rule): + return False + return self.origin == other.origin and self.expansion == other.expansion + + +class Pattern(Serialize, ABC): + # -- + + value: str + flags: Collection[str] + raw: Optional[str] + type: ClassVar[str] + + def __init__( + self, value: str, flags: Collection[str] = (), raw: Optional[str] = None + ) -> None: + self.value = value + self.flags = frozenset(flags) + self.raw = raw + + def __repr__(self): + return repr(self.to_regexp()) + + ## + + def __hash__(self): + return hash((type(self), self.value, self.flags)) + + def __eq__(self, other): + return ( + type(self) == type(other) # pylint: disable=unidiomatic-typecheck + and self.value == other.value + and self.flags == other.flags + ) + + @abstractmethod + def to_regexp(self) -> str: + raise NotImplementedError() + + @property + @abstractmethod + def min_width(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def max_width(self) -> int: + raise NotImplementedError() + + def _get_flags(self, value): + for f in self.flags: + value = '(?%s:%s)' % (f, value) + return value + + +class PatternStr(Pattern): + __serialize_fields__ = 'value', 'flags', 'raw' + + type: ClassVar[str] = 'str' + + def to_regexp(self) -> str: + return self._get_flags(re.escape(self.value)) + + @property + def min_width(self) -> int: + return len(self.value) + + @property + def max_width(self) -> int: + return len(self.value) + + +class PatternRE(Pattern): + __serialize_fields__ = 'value', 'flags', 'raw', '_width' + + type: ClassVar[str] = 're' + + def to_regexp(self) -> str: + return self._get_flags(self.value) + + _width = None + + def _get_width(self): + if self._width is None: + self._width = get_regexp_width(self.to_regexp()) + return self._width + + @property + def min_width(self) -> int: + return self._get_width()[0] + + @property + def max_width(self) -> int: + return self._get_width()[1] + + +class TerminalDef(Serialize): + # -- + __serialize_fields__ = 'name', 'pattern', 'priority' + __serialize_namespace__ = PatternStr, 
PatternRE + + name: str + pattern: Pattern + priority: int + + def __init__( + self, name: str, pattern: Pattern, priority: int = TOKEN_DEFAULT_PRIORITY + ) -> None: + assert isinstance(pattern, Pattern), pattern + self.name = name + self.pattern = pattern + self.priority = priority + + def __repr__(self): + return '%s(%r, %r)' % (type(self).__name__, self.name, self.pattern) + + def user_repr(self) -> str: + if self.name.startswith('__'): ## + + return self.pattern.raw or self.name + else: + return self.name + + +_T = TypeVar('_T', bound='Token') + + +class Token(str): + # -- + __slots__ = ( + 'type', + 'start_pos', + 'value', + 'line', + 'column', + 'end_line', + 'end_column', + 'end_pos', + ) + + __match_args__ = ('type', 'value') + + type: str + start_pos: Optional[int] + value: Any + line: Optional[int] + column: Optional[int] + end_line: Optional[int] + end_column: Optional[int] + end_pos: Optional[int] + + @overload + def __new__( + cls, + type: str, + value: Any, + start_pos: Optional[int] = None, + line: Optional[int] = None, + column: Optional[int] = None, + end_line: Optional[int] = None, + end_column: Optional[int] = None, + end_pos: Optional[int] = None, + ) -> 'Token': + ... + + @overload + def __new__( + cls, + type_: str, + value: Any, + start_pos: Optional[int] = None, + line: Optional[int] = None, + column: Optional[int] = None, + end_line: Optional[int] = None, + end_column: Optional[int] = None, + end_pos: Optional[int] = None, + ) -> 'Token': + ... + + def __new__(cls, *args, **kwargs): + return cls._future_new(*args, **kwargs) + + @classmethod + def _future_new( + cls, + type, + value, + start_pos=None, + line=None, + column=None, + end_line=None, + end_column=None, + end_pos=None, + ): + inst = super(Token, cls).__new__(cls, value) + + inst.type = type + inst.start_pos = start_pos + inst.value = value + inst.line = line + inst.column = column + inst.end_line = end_line + inst.end_column = end_column + inst.end_pos = end_pos + return inst + + @overload + def update( + self, type: Optional[str] = None, value: Optional[Any] = None + ) -> 'Token': + ... + + @overload + def update( + self, type_: Optional[str] = None, value: Optional[Any] = None + ) -> 'Token': + ... 
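+
+  # Token subclasses `str`, so a token compares equal to its raw value while
+  # still carrying its type and source position. A minimal illustrative
+  # sketch (not part of this module; names assumed importable from it):
+  #
+  #   tok = Token('NAME', 'foo', start_pos=0, line=1, column=1)
+  #   assert tok == 'foo' and tok.type == 'NAME'
+  #   tok2 = tok.update(value='bar')  # new Token; position borrowed from tok
+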
+ + def update(self, *args, **kwargs): + + return self._future_update(*args, **kwargs) + + def _future_update( + self, type: Optional[str] = None, value: Optional[Any] = None + ) -> 'Token': + return Token.new_borrow_pos( + type if type is not None else self.type, + value if value is not None else self.value, + self, + ) + + @classmethod + def new_borrow_pos( + cls: Type[_T], type_: str, value: Any, borrow_t: 'Token' + ) -> _T: + return cls( + type_, + value, + borrow_t.start_pos, + borrow_t.line, + borrow_t.column, + borrow_t.end_line, + borrow_t.end_column, + borrow_t.end_pos, + ) + + def __reduce__(self): + return ( + self.__class__, + (self.type, self.value, self.start_pos, self.line, self.column), + ) + + def __repr__(self): + return 'Token(%r, %r)' % (self.type, self.value) + + def __deepcopy__(self, memo): + return Token(self.type, self.value, self.start_pos, self.line, self.column) + + def __eq__(self, other): + if isinstance(other, Token) and self.type != other.type: + return False + + return str.__eq__(self, other) + + __hash__ = str.__hash__ + + +class LineCounter: + # -- + + __slots__ = 'char_pos', 'line', 'column', 'line_start_pos', 'newline_char' + + def __init__(self, newline_char): + self.newline_char = newline_char + self.char_pos = 0 + self.line = 1 + self.column = 1 + self.line_start_pos = 0 + + def __eq__(self, other): + if not isinstance(other, LineCounter): + return NotImplemented + + return ( + self.char_pos == other.char_pos + and self.newline_char == other.newline_char + ) + + def feed(self, token: Token, test_newline=True): + # -- + if test_newline: + newlines = token.count(self.newline_char) + if newlines: + self.line += newlines + self.line_start_pos = ( + self.char_pos + token.rindex(self.newline_char) + 1 + ) + + self.char_pos += len(token) + self.column = self.char_pos - self.line_start_pos + 1 + + +class UnlessCallback: + + def __init__(self, scanner): + self.scanner = scanner + + def __call__(self, t): + res = self.scanner.match(t.value, 0) + if res: + _value, t.type = res + return t + + +class CallChain: + + def __init__(self, callback1, callback2, cond): + self.callback1 = callback1 + self.callback2 = callback2 + self.cond = cond + + def __call__(self, t): + t2 = self.callback1(t) + return self.callback2(t) if self.cond(t2) else t2 + + +def _get_match(re_, regexp, s, flags): + m = re_.match(regexp, s, flags) + if m: + return m.group(0) + + +def _create_unless(terminals, g_regex_flags, re_, use_bytes): + tokens_by_type = classify(terminals, lambda t: type(t.pattern)) + assert len(tokens_by_type) <= 2, tokens_by_type.keys() + embedded_strs = set() + callback = {} + for retok in tokens_by_type.get(PatternRE, []): + unless = [] + for strtok in tokens_by_type.get(PatternStr, []): + if strtok.priority != retok.priority: + continue + s = strtok.pattern.value + if s == _get_match(re_, retok.pattern.to_regexp(), s, g_regex_flags): + unless.append(strtok) + if strtok.pattern.flags <= retok.pattern.flags: + embedded_strs.add(strtok) + if unless: + callback[retok.name] = UnlessCallback( + Scanner( + unless, g_regex_flags, re_, match_whole=True, use_bytes=use_bytes + ) + ) + + new_terminals = [t for t in terminals if t not in embedded_strs] + return new_terminals, callback + + +class Scanner: + + def __init__( + self, terminals, g_regex_flags, re_, use_bytes, match_whole=False + ): + self.terminals = terminals + self.g_regex_flags = g_regex_flags + self.re_ = re_ + self.use_bytes = use_bytes + self.match_whole = match_whole + + self.allowed_types = {t.name for t in 
self.terminals} + + self._mres = self._build_mres(terminals, len(terminals)) + + def _build_mres(self, terminals, max_size): + ## + + ## + + ## + + postfix = '$' if self.match_whole else '' + mres = [] + while terminals: + pattern = '|'.join( + '(?P<%s>%s)' % (t.name, t.pattern.to_regexp() + postfix) + for t in terminals[:max_size] + ) + if self.use_bytes: + pattern = pattern.encode('latin-1') + try: + mre = self.re_.compile(pattern, self.g_regex_flags) + except AssertionError: ## + + return self._build_mres(terminals, max_size // 2) + + mres.append(mre) + terminals = terminals[max_size:] + return mres + + def match(self, text, pos): + for mre in self._mres: + m = mre.match(text, pos) + if m: + return m.group(0), m.lastgroup + + +def _regexp_has_newline(r: str): + # -- + return ( + '\n' in r + or '\\n' in r + or '\\s' in r + or '[^' in r + or ('(?s' in r and '.' in r) + ) + + +class LexerState: + # -- + + __slots__ = 'text', 'line_ctr', 'last_token' + + text: str + line_ctr: LineCounter + last_token: Optional[Token] + + def __init__( + self, + text: str, + line_ctr: Optional[LineCounter] = None, + last_token: Optional[Token] = None, + ): + self.text = text + self.line_ctr = line_ctr or LineCounter( + b'\n' if isinstance(text, bytes) else '\n' + ) + self.last_token = last_token + + def __eq__(self, other): + if not isinstance(other, LexerState): + return NotImplemented + + return ( + self.text is other.text + and self.line_ctr == other.line_ctr + and self.last_token == other.last_token + ) + + def __copy__(self): + return type(self)(self.text, copy.copy(self.line_ctr), self.last_token) + + +class LexerThread: + # -- + + def __init__(self, lexer: 'Lexer', lexer_state: LexerState): + self.lexer = lexer + self.state = lexer_state + + @classmethod + def from_text(cls, lexer: 'Lexer', text: str) -> 'LexerThread': + return cls(lexer, LexerState(text)) + + def lex(self, parser_state): + return self.lexer.lex(self.state, parser_state) + + def __copy__(self): + return type(self)(self.lexer, copy.copy(self.state)) + + _Token = Token + + +_Callback = Callable[[Token], Token] + + +class Lexer(ABC): + # -- + @abstractmethod + def lex(self, lexer_state: LexerState, parser_state: Any) -> Iterator[Token]: + return NotImplemented + + def make_lexer_state(self, text): + # -- + return LexerState(text) + + +class AbstractBasicLexer(Lexer): + terminals_by_name: Dict[str, TerminalDef] + + @abstractmethod + def __init__(self, conf: 'LexerConf', comparator=None) -> None: + ... + + @abstractmethod + def next_token( + self, lex_state: LexerState, parser_state: Any = None + ) -> Token: + ... 
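+
+  # Contract between lex() and next_token(): next_token() raises EOFError
+  # once the input is exhausted, which lex() below converts into normal
+  # generator termination via contextlib.suppress.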
+ + def lex(self, lexer_state: LexerState, parser_state: Any) -> Iterator[Token]: + with contextlib.suppress(EOFError): + while True: + yield self.next_token(lexer_state, parser_state) + + +class BasicLexer(AbstractBasicLexer): + terminals: Collection[TerminalDef] + ignore_types: FrozenSet[str] + newline_types: FrozenSet[str] + user_callbacks: Dict[str, _Callback] + callback: Dict[str, _Callback] + re: ModuleType + + def __init__(self, conf: 'LexerConf', comparator=None) -> None: + terminals = list(conf.terminals) + assert all(isinstance(t, TerminalDef) for t in terminals), terminals + + self.re = conf.re_module + + if not conf.skip_validation: + ## + + terminal_to_regexp = {} + for t in terminals: + regexp = t.pattern.to_regexp() + try: + self.re.compile(regexp, conf.g_regex_flags) + except self.re.error as e: + raise LexError( + 'Cannot compile token %s: %s' % (t.name, t.pattern) + ) from e + + if t.pattern.min_width == 0: + raise LexError( + 'Lexer does not allow zero-width terminals. (%s: %s)' + % (t.name, t.pattern) + ) + if t.pattern.type == 're': + terminal_to_regexp[t] = regexp + + if not (set(conf.ignore) <= {t.name for t in terminals}): + raise LexError( + 'Ignore terminals are not defined: %s' + % (set(conf.ignore) - {t.name for t in terminals}) + ) + + raise LexError( + 'interegular must be installed for strict mode. Use `pip install' + " 'lark[interegular]'`." + ) + + ## + + self.newline_types = frozenset( + t.name for t in terminals if _regexp_has_newline(t.pattern.to_regexp()) + ) + self.ignore_types = frozenset(conf.ignore) + + terminals.sort( + key=lambda x: ( + -x.priority, + -x.pattern.max_width, + -len(x.pattern.value), + x.name, + ) + ) + self.terminals = terminals + self.user_callbacks = conf.callbacks + self.g_regex_flags = conf.g_regex_flags + self.use_bytes = conf.use_bytes + self.terminals_by_name = conf.terminals_by_name + + self._scanner = None + + def _build_scanner(self): + terminals, self.callback = _create_unless( + self.terminals, self.g_regex_flags, self.re, self.use_bytes + ) + assert all(self.callback.values()) + + for type_, f in self.user_callbacks.items(): + if type_ in self.callback: + ## + def scanner_callback(t, target_type=type_): + return t.type == target_type + + self.callback[type_] = CallChain( + self.callback[type_], f, scanner_callback + ) + else: + self.callback[type_] = f + + self._scanner = Scanner( + terminals, self.g_regex_flags, self.re, self.use_bytes + ) + + @property + def scanner(self): + if self._scanner is None: + self._build_scanner() + return self._scanner + + def match(self, text, pos): + assert self.scanner is not None + return self.scanner.match(text, pos) + + def next_token( + self, lex_state: LexerState, parser_state: Any = None + ) -> Token: + line_ctr = lex_state.line_ctr + assert self.scanner is not None + while line_ctr.char_pos < len(lex_state.text): + res = self.match(lex_state.text, line_ctr.char_pos) + if not res: + allowed = self.scanner.allowed_types - self.ignore_types + if not allowed: + allowed = {''} + raise UnexpectedCharacters( + lex_state.text, + line_ctr.char_pos, + line_ctr.line, + line_ctr.column, + allowed=allowed, + token_history=lex_state.last_token and [lex_state.last_token], + state=parser_state, + terminals_by_name=self.terminals_by_name, + ) + + value, type_ = res + + ignored = type_ in self.ignore_types + t = None + if not ignored or type_ in self.callback: + t = Token( + type_, value, line_ctr.char_pos, line_ctr.line, line_ctr.column + ) + line_ctr.feed(value, type_ in self.newline_types) + if 
t is not None: + t.end_line = line_ctr.line + t.end_column = line_ctr.column + t.end_pos = line_ctr.char_pos + if t.type in self.callback: + t = self.callback[t.type](t) + if not ignored: + if not isinstance(t, Token): + raise LexError('Callbacks must return a token (returned %r)' % t) + lex_state.last_token = t + return t + + ## + + raise EOFError(self) + + +class ContextualLexer(Lexer): + lexers: Dict[int, AbstractBasicLexer] + root_lexer: AbstractBasicLexer + + BasicLexer: Type[AbstractBasicLexer] = BasicLexer + + def __init__( + self, + conf: 'LexerConf', + states: Dict[int, Collection[str]], + always_accept: Collection[str] = (), + ) -> None: + terminals = list(conf.terminals) + terminals_by_name = conf.terminals_by_name + + trad_conf = copy.copy(conf) + trad_conf.terminals = terminals + + comparator = None + lexer_by_tokens: Dict[FrozenSet[str], AbstractBasicLexer] = {} + self.lexers = {} + for state, accepts in states.items(): + key = frozenset(accepts) + try: + lexer = lexer_by_tokens[key] + except KeyError: + accepts = set(accepts) | set(conf.ignore) | set(always_accept) + lexer_conf = copy.copy(trad_conf) + lexer_conf.terminals = [ + terminals_by_name[n] for n in accepts if n in terminals_by_name + ] + lexer = self.BasicLexer(lexer_conf, comparator) + lexer_by_tokens[key] = lexer + + self.lexers[state] = lexer + + assert trad_conf.terminals is terminals + trad_conf.skip_validation = True ## + + self.root_lexer = self.BasicLexer(trad_conf, comparator) + + def lex( + self, lexer_state: LexerState, parser_state: 'ParserState' + ) -> Iterator[Token]: + try: + while True: + lexer = self.lexers[parser_state.position] + yield lexer.next_token(lexer_state, parser_state) + except EOFError: + pass + except UnexpectedCharacters as e: + last_token = lexer_state.last_token ## + + token = self.root_lexer.next_token(lexer_state, parser_state) + raise UnexpectedToken( + token, + e.allowed, + state=parser_state, + token_history=[last_token], + terminals_by_name=self.root_lexer.terminals_by_name, + ) from e + + +_ParserArgType: 'TypeAlias' = 'Literal["earley", "lalr", "cyk", "auto"]' +_LexerArgType: 'TypeAlias' = ( + 'Union[Literal["auto", "basic", "contextual", "dynamic",' + ' "dynamic_complete"], Type[Lexer]]' +) +_LexerCallback = Callable[[Token], Token] +ParserCallbacks = Dict[str, Callable] + + +class LexerConf(Serialize): + __serialize_fields__ = ( + 'terminals', + 'ignore', + 'g_regex_flags', + 'use_bytes', + 'lexer_type', + ) + __serialize_namespace__ = (TerminalDef,) + + terminals: Collection[TerminalDef] + re_module: ModuleType + ignore: Collection[str] + postlex: 'Optional[PostLex]' + callbacks: Dict[str, _LexerCallback] + g_regex_flags: int + skip_validation: bool + use_bytes: bool + lexer_type: Optional[_LexerArgType] + strict: bool + + def __init__( + self, + terminals: Collection[TerminalDef], + re_module: ModuleType, + ignore: Collection[str] = (), + postlex: 'Optional[PostLex]' = None, + callbacks: Optional[Dict[str, _LexerCallback]] = None, + g_regex_flags: int = 0, + skip_validation: bool = False, + use_bytes: bool = False, + strict: bool = False, + ): + self.terminals = terminals + self.terminals_by_name = {t.name: t for t in self.terminals} + assert len(self.terminals) == len(self.terminals_by_name) + self.ignore = ignore + self.postlex = postlex + self.callbacks = callbacks or {} + self.g_regex_flags = g_regex_flags + self.re_module = re_module + self.skip_validation = skip_validation + self.use_bytes = use_bytes + self.strict = strict + self.lexer_type = None + + def 
_deserialize(self): + self.terminals_by_name = {t.name: t for t in self.terminals} + + def __deepcopy__(self, memo=None): + return type(self)( + copy.deepcopy(self.terminals, memo), + self.re_module, + copy.deepcopy(self.ignore, memo), + copy.deepcopy(self.postlex, memo), + copy.deepcopy(self.callbacks, memo), + copy.deepcopy(self.g_regex_flags, memo), + copy.deepcopy(self.skip_validation, memo), + copy.deepcopy(self.use_bytes, memo), + ) + + +class ParserConf(Serialize): + __serialize_fields__ = 'rules', 'start', 'parser_type' + + rules: List['Rule'] + callbacks: ParserCallbacks + start: List[str] + parser_type: _ParserArgType + + def __init__( + self, rules: List['Rule'], callbacks: ParserCallbacks, start: List[str] + ): + assert isinstance(start, list) + self.rules = rules + self.callbacks = callbacks + self.start = start + + +class ExpandSingleChild: + + def __init__(self, node_builder): + self.node_builder = node_builder + + def __call__(self, children): + if len(children) == 1: + return children[0] + else: + return self.node_builder(children) + + +class PropagatePositions: + + def __init__(self, node_builder, node_filter=None): + self.node_builder = node_builder + self.node_filter = node_filter + + def __call__(self, children): + res = self.node_builder(children) + + if isinstance(res, Tree): + ## + + ## + + ## + + ## + + res_meta = res.meta + + first_meta = self._pp_get_meta(children) + if first_meta is not None: + if not hasattr(res_meta, 'line'): + ## + + res_meta.line = getattr(first_meta, 'container_line', first_meta.line) + res_meta.column = getattr( + first_meta, 'container_column', first_meta.column + ) + res_meta.start_pos = getattr( + first_meta, 'container_start_pos', first_meta.start_pos + ) + res_meta.empty = False + + res_meta.container_line = getattr( + first_meta, 'container_line', first_meta.line + ) + res_meta.container_column = getattr( + first_meta, 'container_column', first_meta.column + ) + res_meta.container_start_pos = getattr( + first_meta, 'container_start_pos', first_meta.start_pos + ) + + last_meta = self._pp_get_meta(reversed(children)) + if last_meta is not None: + if not hasattr(res_meta, 'end_line'): + res_meta.end_line = getattr( + last_meta, 'container_end_line', last_meta.end_line + ) + res_meta.end_column = getattr( + last_meta, 'container_end_column', last_meta.end_column + ) + res_meta.end_pos = getattr( + last_meta, 'container_end_pos', last_meta.end_pos + ) + res_meta.empty = False + + res_meta.container_end_line = getattr( + last_meta, 'container_end_line', last_meta.end_line + ) + res_meta.container_end_column = getattr( + last_meta, 'container_end_column', last_meta.end_column + ) + res_meta.container_end_pos = getattr( + last_meta, 'container_end_pos', last_meta.end_pos + ) + + return res + + def _pp_get_meta(self, children): + for c in children: + if self.node_filter is not None and not self.node_filter(c): + continue + if isinstance(c, Tree): + if not c.meta.empty: + return c.meta + elif isinstance(c, Token): + return c + elif hasattr(c, '__lark_meta__'): + return c.__lark_meta__() + + +def make_propagate_positions(option): + if callable(option): + return partial(PropagatePositions, node_filter=option) + elif option == True: + return PropagatePositions + elif option == False: + return None + + raise ConfigurationError( + 'Invalid option for propagate_positions: %r' % option + ) + + +class ChildFilter: + + def __init__(self, to_include, append_none, node_builder): + self.node_builder = node_builder + self.to_include = to_include + 
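+    # to_include is a list of (child_index, expand_children, n_leading_Nones)
+    # triples; append_none pads the tail with None placeholders (used by the
+    # maybe_placeholders feature). See __call__ below.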
self.append_none = append_none + + def __call__(self, children): + filtered = [] + + for i, to_expand, add_none in self.to_include: + if add_none: + filtered += [None] * add_none + if to_expand: + filtered += children[i].children + else: + filtered.append(children[i]) + + if self.append_none: + filtered += [None] * self.append_none + + return self.node_builder(filtered) + + +class ChildFilterLALR(ChildFilter): + # -- + + def __call__(self, children): + filtered = [] + for i, to_expand, add_none in self.to_include: + if add_none: + filtered += [None] * add_none + if to_expand: + if filtered: + filtered += children[i].children + else: ## + + filtered = children[i].children + else: + filtered.append(children[i]) + + if self.append_none: + filtered += [None] * self.append_none + + return self.node_builder(filtered) + + +class ChildFilterLALR_NoPlaceholders(ChildFilter): + # -- + def __init__(self, to_include, node_builder): # pylint: disable=super-init-not-called + self.node_builder = node_builder + self.to_include = to_include + + def __call__(self, children): + filtered = [] + for i, to_expand in self.to_include: + if to_expand: + if filtered: + filtered += children[i].children + else: ## + + filtered = children[i].children + else: + filtered.append(children[i]) + return self.node_builder(filtered) + + +def _should_expand(sym): + return not sym.is_term and sym.name.startswith('_') + + +def maybe_create_child_filter( + expansion, keep_all_tokens, ambiguous, _empty_indices: List[bool] +): + ## + + if _empty_indices: + assert _empty_indices.count(False) == len(expansion) + s = ''.join(str(int(b)) for b in _empty_indices) + empty_indices = [len(ones) for ones in s.split('0')] + assert len(empty_indices) == len(expansion) + 1, ( + empty_indices, + len(expansion), + ) + else: + empty_indices = [0] * (len(expansion) + 1) + + to_include = [] + nones_to_add = 0 + for i, sym in enumerate(expansion): + nones_to_add += empty_indices[i] + if keep_all_tokens or not (sym.is_term and sym.filter_out): + to_include.append((i, _should_expand(sym), nones_to_add)) + nones_to_add = 0 + + nones_to_add += empty_indices[len(expansion)] + + if ( + _empty_indices + or len(to_include) < len(expansion) + or any(to_expand for _, to_expand, _ in to_include) + ): + if _empty_indices or ambiguous: + return partial( + ChildFilter if ambiguous else ChildFilterLALR, + to_include, + nones_to_add, + ) + else: + ## + + return partial( + ChildFilterLALR_NoPlaceholders, [(i, x) for i, x, _ in to_include] + ) + + +class AmbiguousExpander: + # -- + def __init__(self, to_expand, tree_class, node_builder): + self.node_builder = node_builder + self.tree_class = tree_class + self.to_expand = to_expand + + def __call__(self, children): + def _is_ambig_tree(t): + return hasattr(t, 'data') and t.data == '_ambig' + + ## + + ## + + ## + + ## + + ambiguous = [] + for i, child in enumerate(children): + if _is_ambig_tree(child): + if i in self.to_expand: + ambiguous.append(i) + + child.expand_kids_by_data('_ambig') + + if not ambiguous: + return self.node_builder(children) + + expand = [ + child.children if i in ambiguous else (child,) + for i, child in enumerate(children) + ] + return self.tree_class( + '_ambig', [self.node_builder(list(f)) for f in product(*expand)] + ) + + +def maybe_create_ambiguous_expander(tree_class, expansion, keep_all_tokens): + to_expand = [ + i + for i, sym in enumerate(expansion) + if keep_all_tokens + or ((not (sym.is_term and sym.filter_out)) and _should_expand(sym)) + ] + if to_expand: + return 
partial(AmbiguousExpander, to_expand, tree_class) + + +class AmbiguousIntermediateExpander: + # -- + + def __init__(self, tree_class, node_builder): + self.node_builder = node_builder + self.tree_class = tree_class + + def __call__(self, children): + def _is_iambig_tree(child): + return hasattr(child, 'data') and child.data == '_iambig' + + def _collapse_iambig(children): + # -- + + ## + + ## + + if children and _is_iambig_tree(children[0]): + iambig_node = children[0] + result = [] + for grandchild in iambig_node.children: + collapsed = _collapse_iambig(grandchild.children) + if collapsed: + for child in collapsed: + child.children += children[1:] + result += collapsed + else: + new_tree = self.tree_class( + '_inter', grandchild.children + children[1:] + ) + result.append(new_tree) + return result + + collapsed = _collapse_iambig(children) + if collapsed: + processed_nodes = [self.node_builder(c.children) for c in collapsed] + return self.tree_class('_ambig', processed_nodes) + + return self.node_builder(children) + + +def inplace_transformer(func): + @wraps(func) + def f(children): + ## + + tree = Tree(func.__name__, children) + return func(tree) + + return f + + +def apply_visit_wrapper(func, name, wrapper): + if wrapper is _vargs_meta or wrapper is _vargs_meta_inline: + raise NotImplementedError( + 'Meta args not supported for internal transformer' + ) + + @wraps(func) + def f(children): + return wrapper(func, name, children, None) + + return f + + +class ParseTreeBuilder: + + def __init__( + self, + rules, + tree_class, + propagate_positions=False, + ambiguous=False, + maybe_placeholders=False, + ): + self.tree_class = tree_class + self.propagate_positions = propagate_positions + self.ambiguous = ambiguous + self.maybe_placeholders = maybe_placeholders + + self.rule_builders = list(self._init_builders(rules)) + + def _init_builders(self, rules): + propagate_positions = make_propagate_positions(self.propagate_positions) + + for rule in rules: + options = rule.options + keep_all_tokens = options.keep_all_tokens + expand_single_child = options.expand1 + + wrapper_chain = list( + filter( + None, + [ + (expand_single_child and not rule.alias) + and ExpandSingleChild, + maybe_create_child_filter( + rule.expansion, + keep_all_tokens, + self.ambiguous, + options.empty_indices + if self.maybe_placeholders + else None, + ), + propagate_positions, + self.ambiguous + and maybe_create_ambiguous_expander( + self.tree_class, rule.expansion, keep_all_tokens + ), + self.ambiguous + and partial(AmbiguousIntermediateExpander, self.tree_class), + ], + ) + ) + + yield rule, wrapper_chain + + def create_callback(self, transformer=None): + callbacks = {} + + default_handler = getattr(transformer, '__default__', None) + if default_handler: + + def default_callback(data, children): + return default_handler(data, children, None) + + else: + default_callback = self.tree_class + + for rule, wrapper_chain in self.rule_builders: + + user_callback_name = ( + rule.alias or rule.options.template_source or rule.origin.name + ) + try: + f = getattr(transformer, user_callback_name) + wrapper = getattr(f, 'visit_wrapper', None) + if wrapper is not None: + f = apply_visit_wrapper(f, user_callback_name, wrapper) + elif isinstance(transformer, Transformer_InPlace): + f = inplace_transformer(f) + except AttributeError: + f = partial(default_callback, user_callback_name) + + for w in wrapper_chain: + f = w(f) + + if rule in callbacks: + raise GrammarError("Rule '%s' already exists" % (rule,)) + + callbacks[rule] = f + + 
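+
+    # At this point every rule maps to its user callback (or the default tree
+    # builder), wrapped by the expand/filter/propagate-positions/ambiguity
+    # chain assembled in _init_builders.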
return callbacks + + +class Action: + + def __init__(self, name): + self.name = name + + def __str__(self): + return self.name + + def __repr__(self): + return str(self) + + +Shift = Action('Shift') +Reduce = Action('Reduce') + +StateT = TypeVar('StateT') + + +class ParseTableBase(Generic[StateT]): + states: Dict[StateT, Dict[str, Tuple]] + start_states: Dict[str, StateT] + end_states: Dict[str, StateT] + + def __init__(self, states, start_states, end_states): + self.states = states + self.start_states = start_states + self.end_states = end_states + + @classmethod + def deserialize(cls, data, memo): + tokens = data['tokens'] + states = { + state: { + tokens[token]: ( + (Reduce, Rule.deserialize(arg, memo)) + if action == 1 + else (Shift, arg) + ) + for token, (action, arg) in actions.items() + } + for state, actions in data['states'].items() + } + return cls(states, data['start_states'], data['end_states']) + + +class ParseTable(ParseTableBase['State']): + # -- + pass + + +class RulePtr: + __slots__ = ('rule', 'index') + rule: Rule + index: int + + def __init__(self, rule: Rule, index: int): + assert isinstance(rule, Rule) + assert index <= len(rule.expansion) + self.rule = rule + self.index = index + + def __repr__(self): + before = [x.name for x in self.rule.expansion[: self.index]] + after = [x.name for x in self.rule.expansion[self.index :]] + return '<%s : %s * %s>' % ( + self.rule.origin.name, + ' '.join(before), + ' '.join(after), + ) + + @property + def next(self) -> Symbol: + return self.rule.expansion[self.index] + + def advance(self, sym: Symbol) -> 'RulePtr': + assert self.next == sym + return RulePtr(self.rule, self.index + 1) + + @property + def is_satisfied(self) -> bool: + return self.index == len(self.rule.expansion) + + def __eq__(self, other) -> bool: + if not isinstance(other, RulePtr): + return NotImplemented + return self.rule == other.rule and self.index == other.index + + def __hash__(self) -> int: + return hash((self.rule, self.index)) + + +State = FrozenSet[RulePtr] + + +class IntParseTable(ParseTableBase[int]): + # -- + + @classmethod + def from_ParseTable(cls, parse_table: ParseTable): + enum = list(parse_table.states) + state_to_idx: Dict['State', int] = {s: i for i, s in enumerate(enum)} + int_states = {} + + for s, la in parse_table.states.items(): + la = { + k: (v[0], state_to_idx[v[1]]) if v[0] == Shift else v + for k, v in la.items() + } + int_states[state_to_idx[s]] = la + + start_states = { + start: state_to_idx[s] for start, s in parse_table.start_states.items() + } + end_states = { + start: state_to_idx[s] for start, s in parse_table.end_states.items() + } + return cls(int_states, start_states, end_states) + + +class ParseConf(Generic[StateT]): + __slots__ = ( + 'parse_table', + 'callbacks', + 'start', + 'start_state', + 'end_state', + 'states', + ) + + parse_table: ParseTableBase[StateT] + callbacks: ParserCallbacks + start: str + + start_state: StateT + end_state: StateT + states: Dict[StateT, Dict[str, tuple]] + + def __init__( + self, + parse_table: ParseTableBase[StateT], + callbacks: ParserCallbacks, + start: str, + ): + self.parse_table = parse_table + + self.start_state = self.parse_table.start_states[start] + self.end_state = self.parse_table.end_states[start] + self.states = self.parse_table.states + + self.callbacks = callbacks + self.start = start + + +class ParserState(Generic[StateT]): + __slots__ = 'parse_conf', 'lexer', 'state_stack', 'value_stack' + + parse_conf: ParseConf[StateT] + lexer: LexerThread + state_stack: List[StateT] + 
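+  # value_stack runs in lockstep with state_stack, holding the tokens and
+  # partially built values produced by reductions.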
value_stack: list + + def __init__( + self, + parse_conf: ParseConf[StateT], + lexer: LexerThread, + state_stack=None, + value_stack=None, + ): + self.parse_conf = parse_conf + self.lexer = lexer + self.state_stack = state_stack or [self.parse_conf.start_state] + self.value_stack = value_stack or [] + + @property + def position(self) -> StateT: + return self.state_stack[-1] + + ## + + def __eq__(self, other) -> bool: + if not isinstance(other, ParserState): + return NotImplemented + return ( + len(self.state_stack) == len(other.state_stack) + and self.position == other.position + ) + + def __copy__(self): + return self.copy() + + def copy(self, deepcopy_values=True) -> 'ParserState[StateT]': + return type(self)( + self.parse_conf, + self.lexer, ## + copy.copy(self.state_stack), + copy.deepcopy(self.value_stack) + if deepcopy_values + else copy.copy(self.value_stack), + ) + + def feed_token(self, token: Token, is_end=False) -> Any: + state_stack = self.state_stack + value_stack = self.value_stack + states = self.parse_conf.states + end_state = self.parse_conf.end_state + callbacks = self.parse_conf.callbacks + + while True: + state = state_stack[-1] + try: + action, arg = states[state][token.type] + except KeyError: + expected = {s for s in states[state].keys() if s.isupper()} + raise UnexpectedToken( + token, expected, state=self, interactive_parser=None + ) + + assert arg != end_state + + if action == Shift: + ## + + assert not is_end + state_stack.append(arg) + value_stack.append( + token + if token.type not in callbacks + else callbacks[token.type](token) + ) + return + else: + ## + + rule = arg + size = len(rule.expansion) + if size: + s = value_stack[-size:] + del state_stack[-size:] + del value_stack[-size:] + else: + s = [] + + value = callbacks[rule](s) if callbacks else s + + _action, new_state = states[state_stack[-1]][rule.origin.name] + assert _action == Shift + state_stack.append(new_state) + value_stack.append(value) + + if is_end and state_stack[-1] == end_state: + return value_stack[-1] + + +class LALR_Parser(Serialize): + + def __init__(self, parser_conf: ParserConf, debug: bool = False): + self.parser_conf = parser_conf + self.parser = _Parser(None, {}, debug) # pytype: disable=wrong-arg-types + + @classmethod + def deserialize(cls, data, memo, callbacks, debug=False): + inst = cls.__new__(cls) + inst._parse_table = IntParseTable.deserialize(data, memo) + inst.parser = _Parser(inst._parse_table, callbacks, debug) + return inst + + def parse(self, lexer, start, on_error=None): + del on_error + return self.parser.parse(lexer, start) + + +class _Parser: + parse_table: ParseTableBase + callbacks: ParserCallbacks + debug: bool + + def __init__( + self, + parse_table: ParseTableBase, + callbacks: ParserCallbacks, + debug: bool = False, + ): + self.parse_table = parse_table + self.callbacks = callbacks + self.debug = debug + + def parse( + self, + lexer: LexerThread, + start: str, + value_stack=None, + state_stack=None, + start_interactive=False, + ): + parse_conf = ParseConf(self.parse_table, self.callbacks, start) + parser_state = ParserState(parse_conf, lexer, state_stack, value_stack) + if start_interactive: + return InteractiveParser(self, parser_state, parser_state.lexer) + return self.parse_from_state(parser_state) + + def parse_from_state( + self, state: ParserState, last_token: Optional[Token] = None + ): + # -- + try: + token = last_token + for token in state.lexer.lex(state): + assert token is not None + state.feed_token(token) + + end_token = ( + 
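+          # Synthesize the terminating $END token, borrowing the position of
+          # the last real token (or a dummy position for empty input).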
Token.new_borrow_pos('$END', '', token) + if token + else Token('$END', '', 0, 1, 1) + ) + return state.feed_token(end_token, True) + except UnexpectedInput as e: + try: + e.interactive_parser = InteractiveParser(self, state, state.lexer) + except NameError: + pass + raise e + + +class InteractiveParser: + # -- + def __init__( + self, parser, parser_state: ParserState, lexer_thread: LexerThread + ): + self.parser = parser + self.parser_state = parser_state + self.lexer_thread = lexer_thread + self.result = None + + @property + def lexer_state(self) -> LexerThread: + return self.lexer_thread + + def feed_token(self, token: Token): + # -- + return self.parser_state.feed_token(token, token.type == '$END') + + def iter_parse(self) -> Iterator[Token]: + # -- + for token in self.lexer_thread.lex(self.parser_state): + yield token + self.result = self.feed_token(token) + + def exhaust_lexer(self) -> List[Token]: + # -- + return list(self.iter_parse()) + + def feed_eof(self, last_token=None): + # -- + eof = ( + Token.new_borrow_pos('$END', '', last_token) + if last_token is not None + else self.lexer_thread._Token('$END', '', 0, 1, 1) + ) + return self.feed_token(eof) + + def __copy__(self): + # -- + return self.copy() + + def copy(self, deepcopy_values=True): + return type(self)( + self.parser, + self.parser_state.copy(deepcopy_values=deepcopy_values), + copy.copy(self.lexer_thread), + ) + + def __eq__(self, other): + if not isinstance(other, InteractiveParser): + return False + + return ( + self.parser_state == other.parser_state + and self.lexer_thread == other.lexer_thread + ) + + def as_immutable(self): + # -- + p = copy.copy(self) + return ImmutableInteractiveParser(p.parser, p.parser_state, p.lexer_thread) + + def pretty(self): + # -- + out = ['Parser choices:'] + for k, v in self.choices().items(): + out.append('\t- %s -> %r' % (k, v)) + out.append('stack size: %s' % len(self.parser_state.state_stack)) + return '\n'.join(out) + + def choices(self): + # -- + return self.parser_state.parse_conf.parse_table.states[ + self.parser_state.position + ] + + def accepts(self): + # -- + accepts = set() + conf_no_callbacks = copy.copy(self.parser_state.parse_conf) + ## + + ## + + conf_no_callbacks.callbacks = {} + for t in self.choices(): + if t.isupper(): ## + + new_cursor = self.copy(deepcopy_values=False) + new_cursor.parser_state.parse_conf = conf_no_callbacks + try: + new_cursor.feed_token(self.lexer_thread._Token(t, '')) + except UnexpectedToken: + pass + else: + accepts.add(t) + return accepts + + def resume_parse(self): + # -- + return self.parser.parse_from_state( + self.parser_state, last_token=self.lexer_thread.state.last_token + ) + + +class ImmutableInteractiveParser(InteractiveParser): + # -- + + result = None + + def __hash__(self): + return hash((self.parser_state, self.lexer_thread)) + + def feed_token(self, token): + c = copy.copy(self) + c.result = InteractiveParser.feed_token(c, token) + return c + + def exhaust_lexer(self): + # -- + cursor = self.as_mutable() + cursor.exhaust_lexer() + return cursor.as_immutable() + + def as_mutable(self): + # -- + p = copy.copy(self) + return InteractiveParser(p.parser, p.parser_state, p.lexer_thread) + + +def _wrap_lexer(lexer_class): + future_interface = getattr(lexer_class, '__future_interface__', False) + if future_interface: + return lexer_class + else: + + class CustomLexerWrapper(Lexer): + + def __init__(self, lexer_conf): + self.lexer = lexer_class(lexer_conf) + + def lex(self, lexer_state, parser_state): + return 
self.lexer.lex(lexer_state.text) + + return CustomLexerWrapper + + +def _deserialize_parsing_frontend(data, memo, lexer_conf, callbacks, options): + parser_conf = ParserConf.deserialize(data['parser_conf'], memo) + cls = (options and options._plugins.get('LALR_Parser')) or LALR_Parser + parser = cls.deserialize(data['parser'], memo, callbacks, options.debug) + parser_conf.callbacks = callbacks + return ParsingFrontend(lexer_conf, parser_conf, options, parser=parser) + + +_parser_creators: 'Dict[str, Callable[[LexerConf, Any, Any], Any]]' = {} + + +class ParsingFrontend(Serialize): + __serialize_fields__ = 'lexer_conf', 'parser_conf', 'parser' + + lexer_conf: LexerConf + parser_conf: ParserConf + options: Any + + def __init__( + self, lexer_conf: LexerConf, parser_conf: ParserConf, options, parser=None + ): + self.parser_conf = parser_conf + self.lexer_conf = lexer_conf + self.options = options + + ## + + if parser: ## + + self.parser = parser + else: + create_parser = _parser_creators.get(parser_conf.parser_type) + assert ( + create_parser is not None + ), '{} is not supported in standalone mode'.format( + parser_conf.parser_type + ) + self.parser = create_parser(lexer_conf, parser_conf, options) + + ## + + lexer_type = lexer_conf.lexer_type + self.skip_lexer = False + if lexer_type in ('dynamic', 'dynamic_complete'): + assert lexer_conf.postlex is None + self.skip_lexer = True + return + + if isinstance(lexer_type, type): + assert issubclass(lexer_type, Lexer) + self.lexer = _wrap_lexer(lexer_type)(lexer_conf) + elif isinstance(lexer_type, str): + create_lexer = { + 'basic': create_basic_lexer, + 'contextual': create_contextual_lexer, + }[lexer_type] + self.lexer = create_lexer( + lexer_conf, self.parser, lexer_conf.postlex, options + ) + else: + raise TypeError('Bad value for lexer_type: {lexer_type}') + + if lexer_conf.postlex: + self.lexer = PostLexConnector(self.lexer, lexer_conf.postlex) + + def _verify_start(self, start=None): + if start is None: + start_decls = self.parser_conf.start + if len(start_decls) > 1: + raise ConfigurationError( + 'Lark initialized with more than 1 possible start rule. Must' + ' specify which start rule to parse', + start_decls, + ) + (start,) = start_decls + elif start not in self.parser_conf.start: + raise ConfigurationError( + 'Unknown start rule %s. 
Must be one of %r' + % (start, self.parser_conf.start) + ) + return start + + def _make_lexer_thread(self, text: str) -> Union[str, LexerThread]: + cls = ( + self.options and self.options._plugins.get('LexerThread') + ) or LexerThread + return text if self.skip_lexer else cls.from_text(self.lexer, text) + + def parse(self, text: str, start=None, on_error=None): + chosen_start = self._verify_start(start) + kw = {} if on_error is None else {'on_error': on_error} + stream = self._make_lexer_thread(text) + return self.parser.parse(stream, chosen_start, **kw) + + +def _validate_frontend_args(parser, lexer) -> None: + assert_config(parser, ('lalr', 'earley', 'cyk')) + if not isinstance(lexer, type): ## + + expected = { + 'lalr': ('basic', 'contextual'), + 'earley': ('basic', 'dynamic', 'dynamic_complete'), + 'cyk': ('basic',), + }[parser] + assert_config( + lexer, + expected, + 'Parser %r does not support lexer %%r, expected one of %%s' % parser, + ) + + +def _get_lexer_callbacks(transformer, terminals): + result = {} + for terminal in terminals: + callback = getattr(transformer, terminal.name, None) + if callback is not None: + result[terminal.name] = callback + return result + + +class PostLexConnector: + + def __init__(self, lexer, postlexer): + self.lexer = lexer + self.postlexer = postlexer + + def lex(self, lexer_state, parser_state): + i = self.lexer.lex(lexer_state, parser_state) + return self.postlexer.process(i) + + +def create_basic_lexer(lexer_conf, parser, postlex, options) -> BasicLexer: + del parser, postlex + cls = (options and options._plugins.get('BasicLexer')) or BasicLexer + return cls(lexer_conf) + + +def create_contextual_lexer( + lexer_conf: LexerConf, parser, postlex, options +) -> ContextualLexer: + cls = (options and options._plugins.get('ContextualLexer')) or ContextualLexer + parse_table: ParseTableBase[int] = parser._parse_table + states: Dict[int, Collection[str]] = { + idx: list(t.keys()) for idx, t in parse_table.states.items() + } + always_accept: Collection[str] = postlex.always_accept if postlex else () + return cls(lexer_conf, states, always_accept=always_accept) + + +def create_lalr_parser( + lexer_conf: LexerConf, parser_conf: ParserConf, options=None +) -> LALR_Parser: + del lexer_conf + debug = options.debug if options else False + strict = options.strict if options else False + cls = (options and options._plugins.get('LALR_Parser')) or LALR_Parser + return cls(parser_conf, debug=debug, strict=strict) + + +_parser_creators['lalr'] = create_lalr_parser + + +class PostLex(ABC): + + @abstractmethod + def process(self, stream: Iterator[Token]) -> Iterator[Token]: + return stream + + always_accept: Iterable[str] = () + + +class LarkOptions(Serialize): + # -- + + start: List[str] + debug: bool + strict: bool + transformer: 'Optional[Transformer]' + propagate_positions: Union[bool, str] + maybe_placeholders: bool + cache: Union[bool, str] + regex: bool + g_regex_flags: int + keep_all_tokens: bool + tree_class: Optional[Callable[[str, List], Any]] + parser: _ParserArgType + lexer: _LexerArgType + ambiguity: Literal['auto', 'resolve', 'explicit', 'forest'] + postlex: Optional[PostLex] + priority: Optional[Literal['auto', 'normal', 'invert']] + lexer_callbacks: Dict[str, Callable[[Token], Token]] + use_bytes: bool + ordered_sets: bool + edit_terminals: Optional[Callable[[TerminalDef], TerminalDef]] + import_paths: ( + 'List[Union[str, Callable[[Union[None, str], str], Tuple[str, str]]]]' + ) + source_path: Optional[str] + + _defaults: Dict[str, Any] = { + 'debug': 
+      'strict': False,
+      'keep_all_tokens': False,
+      'tree_class': None,
+      'cache': False,
+      'postlex': None,
+      'parser': 'earley',
+      'lexer': 'auto',
+      'transformer': None,
+      'start': 'start',
+      'priority': 'auto',
+      'ambiguity': 'auto',
+      'regex': False,
+      'propagate_positions': False,
+      'lexer_callbacks': {},
+      'maybe_placeholders': True,
+      'edit_terminals': None,
+      'g_regex_flags': 0,
+      'use_bytes': False,
+      'ordered_sets': True,
+      'import_paths': [],
+      'source_path': None,
+      '_plugins': {},
+  }
+
+  def __init__(self, options_dict: Dict[str, Any]) -> None:
+    o = dict(options_dict)
+
+    options = {}
+    for name, default in self._defaults.items():
+      if name in o:
+        value = o.pop(name)
+        if isinstance(default, bool) and name not in (
+            'cache',
+            'use_bytes',
+            'propagate_positions',
+        ):
+          value = bool(value)
+      else:
+        value = default
+
+      options[name] = value
+
+    if isinstance(options['start'], str):
+      options['start'] = [options['start']]
+
+    self.__dict__['options'] = options
+
+    assert_config(self.parser, ('earley', 'lalr', 'cyk', None))
+
+    if self.parser == 'earley' and self.transformer:
+      raise ConfigurationError(
+          'Cannot specify an embedded transformer when using the Earley'
+          ' algorithm. Please use your transformer on the resulting parse tree,'
+          ' or use a different algorithm (i.e. LALR)'
+      )
+
+    if o:
+      raise ConfigurationError('Unknown options: %s' % o.keys())
+
+  def __getattr__(self, name: str) -> Any:
+    try:
+      return self.__dict__['options'][name]
+    except KeyError as e:
+      raise AttributeError(e) from e
+
+  def __setattr__(self, name: str, value: str) -> None:
+    assert_config(
+        name,
+        self.options.keys(),
+        "%r isn't a valid option. Expected one of: %s",
+    )
+    self.options[name] = value
+
+  @classmethod
+  def deserialize(
+      cls, data: Dict[str, Any], memo: Dict[int, Union[TerminalDef, Rule]]
+  ) -> 'LarkOptions':
+    return cls(data)
+
+
+# Options that can be passed to the Lark parser, even when it was loaded
+# from cache/standalone. These options are only used outside of
+# `load_grammar`.
+_LOAD_ALLOWED_OPTIONS = frozenset({
+    'postlex',
+    'transformer',
+    'lexer_callbacks',
+    'use_bytes',
+    'debug',
+    'g_regex_flags',
+    'regex',
+    'propagate_positions',
+    'tree_class',
+    '_plugins',
+})
+
+_VALID_PRIORITY_OPTIONS = ('auto', 'normal', 'invert', None)
+_VALID_AMBIGUITY_OPTIONS = ('auto', 'resolve', 'explicit', 'forest')
+
+
+_T = TypeVar('_T', bound='Lark')
+
+
+class Grammar:
+  """Context-free grammar."""
+
+  def __init__(self, rules):
+    self.rules = frozenset(rules)
+
+  def __eq__(self, other):
+    return self.rules == other.rules
+
+  def __str__(self):
+    return '\n' + '\n'.join(sorted(repr(x) for x in self.rules)) + '\n'
+
+  def __repr__(self):
+    return str(self)
+
+
+class Lark(Serialize):
+  # --
+
+  source_path: str
+  source_grammar: str
+  grammar: 'Grammar'
+  options: LarkOptions
+  lexer: Lexer
+  parser: 'ParsingFrontend'
+  terminals: Collection[TerminalDef]
+
+  def __init__(self, grammar: 'Grammar', **options) -> None:
+    # Standalone parsers are not built from a grammar; instances are created
+    # via `load` / `_load_from_dict` (see `get_parser` below).
+    pass
+
+  __serialize_fields__ = 'parser', 'rules', 'options'
+
+  def _build_lexer(self, dont_ignore: bool = False) -> BasicLexer:
+    lexer_conf = self.lexer_conf
+    if dont_ignore:
+      lexer_conf = copy.copy(lexer_conf)
+      lexer_conf.ignore = ()
+    return BasicLexer(lexer_conf)
+
+  def _prepare_callbacks(self) -> None:
+    self._callbacks = {}
+    # We don't need these callbacks if we aren't building a tree.
+    if self.options.ambiguity != 'forest':
+      self._parse_tree_builder = ParseTreeBuilder(
+          self.rules,
+          self.options.tree_class or Tree,
+          self.options.propagate_positions,
+          self.options.parser != 'lalr'
+          and self.options.ambiguity == 'explicit',
+          self.options.maybe_placeholders,
+      )
+      self._callbacks = self._parse_tree_builder.create_callback(
+          self.options.transformer
+      )
+      self._callbacks.update(
+          _get_lexer_callbacks(self.options.transformer, self.terminals)
+      )
+
+  @classmethod
+  def load(cls: Type[_T], f) -> _T:
+    # --
+    inst = cls.__new__(cls)
+    return inst._load(f)
+
+  def _deserialize_lexer_conf(
+      self,
+      data: Dict[str, Any],
+      memo: Dict[int, Union[TerminalDef, Rule]],
+      options: LarkOptions,
+  ) -> LexerConf:
+    lexer_conf = LexerConf.deserialize(data['lexer_conf'], memo)
+    lexer_conf.callbacks = options.lexer_callbacks or {}
+    lexer_conf.re_module = re
+    lexer_conf.use_bytes = options.use_bytes
+    lexer_conf.g_regex_flags = options.g_regex_flags
+    lexer_conf.skip_validation = True
+    lexer_conf.postlex = options.postlex
+    return lexer_conf
+
+  def _load(self: _T, d: Any, **kwargs) -> _T:
+    memo_json = d['memo']
+    data = d['data']
+
+    assert memo_json
+    memo = SerializeMemoizer.deserialize(
+        memo_json, {'Rule': Rule, 'TerminalDef': TerminalDef}, {}
+    )
+    options = dict(data['options'])
+    if (set(kwargs) - _LOAD_ALLOWED_OPTIONS) & set(LarkOptions._defaults):
+      raise ConfigurationError(
+          'Some options are not allowed when loading a Parser: {}'.format(
+              set(kwargs) - _LOAD_ALLOWED_OPTIONS
+          )
+      )
+    options.update(kwargs)
+    self.options = LarkOptions.deserialize(options, memo)
+    self.rules = [Rule.deserialize(r, memo) for r in data['rules']]
+    self.source_path = ''
+    _validate_frontend_args(self.options.parser, self.options.lexer)
+    self.lexer_conf = self._deserialize_lexer_conf(
+        data['parser'], memo, self.options
+    )
+    self.terminals = self.lexer_conf.terminals
+    self._prepare_callbacks()
+    self._terminals_dict = {t.name: t for t in self.terminals}
+    self.parser = _deserialize_parsing_frontend(
+        data['parser'],
+        memo,
+        self.lexer_conf,
+        self._callbacks,
+        self.options,  # Not all, but multiple attributes are used.
+    )
+    return self
+
+  @classmethod
+  def _load_from_dict(cls, data, memo, **kwargs):
+    inst = cls.__new__(cls)
+    return inst._load({'data': data, 'memo': memo}, **kwargs)
+
+  def __repr__(self):
+    return 'Lark(open(%r), parser=%r, lexer=%r, ...)' % (
+        self.source_path,
+        self.options.parser,
+        self.options.lexer,
+    )
+
+  def lex(self, text: str, dont_ignore: bool = False) -> Iterator[Token]:
+    # --
+    lexer: Lexer
+    if not hasattr(self, 'lexer') or dont_ignore:
+      lexer = self._build_lexer(dont_ignore)
+    else:
+      lexer = self.lexer
+    lexer_thread = LexerThread.from_text(lexer, text)
+    stream = lexer_thread.lex(None)
+    if self.options.postlex:
+      return self.options.postlex.process(stream)
+    return stream
+
+  def get_terminal(self, name: str) -> TerminalDef:
+    # --
+    return self._terminals_dict[name]
+
+  def parse(
+      self,
+      text: str,
+      start: Optional[str] = None,
+      on_error: 'Optional[Callable[[UnexpectedInput], bool]]' = None,
+  ):  # -> 'ParseTree'
+    return self.parser.parse(text, start=start, on_error=on_error)
+
+
+Shift = 0
+Reduce = 1
+
+
+def get_parser(data_and_memo: tuple[dict[str, Any], dict[int, Any]]) -> Lark:
+  """Construct a standalone LALR parser from a serialized Lark parser.
+
+  Use `memo_serialize` to serialize a Lark parser:
+  ```
+  import lark
+  p = lark.Lark(parser="lalr", grammar=YOUR_LARK_GRAMMAR_AS_STRING)
+  data_and_memo = p.memo_serialize([lark.lexer.TerminalDef,
+                                    lark.grammar.Rule])
+  ```
+
+  Args:
+    data_and_memo: The serialized Lark parser as returned by `memo_serialize`.
+
+  Returns:
+    A standalone parser.
+  """
+  return Lark._load_from_dict(*data_and_memo)
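---

Usage sketch for `get_parser` (illustrative, not part of the patch): the round
trip below follows the docstring above. The toy grammar, input string, and
variable names are invented for this example; `memo_serialize`,
`lark.lexer.TerminalDef`, and `lark.grammar.Rule` are the `lark` APIs the
docstring already references, and the import path follows the file added by
this change.

```
import lark

from kauldron.utils import standalone_parser

# A small LALR-compatible grammar (illustrative only).
GRAMMAR = r"""
start: WORD ","? WORD "!"
WORD: /\w+/
%ignore " "
"""

# Serialize once, wherever the full `lark` package is available.
p = lark.Lark(GRAMMAR, parser='lalr')
data_and_memo = p.memo_serialize([lark.lexer.TerminalDef, lark.grammar.Rule])

# Later, reconstruct a standalone parser from the serialized form and use it.
parser = standalone_parser.get_parser(data_and_memo)
tree = parser.parse('hello, world!')
print(tree.pretty())
```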