diff --git a/tests/cases/corner_cases.cddl b/tests/cases/corner_cases.cddl index fe520d7b..1d5ef5e0 100644 --- a/tests/cases/corner_cases.cddl +++ b/tests/cases/corner_cases.cddl @@ -254,3 +254,8 @@ Uint64List = [ ? uint64_lit: 0x0123456789abcdef, * nint64_lit: -0x0123456789abcdef, ] + +BstrSize = { + bstr12, +} +bstr12 = ("s" : bstr .size 12) diff --git a/tests/decode/test5_corner_cases/CMakeLists.txt b/tests/decode/test5_corner_cases/CMakeLists.txt index f353edd9..18a586aa 100644 --- a/tests/decode/test5_corner_cases/CMakeLists.txt +++ b/tests/decode/test5_corner_cases/CMakeLists.txt @@ -53,6 +53,7 @@ set(py_command Intmax2 InvalidIdentifiers Uint64List + BstrSize --decode --git-sha-header --short-names diff --git a/tests/decode/test5_corner_cases/src/main.c b/tests/decode/test5_corner_cases/src/main.c index a7090322..77dfa955 100644 --- a/tests/decode/test5_corner_cases/src/main.c +++ b/tests/decode/test5_corner_cases/src/main.c @@ -2103,4 +2103,24 @@ ZTEST(cbor_decode_test5, test_uint64_list) } +ZTEST(cbor_decode_test5, test_bstr_size) +{ + uint8_t bstr_size_payload1[] = {MAP(1), + 0x61, 's', + 0x4c, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, + 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, + END + }; + + struct BstrSize result; + size_t num_decode; + + zassert_equal(ZCBOR_SUCCESS, cbor_decode_BstrSize(bstr_size_payload1, + sizeof(bstr_size_payload1), &result, &num_decode), NULL); + zassert_equal(sizeof(bstr_size_payload1), num_decode, NULL); + + zassert_equal(12, result.bstr12_m.s.len, NULL); +} + + ZTEST_SUITE(cbor_decode_test5, NULL, NULL, NULL, NULL, NULL); diff --git a/zcbor/zcbor.py b/zcbor/zcbor.py index f227434c..9f810f8a 100755 --- a/zcbor/zcbor.py +++ b/zcbor/zcbor.py @@ -2389,7 +2389,8 @@ def range_checks(self, access): if self.max_size is not None: range_checks.append(f"({access}.len <= {self.max_size})") elif self.type == "OTHER": - range_checks.extend(self.my_types[self.value].range_checks(access)) + if not self.my_types[self.value].single_func_impl_condition(): + range_checks.extend(self.my_types[self.value].range_checks(access)) if range_checks: range_checks[0] = "((" + range_checks[0]