From f61968b576cd1efe72349d2e61cc9a37c3c07095 Mon Sep 17 00:00:00 2001 From: Fanda Vacek Date: Wed, 30 Oct 2024 17:30:51 +0100 Subject: [PATCH] Implement defined read of very long cpon ints and decimals --- libshvchainpack/c/ccpon.c | 198 +++++++++++++----- .../include/shv/chainpack/rpcvalue.h | 7 +- libshvchainpack/src/chainpackwriter.cpp | 2 +- libshvchainpack/src/cponwriter.cpp | 2 +- libshvchainpack/src/rpcvalue.cpp | 18 +- libshvchainpack/tests/test_cpon.cpp | 29 +++ .../include/shv/core/utils/shvtypeinfo.h | 1 - 7 files changed, 194 insertions(+), 63 deletions(-) diff --git a/libshvchainpack/c/ccpon.c b/libshvchainpack/c/ccpon.c index cb1533ae5..255f2e374 100644 --- a/libshvchainpack/c/ccpon.c +++ b/libshvchainpack/c/ccpon.c @@ -1,7 +1,7 @@ #include #include -#include +#include static inline uint8_t hexify(uint8_t b) { @@ -801,16 +801,92 @@ const char* ccpon_unpack_skip_insignificant(ccpcp_unpack_context* unpack_context } } -static int unpack_int(ccpcp_unpack_context* unpack_context, int64_t *p_val) +static bool add_with_overflow_check_subopt(int64_t a, int b, int64_t *result) +{ + int64_t res = a + b; + if (res < 0) { + *result = INT64_MAX; + return true; + } + *result = res; + return false; +} + +static bool multiply_with_overflow_check_subopt(int64_t a, int b, int64_t* result) { - int64_t val = 0; + // Check for zero multiplication + if (a == 0 || b == 0) { + *result = 0; + return false; // No overflow + } + + assert(a > 0 && b > 0); + + // Check for overflow + if (a > INT64_MAX / b) { + return true; // Overflow occurred + } + + *result = a * b; + return false; // No overflow +} + +#if defined(__has_builtin) +# if __has_builtin(__builtin_add_overflow) +# define HAS_BUILTIN_ADD_OVERFLOW +# endif +# if __has_builtin(__builtin_mul_overflow) +# define HAS_BUILTIN_MUL_OVERFLOW +# endif +#endif + +static bool add_with_overflow_check(int64_t a, int b, int64_t* result) { +#ifdef HAS_BUILTIN_ADD_OVERFLOW + if (sizeof(long long int) == sizeof(int64_t)) { + long long res; + bool ret = __builtin_saddll_overflow(a, b, &res); + *result = (int64_t)res; + return ret; + } +#endif + return add_with_overflow_check_subopt(a, b, result); +} + +static bool multiply_with_overflow_check(int64_t a, int b, int64_t* result) { +#ifdef HAS_BUILTIN_MUL_OVERFLOW + if (sizeof(long long int) == 64) { + long long res; + bool ret = __builtin_smulll_overflow(a, b, &res); + *result = (int64_t)res; + return ret; + } +#endif + return multiply_with_overflow_check_subopt(a, b, result); +} +#undef HAS_BUILTIN_ADD_OVERFLOW +#undef HAS_BUILTIN_MUL_OVERFLOW + +typedef struct { + int64_t value; + int digit_cnt; + int signum; + bool is_overflow; +} read_int_result; + +static int unpack_int_to_result(ccpcp_unpack_context* unpack_context, int64_t init_value, read_int_result *result) +{ + result->digit_cnt = 0; + result->is_overflow = false; + result->value = init_value; int neg = 0; int base = 10; int n = 0; + bool starts_with_signum = false; for (; ; n++) { const char *p = ccpcp_unpack_take_byte(unpack_context); if(!p) goto eonumb; + int digit = -1; uint8_t b = (uint8_t)(*p); switch (b) { case '+': @@ -819,20 +895,22 @@ static int unpack_int(ccpcp_unpack_context* unpack_context, int64_t *p_val) unpack_context->current--; goto eonumb; } + starts_with_signum = true; if(b == '-') neg = 1; break; - case 'x': - if(n == 1 && val != 0) { - unpack_context->current--; - goto eonumb; + case 'x': { + int expected_x_pos = 1 + (starts_with_signum? 1: 0); + if(n == expected_x_pos && result->value == 0) { + base = 16; + result->digit_cnt = 0; } - if(n != 1) { + else { unpack_context->current--; goto eonumb; } - base = 16; break; + } case '0': case '1': case '2': @@ -843,8 +921,7 @@ static int unpack_int(ccpcp_unpack_context* unpack_context, int64_t *p_val) case '7': case '8': case '9': - val *= base; - val += b - '0'; + digit = b - '0'; break; case 'a': case 'b': @@ -856,8 +933,7 @@ static int unpack_int(ccpcp_unpack_context* unpack_context, int64_t *p_val) unpack_context->current--; goto eonumb; } - val *= base; - val += b - 'a' + 10; + digit = b - 'a' + 10; break; case 'A': case 'B': @@ -869,22 +945,36 @@ static int unpack_int(ccpcp_unpack_context* unpack_context, int64_t *p_val) unpack_context->current--; goto eonumb; } - val *= base; - val += b - 'A' + 10; + digit = b - 'A' + 10; break; default: unpack_context->current--; goto eonumb; } + if (digit >= 0) { + if (multiply_with_overflow_check(result->value, base, &result->value)) { + result->is_overflow = true; + } + else if (add_with_overflow_check(result->value, digit, &result->value)) { + result->is_overflow = true; + } + else { + result->digit_cnt++; + } + } } eonumb: - if(neg) - val = -val; + result->signum = neg? -1: 1; + return n; +} +static int unpack_int(ccpcp_unpack_context* unpack_context, int64_t *p_val) +{ + read_int_result res; + int n = unpack_int_to_result(unpack_context, 0, &res); if(p_val) - *p_val = val; + *p_val = res.value * res.signum; return n; } - void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm, int *msec, int *utc_offset) { tm->tm_year = 0; @@ -902,7 +992,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm, int64_t val; int n = unpack_int(unpack_context, &val); - if(n < 0) { + if(n == 0) { unpack_context->err_no = CCPCP_RC_MALFORMED_INPUT; unpack_context->err_msg = "Malformed year in DateTime"; return; @@ -917,7 +1007,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm, } n = unpack_int(unpack_context, &val); - if(n <= 0) { + if(n == 0) { unpack_context->err_no = CCPCP_RC_MALFORMED_INPUT; unpack_context->err_msg = "Malformed month in DateTime"; return; @@ -932,7 +1022,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm, } n = unpack_int(unpack_context, &val); - if(n <= 0) { + if(n == 0) { unpack_context->err_no = CCPCP_RC_MALFORMED_INPUT; unpack_context->err_msg = "Malformed day in DateTime"; return; @@ -947,7 +1037,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm, } n = unpack_int(unpack_context, &val); - if(n <= 0) { + if(n == 0) { unpack_context->err_no = CCPCP_RC_MALFORMED_INPUT; unpack_context->err_msg = "Malformed hour in DateTime"; return; @@ -957,7 +1047,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm, UNPACK_TAKE_BYTE(p); n = unpack_int(unpack_context, &val); - if(n <= 0) { + if(n == 0) { unpack_context->err_no = CCPCP_RC_MALFORMED_INPUT; unpack_context->err_msg = "Malformed minutes in DateTime"; return; @@ -967,7 +1057,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm, UNPACK_TAKE_BYTE(p); n = unpack_int(unpack_context, &val); - if(n <= 0) { + if(n == 0) { unpack_context->err_no = CCPCP_RC_MALFORMED_INPUT; unpack_context->err_msg = "Malformed seconds in DateTime"; return; @@ -978,7 +1068,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm, if(p) { if(*p == '.') { n = unpack_int(unpack_context, &val); - if(n < 0) + if(n == 0) return; *msec = (int)val; p = ccpcp_unpack_take_byte(unpack_context); @@ -1322,25 +1412,29 @@ void ccpon_unpack_next (ccpcp_unpack_context* unpack_context) case '+': case '-': { // number - int64_t mantisa = 0; + int64_t mantissa = 0; int64_t exponent = 0; - int64_t decimals = 0; int dec_cnt = 0; struct { uint8_t is_decimal: 1; uint8_t is_uint: 1; uint8_t is_neg: 1; + uint8_t is_overflow: 1; } flags; flags.is_decimal = 0; flags.is_uint = 0; flags.is_neg = 0; - flags.is_neg = *p == '-'; - if(!flags.is_neg) - unpack_context->current--; - int n = unpack_int(unpack_context, &mantisa); - if(n < 0) + unpack_context->current--; + + read_int_result result; + unpack_int_to_result(unpack_context, 0, &result); + if(result.digit_cnt == 0) { UNPACK_ERROR(CCPCP_RC_MALFORMED_INPUT, "Malformed number.") + } + mantissa = result.value; + flags.is_overflow = result.is_overflow; + flags.is_neg = result.signum == -1; p = ccpcp_unpack_take_byte(unpack_context); while(p) { if(*p == CCPON_C_UNSIGNED_END) { @@ -1349,18 +1443,20 @@ void ccpon_unpack_next (ccpcp_unpack_context* unpack_context) } if(*p == '.') { flags.is_decimal = 1; - n = unpack_int(unpack_context, &decimals); - if(n < 0) - UNPACK_ERROR(CCPCP_RC_MALFORMED_INPUT, "Malformed number decimal part.") - dec_cnt = n; + unpack_int_to_result(unpack_context, mantissa, &result); + mantissa = result.value; + flags.is_overflow |= result.is_overflow; + dec_cnt = result.digit_cnt; p = ccpcp_unpack_take_byte(unpack_context); if(!p) break; } if(*p == 'e' || *p == 'E') { flags.is_decimal = 1; - n = unpack_int(unpack_context, &exponent); - if(n < 0) + unpack_int_to_result(unpack_context, 0, &result); + flags.is_overflow |= result.is_overflow; + exponent = result.value * result.signum; + if(result.digit_cnt == 0) UNPACK_ERROR(CCPCP_RC_MALFORMED_INPUT, "Malformed number exponetional part.") break; } @@ -1371,22 +1467,28 @@ void ccpon_unpack_next (ccpcp_unpack_context* unpack_context) break; } if(flags.is_decimal) { - int i; - for (i = 0; i < dec_cnt; ++i) - mantisa *= 10; - mantisa += decimals; unpack_context->item.type = CCPCP_ITEM_DECIMAL; - unpack_context->item.as.Decimal.mantisa = flags.is_neg? -mantisa: mantisa; - unpack_context->item.as.Decimal.exponent = (int)(exponent - dec_cnt); + if (flags.is_overflow && dec_cnt == 0) { + unpack_context->item.as.Decimal.mantisa = flags.is_neg? INT64_MIN: INT64_MAX; + unpack_context->item.as.Decimal.exponent = 0; + } + else { + unpack_context->item.as.Decimal.mantisa = flags.is_neg? -mantissa: mantissa; + unpack_context->item.as.Decimal.exponent = (int)(exponent - dec_cnt); + } } else if(flags.is_uint) { unpack_context->item.type = CCPCP_ITEM_UINT; - unpack_context->item.as.UInt = (uint64_t)mantisa; - + unpack_context->item.as.UInt = (uint64_t)(flags.is_overflow? INT64_MAX: mantissa); } else { unpack_context->item.type = CCPCP_ITEM_INT; - unpack_context->item.as.Int = flags.is_neg? -mantisa: mantisa; + if (flags.is_overflow) { + unpack_context->item.as.Int = flags.is_neg? INT64_MIN: INT64_MAX; + } + else { + unpack_context->item.as.Int = flags.is_neg? -mantissa: mantissa; + } } unpack_context->err_no = CCPCP_RC_OK; break; diff --git a/libshvchainpack/include/shv/chainpack/rpcvalue.h b/libshvchainpack/include/shv/chainpack/rpcvalue.h index 0b3e5ef77..0648d9bc4 100644 --- a/libshvchainpack/include/shv/chainpack/rpcvalue.h +++ b/libshvchainpack/include/shv/chainpack/rpcvalue.h @@ -142,7 +142,7 @@ class SHVCHAINPACK_DECL_EXPORT RpcDecimal { static constexpr int Base = 10; struct Num { - int64_t mantisa = 0; + int64_t mantissa = 0; int exponent = 0; Num(); @@ -152,10 +152,11 @@ class SHVCHAINPACK_DECL_EXPORT RpcDecimal Num m_num; public: RpcDecimal(); - RpcDecimal(int64_t mantisa, int exponent); + RpcDecimal(int64_t mantissa, int exponent); RpcDecimal(int dec_places); - int64_t mantisa() const; + [[deprecated]] int64_t mantisa() const { return mantissa(); } + int64_t mantissa() const; int exponent() const; static RpcDecimal fromDouble(double d, int round_to_dec_places); diff --git a/libshvchainpack/src/chainpackwriter.cpp b/libshvchainpack/src/chainpackwriter.cpp index 97d742cc3..a623f7066 100644 --- a/libshvchainpack/src/chainpackwriter.cpp +++ b/libshvchainpack/src/chainpackwriter.cpp @@ -172,7 +172,7 @@ ChainPackWriter &ChainPackWriter::write_p(double value) ChainPackWriter &ChainPackWriter::write_p(RpcValue::Decimal value) { - cchainpack_pack_decimal(&m_outCtx, value.mantisa(), value.exponent()); + cchainpack_pack_decimal(&m_outCtx, value.mantissa(), value.exponent()); return *this; } diff --git a/libshvchainpack/src/cponwriter.cpp b/libshvchainpack/src/cponwriter.cpp index ddca5ca1f..036035072 100644 --- a/libshvchainpack/src/cponwriter.cpp +++ b/libshvchainpack/src/cponwriter.cpp @@ -375,7 +375,7 @@ CponWriter &CponWriter::write_p(double value) CponWriter &CponWriter::write_p(RpcValue::Decimal value) { - ccpon_pack_decimal(&m_outCtx, value.mantisa(), value.exponent()); + ccpon_pack_decimal(&m_outCtx, value.mantissa(), value.exponent()); return *this; } diff --git a/libshvchainpack/src/rpcvalue.cpp b/libshvchainpack/src/rpcvalue.cpp index 7b65b2052..626313106 100644 --- a/libshvchainpack/src/rpcvalue.cpp +++ b/libshvchainpack/src/rpcvalue.cpp @@ -220,7 +220,7 @@ bool RpcValue::isDefaultValue() const case RpcValue::Type::Int: return (toInt() == 0); case RpcValue::Type::UInt: return (toUInt() == 0); case RpcValue::Type::DateTime: return (toDateTime().msecsSinceEpoch() == 0); - case RpcValue::Type::Decimal: return (toDecimal().mantisa() == 0); + case RpcValue::Type::Decimal: return (toDecimal().mantissa() == 0); case RpcValue::Type::Double: return (toDouble() == 0); case RpcValue::Type::String: return (asString().empty()); case RpcValue::Type::Blob: return (asBlob().empty()); @@ -416,7 +416,7 @@ bool RpcValue::toBool() const if constexpr (std::is_same_v) { return x.msecsSinceEpoch() != 0; } else if constexpr (std::is_same_v) { - return x.mantisa() != 0; + return x.mantissa() != 0; } else if constexpr (std::is_arithmetic_v) { return x != 0; } else if constexpr (std::is_same_v) { @@ -836,7 +836,7 @@ RpcDecimal::Num::Num() } RpcDecimal::Num::Num(int64_t m, int e) - : mantisa(m) + : mantissa(m) , exponent(e) { } @@ -1413,8 +1413,8 @@ void RpcMetaData::swap(RpcMetaData &o) noexcept RpcDecimal::RpcDecimal() = default; -RpcDecimal::RpcDecimal(int64_t mantisa, int exponent) - : m_num{mantisa, exponent} +RpcDecimal::RpcDecimal(int64_t mantissa, int exponent) + : m_num{mantissa, exponent} { } @@ -1423,9 +1423,9 @@ RpcDecimal::RpcDecimal(int dec_places) { } -int64_t RpcDecimal::mantisa() const +int64_t RpcDecimal::mantissa() const { - return m_num.mantisa; + return m_num.mantissa; } int RpcDecimal::exponent() const @@ -1448,12 +1448,12 @@ RpcDecimal RpcDecimal::fromDouble(double d, int round_to_dec_places) void RpcDecimal::setDouble(double d) { RpcDecimal dc = fromDouble(d, -m_num.exponent); - m_num.mantisa = dc.mantisa(); + m_num.mantissa = dc.mantissa(); } double RpcDecimal::toDouble() const { - auto ret = static_cast(mantisa()); + auto ret = static_cast(mantissa()); int exp = exponent(); if(exp > 0) for(; exp > 0; exp--) ret *= Base; diff --git a/libshvchainpack/tests/test_cpon.cpp b/libshvchainpack/tests/test_cpon.cpp index f4ff2b137..469946cdb 100644 --- a/libshvchainpack/tests/test_cpon.cpp +++ b/libshvchainpack/tests/test_cpon.cpp @@ -7,6 +7,7 @@ #include #include +#include using namespace shv::chainpack; using std::string; @@ -39,6 +40,12 @@ namespace shv::chainpack { doctest::String toString(const RpcValue& value) { return value.toCpon().c_str(); } +doctest::String toString(const RpcDecimal& value) { + std::ostringstream sb; + sb << "RpcDecimal(" << value.mantissa() << ',' << value.exponent() << ')'; + return sb.str().c_str(); +} + } DOCTEST_TEST_CASE("Cpon") @@ -391,4 +398,26 @@ DOCTEST_TEST_CASE("Cpon") const string input = "invalid input"; REQUIRE_THROWS_AS(shv::chainpack::RpcValue::fromCpon(input), shv::chainpack::ParseException); } + + DOCTEST_SUBCASE("Very long decimals") + { + // read very long decimal without overflow error, value is capped + REQUIRE(RpcValue::fromCpon("123456789012345678901234567890123456789012345678901234567890").toInt64() == std::numeric_limits::max()); + REQUIRE(RpcValue::fromCpon("9223372036854775806").toInt64() == 9223372036854775806LL); + REQUIRE(RpcValue::fromCpon("9223372036854775807").toInt64() == std::numeric_limits::max()); + REQUIRE(RpcValue::fromCpon("9223372036854775808").toInt64() == std::numeric_limits::max()); + REQUIRE(RpcValue::fromCpon("0x7FFFFFFFFFFFFFFE").toInt64() == 0x7FFFFFFFFFFFFFFELL); + REQUIRE(RpcValue::fromCpon("0x7FFFFFFFFFFFFFFF").toInt64() == std::numeric_limits::max()); + REQUIRE(RpcValue::fromCpon("0x8000000000000000").toInt64() == std::numeric_limits::max()); + REQUIRE(RpcValue::fromCpon("-123456789012345678901234567890123456789012345678901234567890").toInt64() == std::numeric_limits::min()); + REQUIRE(RpcValue::fromCpon("-9223372036854775807").toInt64() == -9223372036854775807LL); + REQUIRE(RpcValue::fromCpon("-9223372036854775808").toInt64() == std::numeric_limits::min()); + REQUIRE(RpcValue::fromCpon("-9223372036854775809").toInt64() == std::numeric_limits::min()); + REQUIRE(RpcValue::fromCpon("-0x7FFFFFFFFFFFFFFF").toInt64() == -0x7FFFFFFFFFFFFFFFLL); + REQUIRE(RpcValue::fromCpon("-0x8000000000000000").toInt64() == std::numeric_limits::min()); + REQUIRE(RpcValue::fromCpon("-0x8000000000000001").toInt64() == std::numeric_limits::min()); + REQUIRE(RpcValue::fromCpon("1.23456789012345678901234567890123456789012345678901234567890").toDecimal() == RpcDecimal(1234567890123456789LL, -18)); + REQUIRE(RpcValue::fromCpon("12345678901234567890123456789012345678901234567890123456.7890").toDecimal() == RpcDecimal(std::numeric_limits::max(), 0)); + REQUIRE(RpcValue::fromCpon("123456789012345678901234567890123456789012345678901234567890.").toDecimal() == RpcDecimal(std::numeric_limits::max(), 0)); + } } diff --git a/libshvcore/include/shv/core/utils/shvtypeinfo.h b/libshvcore/include/shv/core/utils/shvtypeinfo.h index 0500fbae2..ab314c284 100644 --- a/libshvcore/include/shv/core/utils/shvtypeinfo.h +++ b/libshvcore/include/shv/core/utils/shvtypeinfo.h @@ -5,7 +5,6 @@ #include #include -#include #include namespace shv::chainpack { class MetaMethod; }