Skip to content

Commit

Permalink
Implement defined read of very long cpon ints and decimals
Browse files Browse the repository at this point in the history
  • Loading branch information
Fanda Vacek committed Oct 30, 2024
1 parent b9e1e9b commit f61968b
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 63 deletions.
198 changes: 150 additions & 48 deletions libshvchainpack/c/ccpon.c
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <shv/chainpack/ccpon.h>

#include <string.h>
#include <math.h>
#include <assert.h>

static inline uint8_t hexify(uint8_t b)
{
Expand Down Expand Up @@ -801,16 +801,92 @@ const char* ccpon_unpack_skip_insignificant(ccpcp_unpack_context* unpack_context
}
}

static int unpack_int(ccpcp_unpack_context* unpack_context, int64_t *p_val)
static bool add_with_overflow_check_subopt(int64_t a, int b, int64_t *result)
{
int64_t res = a + b;
if (res < 0) {
*result = INT64_MAX;
return true;
}
*result = res;
return false;
}

static bool multiply_with_overflow_check_subopt(int64_t a, int b, int64_t* result)
{
int64_t val = 0;
// Check for zero multiplication
if (a == 0 || b == 0) {
*result = 0;
return false; // No overflow
}

assert(a > 0 && b > 0);

// Check for overflow
if (a > INT64_MAX / b) {
return true; // Overflow occurred
}

*result = a * b;
return false; // No overflow
}

#if defined(__has_builtin)
# if __has_builtin(__builtin_add_overflow)
# define HAS_BUILTIN_ADD_OVERFLOW
# endif
# if __has_builtin(__builtin_mul_overflow)
# define HAS_BUILTIN_MUL_OVERFLOW
# endif
#endif

static bool add_with_overflow_check(int64_t a, int b, int64_t* result) {
#ifdef HAS_BUILTIN_ADD_OVERFLOW
if (sizeof(long long int) == sizeof(int64_t)) {
long long res;
bool ret = __builtin_saddll_overflow(a, b, &res);
*result = (int64_t)res;
return ret;
}
#endif
return add_with_overflow_check_subopt(a, b, result);
}

static bool multiply_with_overflow_check(int64_t a, int b, int64_t* result) {
#ifdef HAS_BUILTIN_MUL_OVERFLOW
if (sizeof(long long int) == 64) {
long long res;
bool ret = __builtin_smulll_overflow(a, b, &res);
*result = (int64_t)res;
return ret;
}
#endif
return multiply_with_overflow_check_subopt(a, b, result);
}
#undef HAS_BUILTIN_ADD_OVERFLOW
#undef HAS_BUILTIN_MUL_OVERFLOW

typedef struct {
int64_t value;
int digit_cnt;
int signum;
bool is_overflow;
} read_int_result;

static int unpack_int_to_result(ccpcp_unpack_context* unpack_context, int64_t init_value, read_int_result *result)
{
result->digit_cnt = 0;
result->is_overflow = false;
result->value = init_value;
int neg = 0;
int base = 10;
int n = 0;
bool starts_with_signum = false;
for (; ; n++) {
const char *p = ccpcp_unpack_take_byte(unpack_context);
if(!p)
goto eonumb;
int digit = -1;
uint8_t b = (uint8_t)(*p);
switch (b) {
case '+':
Expand All @@ -819,20 +895,22 @@ static int unpack_int(ccpcp_unpack_context* unpack_context, int64_t *p_val)
unpack_context->current--;
goto eonumb;
}
starts_with_signum = true;
if(b == '-')
neg = 1;
break;
case 'x':
if(n == 1 && val != 0) {
unpack_context->current--;
goto eonumb;
case 'x': {
int expected_x_pos = 1 + (starts_with_signum? 1: 0);
if(n == expected_x_pos && result->value == 0) {
base = 16;
result->digit_cnt = 0;
}
if(n != 1) {
else {
unpack_context->current--;
goto eonumb;
}
base = 16;
break;
}
case '0':
case '1':
case '2':
Expand All @@ -843,8 +921,7 @@ static int unpack_int(ccpcp_unpack_context* unpack_context, int64_t *p_val)
case '7':
case '8':
case '9':
val *= base;
val += b - '0';
digit = b - '0';
break;
case 'a':
case 'b':
Expand All @@ -856,8 +933,7 @@ static int unpack_int(ccpcp_unpack_context* unpack_context, int64_t *p_val)
unpack_context->current--;
goto eonumb;
}
val *= base;
val += b - 'a' + 10;
digit = b - 'a' + 10;
break;
case 'A':
case 'B':
Expand All @@ -869,22 +945,36 @@ static int unpack_int(ccpcp_unpack_context* unpack_context, int64_t *p_val)
unpack_context->current--;
goto eonumb;
}
val *= base;
val += b - 'A' + 10;
digit = b - 'A' + 10;
break;
default:
unpack_context->current--;
goto eonumb;
}
if (digit >= 0) {
if (multiply_with_overflow_check(result->value, base, &result->value)) {
result->is_overflow = true;
}
else if (add_with_overflow_check(result->value, digit, &result->value)) {
result->is_overflow = true;
}
else {
result->digit_cnt++;
}
}
}
eonumb:
if(neg)
val = -val;
result->signum = neg? -1: 1;
return n;
}
static int unpack_int(ccpcp_unpack_context* unpack_context, int64_t *p_val)
{
read_int_result res;
int n = unpack_int_to_result(unpack_context, 0, &res);
if(p_val)
*p_val = val;
*p_val = res.value * res.signum;
return n;
}

