Skip to content

Commit

Permalink
Support to truncate the fraction part of a decimal number when castin…
Browse files Browse the repository at this point in the history
…g decimal to an integer (6208)
  • Loading branch information
rui-mo committed Aug 24, 2023
1 parent c2e6d61 commit 0034e47
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 5 deletions.
15 changes: 10 additions & 5 deletions velox/expression/CastExpr-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ VectorPtr CastExpr::applyDecimalToFloatCast(
const TypePtr& toType) {
using To = typename TypeTraits<ToKind>::NativeType;

const auto& queryConfig = context.execCtx()->queryCtx()->queryConfig();
VectorPtr result;
context.ensureWritable(rows, toType, result);
(*result).clearNulls(rows);
Expand All @@ -292,6 +293,7 @@ VectorPtr CastExpr::applyDecimalToIntegralCast(
const TypePtr& fromType,
const TypePtr& toType) {
using To = typename TypeTraits<ToKind>::NativeType;
const auto& queryConfig = context.execCtx()->queryCtx()->queryConfig();

VectorPtr result;
context.ensureWritable(rows, toType, result);
Expand All @@ -303,11 +305,14 @@ VectorPtr CastExpr::applyDecimalToIntegralCast(
applyToSelectedNoThrowLocal(context, rows, result, [&](int row) {
auto value = simpleInput->valueAt(row);
auto integralPart = value / scaleFactor;
auto fractionPart = value % scaleFactor;
auto sign = value >= 0 ? 1 : -1;
bool needsRoundUp =
(scaleFactor != 1) && (sign * fractionPart >= (scaleFactor >> 1));
integralPart += needsRoundUp ? sign : 0;
if (!queryConfig.isCastToIntByTruncate()) {
auto fractionPart = value % scaleFactor;
auto sign = value >= 0 ? 1 : -1;
bool needsRoundUp =
(scaleFactor != 1) && (sign * fractionPart >= (scaleFactor >> 1));
integralPart += needsRoundUp ? sign : 0;
}

if (integralPart > std::numeric_limits<To>::max() ||
integralPart < std::numeric_limits<To>::min()) {
if (setNullInResultAtError()) {
Expand Down
36 changes: 36 additions & 0 deletions velox/expression/tests/CastExprTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ class CastExprTest : public functions::test::CastBaseTest {

template <typename T>
void testDecimalToIntegralCasts() {
setCastIntByTruncate(false);
auto shortFlat = makeNullableFlatVector<int64_t>(
{-300,
-260,
Expand Down Expand Up @@ -355,6 +356,41 @@ class CastExprTest : public functions::test::CastBaseTest {
69,
72,
std::nullopt}));

setCastIntByTruncate(true);
testComplexCast(
"c0",
shortFlat,
makeNullableFlatVector<T>(
{-3,
-2 /*-2.6 truncated to -2*/,
-2 /*-2.3 truncated to -2*/,
-2,
-1,
0,
55,
57 /*57.49 truncated to 57*/,
57 /*57.55 truncated to 57*/,
69,
72,
std::nullopt}));

testComplexCast(
"c0",
longFlat,
makeNullableFlatVector<T>(
{-3,
-2 /*-2.55 truncated to -2*/,
-2 /*-2.45 truncated to -2*/,
-2,
-1,
0,
55,
55 /* 55.49 truncated to 55*/,
55 /* 55.99 truncated to 55*/,
69,
72,
std::nullopt}));
}

template <typename T>
Expand Down

0 comments on commit 0034e47

Please sign in to comment.