diff --git a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp index 4f6c19d8c445..72cf2022489b 100644 --- a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp @@ -67,16 +67,17 @@ void SelectiveDecimalColumnReader::seekToRowGroup(int64_t index) { template template -void SelectiveDecimalColumnReader::readHelper(RowSet rows) { - vector_size_t numRows = rows.back() + 1; +void SelectiveDecimalColumnReader::readHelper( + common::Filter* filter, + RowSet rows) { ExtractToReader extractValues(this); - common::AlwaysTrue filter; + common::AlwaysTrue alwaysTrue; DirectRleColumnVisitor< int64_t, common::AlwaysTrue, decltype(extractValues), kDense> - visitor(filter, this, rows, extractValues); + visitor(alwaysTrue, this, rows, extractValues); // decode scale stream if (version_ == velox::dwrf::RleVersion_1) { @@ -96,14 +97,142 @@ void SelectiveDecimalColumnReader::readHelper(RowSet rows) { // reset numValues_ before reading values numValues_ = 0; valueSize_ = sizeof(DataT); + vector_size_t numRows = rows.back() + 1; ensureValuesCapacity(numRows); // decode value stream facebook::velox::dwio::common:: ColumnVisitor - valueVisitor(filter, this, rows, extractValues); + valueVisitor(alwaysTrue, this, rows, extractValues); decodeWithVisitor>(valueDecoder_.get(), valueVisitor); readOffset_ += numRows; + + // Fill decimals before applying filter. + fillDecimals(); + + const auto rawNulls = nullsInReadRange_ + ? (kDense ? nullsInReadRange_->as() : rawResultNulls_) + : nullptr; + // Treat the filter as kAlwaysTrue if any of the following conditions are met: + // 1) No filter found; + // 2) Filter is kIsNotNull but rawNulls==NULL (no elements is null). + auto filterKind = + !filter || (filter->kind() == common::FilterKind::kIsNotNull && !rawNulls) + ? common::FilterKind::kAlwaysTrue + : filter->kind(); + switch (filterKind) { + case common::FilterKind::kAlwaysTrue: + // Simply add all rows to output. + for (vector_size_t i = 0; i < numValues_; i++) { + addOutputRow(rows[i]); + } + break; + case common::FilterKind::kIsNull: + processNulls(true, rows, rawNulls); + break; + case common::FilterKind::kIsNotNull: + processNulls(false, rows, rawNulls); + break; + case common::FilterKind::kBigintRange: + case common::FilterKind::kBigintValuesUsingHashTable: + case common::FilterKind::kBigintValuesUsingBitmask: + case common::FilterKind::kNegatedBigintRange: + case common::FilterKind::kNegatedBigintValuesUsingHashTable: + case common::FilterKind::kNegatedBigintValuesUsingBitmask: + case common::FilterKind::kBigintMultiRange: { + if constexpr (std::is_same_v) { + processFilter(filter, rows, rawNulls); + } else { + VELOX_UNSUPPORTED("Unsupported filter: {}.", (int)filterKind); + } + break; + } + case common::FilterKind::kHugeintValuesUsingHashTable: + case common::FilterKind::kHugeintRange: { + if constexpr (std::is_same_v) { + processFilter(filter, rows, rawNulls); + } else { + VELOX_UNSUPPORTED("Unsupported filter: {}.", (int)filterKind); + } + break; + } + default: + VELOX_UNSUPPORTED("Unsupported filter: {}.", (int)filterKind); + } +} + +template +void SelectiveDecimalColumnReader::processNulls( + const bool isNull, + const RowSet rows, + const uint64_t* rawNulls) { + if (!rawNulls) { + return; + } + auto rawDecimal = values_->asMutable(); + auto rawScale = scaleBuffer_->asMutable(); + + returnReaderNulls_ = false; + anyNulls_ = !isNull; + allNull_ = isNull; + vector_size_t idx = 0; + for (vector_size_t i = 0; i < numValues_; i++) { + if (isNull) { + if (bits::isBitNull(rawNulls, i)) { + bits::setNull(rawResultNulls_, idx); + addOutputRow(rows[i]); + idx++; + } + } else { + if (!bits::isBitNull(rawNulls, i)) { + bits::setNull(rawResultNulls_, idx, false); + rawDecimal[idx] = rawDecimal[i]; + rawScale[idx] = rawScale[i]; + addOutputRow(rows[i]); + idx++; + } + } + } +} + +template +void SelectiveDecimalColumnReader::processFilter( + const common::Filter* filter, + const RowSet rows, + const uint64_t* rawNulls) { + auto rawDecimal = values_->asMutable(); + + returnReaderNulls_ = false; + anyNulls_ = false; + allNull_ = true; + vector_size_t idx = 0; + for (vector_size_t i = 0; i < numValues_; i++) { + if (rawNulls && bits::isBitNull(rawNulls, i)) { + if (filter->testNull()) { + bits::setNull(rawResultNulls_, idx); + addOutputRow(rows[i]); + anyNulls_ = true; + idx++; + } + } else { + bool tested; + if constexpr (std::is_same_v) { + tested = filter->testInt64(rawDecimal[i]); + } else { + tested = filter->testInt128(rawDecimal[i]); + } + + if (tested) { + if (rawNulls) { + bits::setNull(rawResultNulls_, idx, false); + } + rawDecimal[idx] = rawDecimal[i]; + addOutputRow(rows[i]); + allNull_ = false; + idx++; + } + } + } } template @@ -111,14 +240,13 @@ void SelectiveDecimalColumnReader::read( int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) { - VELOX_CHECK(!scanSpec_->filter()); VELOX_CHECK(!scanSpec_->valueHook()); prepareRead(offset, rows, incomingNulls); bool isDense = rows.back() == rows.size() - 1; if (isDense) { - readHelper(rows); + readHelper(scanSpec_->filter(), rows); } else { - readHelper(rows); + readHelper(scanSpec_->filter(), rows); } } @@ -126,16 +254,18 @@ template void SelectiveDecimalColumnReader::getValues( const RowSet& rows, VectorPtr* result) { + rawValues_ = values_->asMutable(); + getIntValues(rows, requestedType_, result); +} + +template +void SelectiveDecimalColumnReader::fillDecimals() { auto nullsPtr = resultNulls() ? resultNulls()->template as() : nullptr; auto scales = scaleBuffer_->as(); auto values = values_->asMutable(); - DecimalUtil::fillDecimals( values, nullsPtr, values, scales, numValues_, scale_); - - rawValues_ = values_->asMutable(); - getIntValues(rows, requestedType_, result); } template class SelectiveDecimalColumnReader; diff --git a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h index 67a82b051e36..d41de405bdc8 100644 --- a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h +++ b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h @@ -49,7 +49,17 @@ class SelectiveDecimalColumnReader : public SelectiveColumnReader { private: template - void readHelper(RowSet rows); + void readHelper(common::Filter* filter, RowSet rows); + + void + processNulls(const bool isNull, const RowSet rows, const uint64_t* rawNulls); + + void processFilter( + const common::Filter* filter, + const RowSet rows, + const uint64_t* rawNulls); + + void fillDecimals(); std::unique_ptr> valueDecoder_; std::unique_ptr> scaleDecoder_; diff --git a/velox/exec/tests/AggregateSpillBenchmarkBase.h b/velox/exec/tests/AggregateSpillBenchmarkBase.h index 15b3bb853d66..4d175adaf5ac 100644 --- a/velox/exec/tests/AggregateSpillBenchmarkBase.h +++ b/velox/exec/tests/AggregateSpillBenchmarkBase.h @@ -20,7 +20,7 @@ namespace facebook::velox::exec::test { class AggregateSpillBenchmarkBase : public SpillerBenchmarkBase { public: explicit AggregateSpillBenchmarkBase(Spiller::Type spillerType) - : spillerType_(spillerType){}; + : spillerType_(spillerType) {}; /// Sets up the test. void setUp() override; diff --git a/velox/exec/tests/TableScanTest.cpp b/velox/exec/tests/TableScanTest.cpp index f228c258c96f..e3f2a7b0a11d 100644 --- a/velox/exec/tests/TableScanTest.cpp +++ b/velox/exec/tests/TableScanTest.cpp @@ -45,6 +45,7 @@ #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/expression/ExprToSubfieldFilter.h" +#include "velox/functions/lib/IsNull.h" #include "velox/type/Timestamp.h" #include "velox/type/Type.h" #include "velox/type/tests/SubfieldFiltersBuilder.h" @@ -1841,6 +1842,124 @@ TEST_F(TableScanTest, validFileNoData) { assertQuery(op, split, ""); } +TEST_F(TableScanTest, shortDecimalFilter) { + functions::registerIsNotNullFunction("isnotnull"); + + std::vector> values = { + 123456789123456789L, + 987654321123456L, + std::nullopt, + 2000000000000000L, + 5000000000000000L, + 987654321987654321L, + 100000000000000L, + 1230000000123456L, + 120000000123456L, + std::nullopt}; + auto rowVector = makeRowVector({ + makeNullableFlatVector(values, DECIMAL(18, 6)), + }); + createDuckDbTable({rowVector}); + + auto filePath = facebook::velox::test::getDataFilePath( + "velox/exec/tests", "data/short_decimal.orc"); + auto split = HiveConnectorSplitBuilder(filePath) + .start(0) + .length(fs::file_size(filePath)) + .fileFormat(dwio::common::FileFormat::ORC) + .build(); + + auto rowType = ROW({"d"}, {DECIMAL(18, 6)}); + + // Is not null. + auto op = + PlanBuilder().tableScan(rowType, {}, "isnotnull(d)", rowType).planNode(); + assertQuery(op, split, "SELECT c0 FROM tmp where c0 is not null"); + + // Is null. + op = PlanBuilder().tableScan(rowType, {}, "is_null(d)", rowType).planNode(); + assertQuery(op, split, "SELECT c0 FROM tmp where c0 is null"); + + // BigintRange. + op = + PlanBuilder() + .tableScan( + rowType, + {}, + "d > 2000000000.0::DECIMAL(18, 6) and d < 6000000000.0::DECIMAL(18, 6)", + rowType) + .planNode(); + assertQuery( + op, + split, + "SELECT c0 FROM tmp where c0 > 2000000000.0 and c0 < 6000000000.0"); + + // NegatedBigintRange. + op = + PlanBuilder() + .tableScan( + rowType, + {}, + "not(d between 2000000000.0::DECIMAL(18, 6) and 6000000000.0::DECIMAL(18, 6))", + rowType) + .planNode(); + assertQuery( + op, + split, + "SELECT c0 FROM tmp where c0 < 2000000000.0 or c0 > 6000000000.0"); +} + +TEST_F(TableScanTest, longDecimalFilter) { + functions::registerIsNotNullFunction("isnotnull"); + + std::vector> values = { + HugeInt::parse("123456789123456789123456789" + std::string(9, '0')), + HugeInt::parse("987654321123456789" + std::string(9, '0')), + std::nullopt, + HugeInt::parse("2" + std::string(37, '0')), + HugeInt::parse("5" + std::string(37, '0')), + HugeInt::parse("987654321987654321987654321" + std::string(9, '0')), + HugeInt::parse("1" + std::string(26, '0')), + HugeInt::parse("123000000012345678" + std::string(10, '0')), + HugeInt::parse("120000000123456789" + std::string(9, '0')), + HugeInt::parse("9" + std::string(37, '0'))}; + auto rowVector = makeRowVector({ + makeNullableFlatVector(values, DECIMAL(38, 18)), + }); + createDuckDbTable({rowVector}); + + auto filePath = facebook::velox::test::getDataFilePath( + "velox/exec/tests", "data/long_decimal.orc"); + auto split = HiveConnectorSplitBuilder(filePath) + .start(0) + .length(fs::file_size(filePath)) + .fileFormat(dwio::common::FileFormat::ORC) + .build(); + + auto rowType = ROW({"d"}, {DECIMAL(38, 18)}); + auto op = + PlanBuilder().tableScan(rowType, {}, "isnotnull(d)", rowType).planNode(); + assertQuery(op, split, "SELECT c0 FROM tmp where c0 is not null"); + + // Is null. + op = PlanBuilder().tableScan(rowType, {}, "is_null(d)", rowType).planNode(); + assertQuery(op, split, "SELECT c0 FROM tmp where c0 is null"); + + // HugeintRange. + op = + PlanBuilder() + .tableScan( + rowType, + {}, + "d > 2000000000.0::DECIMAL(38, 18) and d < 6000000000.0::DECIMAL(38, 18)", + rowType) + .planNode(); + assertQuery( + op, + split, + "SELECT c0 FROM tmp where c0 > 2000000000.0 and c0 < 6000000000.0"); +} + // An invalid (size = 0) file. TEST_F(TableScanTest, emptyFile) { auto filePath = TempFilePath::create(); diff --git a/velox/exec/tests/data/long_decimal.orc b/velox/exec/tests/data/long_decimal.orc new file mode 100644 index 000000000000..f732246b469a Binary files /dev/null and b/velox/exec/tests/data/long_decimal.orc differ diff --git a/velox/exec/tests/data/short_decimal.orc b/velox/exec/tests/data/short_decimal.orc new file mode 100644 index 000000000000..d442711fc750 Binary files /dev/null and b/velox/exec/tests/data/short_decimal.orc differ