void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm, int *msec, int *utc_offset)
{
tm->tm_year = 0;
Expand All @@ -902,7 +992,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm,

int64_t val;
int n = unpack_int(unpack_context, &val);
if(n < 0) {
if(n == 0) {
unpack_context->err_no = CCPCP_RC_MALFORMED_INPUT;
unpack_context->err_msg = "Malformed year in DateTime";
return;
Expand All @@ -917,7 +1007,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm,
}

n = unpack_int(unpack_context, &val);
if(n <= 0) {
if(n == 0) {
unpack_context->err_no = CCPCP_RC_MALFORMED_INPUT;
unpack_context->err_msg = "Malformed month in DateTime";
return;
Expand All @@ -932,7 +1022,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm,
}

n = unpack_int(unpack_context, &val);
if(n <= 0) {
if(n == 0) {
unpack_context->err_no = CCPCP_RC_MALFORMED_INPUT;
unpack_context->err_msg = "Malformed day in DateTime";
return;
Expand All @@ -947,7 +1037,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm,
}

n = unpack_int(unpack_context, &val);
if(n <= 0) {
if(n == 0) {
unpack_context->err_no = CCPCP_RC_MALFORMED_INPUT;
unpack_context->err_msg = "Malformed hour in DateTime";
return;
Expand All @@ -957,7 +1047,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm,
UNPACK_TAKE_BYTE(p);

n = unpack_int(unpack_context, &val);
if(n <= 0) {
if(n == 0) {
unpack_context->err_no = CCPCP_RC_MALFORMED_INPUT;
unpack_context->err_msg = "Malformed minutes in DateTime";
return;
Expand All @@ -967,7 +1057,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm,
UNPACK_TAKE_BYTE(p);

n = unpack_int(unpack_context, &val);
if(n <= 0) {
if(n == 0) {
unpack_context->err_no = CCPCP_RC_MALFORMED_INPUT;
unpack_context->err_msg = "Malformed seconds in DateTime";
return;
Expand All @@ -978,7 +1068,7 @@ void ccpon_unpack_date_time(ccpcp_unpack_context *unpack_context, struct tm *tm,
if(p) {
if(*p == '.') {
n = unpack_int(unpack_context, &val);
if(n < 0)
if(n == 0)
return;
*msec = (int)val;
p = ccpcp_unpack_take_byte(unpack_context);
Expand Down Expand Up @@ -1322,25 +1412,29 @@ void ccpon_unpack_next (ccpcp_unpack_context* unpack_context)
case '+':
case '-': {
// number
int64_t mantisa = 0;
int64_t mantissa = 0;
int64_t exponent = 0;
int64_t decimals = 0;
int dec_cnt = 0;
struct {
uint8_t is_decimal: 1;
uint8_t is_uint: 1;
uint8_t is_neg: 1;
uint8_t is_overflow: 1;
} flags;
flags.is_decimal = 0;
flags.is_uint = 0;
flags.is_neg = 0;

flags.is_neg = *p == '-';
if(!flags.is_neg)
unpack_context->current--;
int n = unpack_int(unpack_context, &mantisa);
if(n < 0)
unpack_context->current--;

read_int_result result;
unpack_int_to_result(unpack_context, 0, &result);
if(result.digit_cnt == 0) {
UNPACK_ERROR(CCPCP_RC_MALFORMED_INPUT, "Malformed number.")
}
mantissa = result.value;
flags.is_overflow = result.is_overflow;
flags.is_neg = result.signum == -1;
p = ccpcp_unpack_take_byte(unpack_context);
while(p) {
if(*p == CCPON_C_UNSIGNED_END) {
Expand All @@ -1349,18 +1443,20 @@ void ccpon_unpack_next (ccpcp_unpack_context* unpack_context)
}
if(*p == '.') {
flags.is_decimal = 1;
n = unpack_int(unpack_context, &decimals);
if(n < 0)
UNPACK_ERROR(CCPCP_RC_MALFORMED_INPUT, "Malformed number decimal part.")
dec_cnt = n;
unpack_int_to_result(unpack_context, mantissa, &result);
mantissa = result.value;
flags.is_overflow |= result.is_overflow;
dec_cnt = result.digit_cnt;
p = ccpcp_unpack_take_byte(unpack_context);
if(!p)
break;
}
if(*p == 'e' || *p == 'E') {
flags.is_decimal = 1;
n = unpack_int(unpack_context, &exponent);
if(n < 0)
unpack_int_to_result(unpack_context, 0, &result);
flags.is_overflow |= result.is_overflow;
exponent = result.value * result.signum;
if(result.digit_cnt == 0)
UNPACK_ERROR(CCPCP_RC_MALFORMED_INPUT, "Malformed number exponetional part.")
break;
}
Expand All @@ -1371,22 +1467,28 @@ void ccpon_unpack_next (ccpcp_unpack_context* unpack_context)
break;
}
if(flags.is_decimal) {
int i;
for (i = 0; i < dec_cnt; ++i)
mantisa *= 10;
mantisa += decimals;
unpack_context->item.type = CCPCP_ITEM_DECIMAL;
unpack_context->item.as.Decimal.mantisa = flags.is_neg? -mantisa: mantisa;
unpack_context->item.as.Decimal.exponent = (int)(exponent - dec_cnt);
if (flags.is_overflow && dec_cnt == 0) {
unpack_context->item.as.Decimal.mantisa = flags.is_neg? INT64_MIN: INT64_MAX;
unpack_context->item.as.Decimal.exponent = 0;
}
else {
unpack_context->item.as.Decimal.mantisa = flags.is_neg? -mantissa: mantissa;
unpack_context->item.as.Decimal.exponent = (int)(exponent - dec_cnt);
}
}
else if(flags.is_uint) {
unpack_context->item.type = CCPCP_ITEM_UINT;
unpack_context->item.as.UInt = (uint64_t)mantisa;

unpack_context->item.as.UInt = (uint64_t)(flags.is_overflow? INT64_MAX: mantissa);
}
else {
unpack_context->item.type = CCPCP_ITEM_INT;
unpack_context->item.as.Int = flags.is_neg? -mantisa: mantisa;
if (flags.is_overflow) {
unpack_context->item.as.Int = flags.is_neg? INT64_MIN: INT64_MAX;
}
else {
unpack_context->item.as.Int = flags.is_neg? -mantissa: mantissa;
}
}
unpack_context->err_no = CCPCP_RC_OK;
break;
Expand Down
7 changes: 4 additions & 3 deletions libshvchainpack/include/shv/chainpack/rpcvalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class SHVCHAINPACK_DECL_EXPORT RpcDecimal
{
static constexpr int Base = 10;
struct Num {
int64_t mantisa = 0;
int64_t mantissa = 0;
int exponent = 0;

Num();
Expand All @@ -152,10 +152,11 @@ class SHVCHAINPACK_DECL_EXPORT RpcDecimal
Num m_num;
public:
RpcDecimal();
RpcDecimal(int64_t mantisa, int exponent);
RpcDecimal(int64_t mantissa, int exponent);
RpcDecimal(int dec_places);

int64_t mantisa() const;
[[deprecated]] int64_t mantisa() const { return mantissa(); }
int64_t mantissa() const;
int exponent() const;

static RpcDecimal fromDouble(double d, int round_to_dec_places);
Expand Down
2 changes: 1 addition & 1 deletion libshvchainpack/src/chainpackwriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ ChainPackWriter &ChainPackWriter::write_p(double value)

ChainPackWriter &ChainPackWriter::write_p(RpcValue::Decimal value)
{
cchainpack_pack_decimal(&m_outCtx, value.mantisa(), value.exponent());
cchainpack_pack_decimal(&m_outCtx, value.mantissa(), value.exponent());
return *this;
}

Expand Down
Loading

0 comments on commit f61968b

Please sign in to comment.