diff --git a/velox/functions/lib/SIMDComparisonUtil.h b/velox/functions/lib/SIMDComparisonUtil.h new file mode 100644 index 000000000000..287fc8c56212 --- /dev/null +++ b/velox/functions/lib/SIMDComparisonUtil.h @@ -0,0 +1,302 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/expression/VectorFunction.h" + +namespace facebook::velox::functions { + +namespace detail { + +template +inline auto loadSimdData(const T* rawData, vector_size_t offset) { + using d_type = xsimd::batch; + if constexpr (kIsConstant) { + return xsimd::broadcast(rawData[0]); + } + return d_type::load_unaligned(rawData + offset); +} + +inline uint64_t to64Bits(const int8_t* resultData) { + using d_type = xsimd::batch; + constexpr auto numScalarElements = d_type::size; + static_assert( + numScalarElements == 16 || numScalarElements == 32 || + numScalarElements == 64, + "Unsupported number of scalar elements"); + uint64_t res = 0UL; + if constexpr (numScalarElements == 64) { + res = simd::toBitMask( + xsimd::batch_bool(d_type::load_unaligned(resultData))); + } else if constexpr (numScalarElements == 32) { + auto* addr = reinterpret_cast(&res); + *(addr) = simd::toBitMask( + xsimd::batch_bool(d_type::load_unaligned(resultData))); + *(addr + 1) = simd::toBitMask( + xsimd::batch_bool(d_type::load_unaligned(resultData + 32))); + } else if constexpr (numScalarElements == 16) { + auto* addr = reinterpret_cast(&res); + *(addr) = simd::toBitMask( + xsimd::batch_bool(d_type::load_unaligned(resultData))); + *(addr + 1) = simd::toBitMask( + xsimd::batch_bool(d_type::load_unaligned(resultData + 16))); + *(addr + 2) = simd::toBitMask( + xsimd::batch_bool(d_type::load_unaligned(resultData + 32))); + *(addr + 3) = simd::toBitMask( + xsimd::batch_bool(d_type::load_unaligned(resultData + 48))); + } + return res; +} + +template +void applyAutoSimdComparisonInternal( + const SelectivityVector& rows, + const A* __restrict rawA, + const B* __restrict rawB, + Compare cmp, + VectorPtr& result) { + int8_t tempBuffer[64]; + int8_t* __restrict resultData = tempBuffer; + const vector_size_t rowsBegin = rows.begin(); + const vector_size_t rowsEnd = rows.end(); + auto* rowsData = reinterpret_cast(rows.allBits()); + auto* resultVector = result->asUnchecked>(); + auto* rawResult = resultVector->mutableRawValues(); + if (rows.isAllSelected()) { + auto i = 0; + for (; i + 64 <= rowsEnd; i += 64) { + for (auto j = 0; j < 64; ++j) { + resultData[j] = cmp(rawA, rawB, i + j) ? -1 : 0; + } + rawResult[i / 64] = to64Bits(resultData); + } + for (; i < rowsEnd; ++i) { + bits::setBit(rawResult, i, cmp(rawA, rawB, i)); + } + } else { + static constexpr uint64_t kAllSet = -1ULL; + bits::forEachWord( + rowsBegin, + rowsEnd, + [&](int32_t idx, uint64_t mask) { + auto word = rowsData[idx] & mask; + if (!word) { + return; + } + const size_t start = idx * 64; + while (word) { + auto index = start + __builtin_ctzll(word); + bits::setBit(rawResult, index, cmp(rawA, rawB, index)); + word &= word - 1; + } + }, + [&](int32_t idx) { + auto word = rowsData[idx]; + const size_t start = idx * 64; + if (kAllSet == word) { + // Do 64 comparisons in a batch, set results by SIMD. + for (size_t row = 0; row < 64; ++row) { + resultData[row] = cmp(rawA, rawB, row + start) ? -1 : 0; + } + rawResult[idx] = to64Bits(resultData); + } else { + while (word) { + auto index = __builtin_ctzll(word); + resultData[index] = cmp(rawA, rawB, start + index) ? -1 : 0; + word &= word - 1; + } + // Set results only for selected rows. + uint64_t mask = rowsData[idx]; + rawResult[idx] = + (rawResult[idx] & ~mask) | (to64Bits(resultData) & mask); + } + }); + } +} +} // namespace detail + +template < + typename T, + bool kIsLeftConstant, + bool kIsRightConstant, + typename ComparisonOp> +void applySimdComparison( + const vector_size_t begin, + const vector_size_t end, + const T* rawLhs, + const T* rawRhs, + uint8_t* rawResult) { + using d_type = xsimd::batch; + constexpr auto numScalarElements = d_type::size; + const auto vectorEnd = (end - begin) - (end - begin) % numScalarElements; + static_assert( + numScalarElements == 2 || numScalarElements == 4 || + numScalarElements == 8 || numScalarElements == 16 || + numScalarElements == 32, + "Unsupported number of scalar elements"); + if constexpr (numScalarElements == 2 || numScalarElements == 4) { + for (auto i = begin; i < vectorEnd; i += 8) { + rawResult[i / 8] = 0; + for (auto j = 0; j < 8 && (i + j) < vectorEnd; j += numScalarElements) { + auto left = detail::loadSimdData(rawLhs, i + j); + auto right = detail::loadSimdData(rawRhs, i + j); + + uint8_t res = simd::toBitMask(ComparisonOp()(left, right)); + rawResult[i / 8] |= res << j; + } + } + } else { + for (auto i = begin; i < vectorEnd; i += numScalarElements) { + auto left = detail::loadSimdData(rawLhs, i); + auto right = detail::loadSimdData(rawRhs, i); + + auto res = simd::toBitMask(ComparisonOp()(left, right)); + if constexpr (numScalarElements == 8) { + rawResult[i / 8] = res; + } else if constexpr (numScalarElements == 16) { + uint16_t* addr = reinterpret_cast(rawResult + i / 8); + *addr = res; + } else if constexpr (numScalarElements == 32) { + uint32_t* addr = reinterpret_cast(rawResult + i / 8); + *addr = res; + } + } + } + + // Evaluate remaining values. + for (auto i = vectorEnd; i < end; i++) { + if constexpr (kIsRightConstant && kIsLeftConstant) { + bits::setBit(rawResult, i, ComparisonOp()(rawLhs[0], rawRhs[0])); + } else if constexpr (kIsRightConstant) { + bits::setBit(rawResult, i, ComparisonOp()(rawLhs[i], rawRhs[0])); + } else if constexpr (kIsLeftConstant) { + bits::setBit(rawResult, i, ComparisonOp()(rawLhs[0], rawRhs[i])); + } else { + bits::setBit(rawResult, i, ComparisonOp()(rawLhs[i], rawRhs[i])); + } + } +} + +template +void applySimdComparison( + const SelectivityVector& rows, + std::vector& args, + VectorPtr& result) { + auto resultVector = result->asUnchecked>(); + auto rawResult = resultVector->mutableRawValues(); + if (args[0]->isConstantEncoding() && args[1]->isConstantEncoding()) { + auto l = args[0]->asUnchecked>()->valueAt(0); + auto r = args[1]->asUnchecked>()->valueAt(0); + applySimdComparison( + rows.begin(), rows.end(), &l, &r, rawResult); + } else if (args[0]->isConstantEncoding()) { + auto l = args[0]->asUnchecked>()->valueAt(0); + auto rawRhs = args[1]->asUnchecked>()->rawValues(); + applySimdComparison( + rows.begin(), rows.end(), &l, rawRhs, rawResult); + } else if (args[1]->isConstantEncoding()) { + auto rawLhs = args[0]->asUnchecked>()->rawValues(); + auto r = args[1]->asUnchecked>()->valueAt(0); + applySimdComparison( + rows.begin(), rows.end(), rawLhs, &r, rawResult); + } else { + auto rawLhs = args[0]->asUnchecked>()->rawValues(); + auto rawRhs = args[1]->asUnchecked>()->rawValues(); + applySimdComparison( + rows.begin(), rows.end(), rawLhs, rawRhs, rawResult); + } +} + +template +void applyAutoSimdComparison( + const SelectivityVector& rows, + std::vector& args, + VectorPtr& result, + Args... cmpArgs) { + const Compare cmp; + if (args[0]->isFlatEncoding() && args[1]->isFlatEncoding()) { + const A* __restrict rawA = + args[0]->asUnchecked>()->template rawValues(); + const B* __restrict rawB = + args[1]->asUnchecked>()->template rawValues(); + detail::applyAutoSimdComparisonInternal( + rows, + rawA, + rawB, + [&](const A* __restrict rawA, const B* __restrict rawB, int i) { + if constexpr (sizeof...(cmpArgs) > 0) { + return Compare::apply(rawA[i], rawB[i], cmpArgs...); + } else { + return cmp(rawA[i], rawB[i]); + } + }, + result); + } else if (args[0]->isConstantEncoding() && args[1]->isFlatEncoding()) { + const A constA = args[0]->asUnchecked>()->valueAt(0); + const A* __restrict rawA = &constA; + const B* __restrict rawB = + args[1]->asUnchecked>()->template rawValues(); + detail::applyAutoSimdComparisonInternal( + rows, + rawA, + rawB, + [&](const A* __restrict rawA, const B* __restrict rawB, int i) { + if constexpr (sizeof...(cmpArgs) > 0) { + return Compare::apply(rawA[0], rawB[i], cmpArgs...); + } else { + return cmp(rawA[0], rawB[i]); + } + }, + result); + } else if (args[0]->isFlatEncoding() && args[1]->isConstantEncoding()) { + const A* __restrict rawA = + args[0]->asUnchecked>()->template rawValues(); + const B constB = args[1]->asUnchecked>()->valueAt(0); + const B* __restrict rawB = &constB; + detail::applyAutoSimdComparisonInternal( + rows, + rawA, + rawB, + [&](const A* __restrict rawA, const B* __restrict rawB, int i) { + if constexpr (sizeof...(cmpArgs) > 0) { + return Compare::apply(rawA[i], rawB[0], cmpArgs...); + } else { + return cmp(rawA[i], rawB[0]); + } + }, + result); + } else if (args[0]->isConstantEncoding() && args[1]->isConstantEncoding()) { + const A constA = args[0]->asUnchecked>()->valueAt(0); + const A* __restrict rawA = &constA; + const B constB = args[1]->asUnchecked>()->valueAt(0); + const B* __restrict rawB = &constB; + detail::applyAutoSimdComparisonInternal( + rows, + rawA, + rawB, + [&](const A* __restrict rawA, const B* __restrict rawB, int i) { + if constexpr (sizeof...(cmpArgs) > 0) { + return Compare::apply(rawA[0], rawB[0], cmpArgs...); + } else { + return cmp(rawA[0], rawB[0]); + } + }, + result); + } else { + VELOX_UNREACHABLE(); + } +} +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/Comparisons.cpp b/velox/functions/prestosql/Comparisons.cpp index 05eab2b80f56..ef5e76d773a5 100644 --- a/velox/functions/prestosql/Comparisons.cpp +++ b/velox/functions/prestosql/Comparisons.cpp @@ -17,6 +17,7 @@ #include "velox/functions/prestosql/Comparisons.h" #include #include "velox/functions/Udf.h" +#include "velox/functions/lib/SIMDComparisonUtil.h" #include "velox/vector/BaseVector.h" namespace facebook::velox::functions { @@ -36,71 +37,6 @@ namespace { /// If the vector encoding is not flat, we revert to non simd approach. template struct SimdComparator { - template - inline auto loadSimdData(const T* rawData, vector_size_t offset) { - using d_type = xsimd::batch; - if constexpr (isConstant) { - return xsimd::broadcast(rawData[0]); - } - return d_type::load_unaligned(rawData + offset); - } - - template - void applySimdComparison( - const vector_size_t begin, - const vector_size_t end, - const T* rawLhs, - const T* rawRhs, - uint8_t* rawResult) { - using d_type = xsimd::batch; - constexpr auto numScalarElements = d_type::size; - const auto vectorEnd = (end - begin) - (end - begin) % numScalarElements; - - if constexpr (numScalarElements == 2 || numScalarElements == 4) { - for (auto i = begin; i < vectorEnd; i += 8) { - rawResult[i / 8] = 0; - for (auto j = 0; j < 8 && (i + j) < vectorEnd; j += numScalarElements) { - auto left = loadSimdData(rawLhs, i + j); - auto right = loadSimdData(rawRhs, i + j); - - uint8_t res = simd::toBitMask(ComparisonOp()(left, right)); - rawResult[i / 8] |= res << j; - } - } - } else { - for (auto i = begin; i < vectorEnd; i += numScalarElements) { - auto left = loadSimdData(rawLhs, i); - auto right = loadSimdData(rawRhs, i); - - auto res = simd::toBitMask(ComparisonOp()(left, right)); - if constexpr (numScalarElements == 8) { - rawResult[i / 8] = res; - } else if constexpr (numScalarElements == 16) { - uint16_t* addr = reinterpret_cast(rawResult + i / 8); - *addr = res; - } else if constexpr (numScalarElements == 32) { - uint32_t* addr = reinterpret_cast(rawResult + i / 8); - *addr = res; - } else { - VELOX_FAIL("Unsupported number of scalar elements"); - } - } - } - - // Evaluate remaining values. - for (auto i = vectorEnd; i < end; i++) { - if constexpr (isRightConstant && isLeftConstant) { - bits::setBit(rawResult, i, ComparisonOp()(rawLhs[0], rawRhs[0])); - } else if constexpr (isRightConstant) { - bits::setBit(rawResult, i, ComparisonOp()(rawLhs[i], rawRhs[0])); - } else if constexpr (isLeftConstant) { - bits::setBit(rawResult, i, ComparisonOp()(rawLhs[0], rawRhs[i])); - } else { - bits::setBit(rawResult, i, ComparisonOp()(rawLhs[i], rawRhs[i])); - } - } - } - template inline bool compare(T& l, T& r) const { if constexpr (std::is_floating_point_v) { @@ -167,22 +103,22 @@ struct SimdComparator { if (lhs.isConstantEncoding() && rhs.isConstantEncoding()) { auto l = lhs.asUnchecked>()->valueAt(0); auto r = rhs.asUnchecked>()->valueAt(0); - applySimdComparison( + applySimdComparison( rows.begin(), rows.end(), &l, &r, rawResult); } else if (lhs.isConstantEncoding()) { auto l = lhs.asUnchecked>()->valueAt(0); auto rawRhs = rhs.asUnchecked>()->rawValues(); - applySimdComparison( + applySimdComparison( rows.begin(), rows.end(), &l, rawRhs, rawResult); } else if (rhs.isConstantEncoding()) { auto rawLhs = lhs.asUnchecked>()->rawValues(); auto r = rhs.asUnchecked>()->valueAt(0); - applySimdComparison( + applySimdComparison( rows.begin(), rows.end(), rawLhs, &r, rawResult); } else { auto rawLhs = lhs.asUnchecked>()->rawValues(); auto rawRhs = rhs.asUnchecked>()->rawValues(); - applySimdComparison( + applySimdComparison( rows.begin(), rows.end(), rawLhs, rawRhs, rawResult); } diff --git a/velox/functions/sparksql/Comparisons.cpp b/velox/functions/sparksql/Comparisons.cpp index 5676503fe9f5..f52e603da816 100644 --- a/velox/functions/sparksql/Comparisons.cpp +++ b/velox/functions/sparksql/Comparisons.cpp @@ -17,6 +17,7 @@ #include "velox/expression/EvalCtx.h" #include "velox/expression/Expr.h" +#include "velox/functions/lib/SIMDComparisonUtil.h" #include "velox/functions/sparksql/Comparisons.h" #include "velox/type/Type.h" @@ -37,28 +38,45 @@ class ComparisonFunction final : public exec::VectorFunction { const TypePtr& outputType, exec::EvalCtx& context, VectorPtr& result) const override { + const Cmp cmp; context.ensureWritable(rows, BOOLEAN(), result); + result->clearNulls(rows); + if ((args[0]->isFlatEncoding() || args[0]->isConstantEncoding()) && + (args[1]->isFlatEncoding() || args[1]->isConstantEncoding())) { + if constexpr ( + kind == TypeKind::TINYINT || kind == TypeKind::SMALLINT || + kind == TypeKind::INTEGER || kind == TypeKind::BIGINT) { + if (rows.isAllSelected()) { + applySimdComparison(rows, args, result); + return; + } + } + if (rows.end() - rows.begin() > 64) { + applyAutoSimdComparison(rows, args, result); + return; + } + } + auto* flatResult = result->asUnchecked>(); - const Cmp cmp; if (args[0]->isFlatEncoding() && args[1]->isFlatEncoding()) { // Fast path for (flat, flat). - auto rawA = args[0]->asUnchecked>()->mutableRawValues(); - auto rawB = args[1]->asUnchecked>()->mutableRawValues(); + const auto* rawA = args[0]->asUnchecked>()->rawValues(); + const auto* rawB = args[1]->asUnchecked>()->rawValues(); rows.applyToSelected( [&](vector_size_t i) { flatResult->set(i, cmp(rawA[i], rawB[i])); }); } else if (args[0]->isConstantEncoding() && args[1]->isFlatEncoding()) { // Fast path for (const, flat). auto constant = args[0]->asUnchecked>()->valueAt(0); - auto rawValues = - args[1]->asUnchecked>()->mutableRawValues(); + const auto* rawValues = + args[1]->asUnchecked>()->rawValues(); rows.applyToSelected([&](vector_size_t i) { flatResult->set(i, cmp(constant, rawValues[i])); }); } else if (args[0]->isFlatEncoding() && args[1]->isConstantEncoding()) { // Fast path for (flat, const). - auto rawValues = - args[0]->asUnchecked>()->mutableRawValues(); + const auto* rawValues = + args[0]->asUnchecked>()->rawValues(); auto constant = args[1]->asUnchecked>()->valueAt(0); rows.applyToSelected([&](vector_size_t i) { flatResult->set(i, cmp(rawValues[i], constant)); @@ -136,7 +154,7 @@ class BoolComparisonFunction final : public exec::VectorFunction { } }; -template