From 5f03ec729183f09d33bc85a4a7c86b8524ca3598 Mon Sep 17 00:00:00 2001 From: zhli1142015 Date: Mon, 4 Mar 2024 14:21:47 +0800 Subject: [PATCH] Add Spark configuration 'spark.partition_id' and refactor rand(seed) function --- velox/core/QueryConfig.h | 11 ++++ velox/docs/configs.rst | 3 + velox/functions/sparksql/Rand.h | 55 +++++-------------- .../functions/sparksql/RegisterArithmetic.cpp | 22 ++------ velox/functions/sparksql/tests/RandTest.cpp | 21 ++++--- 5 files changed, 45 insertions(+), 67 deletions(-) diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h index 39bf52cfdab92..55cc9b5039099 100644 --- a/velox/core/QueryConfig.h +++ b/velox/core/QueryConfig.h @@ -306,6 +306,9 @@ class QueryConfig { static constexpr const char* kSparkBloomFilterMaxNumBits = "spark.bloom_filter.max_num_bits"; + /// The current spark partition id. + static constexpr const char* kSparkPartitionId = "spark.partition_id"; + /// The number of local parallel table writer operators per task. static constexpr const char* kTaskWriterCount = "task_writer_count"; @@ -674,6 +677,14 @@ class QueryConfig { return value; } + int32_t sparkPartitionId() const { + auto id = get(kSparkPartitionId); + VELOX_CHECK(id.has_value(), "Spark partition id is not set."); + auto value = id.value(); + VELOX_CHECK_GE(value, 0, "Invalid Spark partition id."); + return value; + } + bool exprTrackCpuUsage() const { return get(kExprTrackCpuUsage, false); } diff --git a/velox/docs/configs.rst b/velox/docs/configs.rst index 8c7b07fad4b53..400f8e150c159 100644 --- a/velox/docs/configs.rst +++ b/velox/docs/configs.rst @@ -620,3 +620,6 @@ Spark-specific Configuration - 4194304 - The maximum number of bits to use for the bloom filter in :spark:func:`bloom_filter_agg` function, the value of this config can not exceed the default value. + * - spark.partition_id + - integer + - The current task's Spark partition ID. It's set by the query engine (Spark) prior to task execution. diff --git a/velox/functions/sparksql/Rand.h b/velox/functions/sparksql/Rand.h index 50cd2abddd1d3..4fd5030a83973 100644 --- a/velox/functions/sparksql/Rand.h +++ b/velox/functions/sparksql/Rand.h @@ -23,54 +23,25 @@ template struct RandFunction { static constexpr bool is_deterministic = false; - FOLLY_ALWAYS_INLINE void call(double& result) { - result = folly::Random::randDouble01(); + template + void initialize(const core::QueryConfig& config, const TInput* seedInput) { + auto partitionId = config.sparkPartitionId(); + generator_ = std::mt19937{}; + int64_t seed = seedInput ? (int64_t)*seedInput : 0; + generator_.seed(seed + partitionId); } - FOLLY_ALWAYS_INLINE void callNullable( - double& result, - const int32_t* seed, - const int32_t* partitionIndex) { - initializeGenerator(seed, partitionIndex); - result = folly::Random::randDouble01(*generator_); - } - - // To differentiate generator for each thread, seed plus partitionIndex is - // the actual seed used for generator. - FOLLY_ALWAYS_INLINE void callNullable( - double& result, - const int64_t* seed, - const int32_t* partitionIndex) { - initializeGenerator(seed, partitionIndex); - result = folly::Random::randDouble01(*generator_); + FOLLY_ALWAYS_INLINE void call(double& result) { + result = folly::Random::randDouble01(); } - // For NULL constant input of unknown type. - FOLLY_ALWAYS_INLINE void callNullable( - double& result, - const UnknownValue* /*seed*/, - const int32_t* partitionIndex) { - initializeGenerator(nullptr, partitionIndex); - result = folly::Random::randDouble01(*generator_); + template + FOLLY_ALWAYS_INLINE void callNullable(double& result, TInput /*seedInput*/) { + result = folly::Random::randDouble01(generator_); } private: - template - FOLLY_ALWAYS_INLINE void initializeGenerator( - const TSeed* seed, - const int32_t* partitionIndex) { - VELOX_USER_CHECK_NOT_NULL(partitionIndex, "partitionIndex cannot be null."); - if (!generator_.has_value()) { - generator_ = std::mt19937{}; - if (seed != nullptr) { - generator_->seed((int64_t)*seed + *partitionIndex); - } else { - // For null seed, partitionIndex is the seed, consistent with Spark. - generator_->seed(*partitionIndex); - } - } - } - - std::optional generator_; + std::mt19937 generator_; }; + } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/RegisterArithmetic.cpp b/velox/functions/sparksql/RegisterArithmetic.cpp index 5aa85c778878d..6656103cc14bd 100644 --- a/velox/functions/sparksql/RegisterArithmetic.cpp +++ b/velox/functions/sparksql/RegisterArithmetic.cpp @@ -24,24 +24,10 @@ namespace facebook::velox::functions::sparksql { void registerRandFunctions(const std::string& prefix) { registerFunction({prefix + "rand", prefix + "random"}); - // Has seed & partition index as input. - registerFunction< - RandFunction, - double, - int32_t /*seed*/, - int32_t /*partition index*/>({prefix + "rand", prefix + "random"}); - // Has seed & partition index as input. - registerFunction< - RandFunction, - double, - int64_t /*seed*/, - int32_t /*partition index*/>({prefix + "rand", prefix + "random"}); - // NULL constant as seed of unknown type. - registerFunction< - RandFunction, - double, - UnknownValue /*seed*/, - int32_t /*partition index*/>({prefix + "rand", prefix + "random"}); + registerFunction>( + {prefix + "rand", prefix + "random"}); + registerFunction>( + {prefix + "rand", prefix + "random"}); } void registerArithmeticFunctions(const std::string& prefix) { diff --git a/velox/functions/sparksql/tests/RandTest.cpp b/velox/functions/sparksql/tests/RandTest.cpp index cbb49476c8ff9..b8a8bc75551b8 100644 --- a/velox/functions/sparksql/tests/RandTest.cpp +++ b/velox/functions/sparksql/tests/RandTest.cpp @@ -26,25 +26,31 @@ class RandTest : public SparkFunctionBaseTest { } protected: + void setSparkPartitionId(int32_t partitionId) { + queryCtx_->testingOverrideConfigUnsafe( + {{core::QueryConfig::kSparkPartitionId, std::to_string(partitionId)}}); + } + std::optional rand(int32_t seed, int32_t partitionIndex = 0) { + setSparkPartitionId(partitionIndex); return evaluateOnce( - fmt::format("rand({}, {})", seed, partitionIndex), - makeRowVector(ROW({}), 1)); + fmt::format("rand({})", seed), makeRowVector(ROW({}), 1)); } std::optional randWithNullSeed(int32_t partitionIndex = 0) { - return evaluateOnce( - fmt::format("rand(NULL, {})", partitionIndex), - makeRowVector(ROW({}), 1)); + setSparkPartitionId(partitionIndex); + std::optional seed = std::nullopt; + return evaluateOnce("rand(c0)", seed); } std::optional randWithNoSeed() { + setSparkPartitionId(0); return evaluateOnce("rand()", makeRowVector(ROW({}), 1)); } VectorPtr randWithBatchInput(int32_t seed, int32_t partitionIndex = 0) { - auto exprSet = compileExpression( - fmt::format("rand({}, {})", seed, partitionIndex), ROW({})); + setSparkPartitionId(partitionIndex); + auto exprSet = compileExpression(fmt::format("rand({})", seed), ROW({})); return evaluate(*exprSet, makeRowVector(ROW({}), 20)); } @@ -92,6 +98,7 @@ TEST_F(RandTest, withSeed) { // Test with batch input. auto batchResult1 = randWithBatchInput(100); + ASSERT_FALSE(batchResult1->isConstantEncoding()); auto batchResult2 = randWithBatchInput(100); // Same seed & partition index produce same results. velox::test::assertEqualVectors(batchResult1, batchResult2);