From f206ddc2b6b6c236ab36090c4e1b36ae6bd13968 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20R=C3=B8nningstad?= Date: Wed, 15 Jan 2025 13:11:34 +0100 Subject: [PATCH 1/3] Reformat all python files with black MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Modify the TestCodestyle test to use black instead of pycodestyle, to avoid disagreements in formatting between the two. Add black.sh for quickly formatting the codebase. Signed-off-by: Øyvind Rønningstad --- __init__.py | 6 +- scripts/add_helptext.py | 8 +- scripts/black.sh | 1 + scripts/regenerate_samples.py | 18 +- scripts/requirements-test.txt | 2 +- scripts/update_version.py | 26 +- tests/decode/test5_corner_cases/floats.py | 39 +- tests/scripts/test_repo_files.py | 110 +- tests/scripts/test_versions.py | 42 +- tests/scripts/test_zcbor.py | 814 +++++++++-- tests/unit/test3_float16/floats.py | 58 +- zcbor/zcbor.py | 1619 ++++++++++++++------- 12 files changed, 1897 insertions(+), 846 deletions(-) create mode 100755 scripts/black.sh diff --git a/__init__.py b/__init__.py index d5cc08cc..9dc6cbef 100644 --- a/__init__.py +++ b/__init__.py @@ -7,8 +7,4 @@ from pathlib import Path -from .zcbor.zcbor import ( - CddlValidationError, - DataTranslator, - main -) +from .zcbor.zcbor import CddlValidationError, DataTranslator, main diff --git a/scripts/add_helptext.py b/scripts/add_helptext.py index 81fa6d21..edb6f7df 100644 --- a/scripts/add_helptext.py +++ b/scripts/add_helptext.py @@ -11,7 +11,7 @@ from sys import argv p_root = Path(__file__).absolute().parents[1] -p_README = Path(p_root, 'README.md') +p_README = Path(p_root, "README.md") pattern = r""" Command line documentation @@ -42,13 +42,13 @@ ``` """ - with open(p_README, 'r', encoding="utf-8") as f: + with open(p_README, "r", encoding="utf-8") as f: readme_contents = f.read() - new_readme_contents = sub(pattern + r'.*', output, readme_contents, flags=S) + new_readme_contents = sub(pattern + r".*", output, 
readme_contents, flags=S) if len(argv) > 1 and argv[1] == "--check": if new_readme_contents != readme_contents: print("Check failed") exit(9) else: - with open(p_README, 'w', encoding="utf-8") as f: + with open(p_README, "w", encoding="utf-8") as f: f.write(new_readme_contents) diff --git a/scripts/black.sh b/scripts/black.sh new file mode 100755 index 00000000..2a94bd7b --- /dev/null +++ b/scripts/black.sh @@ -0,0 +1 @@ +black $(dirname "$0")/.. -l 100 diff --git a/scripts/regenerate_samples.py b/scripts/regenerate_samples.py index ed4eec22..5d4c4f3e 100644 --- a/scripts/regenerate_samples.py +++ b/scripts/regenerate_samples.py @@ -13,22 +13,24 @@ from tempfile import mkdtemp p_root = Path(__file__).absolute().parents[1] -p_build = p_root / 'build' -p_pet_sample = p_root / 'samples' / 'pet' -p_pet_cmake = p_pet_sample / 'pet.cmake' -p_pet_include = p_pet_sample / 'include' -p_pet_src = p_pet_sample / 'src' +p_build = p_root / "build" +p_pet_sample = p_root / "samples" / "pet" +p_pet_cmake = p_pet_sample / "pet.cmake" +p_pet_include = p_pet_sample / "include" +p_pet_src = p_pet_sample / "src" def regenerate(): tmpdir = Path(mkdtemp()) - run(['cmake', p_pet_sample, "-DREGENERATE_ZCBOR=Y", "-DCMAKE_MESSAGE_LOG_LEVEL=WARNING"], - cwd=tmpdir) + run( + ["cmake", p_pet_sample, "-DREGENERATE_ZCBOR=Y", "-DCMAKE_MESSAGE_LOG_LEVEL=WARNING"], + cwd=tmpdir, + ) rmtree(tmpdir) def check(): - files = (list(p_pet_include.iterdir()) + list(p_pet_src.iterdir()) + [p_pet_cmake]) + files = list(p_pet_include.iterdir()) + list(p_pet_src.iterdir()) + [p_pet_cmake] contents = "".join(p.read_text(encoding="utf-8") for p in files) tmpdir = Path(mkdtemp()) list(makedirs(tmpdir / f.relative_to(p_pet_sample).parent, exist_ok=True) for f in files) diff --git a/scripts/requirements-test.txt b/scripts/requirements-test.txt index 3fa33988..ae23e92a 100644 --- a/scripts/requirements-test.txt +++ b/scripts/requirements-test.txt @@ -1,4 +1,4 @@ pyelftools -pycodestyle +black west ecdsa diff --git 
a/scripts/update_version.py b/scripts/update_version.py index 48dd8b64..49c4dddf 100644 --- a/scripts/update_version.py +++ b/scripts/update_version.py @@ -10,10 +10,10 @@ from datetime import datetime p_root = Path(__file__).absolute().parents[1] -p_VERSION = Path(p_root, 'zcbor', 'VERSION') -p_RELEASE_NOTES = Path(p_root, 'RELEASE_NOTES.md') -p_MIGRATION_GUIDE = Path(p_root, 'MIGRATION_GUIDE.md') -p_common_h = Path(p_root, 'include', 'zcbor_common.h') +p_VERSION = Path(p_root, "zcbor", "VERSION") +p_RELEASE_NOTES = Path(p_root, "RELEASE_NOTES.md") +p_MIGRATION_GUIDE = Path(p_root, "MIGRATION_GUIDE.md") +p_common_h = Path(p_root, "include", "zcbor_common.h") RELEASE_NOTES_boilerplate = """ Any new bugs, requests, or missing features should be reported as [Github issues](https://github.com/NordicSemiconductor/zcbor/issues). @@ -31,23 +31,29 @@ def update_relnotes(p_relnotes, version, boilerplate="", include_date=True): new_date = f" ({datetime.today().strftime('%Y-%m-%d')})" if include_date else "" relnotes_new_header = f"# zcbor v. 
{version}{new_date}\n" if ".99" not in relnotes_lines[0]: - relnotes_contents = relnotes_new_header + boilerplate + '\n\n' + relnotes_contents + relnotes_contents = relnotes_new_header + boilerplate + "\n\n" + relnotes_contents relnotes_contents = sub(r".*?\n", relnotes_new_header, relnotes_contents, count=1) p_relnotes.write_text(relnotes_contents, encoding="utf-8") if __name__ == "__main__": - if len(argv) != 2 or match(r'\d+\.\d+\.\d+', argv[1]) is None: + if len(argv) != 2 or match(r"\d+\.\d+\.\d+", argv[1]) is None: print(f"Usage: {argv[0]} ") exit(1) version = argv[1] - (major, minor, bugfix) = version.split('.') + (major, minor, bugfix) = version.split(".") p_VERSION.write_text(version, encoding="utf-8") update_relnotes(p_RELEASE_NOTES, version, boilerplate=RELEASE_NOTES_boilerplate) update_relnotes(p_MIGRATION_GUIDE, version, include_date=False) p_common_h_contents = p_common_h.read_text(encoding="utf-8") - common_h_new_contents = sub(r"(#define ZCBOR_VERSION_MAJOR )\d+", f"\\g<1>{major}", p_common_h_contents) - common_h_new_contents = sub(r"(#define ZCBOR_VERSION_MINOR )\d+", f"\\g<1>{minor}", common_h_new_contents) - common_h_new_contents = sub(r"(#define ZCBOR_VERSION_BUGFIX )\d+", f"\\g<1>{bugfix}", common_h_new_contents) + common_h_new_contents = sub( + r"(#define ZCBOR_VERSION_MAJOR )\d+", f"\\g<1>{major}", p_common_h_contents + ) + common_h_new_contents = sub( + r"(#define ZCBOR_VERSION_MINOR )\d+", f"\\g<1>{minor}", common_h_new_contents + ) + common_h_new_contents = sub( + r"(#define ZCBOR_VERSION_BUGFIX )\d+", f"\\g<1>{bugfix}", common_h_new_contents + ) p_common_h.write_text(common_h_new_contents, encoding="utf-8") diff --git a/tests/decode/test5_corner_cases/floats.py b/tests/decode/test5_corner_cases/floats.py index a90ebe4e..76e57685 100644 --- a/tests/decode/test5_corner_cases/floats.py +++ b/tests/decode/test5_corner_cases/floats.py @@ -12,28 +12,35 @@ import math import cbor2 + def print_var(val1, val2, bytestr): - var_str = "" - for b in 
bytestr: - var_str += hex(b) + ", " - print(str(val1) + ":", val2, bytestr.hex(), var_str) - -def print_32_64(str_val, val = None): - val = val or float(str_val) - print_var(str_val, val, struct.pack("!e", numpy.float16(val))) - print_var(str_val, val, struct.pack("!f", struct.unpack("!e", struct.pack("!e", numpy.float16(val)))[0])) - print (numpy.float32(struct.unpack("!e", struct.pack("!e", numpy.float16(val)))[0])) - print_var(str_val, val, struct.pack("!f", val)) - print_var(str_val, val, struct.pack("!d", val)) - print_var(str_val, val, struct.pack("!d", struct.unpack("!f", struct.pack("!f", val))[0])) - print() + var_str = "" + for b in bytestr: + var_str += hex(b) + ", " + print(str(val1) + ":", val2, bytestr.hex(), var_str) + + +def print_32_64(str_val, val=None): + val = val or float(str_val) + print_var(str_val, val, struct.pack("!e", numpy.float16(val))) + print_var( + str_val, + val, + struct.pack("!f", struct.unpack("!e", struct.pack("!e", numpy.float16(val)))[0]), + ) + print(numpy.float32(struct.unpack("!e", struct.pack("!e", numpy.float16(val)))[0])) + print_var(str_val, val, struct.pack("!f", val)) + print_var(str_val, val, struct.pack("!d", val)) + print_var(str_val, val, struct.pack("!d", struct.unpack("!f", struct.pack("!f", val))[0])) + print() + print_32_64("3.1415") print_32_64("2.71828") print_32_64("1234567.89") print_32_64("-98765.4321") -print_32_64("123/456789", 123/456789) -print_32_64("-2^(-42)", -1/(2**(42))) +print_32_64("123/456789", 123 / 456789) +print_32_64("-2^(-42)", -1 / (2 ** (42))) print_32_64("1.0") print_32_64("-10000.0") print_32_64("0.0") diff --git a/tests/scripts/test_repo_files.py b/tests/scripts/test_repo_files.py index ae62a867..5a692d95 100644 --- a/tests/scripts/test_repo_files.py +++ b/tests/scripts/test_repo_files.py @@ -11,7 +11,6 @@ from urllib.error import HTTPError from argparse import ArgumentParser from subprocess import Popen, check_output, PIPE, run -from pycodestyle import StyleGuide from shutil import 
rmtree, copy2 from platform import python_version_tuple from sys import platform @@ -21,77 +20,63 @@ p_root = Path(__file__).absolute().parents[2] -p_tests = p_root / 'tests' +p_tests = p_root / "tests" p_readme = p_root / "README.md" p_pypi_readme = p_root / "pypi_README.md" p_architecture = p_root / "ARCHITECTURE.md" p_release_notes = p_root / "RELEASE_NOTES.md" -p_init_py = p_root / '__init__.py' -p_zcbor_py = p_root / 'zcbor' / 'zcbor.py' -p_add_helptext = p_root / 'scripts' / 'add_helptext.py' -p_regenerate_samples = p_root / 'scripts' / 'regenerate_samples.py' -p_test_zcbor_py = p_tests / 'scripts' / 'test_zcbor.py' -p_test_versions_py = p_tests / 'scripts' / 'test_versions.py' -p_test_repo_files_py = p_tests / 'scripts' / 'test_repo_files.py' -p_hello_world_sample = p_root / 'samples' / 'hello_world' -p_hello_world_build = p_hello_world_sample / 'build' -p_pet_sample = p_root / 'samples' / 'pet' -p_pet_cmake = p_pet_sample / 'pet.cmake' -p_pet_include = p_pet_sample / 'include' -p_pet_src = p_pet_sample / 'src' -p_pet_build = p_pet_sample / 'build' +p_add_helptext = p_root / "scripts" / "add_helptext.py" +p_regenerate_samples = p_root / "scripts" / "regenerate_samples.py" +p_hello_world_sample = p_root / "samples" / "hello_world" +p_hello_world_build = p_hello_world_sample / "build" +p_pet_sample = p_root / "samples" / "pet" +p_pet_cmake = p_pet_sample / "pet.cmake" +p_pet_include = p_pet_sample / "include" +p_pet_src = p_pet_sample / "src" +p_pet_build = p_pet_sample / "build" class TestCodestyle(TestCase): - def do_codestyle(self, files, **kwargs): - style = StyleGuide(max_line_length=100, **kwargs) - result = style.check_files([str(f) for f in files]) - result.print_statistics() - self.assertEqual(result.total_errors, 0, - f"Found {result.total_errors} style errors") - def test_codestyle(self): - """Run codestyle tests on all Python scripts in the repo.""" - self.do_codestyle([p_init_py, p_test_versions_py, p_test_repo_files_py, p_add_helptext, - 
p_regenerate_samples]) - self.do_codestyle([p_zcbor_py], ignore=['W191', 'E101', 'W503']) - self.do_codestyle([p_test_zcbor_py], ignore=['E402', 'E501', 'W503']) + black_res = Popen(["black", "--check", p_root, "-l", "100"], stdout=PIPE, stderr=PIPE) + _, stderr = black_res.communicate() + self.assertEqual(0, black_res.returncode, "black failed:\n" + stderr.decode("utf-8")) def version_int(in_str): - return int(search(r'\A\d+', in_str)[0]) # e.g. '0rc' -> '0' + return int(search(r"\A\d+", in_str)[0]) # e.g. '0rc' -> '0' class TestSamples(TestCase): - def popen_test(self, args, input='', exp_retcode=0, **kwargs): + def popen_test(self, args, input="", exp_retcode=0, **kwargs): call0 = Popen(args, stdin=PIPE, stdout=PIPE, stderr=PIPE, **kwargs) stdout0, stderr0 = call0.communicate(input) - self.assertEqual(exp_retcode, call0.returncode, stderr0.decode('utf-8')) + self.assertEqual(exp_retcode, call0.returncode, stderr0.decode("utf-8")) return stdout0, stderr0 def cmake_build_run(self, path, build_path): if build_path.exists(): rmtree(build_path) - with open(path / 'README.md', 'r', encoding="utf-8") as f: + with open(path / "README.md", "r", encoding="utf-8") as f: contents = f.read() - to_build_patt = r'### To build:.*?```(?P.*?)```' - to_run_patt = r'### To run:.*?```(?P.*?)```' - exp_out_patt = r'### Expected output:.*?(?P(\n>[^\n]*)+)' - to_build = search(to_build_patt, contents, flags=S)['to_build'].strip() - to_run = search(to_run_patt, contents, flags=S)['to_run'].strip() - exp_out = search(exp_out_patt, contents, flags=S)['exp_out'].replace("\n> ", "\n").strip() + to_build_patt = r"### To build:.*?```(?P.*?)```" + to_run_patt = r"### To run:.*?```(?P.*?)```" + exp_out_patt = r"### Expected output:.*?(?P(\n>[^\n]*)+)" + to_build = search(to_build_patt, contents, flags=S)["to_build"].strip() + to_run = search(to_run_patt, contents, flags=S)["to_run"].strip() + exp_out = search(exp_out_patt, contents, flags=S)["exp_out"].replace("\n> ", "\n").strip() 
os.chdir(path) - commands_build = [(line.split(' ')) for line in to_build.split('\n')] - assert '\n' not in to_run, "The 'to run' section should only have one command." - commands_run = to_run.split(' ') + commands_build = [(line.split(" ")) for line in to_build.split("\n")] + assert "\n" not in to_run, "The 'to run' section should only have one command." + commands_run = to_run.split(" ") for c in commands_build: self.popen_test(c) output_run = "" for c in commands_run: output, _ = self.popen_test(c) - output_run += output.decode('utf-8') + output_run += output.decode("utf-8") self.assertEqual(exp_out, output_run.strip()) @skipIf(platform.startswith("win"), "Skip on Windows because requires a Unix shell.") @@ -109,17 +94,16 @@ def test_pet_regenerate(self): self.assertEqual(0, regenerate.returncode) def test_pet_file_header(self): - files = (list(p_pet_include.iterdir()) + list(p_pet_src.iterdir()) + [p_pet_cmake]) + files = list(p_pet_include.iterdir()) + list(p_pet_src.iterdir()) + [p_pet_cmake] for p in [f for f in files if "pet" in f.name]: - with p.open('r', encoding="utf-8") as f: + with p.open("r", encoding="utf-8") as f: f.readline() # discard self.assertEqual( f.readline().strip(" *#\n"), - "Copyright (c) 2022 Nordic Semiconductor ASA") + "Copyright (c) 2022 Nordic Semiconductor ASA", + ) f.readline() # discard - self.assertEqual( - f.readline().strip(" *#\n"), - "SPDX-License-Identifier: Apache-2.0") + self.assertEqual(f.readline().strip(" *#\n"), "SPDX-License-Identifier: Apache-2.0") f.readline() # discard self.assertIn("Generated using zcbor version", f.readline()) self.assertIn("https://github.com/NordicSemiconductor/zcbor", f.readline()) @@ -130,15 +114,21 @@ class TestDocs(TestCase): def __init__(self, *args, **kwargs): """Overridden to get base URL for relative links from remote tracking branch.""" super(TestDocs, self).__init__(*args, **kwargs) - remote_tr_args = ['git', 'rev-parse', '--abbrev-ref', '--symbolic-full-name', '@{u}'] - 
remote_tracking = run(remote_tr_args, capture_output=True).stdout.decode('utf-8').strip() + remote_tr_args = [ + "git", + "rev-parse", + "--abbrev-ref", + "--symbolic-full-name", + "@{u}", + ] + remote_tracking = run(remote_tr_args, capture_output=True).stdout.decode("utf-8").strip() if remote_tracking: - remote, remote_branch = remote_tracking.split('/', 1) # '1' to only split one time. - repo_url_args = ['git', 'remote', 'get-url', remote] - repo_url = check_output(repo_url_args).decode('utf-8').strip().strip('.git') - if 'github.com' in repo_url: - self.base_url = (repo_url + '/tree/' + remote_branch + '/') + remote, remote_branch = remote_tracking.split("/", 1) # '1' to only split one time. + repo_url_args = ["git", "remote", "get-url", remote] + repo_url = check_output(repo_url_args).decode("utf-8").strip().strip(".git") + if "github.com" in repo_url: + self.base_url = repo_url + "/tree/" + remote_branch + "/" else: # The URL is not in github.com, so we are not sure it is constructed correctly. self.base_url = None @@ -150,7 +140,7 @@ def __init__(self, *args, **kwargs): # There is no remote tracking branch. self.base_url = None - self.link_regex = compile(r'\[.*?\]\((?P.*?)\)') + self.link_regex = compile(r"\[.*?\]\((?P.*?)\)") def check_code(self, link, codes): """Check the status code of a URL link. 
Assert if not 200 (OK).""" @@ -164,7 +154,7 @@ def check_code(self, link, codes): def do_test_links(self, path, allow_local=True): """Get all Markdown links in the file at and check that they work.""" if allow_local and self.base_url is None: - raise SkipTest('This test requires the current branch to be pushed to Github.') + raise SkipTest("This test requires the current branch to be pushed to Github.") text = path.read_text(encoding="utf-8") @@ -211,8 +201,10 @@ def test_pet_readme(self): def test_pypi_readme(self): self.do_test_links(p_pypi_readme, allow_local=False) - @skipIf(list(map(version_int, python_version_tuple())) < [3, 10, 0], - "Skip on Python < 3.10 because of different wording in argparse output.") + @skipIf( + list(map(version_int, python_version_tuple())) < [3, 10, 0], + "Skip on Python < 3.10 because of different wording in argparse output.", + ) @skipIf(platform.startswith("win"), "Skip on Windows because of path/newline issues.") def test_cli_doc(self): """Check the auto-generated CLI docs in the top level README.md file.""" diff --git a/tests/scripts/test_versions.py b/tests/scripts/test_versions.py index 3ecfd6a6..c6440abb 100644 --- a/tests/scripts/test_versions.py +++ b/tests/scripts/test_versions.py @@ -22,38 +22,52 @@ class VersionTest(TestCase): def test_version_num(self): """For release branches - Test that all version numbers have been updated.""" - current_branch = Popen(['git', 'branch', '--show-current'], - stdout=PIPE).communicate()[0].decode("utf-8").strip() + current_branch = ( + Popen(["git", "branch", "--show-current"], stdout=PIPE) + .communicate()[0] + .decode("utf-8") + .strip() + ) if not current_branch: current_branch = p_HEAD_REF.read_text(encoding="utf-8") self.assertRegex( - current_branch, r"release/\d+\.\d+\.\d+", - "This test is meant to be run on a release branch on the form 'release/x.y.z'.") + current_branch, + r"release/\d+\.\d+\.\d+", + "This test is meant to be run on a release branch on the form 
'release/x.y.z'.", + ) version_number = current_branch.replace("release/", "") version_number_no_bugfix = ".".join(version_number.split(".")[:-1]) self.assertRegex( - version_number, r'\d+\.\d+\.(?!99)\d+', - "Releases cannot have the x.y.99 development bugfix release number.") + version_number, + r"\d+\.\d+\.(?!99)\d+", + "Releases cannot have the x.y.99 development bugfix release number.", + ) self.assertEqual( - version_number, p_VERSION.read_text(encoding="utf-8"), - f"{p_VERSION} has not been updated to the correct version number.") + version_number, + p_VERSION.read_text(encoding="utf-8"), + f"{p_VERSION} has not been updated to the correct version number.", + ) tomorrow = date.today() + timedelta(days=1) self.assertRegex( p_release_notes.read_text(encoding="utf-8").splitlines()[0], escape(r"# zcbor v. " + version_number) - + fr" \(({date.today():%Y-%m-%d}|{tomorrow:%Y-%m-%d})\)", - f"{p_release_notes} has not been updated with the correct version number or date.") + + rf" \(({date.today():%Y-%m-%d}|{tomorrow:%Y-%m-%d})\)", + f"{p_release_notes} has not been updated with the correct version number or date.", + ) self.assertRegex( p_migration_guide.read_text(encoding="utf-8").splitlines()[0], escape(r"# zcbor v. " + version_number_no_bugfix) + r"\.\d", - f"{p_migration_guide} has not been updated with the correct minor/major version num.") + f"{p_migration_guide} has not been updated with the correct minor/major version num.", + ) - tags_stdout, _ = Popen(['git', 'tag'], stdout=PIPE).communicate() + tags_stdout, _ = Popen(["git", "tag"], stdout=PIPE).communicate() tags = tags_stdout.decode("utf-8").strip().splitlines() self.assertNotIn( - version_number, tags, - "Version number already exists as a tag. Has the version number been updated?") + version_number, + tags, + "Version number already exists as a tag. 
Has the version number been updated?", + ) if __name__ == "__main__": diff --git a/tests/scripts/test_zcbor.py b/tests/scripts/test_zcbor.py index 4cc4433d..f1490933 100644 --- a/tests/scripts/test_zcbor.py +++ b/tests/scripts/test_zcbor.py @@ -23,42 +23,47 @@ try: import zcbor except ImportError: - print(""" + print( + """ The zcbor package must be installed to run these tests. During development, install with `pip3 install -e .` to install in a way that picks up changes in the files without having to reinstall. -""") +""" + ) exit(1) p_root = Path(__file__).absolute().parents[2] -p_tests = Path(p_root, 'tests') -p_cases = Path(p_tests, 'cases') -p_manifest12 = Path(p_cases, 'manifest12.cddl') -p_manifest14 = Path(p_cases, 'manifest14.cddl') -p_manifest16 = Path(p_cases, 'manifest16.cddl') -p_manifest20 = Path(p_cases, 'manifest20.cddl') -p_test_vectors12 = tuple(Path(p_cases, f'manifest12_example{i}.cborhex') for i in range(6)) -p_test_vectors14 = tuple(Path(p_cases, f'manifest14_example{i}.cborhex') for i in range(6)) -p_test_vectors16 = tuple(Path(p_cases, f'manifest14_example{i}.cborhex') for i in range(6)) # Identical to manifest14. 
-p_test_vectors20 = tuple(Path(p_cases, f'manifest20_example{i}.cborhex') for i in range(6)) -p_optional = Path(p_cases, 'optional.cddl') -p_corner_cases = Path(p_cases, 'corner_cases.cddl') -p_cose = Path(p_cases, 'cose.cddl') -p_manifest14_priv = Path(p_cases, 'manifest14.priv') -p_manifest14_pub = Path(p_cases, 'manifest14.pub') -p_map_bstr_cddl = Path(p_cases, 'map_bstr.cddl') -p_map_bstr_yaml = Path(p_cases, 'map_bstr.yaml') -p_yaml_compat_cddl = Path(p_cases, 'yaml_compatibility.cddl') -p_yaml_compat_yaml = Path(p_cases, 'yaml_compatibility.yaml') -p_pet_cddl = Path(p_cases, 'pet.cddl') -p_README = Path(p_root, 'README.md') -p_prelude = Path(p_root, 'zcbor', 'prelude.cddl') -p_VERSION = Path(p_root, 'zcbor', 'VERSION') +p_tests = Path(p_root, "tests") +p_cases = Path(p_tests, "cases") +p_manifest12 = Path(p_cases, "manifest12.cddl") +p_manifest14 = Path(p_cases, "manifest14.cddl") +p_manifest16 = Path(p_cases, "manifest16.cddl") +p_manifest20 = Path(p_cases, "manifest20.cddl") +p_test_vectors12 = tuple(Path(p_cases, f"manifest12_example{i}.cborhex") for i in range(6)) +p_test_vectors14 = tuple(Path(p_cases, f"manifest14_example{i}.cborhex") for i in range(6)) +p_test_vectors16 = tuple( + Path(p_cases, f"manifest14_example{i}.cborhex") for i in range(6) +) # Identical to manifest14. 
+p_test_vectors20 = tuple(Path(p_cases, f"manifest20_example{i}.cborhex") for i in range(6)) +p_optional = Path(p_cases, "optional.cddl") +p_corner_cases = Path(p_cases, "corner_cases.cddl") +p_cose = Path(p_cases, "cose.cddl") +p_manifest14_priv = Path(p_cases, "manifest14.priv") +p_manifest14_pub = Path(p_cases, "manifest14.pub") +p_map_bstr_cddl = Path(p_cases, "map_bstr.cddl") +p_map_bstr_yaml = Path(p_cases, "map_bstr.yaml") +p_yaml_compat_cddl = Path(p_cases, "yaml_compatibility.cddl") +p_yaml_compat_yaml = Path(p_cases, "yaml_compatibility.yaml") +p_pet_cddl = Path(p_cases, "pet.cddl") +p_README = Path(p_root, "README.md") +p_prelude = Path(p_root, "zcbor", "prelude.cddl") +p_VERSION = Path(p_root, "zcbor", "VERSION") class TestManifest(TestCase): """Class for testing examples against CDDL for various versions of the SUIT manifest spec.""" + def decode_file(self, data_path, *cddl_paths): data = bytes.fromhex(data_path.read_text(encoding="utf-8").replace("\n", "")) self.decode_string(data, *cddl_paths) @@ -78,29 +83,53 @@ def __init__(self, *args, **kwargs): def test_manifest_digest(self): self.assertEqual( bytes.fromhex("5c097ef64bf3bb9b494e71e1f2418eef8d466cc902f639a855ec9af3e9eddb99"), - self.decoded.suit_authentication_wrapper.SUIT_Digest_bstr.suit_digest_bytes) + self.decoded.suit_authentication_wrapper.SUIT_Digest_bstr.suit_digest_bytes, + ) def test_signature(self): self.assertEqual( 1, - self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].COSE_Sign1_Tagged_m.protected.uintint[0].uintint_key) + self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0] + .COSE_Sign1_Tagged_m.protected.uintint[0] + .uintint_key, + ) self.assertEqual( -7, - self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].COSE_Sign1_Tagged_m.protected.uintint[0].uintint) + self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0] + .COSE_Sign1_Tagged_m.protected.uintint[0] + .uintint, + ) self.assertEqual( 
- bytes.fromhex("a19fd1f23b17beed321cece7423dfb48c457b8f1f6ac83577a3c10c6773f6f3a7902376b59540920b6c5f57bac5fc8543d8f5d3d974faa2e6d03daa534b443a7"), - self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].COSE_Sign1_Tagged_m.signature) + bytes.fromhex( + "a19fd1f23b17beed321cece7423dfb48c457b8f1f6ac83577a3c10c6773f6f3a7902376b59540920b6c5f57bac5fc8543d8f5d3d974faa2e6d03daa534b443a7" + ), + self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[ + 0 + ].COSE_Sign1_Tagged_m.signature, + ) def test_validate_run(self): self.assertEqual( "suit_condition_image_match_m_l", - self.decoded.suit_manifest.SUIT_Unseverable_Members.suit_validate[0].suit_validate.union[0].SUIT_Condition_m.union_choice) + self.decoded.suit_manifest.SUIT_Unseverable_Members.suit_validate[0] + .suit_validate.union[0] + .SUIT_Condition_m.union_choice, + ) self.assertEqual( "suit_directive_run_m_l", - self.decoded.suit_manifest.SUIT_Unseverable_Members.suit_run[0].suit_run.union[0].SUIT_Directive_m.union_choice) + self.decoded.suit_manifest.SUIT_Unseverable_Members.suit_run[0] + .suit_run.union[0] + .SUIT_Directive_m.union_choice, + ) def test_image_size(self): - self.assertEqual(34768, self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[0].SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[3].suit_parameter_image_size) + self.assertEqual( + 34768, + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[0] + .SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[3] + .suit_parameter_image_size, + ) class TestEx0InvManifest12(TestManifest): @@ -119,13 +148,17 @@ def __init__(self, *args, **kwargs): def test_components(self): self.assertEqual( - [b'\x00'], - self.decoded.suit_manifest.suit_common.suit_components[0][0].bstr) + [b"\x00"], self.decoded.suit_manifest.suit_common.suit_components[0][0].bstr + ) def test_uri(self): self.assertEqual( 
"http://example.com/file.bin", - self.decoded.suit_manifest.SUIT_Severable_Manifest_Members.suit_install[0].suit_install.union[0].SUIT_Directive_m.suit_directive_set_parameters_m_l.map[0].suit_parameter_uri) + self.decoded.suit_manifest.SUIT_Severable_Manifest_Members.suit_install[0] + .suit_install.union[0] + .SUIT_Directive_m.suit_directive_set_parameters_m_l.map[0] + .suit_parameter_uri, + ) class TestEx2Manifest12(TestManifest): @@ -136,21 +169,37 @@ def __init__(self, *args, **kwargs): def test_severed_uri(self): self.assertEqual( "http://example.com/very/long/path/to/file/file.bin", - self.decoded.SUIT_Severable_Manifest_Members.suit_install[0].suit_install.union[0].SUIT_Directive_m.suit_directive_set_parameters_m_l.map[0].suit_parameter_uri) + self.decoded.SUIT_Severable_Manifest_Members.suit_install[0] + .suit_install.union[0] + .SUIT_Directive_m.suit_directive_set_parameters_m_l.map[0] + .suit_parameter_uri, + ) def test_severed_text(self): self.assertIn( "Example 2", - self.decoded.SUIT_Severable_Manifest_Members.suit_text[0].suit_text.SUIT_Text_Keys.suit_text_manifest_description[0]) + self.decoded.SUIT_Severable_Manifest_Members.suit_text[ + 0 + ].suit_text.SUIT_Text_Keys.suit_text_manifest_description[0], + ) self.assertEqual( - [b'\x00'], - self.decoded.SUIT_Severable_Manifest_Members.suit_text[0].suit_text.SUIT_Component_Identifier[0].SUIT_Component_Identifier_key.bstr) + [b"\x00"], + self.decoded.SUIT_Severable_Manifest_Members.suit_text[0] + .suit_text.SUIT_Component_Identifier[0] + .SUIT_Component_Identifier_key.bstr, + ) self.assertEqual( "arm.com", - self.decoded.SUIT_Severable_Manifest_Members.suit_text[0].suit_text.SUIT_Component_Identifier[0].SUIT_Component_Identifier.SUIT_Text_Component_Keys.suit_text_vendor_domain[0]) + self.decoded.SUIT_Severable_Manifest_Members.suit_text[0] + .suit_text.SUIT_Component_Identifier[0] + .SUIT_Component_Identifier.SUIT_Text_Component_Keys.suit_text_vendor_domain[0], + ) self.assertEqual( "This component is a 
demonstration. The digest is a sample pattern, not a real one.", - self.decoded.SUIT_Severable_Manifest_Members.suit_text[0].suit_text.SUIT_Component_Identifier[0].SUIT_Component_Identifier.SUIT_Text_Component_Keys.suit_text_component_description[0]) + self.decoded.SUIT_Severable_Manifest_Members.suit_text[0] + .suit_text.SUIT_Component_Identifier[0] + .SUIT_Component_Identifier.SUIT_Text_Component_Keys.suit_text_component_description[0], + ) class TestEx3Manifest12(TestManifest): @@ -161,10 +210,26 @@ def __init__(self, *args, **kwargs): def test_A_B_offset(self): self.assertEqual( 33792, - self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[1].SUIT_Common_Commands_m.suit_directive_try_each_m_l.SUIT_Directive_Try_Each_Argument_m.SUIT_Command_Sequence_bstr[0].union[0].SUIT_Directive_m.suit_directive_override_parameters_m_l.map[0].suit_parameter_component_offset) + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[1] + .SUIT_Common_Commands_m.suit_directive_try_each_m_l.SUIT_Directive_Try_Each_Argument_m.SUIT_Command_Sequence_bstr[ + 0 + ] + .union[0] + .SUIT_Directive_m.suit_directive_override_parameters_m_l.map[0] + .suit_parameter_component_offset, + ) self.assertEqual( 541696, - self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[1].SUIT_Common_Commands_m.suit_directive_try_each_m_l.SUIT_Directive_Try_Each_Argument_m.SUIT_Command_Sequence_bstr[1].union[0].SUIT_Directive_m.suit_directive_override_parameters_m_l.map[0].suit_parameter_component_offset) + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[1] + .SUIT_Common_Commands_m.suit_directive_try_each_m_l.SUIT_Directive_Try_Each_Argument_m.SUIT_Command_Sequence_bstr[ + 1 + ] + .union[0] + .SUIT_Directive_m.suit_directive_override_parameters_m_l.map[0] + .suit_parameter_component_offset, + ) class TestEx4Manifest12(TestManifest): @@ -175,10 +240,18 @@ 
def __init__(self, *args, **kwargs): def test_load_decompress(self): self.assertEqual( 0, - self.decoded.suit_manifest.SUIT_Unseverable_Members.suit_load[0].suit_load.union[1].SUIT_Directive_m.suit_directive_set_parameters_m_l.map[3].suit_parameter_source_component) + self.decoded.suit_manifest.SUIT_Unseverable_Members.suit_load[0] + .suit_load.union[1] + .SUIT_Directive_m.suit_directive_set_parameters_m_l.map[3] + .suit_parameter_source_component, + ) self.assertEqual( "SUIT_Compression_Algorithm_zlib_m", - self.decoded.suit_manifest.SUIT_Unseverable_Members.suit_load[0].suit_load.union[1].SUIT_Directive_m.suit_directive_set_parameters_m_l.map[2].suit_parameter_compression_info.suit_compression_algorithm) + self.decoded.suit_manifest.SUIT_Unseverable_Members.suit_load[0] + .suit_load.union[1] + .SUIT_Directive_m.suit_directive_set_parameters_m_l.map[2] + .suit_parameter_compression_info.suit_compression_algorithm, + ) class TestEx5Manifest12(TestManifest): @@ -189,10 +262,16 @@ def __init__(self, *args, **kwargs): def test_two_image_match(self): self.assertEqual( "suit_condition_image_match_m_l", - self.decoded.suit_manifest.SUIT_Severable_Manifest_Members.suit_install[0].suit_install.union[3].SUIT_Condition_m.union_choice) + self.decoded.suit_manifest.SUIT_Severable_Manifest_Members.suit_install[0] + .suit_install.union[3] + .SUIT_Condition_m.union_choice, + ) self.assertEqual( "suit_condition_image_match_m_l", - self.decoded.suit_manifest.SUIT_Severable_Manifest_Members.suit_install[0].suit_install.union[7].SUIT_Condition_m.union_choice) + self.decoded.suit_manifest.SUIT_Severable_Manifest_Members.suit_install[0] + .suit_install.union[7] + .SUIT_Condition_m.union_choice, + ) def dumps(obj): @@ -209,14 +288,28 @@ def __init__(self, *args, **kwargs): self.key = VerifyingKey.from_pem(p_manifest14_pub.read_text(encoding="utf-8")) def do_test_authentication(self): - self.assertEqual("COSE_Sign1_Tagged_m", 
self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].union_choice) - self.assertEqual(-7, self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].COSE_Sign1_Tagged_m.Headers_m.protected.header_map_bstr.Generic_Headers.uint1union[0].int) - - manifest_signature = self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].COSE_Sign1_Tagged_m.signature - signature_header = self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].COSE_Sign1_Tagged_m.Headers_m.protected.header_map_bstr_bstr + self.assertEqual( + "COSE_Sign1_Tagged_m", + self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].union_choice, + ) + self.assertEqual( + -7, + self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0] + .COSE_Sign1_Tagged_m.Headers_m.protected.header_map_bstr.Generic_Headers.uint1union[0] + .int, + ) + + manifest_signature = ( + self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[ + 0 + ].COSE_Sign1_Tagged_m.signature + ) + signature_header = self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[ + 0 + ].COSE_Sign1_Tagged_m.Headers_m.protected.header_map_bstr_bstr manifest_suit_digest = self.decoded.suit_authentication_wrapper.SUIT_Digest_bstr_bstr - sig_struct = dumps(["Signature1", signature_header, b'', manifest_suit_digest]) + sig_struct = dumps(["Signature1", signature_header, b"", manifest_suit_digest]) self.key.verify(manifest_signature, sig_struct, hashfunc=sha256) @@ -249,27 +342,129 @@ class TestEx1Manifest14(TestManifest): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.decode_file(p_test_vectors14[1], p_manifest14, p_cose) - self.manifest_digest = bytes.fromhex("60c61d6eb7a1aaeddc49ce8157a55cff0821537eeee77a4ded44155b03045132") + self.manifest_digest = bytes.fromhex( + "60c61d6eb7a1aaeddc49ce8157a55cff0821537eeee77a4ded44155b03045132" + ) def test_structure(self): - 
self.assertEqual("COSE_Sign1_Tagged_m", self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].union_choice) - self.assertEqual(-7, self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].COSE_Sign1_Tagged_m.Headers_m.protected.header_map_bstr.Generic_Headers.uint1union[0].int) - self.assertEqual(self.manifest_digest, self.decoded.suit_authentication_wrapper.SUIT_Digest_bstr.suit_digest_bytes) + self.assertEqual( + "COSE_Sign1_Tagged_m", + self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].union_choice, + ) + self.assertEqual( + -7, + self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0] + .COSE_Sign1_Tagged_m.Headers_m.protected.header_map_bstr.Generic_Headers.uint1union[0] + .int, + ) + self.assertEqual( + self.manifest_digest, + self.decoded.suit_authentication_wrapper.SUIT_Digest_bstr.suit_digest_bytes, + ) self.assertEqual(1, self.decoded.suit_manifest.suit_manifest_sequence_number) - self.assertEqual(bytes.fromhex("fa6b4a53d5ad5fdfbe9de663e4d41ffe"), self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[0].SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[0].suit_parameter_vendor_identifier.RFC4122_UUID_m) - self.assertEqual(bytes.fromhex("1492af1425695e48bf429b2d51f2ab45"), self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[0].SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[1].suit_parameter_class_identifier) - self.assertEqual(bytes.fromhex("00112233445566778899aabbccddeeff0123456789abcdeffedcba9876543210"), self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[0].SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[2].suit_parameter_image_digest.suit_digest_bytes) - self.assertEqual('cose_alg_sha_256_m', 
self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[0].SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[2].suit_parameter_image_digest.suit_digest_algorithm_id.union_choice) - self.assertEqual(34768, self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[0].SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[3].suit_parameter_image_size) - self.assertEqual(4, len(self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[0].SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map)) - self.assertEqual(15, self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[1].SUIT_Condition_m.suit_condition_vendor_identifier_m_l.SUIT_Rep_Policy_m) - self.assertEqual(15, self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[2].SUIT_Condition_m.suit_condition_class_identifier_m_l.SUIT_Rep_Policy_m) - self.assertEqual(3, len(self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union)) - self.assertEqual(2, len(self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[0])) - self.assertEqual(2, len(self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[0].SUIT_Common_Commands_m)) - self.assertEqual(1, len(self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[0].SUIT_Common_Commands_m.suit_directive_override_parameters_m_l)) - self.assertEqual(4, len(self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[0].SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map)) - self.assertEqual(2, len(self.decoded.suit_manifest.suit_common.suit_common_sequence[0].suit_common_sequence.union[0].SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[0])) + self.assertEqual( + 
bytes.fromhex("fa6b4a53d5ad5fdfbe9de663e4d41ffe"), + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[0] + .SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[0] + .suit_parameter_vendor_identifier.RFC4122_UUID_m, + ) + self.assertEqual( + bytes.fromhex("1492af1425695e48bf429b2d51f2ab45"), + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[0] + .SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[1] + .suit_parameter_class_identifier, + ) + self.assertEqual( + bytes.fromhex("00112233445566778899aabbccddeeff0123456789abcdeffedcba9876543210"), + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[0] + .SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[2] + .suit_parameter_image_digest.suit_digest_bytes, + ) + self.assertEqual( + "cose_alg_sha_256_m", + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[0] + .SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[2] + .suit_parameter_image_digest.suit_digest_algorithm_id.union_choice, + ) + self.assertEqual( + 34768, + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[0] + .SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[3] + .suit_parameter_image_size, + ) + self.assertEqual( + 4, + len( + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[0] + .SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map + ), + ) + self.assertEqual( + 15, + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[1] + .SUIT_Condition_m.suit_condition_vendor_identifier_m_l.SUIT_Rep_Policy_m, + ) + self.assertEqual( + 15, + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[2] + 
.SUIT_Condition_m.suit_condition_class_identifier_m_l.SUIT_Rep_Policy_m, + ) + self.assertEqual( + 3, + len( + self.decoded.suit_manifest.suit_common.suit_common_sequence[ + 0 + ].suit_common_sequence.union + ), + ) + self.assertEqual( + 2, + len( + self.decoded.suit_manifest.suit_common.suit_common_sequence[ + 0 + ].suit_common_sequence.union[0] + ), + ) + self.assertEqual( + 2, + len( + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[0] + .SUIT_Common_Commands_m + ), + ) + self.assertEqual( + 1, + len( + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[0] + .SUIT_Common_Commands_m.suit_directive_override_parameters_m_l + ), + ) + self.assertEqual( + 4, + len( + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[0] + .SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map + ), + ) + self.assertEqual( + 2, + len( + self.decoded.suit_manifest.suit_common.suit_common_sequence[0] + .suit_common_sequence.union[0] + .SUIT_Common_Commands_m.suit_directive_override_parameters_m_l.map[0] + ), + ) def test_cbor_pen(self): data = bytes.fromhex(p_test_vectors14[1].read_text(encoding="utf-8").replace("\n", "")) @@ -321,7 +516,7 @@ def test_inv1(self): def test_inv2(self): data = bytes.fromhex(p_test_vectors14[1].read_text(encoding="utf-8").replace("\n", "")) struct = loads(data) - struct.value[23] = b'' # Invalid integrated payload key + struct.value[23] = b"" # Invalid integrated payload key data = dumps(struct) try: self.decode_string(data, p_manifest14, p_cose) @@ -338,7 +533,7 @@ def test_inv3(self): struct4 = loads(struct3[4]) # override params self.assertEqual(struct4[0], 20) self.assertTrue(isinstance(struct4[1][1], bytes)) - struct4[1][1] += b'x' # vendor ID: wrong length + struct4[1][1] += b"x" # vendor ID: wrong length struct3[4] = dumps(struct4) struct2[3] = dumps(struct3) struct.value[3] = dumps(struct2) @@ -358,26 +553,50 @@ 
def __init__(self, *args, **kwargs): def test_text(self): self.assertEqual( - bytes.fromhex('2bfc4d0cc6680be7dd9f5ca30aa2bb5d1998145de33d54101b80e2ca49faf918'), - self.decoded.suit_manifest.SUIT_Severable_Members_Choice.suit_text[0].SUIT_Digest_m.suit_digest_bytes) + bytes.fromhex("2bfc4d0cc6680be7dd9f5ca30aa2bb5d1998145de33d54101b80e2ca49faf918"), + self.decoded.suit_manifest.SUIT_Severable_Members_Choice.suit_text[ + 0 + ].SUIT_Digest_m.suit_digest_bytes, + ) self.assertEqual( - bytes.fromhex('2bfc4d0cc6680be7dd9f5ca30aa2bb5d1998145de33d54101b80e2ca49faf918'), - sha256(dumps(self.decoded.SUIT_Severable_Manifest_Members.suit_text[0].suit_text_bstr)).digest()) - self.assertEqual('arm.com', self.decoded.SUIT_Severable_Manifest_Members.suit_text[0].suit_text.SUIT_Component_Identifier[0].SUIT_Component_Identifier.SUIT_Text_Component_Keys.suit_text_vendor_domain[0]) - self.assertEqual('This component is a demonstration. The digest is a sample pattern, not a real one.', self.decoded.SUIT_Severable_Manifest_Members.suit_text[0].suit_text.SUIT_Component_Identifier[0].SUIT_Component_Identifier.SUIT_Text_Component_Keys.suit_text_component_description[0]) + bytes.fromhex("2bfc4d0cc6680be7dd9f5ca30aa2bb5d1998145de33d54101b80e2ca49faf918"), + sha256( + dumps(self.decoded.SUIT_Severable_Manifest_Members.suit_text[0].suit_text_bstr) + ).digest(), + ) + self.assertEqual( + "arm.com", + self.decoded.SUIT_Severable_Manifest_Members.suit_text[0] + .suit_text.SUIT_Component_Identifier[0] + .SUIT_Component_Identifier.SUIT_Text_Component_Keys.suit_text_vendor_domain[0], + ) + self.assertEqual( + "This component is a demonstration. The digest is a sample pattern, not a real one.", + self.decoded.SUIT_Severable_Manifest_Members.suit_text[0] + .suit_text.SUIT_Component_Identifier[0] + .SUIT_Component_Identifier.SUIT_Text_Component_Keys.suit_text_component_description[0], + ) # Check manifest description. 
The concatenation and .replace() call are there to add # trailing whitespace to all blank lines except the first. # This is done in this way to avoid editors automatically removing the whitespace. - self.assertEqual('''## Example 2: Simultaneous Download, Installation, Secure Boot, Severed Fields -''' + ''' + self.assertEqual( + """## Example 2: Simultaneous Download, Installation, Secure Boot, Severed Fields +""" + + """ This example covers the following templates: * Compatibility Check ({{template-compatibility-check}}) * Secure Boot ({{template-secure-boot}}) * Firmware Download ({{firmware-download-template}}) - This example also demonstrates severable elements ({{ovr-severable}}), and text ({{manifest-digest-text}}).'''.replace("\n\n", "\n \n"), self.decoded.SUIT_Severable_Manifest_Members.suit_text[0].suit_text.SUIT_Text_Keys.suit_text_manifest_description[0]) + This example also demonstrates severable elements ({{ovr-severable}}), and text ({{manifest-digest-text}}).""".replace( + "\n\n", "\n \n" + ), + self.decoded.SUIT_Severable_Manifest_Members.suit_text[ + 0 + ].suit_text.SUIT_Text_Keys.suit_text_manifest_description[0], + ) class TestEx3Manifest14(TestManifest): @@ -387,9 +606,36 @@ def __init__(self, *args, **kwargs): self.slots = (33792, 541696) def test_try_each(self): - self.assertEqual(2, len(self.decoded.suit_manifest.SUIT_Severable_Members_Choice.suit_install[0].SUIT_Command_Sequence_bstr.union[0].SUIT_Directive_m.suit_directive_try_each_m_l.SUIT_Directive_Try_Each_Argument_m.SUIT_Command_Sequence_bstr)) - self.assertEqual(self.slots[0], self.decoded.suit_manifest.SUIT_Severable_Members_Choice.suit_install[0].SUIT_Command_Sequence_bstr.union[0].SUIT_Directive_m.suit_directive_try_each_m_l.SUIT_Directive_Try_Each_Argument_m.SUIT_Command_Sequence_bstr[0].union[0].SUIT_Directive_m.suit_directive_override_parameters_m_l.map[0].suit_parameter_component_slot) - self.assertEqual(self.slots[1], 
self.decoded.suit_manifest.SUIT_Severable_Members_Choice.suit_install[0].SUIT_Command_Sequence_bstr.union[0].SUIT_Directive_m.suit_directive_try_each_m_l.SUIT_Directive_Try_Each_Argument_m.SUIT_Command_Sequence_bstr[1].union[0].SUIT_Directive_m.suit_directive_override_parameters_m_l.map[0].suit_parameter_component_slot) + self.assertEqual( + 2, + len( + self.decoded.suit_manifest.SUIT_Severable_Members_Choice.suit_install[0] + .SUIT_Command_Sequence_bstr.union[0] + .SUIT_Directive_m.suit_directive_try_each_m_l.SUIT_Directive_Try_Each_Argument_m.SUIT_Command_Sequence_bstr + ), + ) + self.assertEqual( + self.slots[0], + self.decoded.suit_manifest.SUIT_Severable_Members_Choice.suit_install[0] + .SUIT_Command_Sequence_bstr.union[0] + .SUIT_Directive_m.suit_directive_try_each_m_l.SUIT_Directive_Try_Each_Argument_m.SUIT_Command_Sequence_bstr[ + 0 + ] + .union[0] + .SUIT_Directive_m.suit_directive_override_parameters_m_l.map[0] + .suit_parameter_component_slot, + ) + self.assertEqual( + self.slots[1], + self.decoded.suit_manifest.SUIT_Severable_Members_Choice.suit_install[0] + .SUIT_Command_Sequence_bstr.union[0] + .SUIT_Directive_m.suit_directive_try_each_m_l.SUIT_Directive_Try_Each_Argument_m.SUIT_Command_Sequence_bstr[ + 1 + ] + .union[0] + .SUIT_Directive_m.suit_directive_override_parameters_m_l.map[0] + .suit_parameter_component_slot, + ) class TestEx4Manifest14(TestManifest): @@ -399,9 +645,15 @@ def __init__(self, *args, **kwargs): def test_components(self): self.assertEqual(3, len(self.decoded.suit_manifest.suit_common.suit_components[0])) - self.assertEqual(b'\x00', self.decoded.suit_manifest.suit_common.suit_components[0][0].bstr[0]) - self.assertEqual(b'\x02', self.decoded.suit_manifest.suit_common.suit_components[0][1].bstr[0]) - self.assertEqual(b'\x01', self.decoded.suit_manifest.suit_common.suit_components[0][2].bstr[0]) + self.assertEqual( + b"\x00", self.decoded.suit_manifest.suit_common.suit_components[0][0].bstr[0] + ) + self.assertEqual( + b"\x02", 
self.decoded.suit_manifest.suit_common.suit_components[0][1].bstr[0] + ) + self.assertEqual( + b"\x01", self.decoded.suit_manifest.suit_common.suit_components[0][2].bstr[0] + ) class TestEx5Manifest14(TestManifest): @@ -410,8 +662,20 @@ def __init__(self, *args, **kwargs): self.decode_file(p_test_vectors14[5], p_manifest14, p_cose) def test_validate(self): - self.assertEqual(4, len(self.decoded.suit_manifest.SUIT_Unseverable_Members.suit_validate[0].suit_validate.union)) - self.assertEqual(15, self.decoded.suit_manifest.SUIT_Unseverable_Members.suit_validate[0].suit_validate.union[1].SUIT_Condition_m.suit_condition_image_match_m_l.SUIT_Rep_Policy_m) + self.assertEqual( + 4, + len( + self.decoded.suit_manifest.SUIT_Unseverable_Members.suit_validate[ + 0 + ].suit_validate.union + ), + ) + self.assertEqual( + 15, + self.decoded.suit_manifest.SUIT_Unseverable_Members.suit_validate[0] + .suit_validate.union[1] + .SUIT_Condition_m.suit_condition_image_match_m_l.SUIT_Rep_Policy_m, + ) class TestEx5InvManifest14(TestManifest): @@ -491,27 +755,129 @@ class TestEx1Manifest20(TestEx1Manifest16): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.decode_file(p_test_vectors20[1], p_manifest20, p_cose) - self.manifest_digest = bytes.fromhex("ef14b7091e8adae8aa3bb6fca1d64fb37e19dcf8b35714cfdddc5968c80ff50e") + self.manifest_digest = bytes.fromhex( + "ef14b7091e8adae8aa3bb6fca1d64fb37e19dcf8b35714cfdddc5968c80ff50e" + ) def test_structure(self): - self.assertEqual("COSE_Sign1_Tagged_m", self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].union_choice) - self.assertEqual(-7, self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].COSE_Sign1_Tagged_m.Headers_m.protected.header_map_bstr.Generic_Headers.uint1union[0].int) - self.assertEqual(self.manifest_digest, self.decoded.suit_authentication_wrapper.SUIT_Digest_bstr.suit_digest_bytes) + self.assertEqual( + "COSE_Sign1_Tagged_m", + 
self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0].union_choice, + ) + self.assertEqual( + -7, + self.decoded.suit_authentication_wrapper.SUIT_Authentication_Block_bstr[0] + .COSE_Sign1_Tagged_m.Headers_m.protected.header_map_bstr.Generic_Headers.uint1union[0] + .int, + ) + self.assertEqual( + self.manifest_digest, + self.decoded.suit_authentication_wrapper.SUIT_Digest_bstr.suit_digest_bytes, + ) self.assertEqual(1, self.decoded.suit_manifest.suit_manifest_sequence_number) - self.assertEqual(bytes.fromhex("fa6b4a53d5ad5fdfbe9de663e4d41ffe"), self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union[0].SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map[0].suit_parameter_vendor_identifier.RFC4122_UUID_m) - self.assertEqual(bytes.fromhex("1492af1425695e48bf429b2d51f2ab45"), self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union[0].SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map[1].suit_parameter_class_identifier) - self.assertEqual(bytes.fromhex("00112233445566778899aabbccddeeff0123456789abcdeffedcba9876543210"), self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union[0].SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map[2].suit_parameter_image_digest.suit_digest_bytes) - self.assertEqual('cose_alg_sha_256_m', self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union[0].SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map[2].suit_parameter_image_digest.suit_digest_algorithm_id.union_choice) - self.assertEqual(34768, self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union[0].SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map[3].suit_parameter_image_size) - self.assertEqual(4, 
len(self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union[0].SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map)) - self.assertEqual(15, self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union[1].SUIT_Condition_m.suit_condition_vendor_identifier_m_l.SUIT_Rep_Policy_m) - self.assertEqual(15, self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union[2].SUIT_Condition_m.suit_condition_class_identifier_m_l.SUIT_Rep_Policy_m) - self.assertEqual(3, len(self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union)) - self.assertEqual(2, len(self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union[0])) - self.assertEqual(2, len(self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union[0].SUIT_Shared_Commands_m)) - self.assertEqual(1, len(self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union[0].SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l)) - self.assertEqual(4, len(self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union[0].SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map)) - self.assertEqual(2, len(self.decoded.suit_manifest.suit_common.suit_shared_sequence[0].suit_shared_sequence.union[0].SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map[0])) + self.assertEqual( + bytes.fromhex("fa6b4a53d5ad5fdfbe9de663e4d41ffe"), + self.decoded.suit_manifest.suit_common.suit_shared_sequence[0] + .suit_shared_sequence.union[0] + .SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map[0] + .suit_parameter_vendor_identifier.RFC4122_UUID_m, + ) + self.assertEqual( + bytes.fromhex("1492af1425695e48bf429b2d51f2ab45"), + self.decoded.suit_manifest.suit_common.suit_shared_sequence[0] + .suit_shared_sequence.union[0] + 
.SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map[1] + .suit_parameter_class_identifier, + ) + self.assertEqual( + bytes.fromhex("00112233445566778899aabbccddeeff0123456789abcdeffedcba9876543210"), + self.decoded.suit_manifest.suit_common.suit_shared_sequence[0] + .suit_shared_sequence.union[0] + .SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map[2] + .suit_parameter_image_digest.suit_digest_bytes, + ) + self.assertEqual( + "cose_alg_sha_256_m", + self.decoded.suit_manifest.suit_common.suit_shared_sequence[0] + .suit_shared_sequence.union[0] + .SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map[2] + .suit_parameter_image_digest.suit_digest_algorithm_id.union_choice, + ) + self.assertEqual( + 34768, + self.decoded.suit_manifest.suit_common.suit_shared_sequence[0] + .suit_shared_sequence.union[0] + .SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map[3] + .suit_parameter_image_size, + ) + self.assertEqual( + 4, + len( + self.decoded.suit_manifest.suit_common.suit_shared_sequence[0] + .suit_shared_sequence.union[0] + .SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map + ), + ) + self.assertEqual( + 15, + self.decoded.suit_manifest.suit_common.suit_shared_sequence[0] + .suit_shared_sequence.union[1] + .SUIT_Condition_m.suit_condition_vendor_identifier_m_l.SUIT_Rep_Policy_m, + ) + self.assertEqual( + 15, + self.decoded.suit_manifest.suit_common.suit_shared_sequence[0] + .suit_shared_sequence.union[2] + .SUIT_Condition_m.suit_condition_class_identifier_m_l.SUIT_Rep_Policy_m, + ) + self.assertEqual( + 3, + len( + self.decoded.suit_manifest.suit_common.suit_shared_sequence[ + 0 + ].suit_shared_sequence.union + ), + ) + self.assertEqual( + 2, + len( + self.decoded.suit_manifest.suit_common.suit_shared_sequence[ + 0 + ].suit_shared_sequence.union[0] + ), + ) + self.assertEqual( + 2, + len( + self.decoded.suit_manifest.suit_common.suit_shared_sequence[0] + .suit_shared_sequence.union[0] + 
.SUIT_Shared_Commands_m + ), + ) + self.assertEqual( + 1, + len( + self.decoded.suit_manifest.suit_common.suit_shared_sequence[0] + .suit_shared_sequence.union[0] + .SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l + ), + ) + self.assertEqual( + 4, + len( + self.decoded.suit_manifest.suit_common.suit_shared_sequence[0] + .suit_shared_sequence.union[0] + .SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map + ), + ) + self.assertEqual( + 2, + len( + self.decoded.suit_manifest.suit_common.suit_shared_sequence[0] + .suit_shared_sequence.union[0] + .SUIT_Shared_Commands_m.suit_directive_override_parameters_m_l.map[0] + ), + ) class TestEx1InvManifest20(TestEx1InvManifest16): @@ -555,40 +921,69 @@ class PopenTest(TestCase): def popen_test(self, args, input="", exp_retcode=0): call0 = Popen(args, stdin=PIPE, stdout=PIPE, stderr=PIPE) stdout0, stderr0 = call0.communicate(input) - self.assertEqual(exp_retcode, call0.returncode, stderr0.decode('utf-8')) + self.assertEqual(exp_retcode, call0.returncode, stderr0.decode("utf-8")) return stdout0, stderr0 class TestCLI(PopenTest): def get_std_args(self, input, cmd="convert"): - return ["zcbor", cmd, "--cddl", str(p_manifest12), "--input", str(input), "-t", "SUIT_Envelope_Tagged", "--yaml-compatibility"] + return [ + "zcbor", + cmd, + "--cddl", + str(p_manifest12), + "--input", + str(input), + "-t", + "SUIT_Envelope_Tagged", + "--yaml-compatibility", + ] def do_testManifest(self, n): self.popen_test(self.get_std_args(p_test_vectors12[n], cmd="validate"), "") - stdout0, _ = self.popen_test(self.get_std_args(p_test_vectors12[n]) + ["--output", "-", "--output-as", "cbor"], "") + stdout0, _ = self.popen_test( + self.get_std_args(p_test_vectors12[n]) + ["--output", "-", "--output-as", "cbor"], "" + ) self.popen_test(self.get_std_args("-", cmd="validate") + ["--input-as", "cbor"], stdout0) - stdout1, _ = self.popen_test(self.get_std_args("-") + ["--input-as", "cbor", "--output", "-", "--output-as", "json"], 
stdout0) + stdout1, _ = self.popen_test( + self.get_std_args("-") + ["--input-as", "cbor", "--output", "-", "--output-as", "json"], + stdout0, + ) self.popen_test(self.get_std_args("-", cmd="validate") + ["--input-as", "json"], stdout1) - stdout2, _ = self.popen_test(self.get_std_args("-") + ["--input-as", "json", "--output", "-", "--output-as", "yaml"], stdout1) + stdout2, _ = self.popen_test( + self.get_std_args("-") + ["--input-as", "json", "--output", "-", "--output-as", "yaml"], + stdout1, + ) self.popen_test(self.get_std_args("-", cmd="validate") + ["--input-as", "yaml"], stdout2) - stdout3, _ = self.popen_test(self.get_std_args("-") + ["--input-as", "yaml", "--output", "-", "--output-as", "cbor"], stdout2) + stdout3, _ = self.popen_test( + self.get_std_args("-") + ["--input-as", "yaml", "--output", "-", "--output-as", "cbor"], + stdout2, + ) self.assertEqual(stdout0, stdout3) self.popen_test(self.get_std_args("-", cmd="validate") + ["--input-as", "cbor"], stdout3) - stdout4, _ = self.popen_test(self.get_std_args("-") + ["--input-as", "cbor", "--output", "-", "--output-as", "cborhex"], stdout3) + stdout4, _ = self.popen_test( + self.get_std_args("-") + + ["--input-as", "cbor", "--output", "-", "--output-as", "cborhex"], + stdout3, + ) self.popen_test(self.get_std_args("-", cmd="validate") + ["--input-as", "cborhex"], stdout4) - stdout5, _ = self.popen_test(self.get_std_args("-") + ["--input-as", "cborhex", "--output", "-", "--output-as", "json"], stdout4) + stdout5, _ = self.popen_test( + self.get_std_args("-") + + ["--input-as", "cborhex", "--output", "-", "--output-as", "json"], + stdout4, + ) self.assertEqual(stdout1, stdout5) self.maxDiff = None - with open(p_test_vectors12[n], 'r', encoding="utf-8") as f: + with open(p_test_vectors12[n], "r", encoding="utf-8") as f: self.assertEqual(sub(r"\W+", "", f.read()), sub(r"\W+", "", stdout4.decode("utf-8"))) def test_0(self): @@ -610,11 +1005,37 @@ def test_5(self): self.do_testManifest(5) def 
test_map_bstr(self): - stdout1, _ = self.popen_test(["zcbor", "convert", "--cddl", str(p_map_bstr_cddl), "--input", str(p_map_bstr_yaml), "-t", "map", "--yaml-compatibility", "--output", "-"], "") - self.assertEqual(dumps({"test": bytes.fromhex("1234abcd"), "test2": cbor2.CBORTag(1234, bytes.fromhex("1a2b3c4d")), ("test3",): dumps(1234)}), stdout1) + stdout1, _ = self.popen_test( + [ + "zcbor", + "convert", + "--cddl", + str(p_map_bstr_cddl), + "--input", + str(p_map_bstr_yaml), + "-t", + "map", + "--yaml-compatibility", + "--output", + "-", + ], + "", + ) + self.assertEqual( + dumps( + { + "test": bytes.fromhex("1234abcd"), + "test2": cbor2.CBORTag(1234, bytes.fromhex("1a2b3c4d")), + ("test3",): dumps(1234), + } + ), + stdout1, + ) def test_decode_encode(self): - _, stderr1 = self.popen_test(["zcbor", "code", "--cddl", str(p_map_bstr_cddl), "-t", "map"], "", exp_retcode=2) + _, stderr1 = self.popen_test( + ["zcbor", "code", "--cddl", str(p_map_bstr_cddl), "-t", "map"], "", exp_retcode=2 + ) self.assertIn(b"error: Please specify at least one of --decode or --encode", stderr1) def test_output_present(self): @@ -623,13 +1044,15 @@ def test_output_present(self): self.assertIn( b"error: Please specify both --output-c and --output-h " b"unless --output-cmake is specified.", - stderr1) + stderr1, + ) _, stderr2 = self.popen_test(args + ["--output-c", "/tmp/map.c"], "", exp_retcode=2) self.assertIn( b"error: Please specify both --output-c and --output-h " b"unless --output-cmake is specified.", - stderr2) + stderr2, + ) def do_test_file_header(self, from_file=False): tempd = Path(mkdtemp()) @@ -642,7 +1065,25 @@ def do_test_file_header(self, from_file=False): else: file_header_input = file_header - _, __ = self.popen_test(["zcbor", "code", "--cddl", str(p_pet_cddl), "-t", "Pet", "--output-cmake", str(tempd / "pet.cmake"), "-d", "-e", "--file-header", (file_header_input), "--dq", "5"], "") + _, __ = self.popen_test( + [ + "zcbor", + "code", + "--cddl", + str(p_pet_cddl), + 
"-t", + "Pet", + "--output-cmake", + str(tempd / "pet.cmake"), + "-d", + "-e", + "--file-header", + (file_header_input), + "--dq", + "5", + ], + "", + ) exp_cmake_header = f"""# # Sample # @@ -661,10 +1102,16 @@ def do_test_file_header(self, from_file=False): * https://github.com/NordicSemiconductor/zcbor * Generated with a --default-max-qty of 5 */""".splitlines() - self.assertEqual(exp_cmake_header, (tempd / "pet.cmake").read_text(encoding="utf-8").splitlines()[:9]) - for p in (tempd / "src" / "pet_decode.c", tempd / "src" / "pet_encode.c", - tempd / "include" / "pet_decode.h", tempd / "include" / "pet_encode.h", - tempd / "include" / "pet_types.h"): + self.assertEqual( + exp_cmake_header, (tempd / "pet.cmake").read_text(encoding="utf-8").splitlines()[:9] + ) + for p in ( + tempd / "src" / "pet_decode.c", + tempd / "src" / "pet_encode.c", + tempd / "include" / "pet_decode.h", + tempd / "include" / "pet_encode.h", + tempd / "include" / "pet_types.h", + ): self.assertEqual(exp_c_header, p.read_text(encoding="utf-8").splitlines()[:9]) rmtree(tempd) @@ -675,9 +1122,9 @@ def test_file_header(self): class TestOptional(TestCase): def test_optional_0(self): - with open(p_optional, 'r', encoding="utf-8") as f: + with open(p_optional, "r", encoding="utf-8") as f: cddl_res = zcbor.DataTranslator.from_cddl(f.read(), 16) - cddl = cddl_res.my_types['cfg'] + cddl = cddl_res.my_types["cfg"] test_yaml = """ mem_config: - 0 @@ -690,8 +1137,12 @@ def test_optional_0(self): class TestUndefined(TestCase): def test_undefined_0(self): cddl_res = zcbor.DataTranslator.from_cddl( - p_prelude.read_text(encoding="utf-8") + '\n' + p_corner_cases.read_text(encoding="utf-8"), 16) - cddl = cddl_res.my_types['Simples'] + p_prelude.read_text(encoding="utf-8") + + "\n" + + p_corner_cases.read_text(encoding="utf-8"), + 16, + ) + cddl = cddl_res.my_types["Simples"] test_yaml = "[true, false, true, null, [zcbor_undefined]]" decoded = cddl.decode_str_yaml(test_yaml, yaml_compat=True) @@ -704,8 
+1155,12 @@ def test_undefined_0(self): class TestFloat(TestCase): def test_float_0(self): cddl_res = zcbor.DataTranslator.from_cddl( - p_prelude.read_text(encoding="utf-8") + '\n' + p_corner_cases.read_text(encoding="utf-8"), 16) - cddl = cddl_res.my_types['Floats'] + p_prelude.read_text(encoding="utf-8") + + "\n" + + p_corner_cases.read_text(encoding="utf-8"), + 16, + ) + cddl = cddl_res.my_types["Floats"] test_yaml = f"[3.1415, 1234567.89, 0.000123, 3.1415, 2.71828, 5.0, {1 / 3}]" decoded = cddl.decode_str_yaml(test_yaml) @@ -722,25 +1177,90 @@ def test_float_0(self): class TestYamlCompatibility(PopenTest): def test_yaml_compatibility(self): - self.popen_test(["zcbor", "validate", "-c", p_yaml_compat_cddl, "-i", p_yaml_compat_yaml, "-t", "Yaml_compatibility_example"], exp_retcode=1) - self.popen_test(["zcbor", "validate", "-c", p_yaml_compat_cddl, "-i", p_yaml_compat_yaml, "-t", "Yaml_compatibility_example", "--yaml-compatibility"]) - stdout1, _ = self.popen_test(["zcbor", "convert", "-c", p_yaml_compat_cddl, "-i", p_yaml_compat_yaml, "-o", "-", "-t", "Yaml_compatibility_example", "--yaml-compatibility"]) - stdout2, _ = self.popen_test(["zcbor", "convert", "-c", p_yaml_compat_cddl, "-i", "-", "-o", "-", "--output-as", "yaml", "-t", "Yaml_compatibility_example", "--yaml-compatibility"], stdout1) - self.assertEqual(safe_load(stdout2), safe_load(p_yaml_compat_yaml.read_text(encoding="utf-8"))) + self.popen_test( + [ + "zcbor", + "validate", + "-c", + p_yaml_compat_cddl, + "-i", + p_yaml_compat_yaml, + "-t", + "Yaml_compatibility_example", + ], + exp_retcode=1, + ) + self.popen_test( + [ + "zcbor", + "validate", + "-c", + p_yaml_compat_cddl, + "-i", + p_yaml_compat_yaml, + "-t", + "Yaml_compatibility_example", + "--yaml-compatibility", + ] + ) + stdout1, _ = self.popen_test( + [ + "zcbor", + "convert", + "-c", + p_yaml_compat_cddl, + "-i", + p_yaml_compat_yaml, + "-o", + "-", + "-t", + "Yaml_compatibility_example", + "--yaml-compatibility", + ] + ) + stdout2, _ = 
self.popen_test( + [ + "zcbor", + "convert", + "-c", + p_yaml_compat_cddl, + "-i", + "-", + "-o", + "-", + "--output-as", + "yaml", + "-t", + "Yaml_compatibility_example", + "--yaml-compatibility", + ], + stdout1, + ) + self.assertEqual( + safe_load(stdout2), safe_load(p_yaml_compat_yaml.read_text(encoding="utf-8")) + ) class TestIntmax(TestCase): def test_intmax1(self): cddl_res = zcbor.DataTranslator.from_cddl( - p_prelude.read_text(encoding="utf-8") + '\n' + p_corner_cases.read_text(encoding="utf-8"), 16) - cddl = cddl_res.my_types['Intmax1'] + p_prelude.read_text(encoding="utf-8") + + "\n" + + p_corner_cases.read_text(encoding="utf-8"), + 16, + ) + cddl = cddl_res.my_types["Intmax1"] test_yaml = f"[-128, 127, 255, -32768, 32767, 65535, -2147483648, 2147483647, 4294967295, -9223372036854775808, 9223372036854775807, 18446744073709551615]" decoded = cddl.decode_str_yaml(test_yaml) def test_intmax2(self): cddl_res = zcbor.DataTranslator.from_cddl( - p_prelude.read_text(encoding="utf-8") + '\n' + p_corner_cases.read_text(encoding="utf-8"), 16) - cddl = cddl_res.my_types['Intmax2'] + p_prelude.read_text(encoding="utf-8") + + "\n" + + p_corner_cases.read_text(encoding="utf-8"), + 16, + ) + cddl = cddl_res.my_types["Intmax2"] test_yaml1 = f"[-128, 0, -32768, 0, -2147483648, 0, -9223372036854775808, 0]" decoded = cddl.decode_str_yaml(test_yaml1) self.assertEqual(decoded.INT_8, -128) @@ -767,8 +1287,12 @@ def test_intmax2(self): class TestInvalidIdentifiers(TestCase): def test_invalid_identifiers0(self): cddl_res = zcbor.DataTranslator.from_cddl( - p_prelude.read_text(encoding="utf-8") + '\n' + p_corner_cases.read_text(encoding="utf-8"), 16) - cddl = cddl_res.my_types['InvalidIdentifiers'] + p_prelude.read_text(encoding="utf-8") + + "\n" + + p_corner_cases.read_text(encoding="utf-8"), + 16, + ) + cddl = cddl_res.my_types["InvalidIdentifiers"] test_yaml = "['1one', 2, '{[a-z]}']" decoded = cddl.decode_str_yaml(test_yaml) self.assertTrue(decoded.f_1one_tstr) diff --git 
a/tests/unit/test3_float16/floats.py b/tests/unit/test3_float16/floats.py index 72d2cdf1..e26cb80f 100644 --- a/tests/unit/test3_float16/floats.py +++ b/tests/unit/test3_float16/floats.py @@ -8,38 +8,46 @@ import sys import os + def decode_test(): - num_start = 0 - num_end = 0x10000 + num_start = 0 + num_end = 0x10000 + + a = numpy.frombuffer( + numpy.arange(num_start, num_end, dtype=numpy.ushort).tobytes(), dtype=numpy.float16 + ) + with open(os.path.join(sys.argv[1], "fp_bytes_decode.bin"), "wb") as f: + f.write(a.astype("f").tobytes()) - a = numpy.frombuffer(numpy.arange(num_start, num_end, dtype=numpy.ushort).tobytes(), dtype=numpy.float16) - with open(os.path.join(sys.argv[1], "fp_bytes_decode.bin"), 'wb') as f: - f.write(a.astype("f").tobytes()) def encode_test(): - num_start = 0x33000001 - num_end = 0x477ff000 + num_start = 0x33000001 + num_end = 0x477FF000 - a = numpy.arange(num_start, num_end, dtype=numpy.uintc) - b = numpy.frombuffer(a.tobytes(), dtype=numpy.float32).astype("I").tobytes()) - with open(os.path.join(sys.argv[1], "fp_bytes_encode.bin"), 'wb') as f: - f.write(c.astype("I").tobytes()) def print_help(): - print("Generate bin files with results from converting between float16 and float32 (both ways)") - print() - print(f"Usage: {sys.argv[0]} ") + print("Generate bin files with results from converting between float16 and float32 (both ways)") + print() + print(f"Usage: {sys.argv[0]} ") + if __name__ == "__main__": - if "--help" in sys.argv or "-h" in sys.argv or len(sys.argv) < 2: - print_help() - elif len(sys.argv) < 3: - decode_test() - encode_test() - elif sys.argv[2] == "decode": - decode_test() - elif sys.argv[2] == "encode": - encode_test() + if "--help" in sys.argv or "-h" in sys.argv or len(sys.argv) < 2: + print_help() + elif len(sys.argv) < 3: + decode_test() + encode_test() + elif sys.argv[2] == "decode": + decode_test() + elif sys.argv[2] == "encode": + encode_test() diff --git a/zcbor/zcbor.py b/zcbor/zcbor.py index 
a49bc2ec..4cd30806 100755 --- a/zcbor/zcbor.py +++ b/zcbor/zcbor.py @@ -15,8 +15,16 @@ from datetime import datetime from copy import copy from itertools import tee, chain -from cbor2 import (loads, dumps, CBORTag, load, CBORDecodeValueError, CBORDecodeEOF, undefined, - CBORSimpleValue) +from cbor2 import ( + loads, + dumps, + CBORTag, + load, + CBORDecodeValueError, + CBORDecodeEOF, + undefined, + CBORSimpleValue, +) from yaml import safe_load as yaml_load, dump as yaml_dump from json import loads as json_load, dumps as json_dump from io import BytesIO @@ -125,13 +133,13 @@ def val_or_null(value, var_name): Return a code snippet that assigns the value to a variable var_name and returns pointer to the variable, or returns NULL if the value is None. - """ + """ return "(%s = %d, &%s)" % (var_name, value, var_name) if value is not None else "NULL" def tmp_str_or_null(value): """Assign the min_value variable.""" - value_str = f'"{value}"' if value is not None else 'NULL' + value_str = f'"{value}"' if value is not None else "NULL" len_str = f"""sizeof({f'"{value}"'}) - 1, &tmp_str)""" return f"(tmp_str.value = (uint8_t *){value_str}, tmp_str.len = {len_str}" @@ -148,11 +156,9 @@ def deref_if_not_null(access): def xcode_args(res, *sargs): """Return an argument list for a function call to a encoder/decoder function.""" if len(sargs) > 0: - return "state, %s, %s, %s" % ( - "&(%s)" % res if res != "NULL" else res, sargs[0], sargs[1]) + return "state, %s, %s, %s" % ("&(%s)" % res if res != "NULL" else res, sargs[0], sargs[1]) else: - return "state, %s" % ( - "(%s)" % res if res != "NULL" else res) + return "state, %s" % ("(%s)" % res if res != "NULL" else res) def xcode_statement(func, *sargs, **kwargs): @@ -179,7 +185,8 @@ def ternary_if_chain(access, names, xcode_strings): names[0], xcode_strings[0], newl_ind, - ternary_if_chain(access, names[1:], xcode_strings[1:]) if len(names) > 1 else "false") + ternary_if_chain(access, names[1:], xcode_strings[1:]) if len(names) > 1 
else "false", + ) def comma_operator(*expressions): @@ -228,8 +235,16 @@ class CddlParser: - For "OTHER" types, one instance points to another type definition. - For "GROUP" and "UNION" types, there is no separate data item for the instance. """ - def __init__(self, default_max_qty, my_types, my_control_groups, base_name=None, - short_names=False, base_stem=''): + + def __init__( + self, + default_max_qty, + my_types, + my_control_groups, + base_name=None, + short_names=False, + base_stem="", + ): self.id_prefix = "temp_" + str(counter()) self.id_num = None # Unique ID number. Only populated if needed. # The value of the data item. Has different meaning for different @@ -291,16 +306,19 @@ def from_cddl(cddl_class, cddl_string, default_max_qty, *args, **kwargs): type_strings = cddl_class.get_types(cddl_string) # Separate type_strings as keys in two dicts, one dict for strings that start with &( which # are special control operators for .bits, and one dict for all the regular types. - my_types = \ - {my_type: None for my_type, val in type_strings.items() if not val.startswith("&(")} - my_control_groups = \ - {my_cg: None for my_cg, val in type_strings.items() if val.startswith("&(")} + my_types = { + my_type: None for my_type, val in type_strings.items() if not val.startswith("&(") + } + my_control_groups = { + my_cg: None for my_cg, val in type_strings.items() if val.startswith("&(") + } # Parse the definitions, replacing the each string with a # CodeGenerator instance. 
for my_type, cddl_string in type_strings.items(): - parsed = cddl_class(*args, default_max_qty, my_types, my_control_groups, **kwargs, - base_stem=my_type) + parsed = cddl_class( + *args, default_max_qty, my_types, my_control_groups, **kwargs, base_stem=my_type + ) parsed.get_value(cddl_string.replace("\n", " ").lstrip("&")) parsed = parsed.flatten()[0] if my_type in my_types: @@ -324,7 +342,7 @@ def from_cddl(cddl_class, cddl_string, default_max_qty, *args, **kwargs): @staticmethod def strip_comments(instr): """Strip CDDL comments (';') from the string.""" - return getrp(r"\;.*?(\n|$)").sub('', instr) + return getrp(r"\;.*?(\n|$)").sub("", instr) @staticmethod def resolve_backslashes(instr): @@ -336,12 +354,14 @@ def get_types(cls, cddl_string): """Returns a dict containing multiple typename=>string""" instr = cls.strip_comments(cddl_string) instr = cls.resolve_backslashes(instr) - type_regex = \ + type_regex = ( r"(\s*?\$?\$?([\w-]+)\s*(\/{0,2})=\s*(.*?)(?=(\Z|\s*\$?\$?[\w-]+\s*\/{0,2}=(?!\>))))" + ) result = defaultdict(lambda: "") types = [ (key, value, slashes) - for (_1, key, slashes, value, _2) in getrp(type_regex, S | M).findall(instr)] + for (_1, key, slashes, value, _2) in getrp(type_regex, S | M).findall(instr) + ] for key, value, slashes in types: if slashes: result[key] += slashes @@ -353,24 +373,30 @@ def get_types(cls, cddl_string): result[key] = value return dict(result) - backslash_quotation_mark = r'\"' + backslash_quotation_mark = r"\"" def generate_base_name(self): """Generate a (hopefully) unique and descriptive name""" - byte_multi = (8 if self.type in ["INT", "UINT", "NINT", "FLOAT"] else 1) + byte_multi = 8 if self.type in ["INT", "UINT", "NINT", "FLOAT"] else 1 # The first non-None entry is used: - raw_name = (( + raw_name = ( # The label is the default if present: self.label # Name a key/value pair by its key type or string value: or (self.key.value if self.key and self.key.type in ["TSTR", "OTHER"] else None) # Name a string by its 
expected value: - or (f"{self.value.replace(self.backslash_quotation_mark, '')}_{self.type.lower()}" - if self.type == "TSTR" and self.value is not None else None) + or ( + f"{self.value.replace(self.backslash_quotation_mark, '')}_{self.type.lower()}" + if self.type == "TSTR" and self.value is not None + else None + ) # Name an integer by its expected value: - or (f"{self.type.lower()}{abs(self.value)}" - if self.type in ["UINT", "NINT"] and self.value is not None else None) + or ( + f"{self.type.lower()}{abs(self.value)}" + if self.type in ["UINT", "NINT"] and self.value is not None + else None + ) # Name a type by its type name or (next((key for key, value in self.my_types.items() if value == self), None)) # Name a control group by its name @@ -378,32 +404,46 @@ def generate_base_name(self): # Name an instance by its type: or (self.value + "_m" if self.type == "OTHER" else None) # Name a list by its first element: - or (self.value[0].get_base_name() + "_l" - if self.type in ["LIST", "GROUP"] and self.value else None) + or ( + self.value[0].get_base_name() + "_l" + if self.type in ["LIST", "GROUP"] and self.value + else None + ) # Name a cbor-encoded bstr by its expected cbor contents: - or ((self.cbor.value + "_bstr") - if self.cbor and self.cbor.type in ["TSTR", "OTHER"] else None) + or ( + (self.cbor.value + "_bstr") + if self.cbor and self.cbor.type in ["TSTR", "OTHER"] + else None + ) # Name a key value pair by its key (regardless of the key type) or ((self.key.generate_base_name() + self.type.lower()) if self.key else None) # Name an element by its minimum/maximum "size" (if the min == the max) - or (f"{self.type.lower()}{self.min_size * byte_multi}" - if (self.min_size is not None) and self.min_size == self.max_size else None) + or ( + f"{self.type.lower()}{self.min_size * byte_multi}" + if (self.min_size is not None) and self.min_size == self.max_size + else None + ) # Name an element by its minimum/maximum "size" (if the min != the max) - or 
(f"{self.type.lower()}{self.min_size * byte_multi}-{self.max_size * byte_multi}" - if (self.min_size is not None) and (self.max_size is not None) else None) + or ( + f"{self.type.lower()}{self.min_size * byte_multi}-{self.max_size * byte_multi}" + if (self.min_size is not None) and (self.max_size is not None) + else None + ) # Name an element by its type. - or self.type.lower()).replace("-", "_")) + or self.type.lower() + ).replace("-", "_") # Make the name compatible with C variable names # (don't start with a digit, don't use accented letters or symbols other than '_') - name_regex = getrp(r'[a-zA-Z_][a-zA-Z\d_]*') + name_regex = getrp(r"[a-zA-Z_][a-zA-Z\d_]*") if name_regex.fullmatch(raw_name) is None: - latinized_name = getrp(r'[^a-zA-Z\d_]').sub("", raw_name) + latinized_name = getrp(r"[^a-zA-Z\d_]").sub("", raw_name) if name_regex.fullmatch(latinized_name) is None: # Add '_' if name starts with a digit or is empty after removing accented chars. latinized_name = "_" + latinized_name - assert name_regex.fullmatch(latinized_name) is not None, \ - f"Couldn't make '{raw_name}' valid. '{latinized_name}' is invalid." + assert ( + name_regex.fullmatch(latinized_name) is not None + ), f"Couldn't make '{raw_name}' valid. '{latinized_name}' is invalid." 
return latinized_name return raw_name @@ -437,13 +477,17 @@ def id(self, with_prefix=True): raw_name = self.get_base_name() if not with_prefix and self.short_names: return raw_name - if (self.id_prefix - and (f"{self.id_prefix}_" not in raw_name) - and (self.id_prefix != raw_name.strip("_"))): + if ( + self.id_prefix + and (f"{self.id_prefix}_" not in raw_name) + and (self.id_prefix != raw_name.strip("_")) + ): return f"{self.id_prefix}_{raw_name}" - if (self.base_stem - and (f"{self.base_stem}_" not in raw_name) - and (self.base_stem != raw_name.strip("_"))): + if ( + self.base_stem + and (f"{self.base_stem}_" not in raw_name) + and (self.base_stem != raw_name.strip("_")) + ): return f"{self.base_stem}_{raw_name}" return raw_name @@ -454,10 +498,12 @@ def init_args(self): def init_kwargs(self): """Return the kwargs that should be used to initialize a new instance of this class.""" return { - "my_types": self.my_types, "my_control_groups": self.my_control_groups, - "short_names": self.short_names} + "my_types": self.my_types, + "my_control_groups": self.my_control_groups, + "short_names": self.short_names, + } - def set_id_prefix(self, id_prefix=''): + def set_id_prefix(self, id_prefix=""): self.id_prefix = id_prefix if self.type in ["LIST", "MAP", "GROUP", "UNION"]: for child in self.value: @@ -476,37 +522,36 @@ def child_base_id(self): def mrepr(self, newline): """Human readable representation.""" - reprstr = '' + reprstr = "" if self.quantifier: reprstr += self.quantifier if self.label: - reprstr += self.label + ':' + reprstr += self.label + ":" for tag in self.tags: reprstr += f"#6.{tag}" if self.key: reprstr += repr(self.key) + " => " if self.is_unambiguous(): - reprstr += '/' + reprstr += "/" if self.is_unambiguous_repeated(): - reprstr += '/' + reprstr += "/" reprstr += self.type if self.size: - reprstr += '(%d)' % self.size + reprstr += "(%d)" % self.size if newline: - reprstr += '\n' + reprstr += "\n" if self.value: reprstr += pformat(self.value, indent=4, 
width=1) if self.cbor: reprstr += " cbor: " + repr(self.cbor) - return reprstr.replace('\n', '\n ') + return reprstr.replace("\n", "\n ") def _flatten(self): """Recursively flatten children, key, and cbor elements.""" new_value = [] if self.type in ["LIST", "MAP", "GROUP", "UNION"]: for child in self.value: - new_value.extend( - child.flatten(allow_multi=self.type != "UNION")) + new_value.extend(child.flatten(allow_multi=self.type != "UNION")) self.value = new_value if self.key: self.key = self.key.flatten()[0] @@ -518,9 +563,11 @@ def flatten(self, allow_multi=False): self._flatten() if self.type == "OTHER" and self.is_socket and self.value not in self.my_types: return [] - if self.type in ["GROUP", "UNION"]\ - and (len(self.value) == 1)\ - and (not (self.key and self.value[0].key)): + if ( + self.type in ["GROUP", "UNION"] + and (len(self.value) == 1) + and (not (self.key and self.value[0].key)) + ): self.value[0].min_qty *= self.min_qty self.value[0].max_qty *= self.max_qty if not self.value[0].label: @@ -544,9 +591,7 @@ def set_max_value(self, max_value): def type_and_value(self, new_type, value_generator): """Set the self.type and self.value of this element.""" if self.type is not None: - raise TypeError( - "Cannot have two values: %s, %s" % - (self.type, new_type)) + raise TypeError("Cannot have two values: %s, %s" % (self.type, new_type)) if new_type is None: raise TypeError("Cannot set None as type") if new_type == "UNION" and self.value is not None: @@ -588,8 +633,10 @@ def set_default(self, value): if not self.type == value.type: if not (self.type == "INT" and value.type in ["UINT", "NINT"]): - raise TypeError(f"Type of default does not match type of element. " - "({self.type} != {value.type})") + raise TypeError( + f"Type of default does not match type of element. 
" + "({self.type} != {value.type})" + ) self.default = value.value @@ -600,13 +647,11 @@ def type_and_range(self, new_type, min_val, max_val, inc_end=True): if not inc_end: max_val -= 1 if new_type not in ["INT", "UINT", "NINT"]: - raise TypeError( - "Only integers (not %s) can have range" % - (new_type,)) + raise TypeError("Only integers (not %s) can have range" % (new_type,)) if min_val > max_val: raise TypeError( - "Range has larger minimum than maximum (min %d, max %d)" % - (min_val, max_val)) + "Range has larger minimum than maximum (min %d, max %d)" % (min_val, max_val) + ) if min_val == max_val: return self.type_and_value(new_type, min_val) self.type = new_type @@ -638,19 +683,20 @@ def set_label(self, label): def set_quantifier(self, quantifier): """Set the self.quantifier, self.min_qty, and self.max_qty of this element""" if self.type is not None: - raise TypeError( - "Cannot have quantifier after value: " + quantifier) + raise TypeError("Cannot have quantifier after value: " + quantifier) quantifier_mapping = [ (r"\?", lambda mo: (0, 1)), (r"\*", lambda mo: (0, None)), (r"\+", lambda mo: (1, None)), - (r"(.*?)\*\*?(.*)", - lambda mo: (int(mo.groups()[0] or "0", 0), int(mo.groups()[1] or "0", 0) or None)), + ( + r"(.*?)\*\*?(.*)", + lambda mo: (int(mo.groups()[0] or "0", 0), int(mo.groups()[1] or "0", 0) or None), + ), ] self.quantifier = quantifier - for (reg, handler) in quantifier_mapping: + for reg, handler in quantifier_mapping: match_obj = getrp(reg).match(quantifier) if match_obj: (self.min_qty, self.max_qty) = handler(match_obj) @@ -685,11 +731,10 @@ def set_size_range(self, min_size, max_size_in, inc_end=True): """ max_size = max_size_in if inc_end else max_size_in - 1 - if (min_size and min_size < 0 or max_size and max_size < 0) \ - or (None not in [min_size, max_size] and min_size > max_size): - raise TypeError( - "Invalid size range (min %d, max %d)" % - (min_size, max_size)) + if (min_size and min_size < 0 or max_size and max_size < 0) or ( + 
None not in [min_size, max_size] and min_size > max_size + ): + raise TypeError("Invalid size range (min %d, max %d)" % (min_size, max_size)) self.set_min_size(min_size) self.set_max_size(max_size) @@ -697,25 +742,21 @@ def set_size_range(self, min_size, max_size_in, inc_end=True): def set_min_size(self, min_size): """Set self.min_size, and self.minValue if type is UINT.""" if self.type == "UINT": - self.minValue = 256**min(0, abs(min_size - 1)) if min_size is not None else None + self.minValue = 256 ** min(0, abs(min_size - 1)) if min_size is not None else None self.min_size = min_size if min_size is not None else None def set_max_size(self, max_size): """Set self.max_size, and self.max_value if type is UINT.""" if self.type == "UINT" and max_size and self.max_value is None: if max_size > 8: - raise TypeError( - "Size too large for integer. size %d" % - max_size) + raise TypeError("Size too large for integer. size %d" % max_size) self.max_value = 256**max_size - 1 self.max_size = max_size def set_cbor(self, cbor, cborseq): """Set the self.cbor of this element. For use during CDDL parsing.""" if self.type != "BSTR": - raise TypeError( - "%s must be used with bstr." % - (".cborseq" if cborseq else ".cbor",)) + raise TypeError("%s must be used with bstr." % (".cborseq" if cborseq else ".cbor",)) self.cbor = cbor if cborseq: self.cbor.max_qty = self.default_max_qty @@ -817,21 +858,32 @@ def cddl_regexes_init(self): # The "range_types" match the contents of brackets i.e. (), [], and {}, # and strings, i.e. 
' or " range_types = [ - (r'(?P\[(?P(?>[^[\]]+|(?&bracket))*)\])', - lambda m_self, list_str: m_self.type_and_value( - "LIST", lambda: m_self.parse(list_str))), - (r'(?P\((?P(?>[^\(\)]+|(?&paren))*)\))', - lambda m_self, group_str: m_self.type_and_value( - "GROUP", lambda: m_self.parse(group_str))), - (r'(?P{(?P(?>[^{}]+|(?&curly))*)})', - lambda m_self, map_str: m_self.type_and_value( - "MAP", lambda: m_self.parse(map_str))), - (r'\'(?P.*?)(?.*?)(?\[(?P(?>[^[\]]+|(?&bracket))*)\])", + lambda m_self, list_str: m_self.type_and_value( + "LIST", lambda: m_self.parse(list_str) + ), + ), + ( + r"(?P\((?P(?>[^\(\)]+|(?&paren))*)\))", + lambda m_self, group_str: m_self.type_and_value( + "GROUP", lambda: m_self.parse(group_str) + ), + ), + ( + r"(?P{(?P(?>[^{}]+|(?&curly))*)})", + lambda m_self, map_str: m_self.type_and_value("MAP", lambda: m_self.parse(map_str)), + ), + ( + r"\'(?P.*?)(?.*?)(?.+?)(?=\/\/|\Z)', - lambda m_self, union_str: m_self.union_add_value( - m_self.parse("(%s)" % union_str if ',' in union_str else union_str)[0], - doubleslash=True)), - (r'(?P[^\W\d][\w-]*)\s*:', - self_type.set_key_or_label), - (r'((\=\>)|:)', - lambda m_self, _: m_self.convert_to_key()), - (r'([+*?])', - self_type.set_quantifier), - (r'(' + match_uint + r'\*\*?' 
+ match_uint + r'?)', - self_type.set_quantifier), - (r'\/\s*(?P((' + range_types_regex + r')|[^,\[\]{}()])+?)(?=\/|\Z|,)', - lambda m_self, union_str: m_self.union_add_value( - m_self.parse(union_str)[0])), - (r'(uint|nint|int|float|bstr|tstr|bool|nil|any)(?![\w-])', - lambda m_self, type_str: m_self.type_and_value(type_str.upper(), lambda: None)), - (r'undefined(?!\w)', - lambda m_self, _: m_self.type_and_value("UNDEF", lambda: None)), - (r'float16(?![\w-])', - lambda m_self, _: m_self.type_value_size("FLOAT", lambda: None, 2)), - (r'float16-32(?![\w-])', - lambda m_self, _: m_self.type_value_size_range("FLOAT", lambda: None, 2, 4)), - (r'float32(?![\w-])', - lambda m_self, _: m_self.type_value_size("FLOAT", lambda: None, 4)), - (r'float32-64(?![\w-])', - lambda m_self, _: m_self.type_value_size_range("FLOAT", lambda: None, 4, 8)), - (r'float64(?![\w-])', - lambda m_self, _: m_self.type_value_size("FLOAT", lambda: None, 8)), - (r'\-?\d*\.\d+', - lambda m_self, num: m_self.type_and_value("FLOAT", lambda: float(num))), - (match_uint + r'\.\.' + match_uint, - lambda m_self, _range: m_self.type_and_range( - "UINT", *map(lambda num: int(num, 0), _range.split("..")))), - (match_nint + r'\.\.' + match_uint, - lambda m_self, _range: m_self.type_and_range( - "INT", *map(lambda num: int(num, 0), _range.split("..")))), - (match_nint + r'\.\.' + match_nint, - lambda m_self, _range: m_self.type_and_range( - "NINT", *map(lambda num: int(num, 0), _range.split("..")))), - (match_uint + r'\.\.\.' + match_uint, - lambda m_self, _range: m_self.type_and_range( - "UINT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)), - (match_nint + r'\.\.\.' + match_uint, - lambda m_self, _range: m_self.type_and_range( - "INT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)), - (match_nint + r'\.\.\.' 
+ match_nint, - lambda m_self, _range: m_self.type_and_range( - "NINT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)), - (match_nint, - lambda m_self, num: m_self.type_and_value("NINT", lambda: int(num, 0))), - (match_uint, - lambda m_self, num: m_self.type_and_value("UINT", lambda: int(num, 0))), - (r'true(?!\w)', - lambda m_self, _: m_self.type_and_value("BOOL", lambda: True)), - (r'false(?!\w)', - lambda m_self, _: m_self.type_and_value("BOOL", lambda: False)), - (r'#6\.(?P\d+)', - self_type.add_tag), - (r'(\$?\$?[\w-]+)', - lambda m_self, other_str: m_self.type_and_value("OTHER", lambda: other_str)), - (r'\.size \(?(?P' + match_int + r'\.\.' + match_int + r')\)?', - lambda m_self, _range: m_self.set_size_range( - *map(lambda num: int(num, 0), _range.split("..")))), - (r'\.size \(?(?P' + match_int + r'\.\.\.' + match_int + r')\)?', - lambda m_self, _range: m_self.set_size_range( - *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)), - (r'\.size \(?(?P' + match_uint + r')\)?', - lambda m_self, size: m_self.set_size(int(size, 0))), - (r'\.gt \(?(?P' + match_int + r')\)?', - lambda m_self, minvalue: m_self.set_min_value(int(minvalue, 0) + 1)), - (r'\.lt \(?(?P' + match_int + r')\)?', - lambda m_self, maxvalue: m_self.set_max_value(int(maxvalue, 0) - 1)), - (r'\.ge \(?(?P' + match_int + r')\)?', - lambda m_self, minvalue: m_self.set_min_value(int(minvalue, 0))), - (r'\.le \(?(?P' + match_int + r')\)?', - lambda m_self, maxvalue: m_self.set_max_value(int(maxvalue, 0))), - (r'\.eq \(?(?P' + match_int + r')\)?', - lambda m_self, value: m_self.set_value(lambda: int(value, 0))), - (r'\.eq \"(?P.*?)(?(?>[^\(\)]+|(?1))*)\))', - lambda m_self, type_str: m_self.set_default(m_self.parse(type_str)[0])), - (r'\.default (?P[^\s,]+)', - lambda m_self, type_str: m_self.set_default(m_self.parse(type_str)[0])), - (r'\.cbor (\((?P(?>[^\(\)]+|(?1))*)\))', - lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], False)), - (r'\.cbor 
(?P[^\s,]+)', - lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], False)), - (r'\.cborseq (\((?P(?>[^\(\)]+|(?1))*)\))', - lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], True)), - (r'\.cborseq (?P[^\s,]+)', - lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], True)), - (r'\.bits (?P[\w-]+)', - lambda m_self, bits_str: m_self.set_bits(bits_str)) + ( + r"\/\/\s*(?P.+?)(?=\/\/|\Z)", + lambda m_self, union_str: m_self.union_add_value( + m_self.parse("(%s)" % union_str if "," in union_str else union_str)[0], + doubleslash=True, + ), + ), + (r"(?P[^\W\d][\w-]*)\s*:", self_type.set_key_or_label), + (r"((\=\>)|:)", lambda m_self, _: m_self.convert_to_key()), + (r"([+*?])", self_type.set_quantifier), + (r"(" + match_uint + r"\*\*?" + match_uint + r"?)", self_type.set_quantifier), + ( + r"\/\s*(?P((" + range_types_regex + r")|[^,\[\]{}()])+?)(?=\/|\Z|,)", + lambda m_self, union_str: m_self.union_add_value(m_self.parse(union_str)[0]), + ), + ( + r"(uint|nint|int|float|bstr|tstr|bool|nil|any)(?![\w-])", + lambda m_self, type_str: m_self.type_and_value(type_str.upper(), lambda: None), + ), + (r"undefined(?!\w)", lambda m_self, _: m_self.type_and_value("UNDEF", lambda: None)), + ( + r"float16(?![\w-])", + lambda m_self, _: m_self.type_value_size("FLOAT", lambda: None, 2), + ), + ( + r"float16-32(?![\w-])", + lambda m_self, _: m_self.type_value_size_range("FLOAT", lambda: None, 2, 4), + ), + ( + r"float32(?![\w-])", + lambda m_self, _: m_self.type_value_size("FLOAT", lambda: None, 4), + ), + ( + r"float32-64(?![\w-])", + lambda m_self, _: m_self.type_value_size_range("FLOAT", lambda: None, 4, 8), + ), + ( + r"float64(?![\w-])", + lambda m_self, _: m_self.type_value_size("FLOAT", lambda: None, 8), + ), + ( + r"\-?\d*\.\d+", + lambda m_self, num: m_self.type_and_value("FLOAT", lambda: float(num)), + ), + ( + match_uint + r"\.\." 
+ match_uint, + lambda m_self, _range: m_self.type_and_range( + "UINT", *map(lambda num: int(num, 0), _range.split("..")) + ), + ), + ( + match_nint + r"\.\." + match_uint, + lambda m_self, _range: m_self.type_and_range( + "INT", *map(lambda num: int(num, 0), _range.split("..")) + ), + ), + ( + match_nint + r"\.\." + match_nint, + lambda m_self, _range: m_self.type_and_range( + "NINT", *map(lambda num: int(num, 0), _range.split("..")) + ), + ), + ( + match_uint + r"\.\.\." + match_uint, + lambda m_self, _range: m_self.type_and_range( + "UINT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False + ), + ), + ( + match_nint + r"\.\.\." + match_uint, + lambda m_self, _range: m_self.type_and_range( + "INT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False + ), + ), + ( + match_nint + r"\.\.\." + match_nint, + lambda m_self, _range: m_self.type_and_range( + "NINT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False + ), + ), + (match_nint, lambda m_self, num: m_self.type_and_value("NINT", lambda: int(num, 0))), + (match_uint, lambda m_self, num: m_self.type_and_value("UINT", lambda: int(num, 0))), + (r"true(?!\w)", lambda m_self, _: m_self.type_and_value("BOOL", lambda: True)), + (r"false(?!\w)", lambda m_self, _: m_self.type_and_value("BOOL", lambda: False)), + (r"#6\.(?P\d+)", self_type.add_tag), + ( + r"(\$?\$?[\w-]+)", + lambda m_self, other_str: m_self.type_and_value("OTHER", lambda: other_str), + ), + ( + r"\.size \(?(?P" + match_int + r"\.\." + match_int + r")\)?", + lambda m_self, _range: m_self.set_size_range( + *map(lambda num: int(num, 0), _range.split("..")) + ), + ), + ( + r"\.size \(?(?P" + match_int + r"\.\.\." 
+ match_int + r")\)?", + lambda m_self, _range: m_self.set_size_range( + *map(lambda num: int(num, 0), _range.split("...")), inc_end=False + ), + ), + ( + r"\.size \(?(?P" + match_uint + r")\)?", + lambda m_self, size: m_self.set_size(int(size, 0)), + ), + ( + r"\.gt \(?(?P" + match_int + r")\)?", + lambda m_self, minvalue: m_self.set_min_value(int(minvalue, 0) + 1), + ), + ( + r"\.lt \(?(?P" + match_int + r")\)?", + lambda m_self, maxvalue: m_self.set_max_value(int(maxvalue, 0) - 1), + ), + ( + r"\.ge \(?(?P" + match_int + r")\)?", + lambda m_self, minvalue: m_self.set_min_value(int(minvalue, 0)), + ), + ( + r"\.le \(?(?P" + match_int + r")\)?", + lambda m_self, maxvalue: m_self.set_max_value(int(maxvalue, 0)), + ), + ( + r"\.eq \(?(?P" + match_int + r")\)?", + lambda m_self, value: m_self.set_value(lambda: int(value, 0)), + ), + ( + r"\.eq \"(?P.*?)(?(?>[^\(\)]+|(?1))*)\))", + lambda m_self, type_str: m_self.set_default(m_self.parse(type_str)[0]), + ), + ( + r"\.default (?P[^\s,]+)", + lambda m_self, type_str: m_self.set_default(m_self.parse(type_str)[0]), + ), + ( + r"\.cbor (\((?P(?>[^\(\)]+|(?1))*)\))", + lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], False), + ), + ( + r"\.cbor (?P[^\s,]+)", + lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], False), + ), + ( + r"\.cborseq (\((?P(?>[^\(\)]+|(?1))*)\))", + lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], True), + ), + ( + r"\.cborseq (?P[^\s,]+)", + lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], True), + ), + (r"\.bits (?P[\w-]+)", lambda m_self, bits_str: m_self.set_bits(bits_str)), ] def get_value(self, instr): @@ -948,9 +1059,9 @@ def get_value(self, instr): types = type(self).cddl_regexes[type(self)] # Keep parsing until a comma, or to the end of the string. 
- while instr != '' and instr[0] != ',': + while instr != "" and instr[0] != ",": match_obj = None - for (reg, handler) in types: + for reg, handler in types: match_obj = getrp(reg).match(instr) if match_obj: try: @@ -963,7 +1074,7 @@ def get_value(self, instr): raise Exception("Failed while parsing this: '%s'" % match_str) from e self.match_str += match_str old_len = len(instr) - instr = getrp(reg).sub('', instr, count=1).lstrip() + instr = getrp(reg).sub("", instr, count=1).lstrip() if old_len == len(instr): raise Exception("empty match") break @@ -984,10 +1095,14 @@ def has_key(self): This must have some recursion since CDDL allows the key to be hidden behind layers of indirection. """ - ret = self.key is not None\ - or (self.type == "OTHER" and self.my_types[self.value].has_key())\ - or (self.type in ["GROUP", "UNION"] - and (self.value and all(child.has_key() for child in self.value))) + ret = ( + self.key is not None + or (self.type == "OTHER" and self.my_types[self.value].has_key()) + or ( + self.type in ["GROUP", "UNION"] + and (self.value and all(child.has_key() for child in self.value)) + ) + ) return ret def is_valid_map_elem(self) -> tuple[bool, str]: @@ -1010,31 +1125,42 @@ def post_validate(self): invalid_elems = [child for child in self.value if not child.is_valid_map_elem()[0]] if self.type == "MAP" and invalid_elems: raise TypeError( - "Map member(s) are invalid:\n" + '\n'.join( - [f"{str(c)}: {c.is_valid_map_elem()[1]}" for c in invalid_elems])) + "Map member(s) are invalid:\n" + + "\n".join([f"{str(c)}: {c.is_valid_map_elem()[1]}" for c in invalid_elems]) + ) child_keys = [child for child in self.value if child not in invalid_elems] if self.type == "LIST" and child_keys: raise TypeError( - str(self) + linesep - + "List member(s) cannot have key: " + str(child_keys) + " pointing to " + str(self) + + linesep + + "List member(s) cannot have key: " + + str(child_keys) + + " pointing to " + str( - [self.my_types[elem.value] for elem in child_keys - if 
elem.type == "OTHER"])) + [self.my_types[elem.value] for elem in child_keys if elem.type == "OTHER"] + ) + ) if self.type == "OTHER": if self.value not in self.my_types.keys() or not isinstance( - self.my_types[self.value], type(self)): + self.my_types[self.value], type(self) + ): raise TypeError("%s has not been parsed." % self.value) if self.type == "LIST": for child in self.value[:-1]: if child.type == "ANY": if child.min_qty != child.max_qty: - raise TypeError(f"ambiguous quantity of 'any' is not supported in list, " - + "except as last element:\n{str(child)}") + raise TypeError( + f"ambiguous quantity of 'any' is not supported in list, " + + "except as last element:\n{str(child)}" + ) if self.type == "UNION" and len(self.value) > 1: - if any(((not child.key and child.type == "ANY") or ( - child.key and child.key.type == "ANY")) for child in self.value): + if any( + ((not child.key and child.type == "ANY") or (child.key and child.key.type == "ANY")) + for child in self.value + ): raise TypeError( - "'any' inside union is not supported since it would always be triggered.") + "'any' inside union is not supported since it would always be triggered." + ) # Validation of child elements. 
if self.type in ["MAP", "LIST", "UNION", "GROUP"]: @@ -1056,7 +1182,7 @@ def parse(self, instr): """Parses entire instr and returns a list of instances.""" instr = instr.strip() values = [] - while instr != '': + while instr != "": value = type(self)(*self.init_args(), **self.init_kwargs(), base_stem=self.base_stem) instr = value.get_value(instr) values.append(value) @@ -1067,18 +1193,76 @@ def __repr__(self): c_keywords = [ - "alignas", "alignof", "atomic_bool", "atomic_int", "auto", "bool", "break", "case", "char", - "complex", "const", "constexpr", "continue", "default", "do", "double", "else", "enum", - "extern", "false", "float", "for", "goto", "if", "imaginary", "inline", "int", "long", - "noreturn", "nullptr", "register", "restrict", "return", "short", "signed", "sizeof", "static", - "static_assert", "struct", "switch", "thread_local", "true", "typedef", "typeof", - "typeof_unqual", "union", "unsigned", "void", "volatile", "while"] + "alignas", + "alignof", + "atomic_bool", + "atomic_int", + "auto", + "bool", + "break", + "case", + "char", + "complex", + "const", + "constexpr", + "continue", + "default", + "do", + "double", + "else", + "enum", + "extern", + "false", + "float", + "for", + "goto", + "if", + "imaginary", + "inline", + "int", + "long", + "noreturn", + "nullptr", + "register", + "restrict", + "return", + "short", + "signed", + "sizeof", + "static", + "static_assert", + "struct", + "switch", + "thread_local", + "true", + "typedef", + "typeof", + "typeof_unqual", + "union", + "unsigned", + "void", + "volatile", + "while", +] c_keywords_underscore = [ - "_Alignas", "_Alignof", "_Atomic", "_BitInt", "_Bool", "_Complex", "_Decimal128", "_Decimal32", - "_Decimal64", "_Generic", "_Imaginary", "_Noreturn", "_Pragma", "_Static_assert", - "_Thread_local"] + "_Alignas", + "_Alignof", + "_Atomic", + "_BitInt", + "_Bool", + "_Complex", + "_Decimal128", + "_Decimal32", + "_Decimal64", + "_Generic", + "_Imaginary", + "_Noreturn", + "_Pragma", + 
"_Static_assert", + "_Thread_local", +] class CddlXcoder(CddlParser): @@ -1096,8 +1280,12 @@ def __init__(self, *args, **kwargs): def var_name(self, with_prefix=False, observe_skipped=True): """Name of variables and enum members for this element.""" - if (observe_skipped and self.skip_condition() - and self.type in ["LIST", "MAP", "GROUP"] and self.value): + if ( + observe_skipped + and self.skip_condition() + and self.type in ["LIST", "MAP", "GROUP"] + and self.value + ): return self.value[0].var_name(with_prefix) name = self.id(with_prefix=with_prefix) if name in c_keywords: @@ -1115,9 +1303,11 @@ def skip_condition(self): return False def set_skipped(self, skipped): - if self.range_check_condition() \ - and self.repeated_single_func_impl_condition() \ - and not self.key: + if ( + self.range_check_condition() + and self.repeated_single_func_impl_condition() + and not self.key + ): self.skipped = True else: self.skipped = skipped @@ -1136,13 +1326,19 @@ def set_access_prefix(self, prefix, is_delegated=False): self.accessPrefix = prefix if self.type in ["LIST", "MAP", "GROUP", "UNION"]: self.set_skipped(self.skip_condition()) - list(map(lambda child: child.set_skipped(child.skip_condition()), - self.value)) - list(map(lambda child: child.set_access_prefix( - self.var_access(), - is_delegated=(self.delegate_type_condition() - or (is_delegated and self.skip_condition()))), - self.value)) + list(map(lambda child: child.set_skipped(child.skip_condition()), self.value)) + list( + map( + lambda child: child.set_access_prefix( + self.var_access(), + is_delegated=( + self.delegate_type_condition() + or (is_delegated and self.skip_condition()) + ), + ), + self.value, + ) + ) elif self in self.my_types.values(): self.set_skipped(not self.multi_member()) if self.key is not None: @@ -1158,22 +1354,30 @@ def multi_member(self): def is_unambiguous_value(self): """Whether this element is a non-compound value that can be known a priori.""" - return (self.type in ["NIL", "UNDEF", 
"ANY"] - or (self.type in ["INT", "NINT", "UINT", "FLOAT", "BSTR", "TSTR", "BOOL"] - and self.value is not None) - or (self.type == "OTHER" and self.my_types[self.value].is_unambiguous())) + return ( + self.type in ["NIL", "UNDEF", "ANY"] + or ( + self.type in ["INT", "NINT", "UINT", "FLOAT", "BSTR", "TSTR", "BOOL"] + and self.value is not None + ) + or (self.type == "OTHER" and self.my_types[self.value].is_unambiguous()) + ) def is_unambiguous_repeated(self): """Whether the repeated part of this element is known a priori.""" - return (self.is_unambiguous_value() - and (self.key is None or self.key.is_unambiguous_repeated()) - or (self.type in ["LIST", "GROUP", "MAP"] and len(self.value) == 0) - or (self.type in ["LIST", "GROUP", "MAP"] - and all((child.is_unambiguous() for child in self.value)))) + return ( + self.is_unambiguous_value() + and (self.key is None or self.key.is_unambiguous_repeated()) + or (self.type in ["LIST", "GROUP", "MAP"] and len(self.value) == 0) + or ( + self.type in ["LIST", "GROUP", "MAP"] + and all((child.is_unambiguous() for child in self.value)) + ) + ) def is_unambiguous(self): """Whether or not we can know the exact encoding of this element a priori.""" - return (self.is_unambiguous_repeated() and (self.min_qty == self.max_qty)) + return self.is_unambiguous_repeated() and (self.min_qty == self.max_qty) def access_append_delimiter(self, prefix, delimiter, *suffix): """Create an access prefix based on an existing prefix, delimiter and a @@ -1187,16 +1391,16 @@ def access_append(self, *suffix): provided suffix. 
""" suffix = list(suffix) - return self.access_append_delimiter(self.accessPrefix, '.', *suffix) + return self.access_append_delimiter(self.accessPrefix, ".", *suffix) def var_access(self): - """"Path" to this element's variable.""" + """ "Path" to this element's variable.""" if self.is_unambiguous(): return "NULL" return self.access_append() def val_access(self, top_level=False): - """"Path" to access this element's actual value variable.""" + """ "Path" to access this element's actual value variable.""" if self.is_unambiguous_repeated(): ret = "NULL" elif self.skip_condition() or self.is_delegated_type(): @@ -1214,7 +1418,7 @@ def repeated_val_access(self): def optional_quantifier(self): """Whether the element has the "optional" quantifier ('?').""" - return (self.min_qty == 0 and isinstance(self.max_qty, int) and self.max_qty <= 1) + return self.min_qty == 0 and isinstance(self.max_qty, int) and self.max_qty <= 1 def present_var_condition(self): """Whether to include a "present" variable for this element.""" @@ -1226,8 +1430,9 @@ def count_var_condition(self): def is_cbor(self): """Whether to include a "cbor" variable for this element.""" - return (self.type not in ["NIL", "UNDEF", "ANY"]) \ - and ((self.type != "OTHER") or (self.my_types[self.value].is_cbor())) + return (self.type not in ["NIL", "UNDEF", "ANY"]) and ( + (self.type != "OTHER") or (self.my_types[self.value].is_cbor()) + ) def cbor_var_condition(self): """Whether to include a "cbor" variable for this element.""" @@ -1249,9 +1454,11 @@ def key_var_condition(self): return True if self.type == "OTHER" and self.my_types[self.value].key_var_condition(): return True - if (self.type in ["GROUP", "UNION"] - and len(self.value) >= 1 - and self.value[0].reduced_key_var_condition()): + if ( + self.type in ["GROUP", "UNION"] + and len(self.value) >= 1 + and self.value[0].reduced_key_var_condition() + ): return True return False @@ -1259,15 +1466,13 @@ def self_repeated_multi_var_condition(self): """Whether 
this value adds any repeated elements by itself. I.e. excluding multiple elements from children. """ - return (self.key_var_condition() - or self.cbor_var_condition() - or self.choice_var_condition()) + return self.key_var_condition() or self.cbor_var_condition() or self.choice_var_condition() def multi_val_condition(self): """Whether this element's actual value has multiple members.""" - return ( - self.type in ["LIST", "MAP", "GROUP", "UNION"] - and (len(self.value) > 1 or (len(self.value) == 1 and self.value[0].multi_member()))) + return self.type in ["LIST", "MAP", "GROUP", "UNION"] and ( + len(self.value) > 1 or (len(self.value) == 1 and self.value[0].multi_member()) + ) def repeated_multi_var_condition(self): """Whether any extra variables are to be included for this element for each @@ -1291,13 +1496,15 @@ def range_check_condition(self): return False if self.value is not None: return False - if self.type in ["INT", "NINT", "UINT"] \ - and (self.min_value is not None or self.max_value is not None): + if self.type in ["INT", "NINT", "UINT"] and ( + self.min_value is not None or self.max_value is not None + ): return True if self.type == "UINT" and self.bits: return True - if self.type in ["BSTR", "TSTR"] \ - and (self.min_size is not None or self.max_size is not None): + if self.type in ["BSTR", "TSTR"] and ( + self.min_size is not None or self.max_size is not None + ): return True return False @@ -1312,7 +1519,8 @@ def repeated_type_def_condition(self): return ( self.repeated_multi_var_condition() and self.multi_var_condition() - and not self.is_unambiguous_repeated()) + and not self.is_unambiguous_repeated() + ) def single_func_impl_condition(self): """Whether this element needs its own encoder/decoder function.""" @@ -1323,15 +1531,19 @@ def single_func_impl_condition(self): or (self.tags and self in self.my_types.values()) or self.type_def_condition() or (self.type in ["LIST", "MAP"]) - or (self.type == "GROUP" and len(self.value) != 0)) + or (self.type 
== "GROUP" and len(self.value) != 0) + ) def repeated_single_func_impl_condition(self): """Whether this element needs its own encoder/decoder function.""" - return self.repeated_type_def_condition() \ - or (self.type in ["LIST", "MAP", "GROUP"] and self.multi_member()) \ + return ( + self.repeated_type_def_condition() + or (self.type in ["LIST", "MAP", "GROUP"] and self.multi_member()) or ( self.multi_var_condition() - and (self.self_repeated_multi_var_condition() or self.range_check_condition())) + and (self.self_repeated_multi_var_condition() or self.range_check_condition()) + ) + ) def int_val(self): """If this element is an integer, or starts with an integer, return the integer value.""" @@ -1341,10 +1553,12 @@ def int_val(self): return self.value elif self.type == "GROUP" and not self.count_var_condition(): return self.value[0].int_val() - elif self.type == "OTHER" \ - and not self.count_var_condition() \ - and not self.single_func_impl_condition() \ - and not self.my_types[self.value].single_func_impl_condition(): + elif ( + self.type == "OTHER" + and not self.count_var_condition() + and not self.single_func_impl_condition() + and not self.my_types[self.value].single_func_impl_condition() + ): return self.my_types[self.value].int_val() return None @@ -1362,8 +1576,12 @@ def all_children_disambiguated(self, min_val, max_val): The min_val and max_val are to check whether the integers are within a certain range. 
""" values = set(child.int_val() for child in self.value) - retval = (len(values) == len(self.value)) and None not in values \ - and max(values) <= max_val and min(values) >= min_val + retval = ( + (len(values) == len(self.value)) + and None not in values + and max(values) <= max_val + and min(values) >= min_val + ) return retval def all_children_int_disambiguated(self): @@ -1400,8 +1618,11 @@ def enum_var_name(self): def enum_var(self, int_val=False): """Enum entry for this element.""" - return f"{self.enum_var_name()} = {val_to_str(self.int_val())}" \ - if int_val else self.enum_var_name() + return ( + f"{self.enum_var_name()} = {val_to_str(self.int_val())}" + if int_val + else self.enum_var_name() + ) def choice_var_access(self): """Full "path" of the "choice" variable for this element.""" @@ -1416,6 +1637,7 @@ class KeyTuple(tuple): """Subclass of tuple for holding key,value pairs. This is to make it possible to use isinstance() to separate it from other tuples.""" + def __new__(cls, *in_tuple): return super(KeyTuple, cls).__new__(cls, *in_tuple) @@ -1466,7 +1688,8 @@ def _decode_assert(self, test, msg=""): """Check a condition and raise a CddlValidationError if not.""" if not test: raise CddlValidationError( - f"Data did not decode correctly {'(' + msg + ')' if msg else ''}") + f"Data did not decode correctly {'(' + msg + ')' if msg else ''}" + ) def _check_tag(self, obj): """Check that no unexpected tags are attached to this data. 
@@ -1511,18 +1734,22 @@ def _check_type(self, obj): exp_type = self._expected_type() self._decode_assert( type(obj) in exp_type, - f"{str(self)}: Wrong type ({type(obj)}) of {str(obj)}, expected {str(exp_type)}") + f"{str(self)}: Wrong type ({type(obj)}) of {str(obj)}, expected {str(exp_type)}", + ) def _check_value(self, obj): """Check that the decode value conforms to the restrictions in the CDDL.""" - if self.type in ["UINT", "INT", "NINT", "FLOAT", "TSTR", "BSTR", "BOOL"] \ - and self.value is not None: + if ( + self.type in ["UINT", "INT", "NINT", "FLOAT", "TSTR", "BSTR", "BOOL"] + and self.value is not None + ): value = self.value if self.type == "BSTR": value = self.value.encode("utf-8") self._decode_assert( self.value == obj, - f"{obj} should have value {self.value} according to {self.var_name()}") + f"{obj} should have value {self.value} according to {self.var_name()}", + ) if self.type in ["UINT", "INT", "NINT", "FLOAT"]: if self.min_value is not None: self._decode_assert(obj >= self.min_value, "Minimum value: " + str(self.min_value)) @@ -1535,15 +1762,18 @@ def _check_value(self, obj): if self.type in ["TSTR", "BSTR"]: if self.min_size is not None: self._decode_assert( - len(obj) >= self.min_size, "Minimum length: " + str(self.min_size)) + len(obj) >= self.min_size, "Minimum length: " + str(self.min_size) + ) if self.max_size is not None: self._decode_assert( - len(obj) <= self.max_size, "Maximum length: " + str(self.max_size)) + len(obj) <= self.max_size, "Maximum length: " + str(self.max_size) + ) def _check_key(self, obj): """Check that the object is not a KeyTuple, which would mean it's not properly processed.""" self._decode_assert( - not isinstance(obj, KeyTuple), "Unexpected key found: (key,value)=" + str(obj)) + not isinstance(obj, KeyTuple), "Unexpected key found: (key,value)=" + str(obj) + ) def _flatten_obj(self, obj): """Recursively remove intermediate objects that have single members. 
Keep lists as is.""" @@ -1553,11 +1783,13 @@ def _flatten_obj(self, obj): def _flatten_list(self, name, obj): """Return the contents of a list if it has a single member and the same name as us.""" - if (isinstance(obj, list) - and len(obj) == 1 - and (isinstance(obj[0], list) or isinstance(obj[0], tuple)) - and len(obj[0]) == 1 - and hasattr(obj[0], name)): + if ( + isinstance(obj, list) + and len(obj) == 1 + and (isinstance(obj[0], list) or isinstance(obj[0], tuple)) + and len(obj[0]) == 1 + and hasattr(obj[0], name) + ): return [obj[0][0]] return obj @@ -1570,10 +1802,11 @@ def _construct_obj(self, my_list): return None names, values = tuple(zip(*my_list)) if len(values) == 1: - values = (self._flatten_obj(values[0]), ) + values = (self._flatten_obj(values[0]),) values = tuple(self._flatten_list(names[i], values[i]) for i in range(len(values))) - assert (not any((isinstance(elem, KeyTuple) for elem in values))), \ - f"KeyTuple not processed: {values}" + assert not any( + (isinstance(elem, KeyTuple) for elem in values) + ), f"KeyTuple not processed: {values}" return namedtuple("_", names)(*values) def _add_if(self, my_list, obj, expect_key=False, name=None): @@ -1593,8 +1826,9 @@ def _add_if(self, my_list, obj, expect_key=False, name=None): self._add_if(retvals, obj[i]) obj[i] = self._construct_obj(retvals) if self.type == "BSTR" and self.cbor_var_condition() and isinstance(obj[i], bytes): - assert all((isinstance(o, bytes) for o in obj)), \ - """Unsupported configuration for cbor bstr. If a list contains a + assert all( + (isinstance(o, bytes) for o in obj) + ), """Unsupported configuration for cbor bstr. If a list contains a CBOR-formatted bstr, all elements must be bstrs. 
If not, it is a programmer error.""" if isinstance(obj, KeyTuple): key, obj = obj @@ -1624,7 +1858,8 @@ def _iter_is_empty(self, it): return True raise CddlValidationError( f"Iterator not consumed while parsing \n{self}\nRemaining elements:\n elem: " - + "\n elem: ".join(str(elem) for elem in ([val] + list(it)))) + + "\n elem: ".join(str(elem) for elem in ([val] + list(it))) + ) def _iter_next(self, it): """Get next element from iterator, throw CddlValidationError instead of StopIteration.""" @@ -1640,8 +1875,18 @@ def _decode_single_obj(self, obj): obj = self._check_tag(obj) self._check_type(obj) self._check_value(obj) - if self.type in ["UINT", "INT", "NINT", "FLOAT", "TSTR", - "BSTR", "BOOL", "NIL", "UNDEF", "ANY"]: + if self.type in [ + "UINT", + "INT", + "NINT", + "FLOAT", + "TSTR", + "BSTR", + "BOOL", + "NIL", + "UNDEF", + "ANY", + ]: return obj elif self.type == "OTHER": return self.my_types[self.value]._decode_single_obj(obj) @@ -1678,7 +1923,8 @@ def _decode_single_obj(self, obj): def _handle_key(self, next_obj): """Decode key and value in the form of a KeyTuple""" self._decode_assert( - isinstance(next_obj, KeyTuple), f"Expected key: {self.key} value=" + pformat(next_obj)) + isinstance(next_obj, KeyTuple), f"Expected key: {self.key} value=" + pformat(next_obj) + ) key, obj = next_obj key_res = self.key._decode_single_obj(key) obj_res = self._decode_single_obj(obj) @@ -1833,7 +2079,9 @@ def _to_yaml_obj(self, obj): for key, val in obj.items(): if not isinstance(key, str): retval[f"zcbor_keyval{i}"] = { - "key": self._to_yaml_obj(key), "val": self._to_yaml_obj(val)} + "key": self._to_yaml_obj(key), + "val": self._to_yaml_obj(val), + } i += 1 else: retval[key] = self._to_yaml_obj(val) @@ -1846,7 +2094,7 @@ def _to_yaml_obj(self, obj): # failed decoding bstr_obj = obj.hex() else: - if f.read(1) != b'': + if f.read(1) != b"": # not fully decoded bstr_obj = obj.hex() return {"zcbor_bstr": bstr_obj} @@ -1895,8 +2143,8 @@ def str_to_c_code(self, cbor_str, 
var_name, columns=0): """CBOR bytestring => C code (uint8_t array initialization)""" arr = ", ".join(f"0x{c:02x}" for c in cbor_str) if columns: - arr = '\n' + indent("\n".join(wrap(arr, 6 * columns)), '\t') + '\n' - return f'uint8_t {var_name}[] = {{{arr}}};\n' + arr = "\n" + indent("\n".join(wrap(arr, 6 * columns)), "\t") + "\n" + return f"uint8_t {var_name}[] = {{{arr}}};\n" class XcoderTuple(NamedTuple): @@ -1911,8 +2159,8 @@ class CddlTypes(NamedTuple): class CodeGenerator(CddlXcoder): - """Class for generating C code that encode/decodes CBOR and validates it according to the CDDL. - """ + """Class for generating C code that encode/decodes CBOR and validates it according to the CDDL.""" + def __init__(self, mode, entry_type_names, default_bit_size, *args, **kwargs): super(CodeGenerator, self).__init__(*args, **kwargs) self.mode = mode @@ -1935,8 +2183,11 @@ def is_entry_type(self): def is_cbor(self): """Whether to include a "cbor" variable for this element.""" - res = (self.type_name() is not None) and not self.is_entry_type() and ( - (self.type != "OTHER") or self.my_types[self.value].is_cbor()) + res = ( + (self.type_name() is not None) + and not self.is_entry_type() + and ((self.type != "OTHER") or self.my_types[self.value].is_cbor()) + ) return res def init_args(self): @@ -1944,10 +2195,12 @@ def init_args(self): def delegate_type_condition(self): """Whether to use the C type of the first child as this type's C type""" - ret = self.skip_condition() and (self.multi_var_condition() - or self.self_repeated_multi_var_condition() - or self.range_check_condition() - or (self in self.my_types.values())) + ret = self.skip_condition() and ( + self.multi_var_condition() + or self.self_repeated_multi_var_condition() + or self.range_check_condition() + or (self in self.my_types.values()) + ) return ret def is_delegated_type(self): @@ -2012,11 +2265,11 @@ def bit_size(self): bit_size = 32 for v in [self.value or 0, self.max_value or 0, self.min_value or 0]: - if 
(type(v) is str): + if type(v) is str: if "64" in v: bit_size = 64 elif self.type == "UINT": - if (v > UINT32_MAX): + if v > UINT32_MAX: bit_size = 64 else: if (v > INT32_MAX) or (v < INT32_MIN): @@ -2092,8 +2345,9 @@ def add_var_name(self, var_type, full=False, anonymous=False): Make it an array if the element is repeated. """ if var_type: - assert (var_type[-1][-1] == "}" or len(var_type) == 1), \ - f"Expected single var: {var_type!r}" + assert ( + var_type[-1][-1] == "}" or len(var_type) == 1 + ), f"Expected single var: {var_type!r}" if not anonymous or var_type[-1][-1] != "}": var_name = self.var_name() array_part = f"[{self.max_qty}]" if full and self.max_qty != 1 else "" @@ -2170,8 +2424,9 @@ def full_declaration(self): decl = [] else: decl = self.add_var_name( - [self.repeated_type_name()] - if self.repeated_type_name() is not None else [], full=True) + [self.repeated_type_name()] if self.repeated_type_name() is not None else [], + full=True, + ) else: decl = self.repeated_declaration() @@ -2208,8 +2463,8 @@ def type_def(self): ret_val = [] if self.type in ["LIST", "MAP", "GROUP", "UNION"]: ret_val.extend( - [elem for typedef in [ - child.type_def() for child in self.value] for elem in typedef]) + [elem for typedef in [child.type_def() for child in self.value] for elem in typedef] + ) if self.bits: ret_val.extend(self.my_control_groups[self.bits].type_def_bits()) if self.cbor_var_condition(): @@ -2257,7 +2512,7 @@ def float_prefix(self): def single_func_prim_prefix(self): if self.type == "OTHER": return self.my_types[self.value].single_func_prim_prefix() - return ({ + return { "INT": f"zcbor_int{self.bit_size()}", "UINT": f"zcbor_uint{self.bit_size()}", "NINT": f"zcbor_int{self.bit_size()}", @@ -2268,7 +2523,7 @@ def single_func_prim_prefix(self): "NIL": f"zcbor_nil", "UNDEF": f"zcbor_undefined", "ANY": f"zcbor_any", - }[self.type]) + }[self.type] def xcode_func_name(self): """Name of the encoder/decoder function for this element.""" @@ -2290,8 +2545,7 @@ 
def single_func_prim_name(self, union_int=None, ptr_result=False): elif not union_int: func = f"{func_prefix}_{'pexpect' if ptr_variant else 'expect'}" elif union_int == "EXPECT": - assert not ptr_variant, \ - "Programmer error: invalid use of expect_union." + assert not ptr_variant, "Programmer error: invalid use of expect_union." func = f"{func_prefix}_expect_union" elif union_int == "DROP": return None @@ -2332,7 +2586,7 @@ def single_func_prim(self, access, union_int=None, ptr_result=False): arg = tmp_str_or_null(self.value) elif self.type in ["UINT", "INT", "NINT", "FLOAT", "BOOL"]: value = val_to_str(self.value) - arg = (f"&({self.val_type_name()}){{{value}}}" if ptr_result else value) + arg = f"&({self.val_type_name()}){{{value}}}" if ptr_result else value else: assert False, "Should not come here." @@ -2343,8 +2597,9 @@ def single_func(self, access=None, union_int=None, ptr_result=False): if self.single_func_impl_condition(): return (self.xcode_func_name(), deref_if_not_null(access or self.var_access())) else: - return self.single_func_prim(access or self.val_access(), union_int, - ptr_result=ptr_result) + return self.single_func_prim( + access or self.val_access(), union_int, ptr_result=ptr_result + ) def repeated_single_func(self, ptr_result=False): """Return the function name and arguments to call to encode/decode the repeated @@ -2356,7 +2611,7 @@ def repeated_single_func(self, ptr_result=False): return self.single_func_prim(self.repeated_val_access(), ptr_result=ptr_result) def has_backup(self): - return (self.cbor_var_condition() or self.type in ["LIST", "MAP", "UNION"]) + return self.cbor_var_condition() or self.type in ["LIST", "MAP", "UNION"] def num_backups(self): """Calculate the number of state var backups needed for this element and all descendants.""" @@ -2400,7 +2655,7 @@ def xcode_single_func_prim(self, union_int=None, top_level=False): def list_counts(self): """Recursively sum the total minimum and maximum element count for this element.""" 
- retval = ({ + retval = { "INT": lambda: (self.min_qty, self.max_qty), "UINT": lambda: (self.min_qty, self.max_qty), "NINT": lambda: (self.min_qty, self.max_qty), @@ -2415,13 +2670,19 @@ def list_counts(self): "LIST": lambda: (self.min_qty, self.max_qty), # Maps are their own element "MAP": lambda: (self.min_qty, self.max_qty), - "GROUP": lambda: (self.min_qty * sum((child.list_counts()[0] for child in self.value)), - self.max_qty * sum((child.list_counts()[1] for child in self.value))), - "UNION": lambda: (self.min_qty * min((child.list_counts()[0] for child in self.value)), - self.max_qty * max((child.list_counts()[1] for child in self.value))), - "OTHER": lambda: (self.min_qty * self.my_types[self.value].list_counts()[0], - self.max_qty * self.my_types[self.value].list_counts()[1]), - }[self.type]()) + "GROUP": lambda: ( + self.min_qty * sum((child.list_counts()[0] for child in self.value)), + self.max_qty * sum((child.list_counts()[1] for child in self.value)), + ), + "UNION": lambda: ( + self.min_qty * min((child.list_counts()[0] for child in self.value)), + self.max_qty * max((child.list_counts()[1] for child in self.value)), + ), + "OTHER": lambda: ( + self.min_qty * self.my_types[self.value].list_counts()[0], + self.max_qty * self.my_types[self.value].list_counts()[1], + ), + }[self.type]() return retval def xcode_list(self): @@ -2430,32 +2691,40 @@ def xcode_list(self): end_func = f"zcbor_{self.type.lower()}_end_{self.mode}" end_func_force = f"zcbor_list_map_end_force_{self.mode}" assert start_func in [ - "zcbor_list_start_decode", "zcbor_list_start_encode", - "zcbor_map_start_decode", "zcbor_map_start_encode"] + "zcbor_list_start_decode", + "zcbor_list_start_encode", + "zcbor_map_start_decode", + "zcbor_map_start_encode", + ] assert end_func in [ - "zcbor_list_end_decode", "zcbor_list_end_encode", - "zcbor_map_end_decode", "zcbor_map_end_encode"] - assert self.type in ["LIST", "MAP"], \ - "Expected LIST or MAP type, was %s." 
% self.type - _, max_counts = zip( - *(child.list_counts() for child in self.value)) if self.value else ((0,), (0,)) - count_arg = f', {str(sum(max_counts))}' if self.mode == 'encode' else '' + "zcbor_list_end_decode", + "zcbor_list_end_encode", + "zcbor_map_end_decode", + "zcbor_map_end_encode", + ] + assert self.type in ["LIST", "MAP"], "Expected LIST or MAP type, was %s." % self.type + _, max_counts = ( + zip(*(child.list_counts() for child in self.value)) if self.value else ((0,), (0,)) + ) + count_arg = f", {str(sum(max_counts))}" if self.mode == "encode" else "" with_children = "(%s && ((%s) || (%s, false)) && %s)" % ( f"{start_func}(state{count_arg})", f"{newl_ind}&& ".join(child.full_xcode() for child in self.value), f"{end_func_force}(state)", - f"{end_func}(state{count_arg})") + f"{end_func}(state{count_arg})", + ) without_children = "(%s && %s)" % ( f"{start_func}(state{count_arg})", - f"{end_func}(state{count_arg})") + f"{end_func}(state{count_arg})", + ) return with_children if len(self.value) > 0 else without_children def xcode_group(self, union_int=None): """Return the full code needed to encode/decode a "GROUP" element's children.""" assert self.type in ["GROUP"], "Expected GROUP type." 
return "(%s)" % (newl_ind + "&& ").join( - [self.value[0].full_xcode(union_int)] - + [child.full_xcode() for child in self.value[1:]]) + [self.value[0].full_xcode(union_int)] + [child.full_xcode() for child in self.value[1:]] + ) def xcode_union(self): """Return the full code needed to encode/decode a "UNION" element's children.""" @@ -2464,61 +2733,89 @@ def xcode_union(self): if self.all_children_int_disambiguated(): lines = [] lines.extend( - ["((%s == %s) && (%s))" % - (self.choice_var_access(), child.enum_var_name(), - child.full_xcode(union_int="DROP")) - for child in self.value]) + [ + "((%s == %s) && (%s))" + % ( + self.choice_var_access(), + child.enum_var_name(), + child.full_xcode(union_int="DROP"), + ) + for child in self.value + ] + ) bit_size = self.value[0].bit_size() - func = f"zcbor_uint_{self.mode}" if self.all_children_uint_disambiguated() else \ - f"zcbor_int_{self.mode}" + func = ( + f"zcbor_uint_{self.mode}" + if self.all_children_uint_disambiguated() + else f"zcbor_int_{self.mode}" + ) return "((%s) && (%s))" % ( f"({func}(state, &{self.choice_var_access()}, " + f"sizeof({self.choice_var_access()})))", - "((" + f"{newl_ind}|| ".join(lines) - + ") || (zcbor_error(state, ZCBOR_ERR_WRONG_VALUE), false))",) - - child_values = ["(%s && ((%s = %s), true))" % - (child.full_xcode( - union_int="EXPECT" if child.is_int_disambiguated() else None), - self.choice_var_access(), child.enum_var_name()) - for child in self.value] + "((" + + f"{newl_ind}|| ".join(lines) + + ") || (zcbor_error(state, ZCBOR_ERR_WRONG_VALUE), false))", + ) + + child_values = [ + "(%s && ((%s = %s), true))" + % ( + child.full_xcode(union_int="EXPECT" if child.is_int_disambiguated() else None), + self.choice_var_access(), + child.enum_var_name(), + ) + for child in self.value + ] # Reset state for all but the first child. 
for i in range(1, len(child_values)): - if ((not self.value[i].is_int_disambiguated()) - and self.value[i - 1].simple_func_condition()): + if (not self.value[i].is_int_disambiguated()) and self.value[ + i - 1 + ].simple_func_condition(): child_values[i] = f"(zcbor_union_elem_code(state) && {child_values[i]})" child_code = f"{newl_ind}|| ".join(child_values) - return f"(zcbor_union_start_code(state) "\ + return ( + f"(zcbor_union_start_code(state) " + f"&& (int_res = ({child_code}), zcbor_union_end_code(state), int_res))" + ) else: return ternary_if_chain( self.choice_var_access(), [child.enum_var_name() for child in self.value], - [child.full_xcode() for child in self.value]) + [child.full_xcode() for child in self.value], + ) def xcode_bstr(self): if self.cbor and not self.cbor.is_entry_type(): - access_arg = f', {deref_if_not_null(self.val_access())}' if self.mode == 'decode' \ - else '' - res_arg = f', &tmp_str' if self.mode == 'encode' \ - else '' - xcode_cbor = "(%s)" % ((newl_ind + "&& ").join( - [f"zcbor_bstr_start_{self.mode}(state{access_arg})", - f"(int_res = ({self.cbor.full_xcode()}), " - f"zcbor_bstr_end_{self.mode}(state{res_arg}), int_res)"])) + access_arg = ( + f", {deref_if_not_null(self.val_access())}" if self.mode == "decode" else "" + ) + res_arg = f", &tmp_str" if self.mode == "encode" else "" + xcode_cbor = "(%s)" % ( + (newl_ind + "&& ").join( + [ + f"zcbor_bstr_start_{self.mode}(state{access_arg})", + f"(int_res = ({self.cbor.full_xcode()}), " + f"zcbor_bstr_end_{self.mode}(state{res_arg}), int_res)", + ] + ) + ) if self.mode == "decode" or self.is_unambiguous(): return xcode_cbor else: - return f"({self.val_access()}.value " \ - f"? (memcpy(&tmp_str, &{self.val_access()}, sizeof(tmp_str)), " \ + return ( + f"({self.val_access()}.value " + f"? 
(memcpy(&tmp_str, &{self.val_access()}, sizeof(tmp_str)), " f"{self.xcode_single_func_prim()}) : ({xcode_cbor}))" + ) return self.xcode_single_func_prim() def xcode_tags(self): - return [f"zcbor_tag_{'put' if (self.mode == 'encode') else 'expect'}(state, {tag})" - for tag in self.tags] + return [ + f"zcbor_tag_{'put' if (self.mode == 'encode') else 'expect'}(state, {tag})" + for tag in self.tags + ] def value_suffix(self, value_str): """Appends ULL or LL if a value exceeding 32-bits is used""" @@ -2548,21 +2845,32 @@ def range_checks(self, access): if self.type in ["INT", "UINT", "NINT", "FLOAT", "BOOL"]: if min_val is not None and min_val == max_val: - range_checks.append(f"({access} == {val_to_str(min_val)}" - f"{self.value_suffix(val_to_str(min_val))})") + range_checks.append( + f"({access} == {val_to_str(min_val)}" + f"{self.value_suffix(val_to_str(min_val))})" + ) else: if min_val is not None: - range_checks.append(f"({access} >= {val_to_str(min_val)}" - f"{self.value_suffix(val_to_str(min_val))})") + range_checks.append( + f"({access} >= {val_to_str(min_val)}" + f"{self.value_suffix(val_to_str(min_val))})" + ) if max_val is not None: - range_checks.append(f"({access} <= {val_to_str(max_val)}" - f"{self.value_suffix(val_to_str(max_val))})") + range_checks.append( + f"({access} <= {val_to_str(max_val)}" + f"{self.value_suffix(val_to_str(max_val))})" + ) if self.bits: range_checks.append( f"!({access} & ~(" - + ' | '.join([f'(1 << {c.enum_var_name()})' - for c in self.my_control_groups[self.bits].value]) - + "))") + + " | ".join( + [ + f"(1 << {c.enum_var_name()})" + for c in self.my_control_groups[self.bits].value + ] + ) + + "))" + ) elif self.type in ["BSTR", "TSTR"]: if self.min_size is not None and self.min_size == self.max_size: range_checks.append(f"({access}.len == {val_to_str(self.min_size)})") @@ -2577,8 +2885,9 @@ def range_checks(self, access): if range_checks: range_checks[0] = "((" + range_checks[0] - range_checks[-1] = range_checks[-1] \ - + ") || 
(zcbor_error(state, ZCBOR_ERR_WRONG_RANGE), false))" + range_checks[-1] = ( + range_checks[-1] + ") || (zcbor_error(state, ZCBOR_ERR_WRONG_RANGE), false))" + ) return range_checks @@ -2592,6 +2901,7 @@ def repeated_xcode(self, union_int=None, top_level=False): def do_xcode_single_func_prim(inner_union_int=None): return self.xcode_single_func_prim(union_int=inner_union_int, top_level=top_level) + xcoder = { "INT": do_xcode_single_func_prim, "UINT": lambda: do_xcode_single_func_prim(val_union_int), @@ -2643,39 +2953,49 @@ def full_xcode(self, union_int=None, top_level=False): func, *arguments = self.repeated_single_func(ptr_result=False) return f"(!{self.present_var_access()} || {func}({xcode_args(*arguments)}))" else: - assert self.mode == "decode", \ - f"This code needs self.mode to be 'decode', not {self.mode}." + assert ( + self.mode == "decode" + ), f"This code needs self.mode to be 'decode', not {self.mode}." assign = not self.repeated_single_func_impl_condition() default_assignment = None if self.default is not None: - default_value = (f"*({tmp_str_or_null(self.default)})" - if self.type in ["TSTR", "BSTR"] else val_to_str(self.default)) + default_value = ( + f"*({tmp_str_or_null(self.default)})" + if self.type in ["TSTR", "BSTR"] + else val_to_str(self.default) + ) access = self.val_access() if assign else self.repeated_val_access() default_assignment = f"({access} = {default_value})" if assign: decode_str = self.repeated_xcode(union_int) - return comma_operator(default_assignment, - f"{self.present_var_access()} = {decode_str}", "1") + return comma_operator( + default_assignment, f"{self.present_var_access()} = {decode_str}", "1" + ) func, *arguments = self.repeated_single_func(ptr_result=True) return comma_operator( default_assignment, - f"(zcbor_present_decode(&(%s), (zcbor_decoder_t *)%s, %s))" % - (self.present_var_access(), func, xcode_args(*arguments),)) + f"(zcbor_present_decode(&(%s), (zcbor_decoder_t *)%s, %s))" + % ( + self.present_var_access(), + 
func, + xcode_args(*arguments), + ), + ) elif self.count_var_condition(): func, arg = self.repeated_single_func(ptr_result=True) minmax = "_minmax" if self.mode == "encode" else "" mode = self.mode - return ( - f"zcbor_multi_{mode}{minmax}(%s, %s, &%s, (zcbor_{mode}r_t *)%s, %s, %s)" % - (self.min_qty, - self.max_qty, - self.count_var_access(), - func, - xcode_args("*" + arg if arg != "NULL" and self.result_len() != "0" else arg), - self.result_len())) + return f"zcbor_multi_{mode}{minmax}(%s, %s, &%s, (zcbor_{mode}r_t *)%s, %s, %s)" % ( + self.min_qty, + self.max_qty, + self.count_var_access(), + func, + xcode_args("*" + arg if arg != "NULL" and self.result_len() != "0" else arg), + self.result_len(), + ) else: return self.repeated_xcode(union_int=union_int, top_level=top_level) @@ -2704,8 +3024,9 @@ def xcoders(self): yield XcoderTuple( self.repeated_xcode(top_level=True), self.repeated_xcode_func_name(), - self.repeated_type_name()) - if (self.single_func_impl_condition()): + self.repeated_type_name(), + ) + if self.single_func_impl_condition(): xcode_body = self.xcode() yield XcoderTuple(xcode_body, self.xcode_func_name(), self.type_name()) @@ -2718,8 +3039,8 @@ def public_xcode_func_sig(self): {"size_t *payload_len_out"})""" -class CodeRenderer(): - def __init__(self, entry_types, modes, print_time, default_max_qty, git_sha='', file_header=''): +class CodeRenderer: + def __init__(self, entry_types, modes, print_time, default_max_qty, git_sha="", file_header=""): self.entry_types = entry_types self.print_time = print_time self.default_max_qty = default_max_qty @@ -2731,8 +3052,9 @@ def __init__(self, entry_types, modes, print_time, default_max_qty, git_sha='', # Sort type definitions so the typedefs will come in the correct order in the header file # and the function in the correct order in the c file. 
for mode in modes: - self.sorted_types[mode] = list(sorted( - self.entry_types[mode], key=lambda _type: _type.depends_on(), reverse=False)) + self.sorted_types[mode] = list( + sorted(self.entry_types[mode], key=lambda _type: _type.depends_on(), reverse=False) + ) self.functions[mode] = self.unique_funcs(mode) self.functions[mode] = self.used_funcs(mode) @@ -2741,7 +3063,7 @@ def __init__(self, entry_types, modes, print_time, default_max_qty, git_sha='', self.version = __version__ if git_sha: - self.version += f'-{git_sha}' + self.version += f"-{git_sha}" self.file_header = file_header.strip() + "\n\n" if file_header.strip() else "" self.file_header += f"""Generated using zcbor version {self.version} @@ -2765,7 +3087,9 @@ def unique_types(self, mode): type_names[type_name] = type_def[0] out_types.append(type_def) else: - assert (''.join(type_names[type_name]) == ''.join(type_def[0])), f""" + assert "".join(type_names[type_name]) == "".join( + type_def[0] + ), f""" Two elements share the type name {type_name}, but their implementations are not identical. Please change one or both names. They are {linesep.join(type_names[type_name])} @@ -2788,10 +3112,10 @@ def unique_funcs(self, mode): func_names[func_name] = funcType out_types.append(funcType) elif func_name in func_names.keys(): - assert func_names[func_name][0] == func_xcode, \ - ("Two elements share the function name %s, but their implementations are " - + "not identical. Please change one or both names.\n\n%s\n\n%s") % \ - (func_name, func_names[func_name][0], func_xcode) + assert func_names[func_name][0] == func_xcode, ( + "Two elements share the function name %s, but their implementations are " + + "not identical. Please change one or both names.\n\n%s\n\n%s" + ) % (func_name, func_names[func_name][0], func_xcode) return out_types @@ -2800,10 +3124,9 @@ def used_funcs(self, mode): functions removed. 
""" mod_entry_types = [ - XcoderTuple( - func_type.xcode(), - func_type.xcode_func_name(), - func_type.type_name()) for func_type in self.entry_types[mode]] + XcoderTuple(func_type.xcode(), func_type.xcode_func_name(), func_type.type_name()) + for func_type in self.entry_types[mode] + ] out_types = [func_type for func_type in mod_entry_types] full_code = "".join([func_type[0] for func_type in mod_entry_types]) for func_type in reversed(self.functions[mode]): @@ -2825,22 +3148,25 @@ def render_function(self, xcoder, mode): body = xcoder.body # Define the subroutine "paren" that matches parenthesised expressions. - paren_re = r'(?(DEFINE)(?P\(((?>[^\(\)]+|(?&paren))*)\)))' + paren_re = r"(?(DEFINE)(?P\(((?>[^\(\)]+|(?&paren))*)\)))" # This uses "paren" to match a single argument to a function. - arg_re = rf'([^,\(\)]|(?&paren))+' + arg_re = rf"([^,\(\)]|(?&paren))+" # Match a function pointer argument to a function. - func_re = rf'\(zcbor_(en|de)coder_t \*\)(?P{arg_re})' + func_re = rf"\(zcbor_(en|de)coder_t \*\)(?P{arg_re})" # Match a triplet of function pointer, state arg, and result arg. 
- call_re = rf'{func_re}, (?P{arg_re}), (?P{arg_re})' - multi_re = rf'{paren_re}zcbor_multi_(en|de)code\(({arg_re},){{3}} {call_re}' - present_re = rf'{paren_re}zcbor_present_(en|de)code\({arg_re}, {call_re}\)' - map_re = rf'{paren_re}zcbor_unordered_map_search\({call_re}\)' - all_funcs = chain(getrp(multi_re).finditer(body), - getrp(present_re).finditer(body), - getrp(map_re).finditer(body)) + call_re = rf"{func_re}, (?P{arg_re}), (?P{arg_re})" + multi_re = rf"{paren_re}zcbor_multi_(en|de)code\(({arg_re},){{3}} {call_re}" + present_re = rf"{paren_re}zcbor_present_(en|de)code\({arg_re}, {call_re}\)" + map_re = rf"{paren_re}zcbor_unordered_map_search\({call_re}\)" + all_funcs = chain( + getrp(multi_re).finditer(body), + getrp(present_re).finditer(body), + getrp(map_re).finditer(body), + ) arg_test = "" - calls = ("\n ".join( - (f"{m.group('func')}({m.group('state')}, {m.group('arg')});" for m in (all_funcs)))) + calls = "\n ".join( + (f"{m.group('func')}({m.group('state')}, {m.group('arg')});" for m in (all_funcs)) + ) if calls != "": arg_test = f""" if (false) {{ @@ -2864,7 +3190,9 @@ def render_function(self, xcoder, mode): {arg_test} log_result(state, res, __func__); return res; -}}""".replace(" \n", "") # call replace() to remove empty lines. +}}""".replace( + " \n", "" + ) # call replace() to remove empty lines. 
def render_entry_function(self, xcoder, mode): """Render a single entry function (API function) with signature and body.""" @@ -2920,8 +3248,7 @@ def render_c_file(self, header_file_name, mode): def render_h_file(self, type_def_file, header_guard, mode): """Render the entire generated header file contents.""" - return \ - f"""/*{self.render_file_header(" *")} + return f"""/*{self.render_file_header(" *")} */ #ifndef {header_guard} @@ -2952,12 +3279,13 @@ def render_h_file(self, type_def_file, header_guard, mode): """ def render_type_file(self, header_guard, mode): - body = ( - linesep + linesep).join( - [f"{typedef[1]} {{{linesep}{linesep.join(typedef[0][1:])};" - for typedef in self.type_defs[mode]]) - return \ - f"""/*{self.render_file_header(" *")} + body = (linesep + linesep).join( + [ + f"{typedef[1]} {{{linesep}{linesep.join(typedef[0][1:])};" + for typedef in self.type_defs[mode] + ] + ) + return f"""/*{self.render_file_header(" *")} */ #ifndef {header_guard} @@ -2990,21 +3318,29 @@ def render_type_file(self, header_guard, mode): #endif /* {header_guard} */ """ - def render_cmake_file(self, target_name, h_files, c_files, type_file, - output_c_dir, output_h_dir, cmake_dir): - include_dirs = sorted(set(((Path(output_h_dir)), - (Path(type_file.name).parent), - *((Path(h.name).parent) for h in h_files.values())))) + def render_cmake_file( + self, target_name, h_files, c_files, type_file, output_c_dir, output_h_dir, cmake_dir + ): + include_dirs = sorted( + set( + ( + (Path(output_h_dir)), + (Path(type_file.name).parent), + *((Path(h.name).parent) for h in h_files.values()), + ) + ) + ) def relativify(p): try: return PurePosixPath( - Path("${CMAKE_CURRENT_LIST_DIR}") / path.relpath(Path(p), cmake_dir)) + Path("${CMAKE_CURRENT_LIST_DIR}") / path.relpath(Path(p), cmake_dir) + ) except ValueError: # On Windows, the above will fail if the paths are on different drives. 
return Path(p).absolute().as_posix() - return \ - f"""\ + + return f"""\ #{self.render_file_header("#")} # @@ -3021,8 +3357,17 @@ def relativify(p): ) """ - def render(self, modes, h_files, c_files, type_file, include_prefix, cmake_file=None, - output_c_dir=None, output_h_dir=None): + def render( + self, + modes, + h_files, + c_files, + type_file, + include_prefix, + cmake_file=None, + output_c_dir=None, + output_h_dir=None, + ): for mode in modes: h_name = Path(include_prefix, Path(h_files[mode].name).name) @@ -3035,18 +3380,26 @@ def render(self, modes, h_files, c_files, type_file, include_prefix, cmake_file= c_files[mode].write(self.render_c_file(h_name, mode)) print("Writing to " + h_files[mode].name) - h_files[mode].write(self.render_h_file( - type_def_name, - self.header_guard(h_files[mode].name), mode)) + h_files[mode].write( + self.render_h_file(type_def_name, self.header_guard(h_files[mode].name), mode) + ) print("Writing to " + type_file.name) type_file.write(self.render_type_file(self.header_guard(type_file.name), mode)) if cmake_file: print("Writing to " + cmake_file.name) - cmake_file.write(self.render_cmake_file( - Path(cmake_file.name).stem, h_files, c_files, type_file, - output_c_dir, output_h_dir, Path(cmake_file.name).absolute().parent)) + cmake_file.write( + self.render_cmake_file( + Path(cmake_file.name).stem, + h_files, + c_files, + type_file, + output_c_dir, + output_h_dir, + Path(cmake_file.name).absolute().parent, + ) + ) def int_or_str(arg): @@ -3057,7 +3410,8 @@ def int_or_str(arg): if getrp(r"\A\w+\Z").match(arg) is not None: return arg raise ArgumentTypeError( - "Argument must be an integer or a string with only letters, numbers, or '_'.") + "Argument must be an integer or a string with only letters, numbers, or '_'." 
+ ) def parse_args(): @@ -3065,27 +3419,42 @@ def parse_args(): parent_parser = ArgumentParser(add_help=False) parent_parser.add_argument( - "-c", "--cddl", required=True, type=FileType('r', encoding='utf-8'), action="append", + "-c", + "--cddl", + required=True, + type=FileType("r", encoding="utf-8"), + action="append", help="""Path to one or more input CDDL file(s). Passing multiple files is equivalent to -concatenating them.""") +concatenating them.""", + ) parent_parser.add_argument( - "--no-prelude", required=False, action="store_true", default=False, + "--no-prelude", + required=False, + action="store_true", + default=False, help=f"""Exclude the standard CDDL prelude from the build. The prelude can be viewed at -{PRELUDE_PATH.relative_to(PACKAGE_PATH)} in the repo, or together with the script.""") +{PRELUDE_PATH.relative_to(PACKAGE_PATH)} in the repo, or together with the script.""", + ) parent_parser.add_argument( - "-v", "--verbose", required=False, action="store_true", default=False, - help="Print more information while parsing CDDL and generating code.") + "-v", + "--verbose", + required=False, + action="store_true", + default=False, + help="Print more information while parsing CDDL and generating code.", + ) parser = ArgumentParser( - description='''Parse a CDDL file and validate/convert between YAML, JSON, and CBOR. -Can also generate C code for validation/encoding/decoding of CBOR.''') + description="""Parse a CDDL file and validate/convert between YAML, JSON, and CBOR. +Can also generate C code for validation/encoding/decoding of CBOR.""" + ) - parser.add_argument( - "--version", action="version", version=f"zcbor {__version__}") + parser.add_argument("--version", action="version", version=f"zcbor {__version__}") subparsers = parser.add_subparsers() code_parser = subparsers.add_parser( - "code", description='''Parse a CDDL file and produce C code that validates and xcodes CBOR. 
+ "code", + description="""Parse a CDDL file and produce C code that validates and xcodes CBOR. The output from this script is a C file and a header file. The header file contains typedefs for all the types specified in the cddl input file, as well as declarations to xcode functions for the types designated as entry types when @@ -3098,44 +3467,68 @@ def parse_args(): decoding. Using this mechanism is necessary when the CDDL contains self- referencing types, since the C type cannot be self referencing. -This script requires 'regex' for lookaround functionality not present in 're'.''', +This script requires 'regex' for lookaround functionality not present in 're'.""", formatter_class=RawDescriptionHelpFormatter, - parents=[parent_parser]) + parents=[parent_parser], + ) code_parser.add_argument( - "--default-max-qty", "--dq", required=False, type=int_or_str, default=3, + "--default-max-qty", + "--dq", + required=False, + type=int_or_str, + default=3, help="""Default maximum number of repetitions when no maximum is specified. This is needed to construct complete C types. The default_max_qty can usually be set to a text symbol if desired, to allow it to be configurable when building the code. This is not always possible, as sometimes the value is needed for internal computations. -If so, the script will raise an exception.""") +If so, the script will raise an exception.""", + ) code_parser.add_argument( - "--output-c", "--oc", required=False, type=str, + "--output-c", + "--oc", + required=False, + type=str, help="""Path to output C file. If both --decode and --encode are specified, _decode and _encode will be appended to the filename when creating the two files. If not specified, the path and name will be based on the --output-cmake file. 
A 'src' directory will be created next to the cmake file, and the C file will be -placed there with the same name (except the file extension) as the cmake file.""") +placed there with the same name (except the file extension) as the cmake file.""", + ) code_parser.add_argument( - "--output-h", "--oh", required=False, type=str, + "--output-h", + "--oh", + required=False, + type=str, help="""Path to output header file. If both --decode and --encode are specified, _decode and _encode will be appended to the filename when creating the two files. If not specified, the path and name will be based on the --output-cmake file. An 'include' directory will be created next to the cmake file, and the C file will be -placed there with the same name (except the file extension) as the cmake file.""") +placed there with the same name (except the file extension) as the cmake file.""", + ) code_parser.add_argument( - "--output-h-types", "--oht", required=False, type=str, + "--output-h-types", + "--oht", + required=False, + type=str, help="""Path to output header file with typedefs (shared between decode and encode). If not specified, the path and name will be taken from the output header file -(--output-h), with '_types' added to the file name.""") +(--output-h), with '_types' added to the file name.""", + ) code_parser.add_argument( - "--copy-sources", required=False, action="store_true", default=False, + "--copy-sources", + required=False, + action="store_true", + default=False, help="""Copy the non-generated source files (zcbor_*.c/h) into the same directories as the -generated files.""") +generated files.""", + ) code_parser.add_argument( - "--output-cmake", required=False, type=str, + "--output-cmake", + required=False, + type=str, help="""Path to output CMake file. The filename of the CMake file without '.cmake' is used as the name of the CMake target in the file. 
The CMake file defines a CMake target with the zcbor source files and the @@ -3143,111 +3536,190 @@ def parse_args(): files' folders as include_directories. Add it to your project via include() in your CMakeLists.txt file, and link the target to your program. -This option works with or without the --copy-sources option.""") +This option works with or without the --copy-sources option.""", + ) code_parser.add_argument( - "-t", "--entry-types", required=True, type=str, nargs="+", - help="Names of the types which should have their xcode functions exposed.") + "-t", + "--entry-types", + required=True, + type=str, + nargs="+", + help="Names of the types which should have their xcode functions exposed.", + ) code_parser.add_argument( - "-d", "--decode", required=False, action="store_true", default=False, - help="Generate decoding code. Either --decode or --encode or both must be specified.") + "-d", + "--decode", + required=False, + action="store_true", + default=False, + help="Generate decoding code. Either --decode or --encode or both must be specified.", + ) code_parser.add_argument( - "-e", "--encode", required=False, action="store_true", default=False, - help="Generate encoding code. Either --decode or --encode or both must be specified.") + "-e", + "--encode", + required=False, + action="store_true", + default=False, + help="Generate encoding code. 
Either --decode or --encode or both must be specified.", + ) code_parser.add_argument( - "--time-header", required=False, action="store_true", default=False, - help="Put the current time in a comment in the generated files.") + "--time-header", + required=False, + action="store_true", + default=False, + help="Put the current time in a comment in the generated files.", + ) code_parser.add_argument( - "--git-sha-header", required=False, action="store_true", default=False, - help="Put the current git sha of zcbor in a comment in the generated files.") + "--git-sha-header", + required=False, + action="store_true", + default=False, + help="Put the current git sha of zcbor in a comment in the generated files.", + ) code_parser.add_argument( - "-b", "--default-bit-size", required=False, type=int, default=32, choices=[32, 64], + "-b", + "--default-bit-size", + required=False, + type=int, + default=32, + choices=[32, 64], help="""Default bit size of integers in code. When integers have no explicit bounds, assume they have this bit width. Should follow the bit width of the architecture -the code will be running on.""") +the code will be running on.""", + ) code_parser.add_argument( - "--include-prefix", default="", - help="""When #include'ing generated files, add this path prefix to the filename.""") + "--include-prefix", + default="", + help="""When #include'ing generated files, add this path prefix to the filename.""", + ) code_parser.add_argument( - "-s", "--short-names", required=False, action="store_true", default=False, + "-s", + "--short-names", + required=False, + action="store_true", + default=False, help="""Attempt to make most generated struct member names shorter. This might make some names identical which will cause a compile error. If so, tweak the CDDL labels or layout, or disable this option. 
This might also make enum names different -from the corresponding union members.""") +from the corresponding union members.""", + ) code_parser.add_argument( - "--file-header", required=False, type=str, default="", + "--file-header", + required=False, + type=str, + default="", help="""Header to be included in the comment at the top of generated files, e.g. copyright. Can be a string or a path to a file. If interpreted as a path to an existing file, -the file's contents will be used.""") +the file's contents will be used.""", + ) code_parser.set_defaults(process=process_code) validate_parent_parser = ArgumentParser(add_help=False) validate_parent_parser.add_argument( - "-i", "--input", required=True, type=str, - help='''Input data file. The option --input-as specifies how to interpret the contents. -Use "-" to indicate stdin.''') + "-i", + "--input", + required=True, + type=str, + help="""Input data file. The option --input-as specifies how to interpret the contents. +Use "-" to indicate stdin.""", + ) validate_parent_parser.add_argument( - "--input-as", required=False, choices=["yaml", "json", "cbor", "cborhex"], - help='''Which format to interpret the input file as. + "--input-as", + required=False, + choices=["yaml", "json", "cbor", "cborhex"], + help="""Which format to interpret the input file as. If omitted, the format is inferred from the file name. 
-.yaml, .yml => YAML, .json => JSON, .cborhex => CBOR as hex string, everything else => CBOR''') +.yaml, .yml => YAML, .json => JSON, .cborhex => CBOR as hex string, everything else => CBOR""", + ) validate_parent_parser.add_argument( - "-t", "--entry-type", required=True, type=str, - help='''Name of the type (from the CDDL) to interpret the data as.''') + "-t", + "--entry-type", + required=True, + type=str, + help="""Name of the type (from the CDDL) to interpret the data as.""", + ) validate_parent_parser.add_argument( - "--default-max-qty", "--dq", required=False, type=int, default=0xFFFFFFFF, + "--default-max-qty", + "--dq", + required=False, + type=int, + default=0xFFFFFFFF, help="""Default maximum number of repetitions when no maximum is specified. It is only relevant when handling data that will be decoded by generated code. -If omitted, a large number will be used.""") +If omitted, a large number will be used.""", + ) validate_parent_parser.add_argument( - "--yaml-compatibility", required=False, action="store_true", default=False, - help='''Whether to convert CBOR-only values to YAML-compatible ones + "--yaml-compatibility", + required=False, + action="store_true", + default=False, + help="""Whether to convert CBOR-only values to YAML-compatible ones (when converting from CBOR), or vice versa (when converting to CBOR). When this is enabled, all CBOR data is guaranteed to convert into YAML/JSON. JSON and YAML do not support all data types that CBOR/CDDL supports. bytestrings (BSTR), tags, undefined, and maps with non-text keys need -special handling. See the zcbor README for more information.''') +special handling. See the zcbor README for more information.""", + ) validate_parser = subparsers.add_parser( - "validate", description='''Read CBOR, YAML, or JSON data from file or stdin and validate + "validate", + description="""Read CBOR, YAML, or JSON data from file or stdin and validate it against a CDDL schema file. 
- ''', - parents=[parent_parser, validate_parent_parser]) + """, + parents=[parent_parser, validate_parent_parser], + ) validate_parser.set_defaults(process=process_validate) convert_parser = subparsers.add_parser( - "convert", description='''Parse a CDDL file and validate/convert between CBOR and YAML/JSON. + "convert", + description="""Parse a CDDL file and validate/convert between CBOR and YAML/JSON. The script decodes the CBOR/YAML/JSON data from a file or stdin and verifies that it conforms to the CDDL description. The script fails if the data does not conform. -'zcbor validate' can be used if only validate is needed.''', - parents=[parent_parser, validate_parent_parser]) +'zcbor validate' can be used if only validate is needed.""", + parents=[parent_parser, validate_parent_parser], + ) convert_parser.add_argument( - "-o", "--output", required=True, type=str, - help='''Output data file. The option --output-as specifies how to interpret the contents. - Use "-" to indicate stdout.''') + "-o", + "--output", + required=True, + type=str, + help="""Output data file. The option --output-as specifies how to interpret the contents. + Use "-" to indicate stdout.""", + ) convert_parser.add_argument( - "--output-as", required=False, choices=["yaml", "json", "cbor", "cborhex", "c_code"], - help='''Which format to interpret the output file as. + "--output-as", + required=False, + choices=["yaml", "json", "cbor", "cborhex", "c_code"], + help="""Which format to interpret the output file as. If omitted, the format is inferred from the file name. 
.yaml, .yml => YAML, .json => JSON, .c, .h => C code, -.cborhex => CBOR as hex string, everything else => CBOR''') +.cborhex => CBOR as hex string, everything else => CBOR""", + ) convert_parser.add_argument( - "--c-code-var-name", required=False, type=str, - help='''Only relevant together with '--output-as c_code' or .c files.''') + "--c-code-var-name", + required=False, + type=str, + help="""Only relevant together with '--output-as c_code' or .c files.""", + ) convert_parser.add_argument( - "--c-code-columns", required=False, type=int, default=0, - help='''Only relevant together with '--output-as c_code' or .c files. + "--c-code-columns", + required=False, + type=int, + default=0, + help="""Only relevant together with '--output-as c_code' or .c files. The number of bytes per line in the variable instantiation. If omitted, the -entire declaration is a single line.''') +entire declaration is a single line.""", + ) convert_parser.set_defaults(process=process_convert) args = parser.parse_args() if not args.no_prelude: - args.cddl.append(open(PRELUDE_PATH, 'r', encoding="utf-8")) + args.cddl.append(open(PRELUDE_PATH, "r", encoding="utf-8")) if hasattr(args, "decode") and not args.decode and not args.encode: parser.error("Please specify at least one of --decode or --encode.") @@ -3257,7 +3729,8 @@ def parse_args(): if not args.output_cmake: parser.error( "Please specify both --output-c and --output-h " - "unless --output-cmake is specified.") + "unless --output-cmake is specified." + ) return args @@ -3279,26 +3752,36 @@ def process_code(args): cddl_res = dict() for mode in modes: cddl_res[mode] = CodeGenerator.from_cddl( - mode, cddl_contents, args.default_max_qty, mode, args.entry_types, - args.default_bit_size, short_names=args.short_names) + mode, + cddl_contents, + args.default_max_qty, + mode, + args.entry_types, + args.default_bit_size, + short_names=args.short_names, + ) # Parsing is done, pretty print the result. 
verbose_print(args.verbose, "Parsed CDDL types:") for mode in modes: verbose_pprint(args.verbose, cddl_res[mode].my_types) - git_sha = '' + git_sha = "" if args.git_sha_header: if "zcbor.py" in sys.argv[0]: - git_args = ['git', 'rev-parse', '--verify', '--short', 'HEAD'] - git_sha = Popen( - git_args, cwd=PACKAGE_PATH, stdout=PIPE).communicate()[0].decode('utf-8').strip() + git_args = ["git", "rev-parse", "--verify", "--short", "HEAD"] + git_sha = ( + Popen(git_args, cwd=PACKAGE_PATH, stdout=PIPE) + .communicate()[0] + .decode("utf-8") + .strip() + ) else: git_sha = __version__ def create_and_open(path): Path(path).absolute().parent.mkdir(parents=True, exist_ok=True) - return Path(path).open('w', encoding='utf-8') + return Path(path).open("w", encoding="utf-8") if args.output_cmake: cmake_dir = Path(args.output_cmake).parent @@ -3317,11 +3800,15 @@ def add_mode_to_fname(filename, mode): out_h = args.output_h if (len(modes) == 1 and args.output_h) else None for mode in modes: output_c[mode] = create_and_open( - out_c or add_mode_to_fname( - args.output_c or Path(cmake_dir, 'src', f'{filenames}.c'), mode)) + out_c + or add_mode_to_fname(args.output_c or Path(cmake_dir, "src", f"{filenames}.c"), mode) + ) output_h[mode] = create_and_open( - out_h or add_mode_to_fname( - args.output_h or Path(cmake_dir, 'include', f'{filenames}.h'), mode)) + out_h + or add_mode_to_fname( + args.output_h or Path(cmake_dir, "include", f"{filenames}.h"), mode + ) + ) out_c_parent = Path(output_c[modes[0]].name).parent out_h_parent = Path(output_h[modes[0]].name).parent @@ -3329,14 +3816,19 @@ def add_mode_to_fname(filename, mode): output_h_types = create_and_open( args.output_h_types or (args.output_h and Path(args.output_h).with_name(Path(args.output_h).stem + "_types.h")) - or Path(cmake_dir, 'include', filenames + '_types.h')) + or Path(cmake_dir, "include", filenames + "_types.h") + ) - renderer = CodeRenderer(entry_types={mode: [cddl_res[mode].my_types[entry] - for entry in 
args.entry_types] for mode in modes}, - modes=modes, print_time=args.time_header, - default_max_qty=args.default_max_qty, git_sha=git_sha, - file_header=args.file_header - ) + renderer = CodeRenderer( + entry_types={ + mode: [cddl_res[mode].my_types[entry] for entry in args.entry_types] for mode in modes + }, + modes=modes, + print_time=args.time_header, + default_max_qty=args.default_max_qty, + git_sha=git_sha, + file_header=args.file_header, + ) c_code_dir = C_SRC_PATH h_code_dir = C_INCLUDE_PATH @@ -3356,8 +3848,16 @@ def add_mode_to_fname(filename, mode): c_code_dir = new_c_code_dir h_code_dir = new_h_code_dir - renderer.render(modes, output_h, output_c, output_h_types, args.include_prefix, - output_cmake, c_code_dir, h_code_dir) + renderer.render( + modes, + output_h, + output_c, + output_h_types, + args.include_prefix, + output_cmake, + c_code_dir, + h_code_dir, + ) def parse_cddl(args): @@ -3398,8 +3898,9 @@ def write_data(args, cddl, cbor_str): f.write(cddl.str_to_json(cbor_str, yaml_compat=args.yaml_compatibility)) elif out_file_format in ["c", "h", "c_code"]: f = sys.stdout if args.output == "-" else open(args.output, "w", encoding="utf-8") - assert args.c_code_var_name is not None, \ - "Must specify --c-code-var-name when outputting c code." + assert ( + args.c_code_var_name is not None + ), "Must specify --c-code-var-name when outputting c code." 
f.write(cddl.str_to_c_code(cbor_str, args.c_code_var_name, args.c_code_columns))
     elif out_file_format == "cborhex":
         f = sys.stdout if args.output == "-" else open(args.output, "w", encoding="utf-8")

From 4f38ecf4eb8ede4a981211a822e8104d99a1e4ab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=98yvind=20R=C3=B8nningstad?=
Date: Mon, 2 Dec 2024 18:36:01 +0100
Subject: [PATCH 2/3] zcbor.py: Performance improvements in DataTranslator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Øyvind Rønningstad

---
 tests/scripts/test_performance.py | 29 +++++++++++++
 zcbor/zcbor.py                    | 71 +++++++++++++++++++------------
 2 files changed, 72 insertions(+), 28 deletions(-)
 create mode 100644 tests/scripts/test_performance.py

diff --git a/tests/scripts/test_performance.py b/tests/scripts/test_performance.py
new file mode 100644
index 00000000..53728f9b
--- /dev/null
+++ b/tests/scripts/test_performance.py
@@ -0,0 +1,29 @@
+import cbor2
+import cProfile, pstats
+
+
+try:
+    import zcbor
+except ImportError:
+    print(
+        """
+The zcbor package must be installed to run these tests.
+During development, install with `pip3 install -e .` to install in a way
+that picks up changes in the files without having to reinstall. 
+""" + ) + exit(1) + +cddl_contents = """ +perf_int = [0*1000(int/bool)] +""" +raw_message = cbor2.dumps(list(range(1000))) +cmd_spec = zcbor.DataTranslator.from_cddl(cddl_contents, 3).my_types["perf_int"] +# cmd_spec = zcbor.DataDecoder.from_cddl(cddl_contents, 3).my_types["perf_int"] + +profiler = cProfile.Profile() +profiler.enable() +json_obj = cmd_spec.str_to_json(raw_message) +profiler.disable() + +profiler.print_stats() diff --git a/zcbor/zcbor.py b/zcbor/zcbor.py index 4cd30806..546a205d 100755 --- a/zcbor/zcbor.py +++ b/zcbor/zcbor.py @@ -1277,6 +1277,7 @@ def __init__(self, *args, **kwargs): # Used as a guard against endless recursion in self.dependsOn() self.dependsOnCall = False self.skipped = False + self.stored_id = None def var_name(self, with_prefix=False, observe_skipped=True): """Name of variables and enum members for this element.""" @@ -1678,7 +1679,9 @@ def id(self): If the name starts with an underscore, prepend an 'f', since namedtuple() doesn't support identifiers that start with an underscore. """ - return getrp(r"\A_").sub("f_", self.generate_base_name()) + if self.stored_id is None: + self.stored_id = getrp(r"\A_").sub("f_", self.get_base_name()) + return self.stored_id def var_name(self): """Override the var_name()""" @@ -1687,6 +1690,8 @@ def var_name(self): def _decode_assert(self, test, msg=""): """Check a condition and raise a CddlValidationError if not.""" if not test: + if callable(msg): + msg = msg() raise CddlValidationError( f"Data did not decode correctly {'(' + msg + ')' if msg else ''}" ) @@ -1696,6 +1701,9 @@ def _check_tag(self, obj): Return whether a tag was present. 
""" + if not self.tags and not isinstance(obj, CBORTag): + return obj + tags = copy(self.tags) # All expected tags # Process all tags present in obj while isinstance(obj, CBORTag): @@ -1706,27 +1714,29 @@ def _check_tag(self, obj): continue elif self.type in ["OTHER", "GROUP", "UNION"]: break - self._decode_assert(False, f"Tag ({obj.tag}) not expected for {self}") + self._decode_assert(False, lambda: f"Tag ({obj.tag}) not expected for {self}") # Check that all expected tags were found in obj. - self._decode_assert(not tags, f"Expected tags ({tags}), but none present.") + self._decode_assert(not tags, lambda: f"Expected tags ({tags}), but none present.") return obj + _exp_types = { + "UINT": (int,), + "INT": (int,), + "NINT": (int,), + "FLOAT": (float,), + "TSTR": (str,), + "BSTR": (bytes,), + "NIL": (type(None),), + "UNDEF": (type(undefined),), + "ANY": (int, float, str, bytes, type(None), type(undefined), bool, list, dict), + "BOOL": (bool,), + "LIST": (tuple, list), + "MAP": (dict,), + } + def _expected_type(self): """Return our expected python type as returned by cbor2.""" - return { - "UINT": lambda: (int,), - "INT": lambda: (int,), - "NINT": lambda: (int,), - "FLOAT": lambda: (float,), - "TSTR": lambda: (str,), - "BSTR": lambda: (bytes,), - "NIL": lambda: (type(None),), - "UNDEF": lambda: (type(undefined),), - "ANY": lambda: (int, float, str, bytes, type(None), type(undefined), bool, list, dict), - "BOOL": lambda: (bool,), - "LIST": lambda: (tuple, list), - "MAP": lambda: (dict,), - }[self.type]() + return self._exp_types[self.type] def _check_type(self, obj): """Check that the decoded object has the correct type.""" @@ -1734,7 +1744,7 @@ def _check_type(self, obj): exp_type = self._expected_type() self._decode_assert( type(obj) in exp_type, - f"{str(self)}: Wrong type ({type(obj)}) of {str(obj)}, expected {str(exp_type)}", + lambda: f"{str(self)}: Wrong type ({type(obj)}) of {str(obj)}, expected {str(exp_type)}", ) def _check_value(self, obj): @@ -1748,31 
+1758,35 @@ def _check_value(self, obj): value = self.value.encode("utf-8") self._decode_assert( self.value == obj, - f"{obj} should have value {self.value} according to {self.var_name()}", + lambda: f"{obj} should have value {self.value} according to {self.var_name()}", ) if self.type in ["UINT", "INT", "NINT", "FLOAT"]: if self.min_value is not None: - self._decode_assert(obj >= self.min_value, "Minimum value: " + str(self.min_value)) + self._decode_assert( + obj >= self.min_value, lambda: "Minimum value: " + str(self.min_value) + ) if self.max_value is not None: - self._decode_assert(obj <= self.max_value, "Maximum value: " + str(self.max_value)) + self._decode_assert( + obj <= self.max_value, lambda: "Maximum value: " + str(self.max_value) + ) if self.type == "UINT": if self.bits: mask = sum(((1 << b.value) for b in self.my_control_groups[self.bits].value)) - self._decode_assert(not (obj & ~mask), "Allowed bitmask: " + bin(mask)) + self._decode_assert(not (obj & ~mask), lambda: "Allowed bitmask: " + bin(mask)) if self.type in ["TSTR", "BSTR"]: if self.min_size is not None: self._decode_assert( - len(obj) >= self.min_size, "Minimum length: " + str(self.min_size) + len(obj) >= self.min_size, lambda: "Minimum length: " + str(self.min_size) ) if self.max_size is not None: self._decode_assert( - len(obj) <= self.max_size, "Maximum length: " + str(self.max_size) + len(obj) <= self.max_size, lambda: "Maximum length: " + str(self.max_size) ) def _check_key(self, obj): """Check that the object is not a KeyTuple, which would mean it's not properly processed.""" self._decode_assert( - not isinstance(obj, KeyTuple), "Unexpected key found: (key,value)=" + str(obj) + not isinstance(obj, KeyTuple), lambda: "Unexpected key found: (key,value)=" + str(obj) ) def _flatten_obj(self, obj): @@ -1917,13 +1931,14 @@ def _decode_single_obj(self, obj): return self._construct_obj(retval) except CddlValidationError as c: self.errors.append(str(c)) - self._decode_assert(False, "No matches 
for union: " + str(self)) + self._decode_assert(False, lambda: "No matches for union: " + str(self)) assert False, "Unexpected type: " + self.type def _handle_key(self, next_obj): """Decode key and value in the form of a KeyTuple""" self._decode_assert( - isinstance(next_obj, KeyTuple), f"Expected key: {self.key} value=" + pformat(next_obj) + isinstance(next_obj, KeyTuple), + lambda: f"Expected key: {self.key} value=" + pformat(next_obj), ) key, obj = next_obj key_res = self.key._decode_single_obj(key) @@ -1975,7 +1990,7 @@ def _decode_obj(self, it): except CddlValidationError as c: self.errors.append(str(c)) child_it = it_copy - self._decode_assert(found, "No matches for union: " + str(self)) + self._decode_assert(found, lambda: "No matches for union: " + str(self)) else: ret = (it, self._decode_single_obj(self._iter_next(it))) return ret From 0ed5d594570a15938b4a1afbdba0756228c9e7ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20R=C3=B8nningstad?= Date: Tue, 21 Jan 2025 14:58:28 +0100 Subject: [PATCH 3/3] Split DataTranslator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Put the namedtuple-generation into a new class DataDecoder. This is done for performance reasons. Document the split in MIGRATION_GUIDE and ARCHITECTURE (and make some small improvements to ARCHITECTURE). Signed-off-by: Øyvind Rønningstad --- ARCHITECTURE.md | 16 ++++++-- MIGRATION_GUIDE.md | 9 +++++ __init__.py | 2 +- tests/scripts/test_zcbor.py | 14 +++---- zcbor/zcbor.py | 80 +++++++++++++++++++++++-------------- 5 files changed, 79 insertions(+), 42 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 56fd40e7..ea3970d1 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -16,6 +16,7 @@ The functionality is spread across 5 classes: 1. CddlParser 2. CddlXcoder (inherits from CddlParser) 3. DataTranslator (inherits from CddlXcoder) +4. DataDecoder (inherits from DataTranslator) 4. CodeGenerator (inherits from CddlXcoder) 5. 
CodeRenderer @@ -100,7 +101,7 @@ Most of the functionality falls into one of two categories: - is_unambiguous(): Whether the type is completely specified, i.e. whether we know beforehand exactly how the encoding will look (e.g. `Foo = 5`). DataTranslator ------------ +-------------- DataTranslator is for handling and manipulating CBOR on the "host". For example, the user can compose data in YAML or JSON files and have them converted to CBOR and validated against the CDDL. @@ -127,15 +128,23 @@ One caveat is that CBOR supports more features than YAML/JSON, namely: zcbor allows creating bespoke representations via `--yaml-compatibility`, see the README or CLI docs for more info. -Finally, DataTranslator can also generate a separate internal representation using `namedtuple`s to allow browsing CBOR data by the names given in the CDDL. +DataTranslator functionality is tested in [tests/scripts/test_zcbor.py](tests/scripts/test_zcbor.py) + +DataDecoder +----------- + +DataDecoder contains functions for generating a separate internal representation using `namedtuple`s to allow browsing CBOR data by the names given in the CDDL. (This is more analogous to how the data is accessed in the C code.) -DataTranslator functionality is tested in [tests/scripts](tests/scripts) +This functionality was originally part of DataTranslator, but was moved because the internal representation was always created but seldom used, and the namedtuples caused a noticeable performance hit. + +DataDecoder functionality is tested in [tests/scripts/test_zcbor.py](tests/scripts/test_zcbor.py) CodeGenerator ------------- CodeGenerator, like DataTranslator, inherits from CddlXcoder. +It is used to generate C code. Its primary purpose is to construct the individual decoding/encoding functions for the types specified in the given CDDL document. It also constructs struct definitions used to hold the decoded data/data to be encoded. 
@@ -158,6 +167,7 @@ repeated_foo() concerns itself with the individual value, while foo() concerns i
 
 When invoking CodeGenerator, the user must decide which types it will need direct access to decode/encode.
 These types are called "entry types" and they are typically the "outermost" types, or the types it is expected that the data will have.
+CodeGenerator will generate a public function for each entry type.
 
 The user can also use entry types when there are `"BSTR"`s that are CBOR encoded, specified as `Foo = bstr .cbor Bar`.
 Usually such strings are automatically decoded/encoded by the generated code, and the objects part of the encompassing struct.
diff --git a/MIGRATION_GUIDE.md b/MIGRATION_GUIDE.md
index 1eaab6fc..f0158deb 100644
--- a/MIGRATION_GUIDE.md
+++ b/MIGRATION_GUIDE.md
@@ -1,5 +1,14 @@
 # zcbor v. 0.9.99
 
+* The following `DataTranslator` functions have been moved to a separate class `DataDecoder`:
+
+  * `decode_obj()`
+  * `decode_str_yaml()`
+  * `decode_str()`
+
+  The split was done for performance reasons (namedtuple objects are slow to create).
+  The `DataDecoder` class is a subclass of `DataTranslator` and can do all the same things, just a bit slower.
+  This functionality is only relevant when zcbor is imported, so all CLI usage is unaffected.
 
 # zcbor v. 
0.9.0 diff --git a/__init__.py b/__init__.py index 9dc6cbef..5e326d9c 100644 --- a/__init__.py +++ b/__init__.py @@ -7,4 +7,4 @@ from pathlib import Path -from .zcbor.zcbor import CddlValidationError, DataTranslator, main +from .zcbor.zcbor import CddlValidationError, DataTranslator, DataDecoder, main diff --git a/tests/scripts/test_zcbor.py b/tests/scripts/test_zcbor.py index f1490933..3adbd127 100644 --- a/tests/scripts/test_zcbor.py +++ b/tests/scripts/test_zcbor.py @@ -70,7 +70,7 @@ def decode_file(self, data_path, *cddl_paths): def decode_string(self, data_string, *cddl_paths): cddl_str = " ".join((Path(p).read_text(encoding="utf-8") for p in cddl_paths)) - self.my_types = zcbor.DataTranslator.from_cddl(cddl_str, 16).my_types + self.my_types = zcbor.DataDecoder.from_cddl(cddl_str, 16).my_types cddl = self.my_types["SUIT_Envelope_Tagged"] self.decoded = cddl.decode_str(data_string) @@ -1123,7 +1123,7 @@ def test_file_header(self): class TestOptional(TestCase): def test_optional_0(self): with open(p_optional, "r", encoding="utf-8") as f: - cddl_res = zcbor.DataTranslator.from_cddl(f.read(), 16) + cddl_res = zcbor.DataDecoder.from_cddl(f.read(), 16) cddl = cddl_res.my_types["cfg"] test_yaml = """ mem_config: @@ -1136,7 +1136,7 @@ def test_optional_0(self): class TestUndefined(TestCase): def test_undefined_0(self): - cddl_res = zcbor.DataTranslator.from_cddl( + cddl_res = zcbor.DataDecoder.from_cddl( p_prelude.read_text(encoding="utf-8") + "\n" + p_corner_cases.read_text(encoding="utf-8"), @@ -1154,7 +1154,7 @@ def test_undefined_0(self): class TestFloat(TestCase): def test_float_0(self): - cddl_res = zcbor.DataTranslator.from_cddl( + cddl_res = zcbor.DataDecoder.from_cddl( p_prelude.read_text(encoding="utf-8") + "\n" + p_corner_cases.read_text(encoding="utf-8"), @@ -1243,7 +1243,7 @@ def test_yaml_compatibility(self): class TestIntmax(TestCase): def test_intmax1(self): - cddl_res = zcbor.DataTranslator.from_cddl( + cddl_res = zcbor.DataDecoder.from_cddl( 
p_prelude.read_text(encoding="utf-8") + "\n" + p_corner_cases.read_text(encoding="utf-8"), @@ -1254,7 +1254,7 @@ def test_intmax1(self): decoded = cddl.decode_str_yaml(test_yaml) def test_intmax2(self): - cddl_res = zcbor.DataTranslator.from_cddl( + cddl_res = zcbor.DataDecoder.from_cddl( p_prelude.read_text(encoding="utf-8") + "\n" + p_corner_cases.read_text(encoding="utf-8"), @@ -1286,7 +1286,7 @@ def test_intmax2(self): class TestInvalidIdentifiers(TestCase): def test_invalid_identifiers0(self): - cddl_res = zcbor.DataTranslator.from_cddl( + cddl_res = zcbor.DataDecoder.from_cddl( p_prelude.read_text(encoding="utf-8") + "\n" + p_corner_cases.read_text(encoding="utf-8"), diff --git a/zcbor/zcbor.py b/zcbor/zcbor.py index 546a205d..7813cfb6 100755 --- a/zcbor/zcbor.py +++ b/zcbor/zcbor.py @@ -1808,20 +1808,8 @@ def _flatten_list(self, name, obj): return obj def _construct_obj(self, my_list): - """Construct a namedtuple object from my_list. my_list contains tuples of name/value. - - Also, attempt to flatten redundant levels of abstraction. - """ - if my_list == []: - return None - names, values = tuple(zip(*my_list)) - if len(values) == 1: - values = (self._flatten_obj(values[0]),) - values = tuple(self._flatten_list(names[i], values[i]) for i in range(len(values))) - assert not any( - (isinstance(elem, KeyTuple) for elem in values) - ), f"KeyTuple not processed: {values}" - return namedtuple("_", names)(*values) + """Can be overridden to construct a decoded object.""" + pass def _add_if(self, my_list, obj, expect_key=False, name=None): """Add construct obj and add it to my_list if relevant. @@ -1852,11 +1840,11 @@ def _add_if(self, my_list, obj, expect_key=False, name=None): # If a bstr is CBOR-formatted, add both the string and the decoding of the string here if isinstance(obj, list) and all((isinstance(o, bytes) for o in obj)): # One or more bstr in a list (i.e. 
it is optional or repeated) - my_list.append((name or self.var_name(), [self.cbor.decode_str(o) for o in obj])) + my_list.append((name or self.var_name(), [self.cbor._decode_str(o) for o in obj])) my_list.append(((name or self.var_name()) + "_bstr", obj)) return if isinstance(obj, bytes): - my_list.append((name or self.var_name(), self.cbor.decode_str(obj))) + my_list.append((name or self.var_name(), self.cbor._decode_str(obj))) my_list.append(((name or self.var_name()) + "_bstr", obj)) return my_list.append((name or self.var_name(), obj)) @@ -1946,7 +1934,7 @@ def _handle_key(self, next_obj): res = KeyTuple((key_res if not self.key.is_unambiguous() else None, obj_res)) return res - def _decode_obj(self, it): + def _decode_obj_it(self, it): """Decode single CDDL value, excluding repetitions. May consume 0 to n CBOR objects via the iterator. @@ -2003,22 +1991,22 @@ def _decode_full(self, it): if self.multi_var_condition(): retvals = [] for i in range(self.min_qty): - it, retval = self._decode_obj(it) + it, retval = self._decode_obj_it(it) retvals.append(retval if not self.is_unambiguous_repeated() else None) try: for i in range(self.max_qty - self.min_qty): it, it_copy = tee(it) - it, retval = self._decode_obj(it) + it, retval = self._decode_obj_it(it) retvals.append(retval if not self.is_unambiguous_repeated() else None) except CddlValidationError as c: self.errors.append(str(c)) it = it_copy return it, retvals else: - ret = self._decode_obj(it) + ret = self._decode_obj_it(it) return ret - def decode_obj(self, obj): + def _decode_obj(self, obj): """CBOR object => python object""" it = iter([obj]) try: @@ -2031,21 +2019,14 @@ def decode_obj(self, obj): raise e return decoded - def decode_str_yaml(self, yaml_str, yaml_compat=False): - """YAML => python object""" - yaml_obj = yaml_load(yaml_str) - obj = self._from_yaml_obj(yaml_obj) if yaml_compat else yaml_obj - self.validate_obj(obj) - return self.decode_obj(obj) - - def decode_str(self, cbor_str): + def 
_decode_str(self, cbor_str): """CBOR bytestring => python object""" cbor_obj = loads(cbor_str) - return self.decode_obj(cbor_obj) + return self._decode_obj(cbor_obj) def validate_obj(self, obj): """Validate CBOR object against CDDL. Exception if not valid.""" - self.decode_obj(obj) + self._decode_obj(obj) # Will raise exception if not valid return True def validate_str(self, cbor_str): @@ -2162,6 +2143,43 @@ def str_to_c_code(self, cbor_str, var_name, columns=0): return f"uint8_t {var_name}[] = {{{arr}}};\n" +class DataDecoder(DataTranslator): + """Create a decoded object with element names taken from the CDDL. + + This is kept separate from DataTranslator for performance reasons.""" + + def _construct_obj(self, my_list): + """Construct a namedtuple object from my_list. my_list contains tuples of name/value. + + Also, attempt to flatten redundant levels of abstraction. + """ + if my_list == []: + return None + names, values = tuple(zip(*my_list)) + if len(values) == 1: + values = (self._flatten_obj(values[0]),) + values = tuple(self._flatten_list(names[i], values[i]) for i in range(len(values))) + assert not any( + (isinstance(elem, KeyTuple) for elem in values) + ), f"KeyTuple not processed: {values}" + return namedtuple("_", names)(*values) + + def decode_obj(self, obj): + """CBOR object => python object""" + return self._decode_obj(obj) + + def decode_str_yaml(self, yaml_str, yaml_compat=False): + """YAML => python object""" + yaml_obj = yaml_load(yaml_str) + obj = self._from_yaml_obj(yaml_obj) if yaml_compat else yaml_obj + self.validate_obj(obj) + return self.decode_obj(obj) + + def decode_str(self, cbor_str): + """CBOR bytestring => python object""" + return self._decode_str(cbor_str) + + class XcoderTuple(NamedTuple): body: list func_name: str