diff --git a/CMakeLists.txt b/CMakeLists.txt index 18c3f51eede42..8f73faaa987f1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -108,7 +108,7 @@ option(VELOX_ENABLE_ABFS "Build Abfs Connector" OFF) option(VELOX_ENABLE_HDFS "Build Hdfs Connector" OFF) option(VELOX_ENABLE_PARQUET "Enable Parquet support" OFF) option(VELOX_ENABLE_ARROW "Enable Arrow support" OFF) -option(VELOX_ENABLE_REMOTE_FUNCTIONS "Enable remote function support" OFF) +option(VELOX_ENABLE_REMOTE_FUNCTIONS "Enable remote function support" ON) option(VELOX_ENABLE_CCACHE "Use ccache if installed." ON) option(VELOX_BUILD_TEST_UTILS "Builds Velox test utilities" OFF) diff --git a/velox/common/config/CMakeLists.txt b/velox/common/config/CMakeLists.txt index 7780665a29251..9639a2c8b6f76 100644 --- a/velox/common/config/CMakeLists.txt +++ b/velox/common/config/CMakeLists.txt @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -if (${VELOX_BUILD_TESTING}) +if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) -endif () +endif() velox_add_library(velox_common_config Config.cpp) velox_link_libraries( velox_common_config - PUBLIC velox_common_base - velox_exception + PUBLIC velox_common_base velox_exception PRIVATE re2::re2) diff --git a/velox/functions/remote/CMakeLists.txt b/velox/functions/remote/CMakeLists.txt index ccc8a2c5ec483..a38c65894a2ff 100644 --- a/velox/functions/remote/CMakeLists.txt +++ b/velox/functions/remote/CMakeLists.txt @@ -12,6 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +if(NOT DEFINED PROXYGEN_LIBRARIES) + find_package(Sodium REQUIRED) + find_library(PROXYGEN proxygen) + find_library(PROXYGEN_HTTP_SERVER proxygenhttpserver) + find_library(FIZZ fizz) + find_library(WANGLE wangle) + set(PROXYGEN_LIBRARIES ${PROXYGEN_HTTP_SERVER} ${PROXYGEN} ${WANGLE} ${FIZZ}) +endif() + add_subdirectory(if) add_subdirectory(client) add_subdirectory(server) diff --git a/velox/functions/remote/client/CMakeLists.txt b/velox/functions/remote/client/CMakeLists.txt index 56663a29d04b8..4fe8172d81d04 100644 --- a/velox/functions/remote/client/CMakeLists.txt +++ b/velox/functions/remote/client/CMakeLists.txt @@ -16,11 +16,16 @@ velox_add_library(velox_functions_remote_thrift_client ThriftClient.cpp) velox_link_libraries(velox_functions_remote_thrift_client PUBLIC remote_function_thrift FBThrift::thriftcpp2) +velox_add_library(velox_functions_remote_rest_client RestClient.cpp) +velox_link_libraries(velox_functions_remote_rest_client ${PROXYGEN_LIBRARIES} + Folly::folly) + velox_add_library(velox_functions_remote Remote.cpp) velox_link_libraries( velox_functions_remote PUBLIC velox_expression velox_functions_remote_thrift_client + velox_functions_remote_rest_client velox_functions_remote_get_serde velox_type_fbhive Folly::folly) diff --git a/velox/functions/remote/client/Remote.cpp b/velox/functions/remote/client/Remote.cpp index 1f88745aa73ea..7614a9dec662b 100644 --- a/velox/functions/remote/client/Remote.cpp +++ b/velox/functions/remote/client/Remote.cpp @@ -19,6 +19,7 @@ #include #include "velox/expression/Expr.h" #include "velox/expression/VectorFunction.h" +#include "velox/functions/remote/client/RestClient.h" #include "velox/functions/remote/client/ThriftClient.h" #include "velox/functions/remote/if/GetSerde.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunctionServiceAsyncClient.h" @@ -33,6 +34,17 @@ std::string serializeType(const TypePtr& type) { return type::fbhive::HiveTypeSerializer::serialize(type); } +std::string iobufToString(const folly::IOBuf& buf) { + std::string result; + result.reserve(buf.computeChainDataLength()); + + for (auto range : buf) { + result.append(reinterpret_cast(range.data()), range.size()); + } + + return result; +} + class RemoteFunction : public exec::VectorFunction { public: RemoteFunction( @@ -40,10 +52,16 @@ class RemoteFunction : public exec::VectorFunction { const std::vector& inputArgs, const RemoteVectorFunctionMetadata& metadata) : functionName_(functionName), - location_(metadata.location), - thriftClient_(getThriftClient(location_, &eventBase_)), serdeFormat_(metadata.serdeFormat), serde_(getSerde(serdeFormat_)) { + if (metadata.location.type() == typeid(SocketAddress)) { + location_ = boost::get(metadata.location); + thriftClient_ = getThriftClient(location_, &eventBase_); + } else if (metadata.location.type() == typeid(URL)) { + url_ = boost::get(metadata.location); + restClient_ = std::make_unique(url_.getUrl()); + } + std::vector types; types.reserve(inputArgs.size()); serializedInputTypes_.reserve(inputArgs.size()); @@ -62,7 +80,11 @@ class RemoteFunction : public exec::VectorFunction { exec::EvalCtx& context, VectorPtr& result) const override { try { - applyRemote(rows, args, outputType, context, result); + if (thriftClient_) { + applyRemote(rows, args, outputType, context, result); + } else if (restClient_) { + applyRestRemote(rows, args, outputType, context, result); + } } catch (const VeloxRuntimeError&) { throw; } catch (const std::exception&) { @@ -71,6 +93,69 @@ class RemoteFunction : public exec::VectorFunction { } private: + void applyRestRemote( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const { + try { + std::string responseBody; + auto remoteRowVector = std::make_shared( + context.pool(), + remoteInputType_, + BufferPtr{}, + rows.end(), + std::move(args)); + + /// construct json request + folly::dynamic remoteFunctionHandle = folly::dynamic::object; + remoteFunctionHandle["functionName"] = functionName_; + remoteFunctionHandle["returnType"] = serializeType(outputType); + remoteFunctionHandle["argumentTypes"] = folly::dynamic::array; + for (const auto& value : serializedInputTypes_) { + remoteFunctionHandle["argumentTypes"].push_back(value); + } + + folly::dynamic inputs = folly::dynamic::object; + inputs["pageFormat"] = static_cast(serdeFormat_); + // use existing serializer(Prestopage or Sparkunsaferow) + inputs["payload"] = iobufToString(rowVectorToIOBuf( + remoteRowVector, rows.end(), *context.pool(), serde_.get())); + inputs["rowCount"] = remoteRowVector->size(); + + folly::dynamic jsonObject = folly::dynamic::object; + jsonObject["remoteFunctionHandle"] = remoteFunctionHandle; + jsonObject["inputs"] = inputs; + jsonObject["throwOnError"] = context.throwOnError(); + + // call Rest client to send request + restClient_->invoke_function(folly::toJson(jsonObject), responseBody); + LOG(INFO) << responseBody; + + // parse json response + auto responseJsonObj = parseJson(responseBody); + if (responseJsonObj.count("err") > 0) { + VELOX_NYI(responseJsonObj["err"].asString()); + } + + auto payloadIObuf = folly::IOBuf::copyBuffer( + responseJsonObj["result"]["payload"].asString()); + + // use existing deserializer(Prestopage or Sparkunsaferow) + auto outputRowVector = IOBufToRowVector( + *payloadIObuf, ROW({outputType}), *context.pool(), serde_.get()); + result = outputRowVector->childAt(0); + + } catch (const std::exception& e) { + VELOX_FAIL( + "Error while executing remote function '{}' at '{}': {}", + functionName_, + url_.getUrl(), + e.what()); + } + } + void applyRemote( const SelectivityVector& rows, std::vector& args, @@ -122,10 +207,14 @@ class RemoteFunction : public exec::VectorFunction { } const std::string functionName_; - folly::SocketAddress location_; folly::EventBase eventBase_; std::unique_ptr thriftClient_; + folly::SocketAddress location_; + + std::unique_ptr restClient_; + proxygen::URL url_; + remote::PageFormat serdeFormat_; std::unique_ptr serde_; diff --git a/velox/functions/remote/client/Remote.h b/velox/functions/remote/client/Remote.h index a6a1e773dc812..fd90009da457f 100644 --- a/velox/functions/remote/client/Remote.h +++ b/velox/functions/remote/client/Remote.h @@ -16,17 +16,20 @@ #pragma once +#include #include +#include #include "velox/expression/VectorFunction.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunction_types.h" namespace facebook::velox::functions { struct RemoteVectorFunctionMetadata : public exec::VectorFunctionMetadata { - /// Network address of the servr to communicate with. Note that this can hold - /// a network location (ip/port pair) or a unix domain socket path (see + /// URL of the HTTP/REST server for remote function. + /// Or Network address of the servr to communicate with. Note that this can + /// hold a network location (ip/port pair) or a unix domain socket path (see /// SocketAddress::makeFromPath()). - folly::SocketAddress location; + boost::variant location; /// The serialization format to be used remote::PageFormat serdeFormat{remote::PageFormat::PRESTO_PAGE}; diff --git a/velox/functions/remote/client/RestClient.cpp b/velox/functions/remote/client/RestClient.cpp new file mode 100644 index 0000000000000..f21bb53485498 --- /dev/null +++ b/velox/functions/remote/client/RestClient.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/remote/client/RestClient.h" +#include + +using namespace facebook::velox::functions; + +namespace facebook::velox::functions { + +RestClient::RestClient(const std::string& url) : url_(url) { + httpClient_ = std::make_shared(url_); +}; + +void RestClient::invoke_function( + const std::string& requestBody, + std::string& responseBody) { + httpClient_->send(requestBody); + responseBody = httpClient_->getResponseBody(); + LOG(INFO) << responseBody; +}; + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/client/RestClient.h b/velox/functions/remote/client/RestClient.h new file mode 100644 index 0000000000000..ee5ea46ad3237 --- /dev/null +++ b/velox/functions/remote/client/RestClient.h @@ -0,0 +1,129 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "velox/functions/remote/client/RestClient.h" + +using namespace proxygen; +using namespace folly; + +namespace facebook::velox::functions { + +class HttpClient : public HTTPConnector::Callback, + public HTTPTransactionHandler { + public: + HttpClient(const URL& url) : url_(url) {} + + void send(std::string requestBody) { + requestBody_ = requestBody; + connector_ = std::make_unique( + this, WheelTimerInstance(std::chrono::milliseconds(1000))); + connector_->connect( + &evb_, + SocketAddress(url_.getHost(), url_.getPort(), true), + std::chrono::milliseconds(10000)); + evb_.loop(); + } + + std::string getResponseBody() { + return std::move(responseBody_); + } + + private: + URL url_; + EventBase evb_; + std::unique_ptr connector_; + std::shared_ptr session_; + std::string requestBody_; + std::string responseBody_; + + void connectSuccess(HTTPUpstreamSession* session) noexcept override { + session_ = std::shared_ptr( + session, [](HTTPUpstreamSession* s) { + // No-op deleter, managed by Proxygen + }); + sendRequest(); + } + + void connectError(const folly::AsyncSocketException& ex) noexcept override { + LOG(ERROR) << "Failed to connect: " << ex.what(); + evb_.terminateLoopSoon(); + } + + void sendRequest() { + auto txn = session_->newTransaction(this); + HTTPMessage req; + req.setMethod(HTTPMethod::POST); + req.setURL(url_.getUrl()); + req.getHeaders().add(HTTP_HEADER_CONTENT_TYPE, "application/json"); + req.getHeaders().add( + HTTP_HEADER_CONTENT_LENGTH, std::to_string(requestBody_.size())); + req.getHeaders().add(HTTP_HEADER_USER_AGENT, "Velox HTTPClient"); + + txn->sendHeaders(req); + txn->sendBody(folly::IOBuf::copyBuffer(requestBody_)); + txn->sendEOM(); + } + + void setTransaction(HTTPTransaction*) noexcept override {} + void detachTransaction() noexcept override { + session_.reset(); + evb_.terminateLoopSoon(); + } + + void onHeadersComplete(std::unique_ptr msg) noexcept override {} + + void onBody(std::unique_ptr chain) noexcept override { + if (chain) { + responseBody_.append( + reinterpret_cast(chain->data()), chain->length()); + } + } + + void onEOM() noexcept override { + session_->drain(); + } + + void onError(const HTTPException& error) noexcept override { + LOG(ERROR) << "Error: " << error.what(); + } + void onUpgrade(UpgradeProtocol) noexcept override {} + void onTrailers(std::unique_ptr) noexcept override {} + void onEgressPaused() noexcept override {} + void onEgressResumed() noexcept override {} +}; + +class RestClient { + public: + RestClient(const std::string& url); + void invoke_function(const std::string& request, std::string& response); + + private: + URL url_; + std::shared_ptr httpClient_; +}; + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/client/tests/CMakeLists.txt b/velox/functions/remote/client/tests/CMakeLists.txt index 1659ad9d7e5a3..38d0b25dbbd73 100644 --- a/velox/functions/remote/client/tests/CMakeLists.txt +++ b/velox/functions/remote/client/tests/CMakeLists.txt @@ -27,3 +27,20 @@ target_link_libraries( GTest::gmock GTest::gtest GTest::gtest_main) + +add_executable(velox_functions_remote_client_rest_test + RemoteFunctionRestTest.cpp) + +add_test(velox_functions_remote_client_rest_test + velox_functions_remote_client_rest_test) + +target_link_libraries( + velox_functions_remote_client_rest_test + velox_functions_remote_server_rest + velox_functions_remote + velox_function_registry + velox_functions_test_lib + velox_exec_test_lib + GTest::gmock + GTest::gtest + GTest::gtest_main) diff --git a/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp b/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp new file mode 100644 index 0000000000000..3ca2ecca7fce3 --- /dev/null +++ b/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp @@ -0,0 +1,204 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/Registerer.h" +#include "velox/functions/lib/CheckedArithmetic.h" +#include "velox/functions/prestosql/Arithmetic.h" +#include "velox/functions/prestosql/StringFunctions.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/remote/client/Remote.h" +// #include "velox/functions/remote/if/gen-cpp2/RemoteFunctionService.h" +#include "velox/functions/remote/server/RemoteFunctionRestService.h" + +using ::facebook::velox::test::assertEqualVectors; + +namespace facebook::velox::functions { +namespace { + +// Parametrize in the serialization format so we can test both presto page and +// unsafe row. +class RemoteFunctionRestTest + : public functions::test::FunctionBaseTest, + public ::testing::WithParamInterface { + public: + void SetUp() override { + initializeServer(); + registerRemoteFunctions(); + } + + // Registers a few remote functions to be used in this test. + void registerRemoteFunctions() { + RemoteVectorFunctionMetadata metadata; + metadata.serdeFormat = GetParam(); + metadata.location = location_; + + // Register the remote adapter. + auto plusSignatures = {exec::FunctionSignatureBuilder() + .returnType("bigint") + .argumentType("bigint") + .argumentType("bigint") + .build()}; + registerRemoteFunction("remote_plus", plusSignatures, metadata); + + RemoteVectorFunctionMetadata wrongMetadata = metadata; + wrongMetadata.location = folly::SocketAddress(); // empty address. + registerRemoteFunction("remote_wrong_port", plusSignatures, wrongMetadata); + + auto divSignatures = {exec::FunctionSignatureBuilder() + .returnType("double") + .argumentType("double") + .argumentType("double") + .build()}; + registerRemoteFunction("remote_divide", divSignatures, metadata); + + auto substrSignatures = {exec::FunctionSignatureBuilder() + .returnType("varchar") + .argumentType("varchar") + .argumentType("integer") + .build()}; + registerRemoteFunction("remote_substr", substrSignatures, metadata); + + // Registers the actual function under a different prefix. This is only + // needed for tests since the http service runs in the same process. + registerFunction( + {remotePrefix_ + ".remote_plus"}); + registerFunction( + {remotePrefix_ + ".remote_divide"}); + registerFunction( + {remotePrefix_ + ".remote_substr"}); + } + + void initializeServer() { + HTTPServerOptions options; + // options.threads = static_cast(sysconf(_SC_NPROCESSORS_ONLN)); + options.idleTimeout = std::chrono::milliseconds(6000); + options.handlerFactories = + RequestHandlerChain() + .addThen(remotePrefix_) + .build(); + options.h2cEnabled = true; + + std::vector IPs = { + {folly::SocketAddress(location_.getHost(), location_.getPort(), true), + HTTPServer::Protocol::HTTP}}; + + server_ = std::make_shared(std::move(options)); + server_->bind(IPs); + + thread_ = std::make_unique([&] { server_->start(); }); + + VELOX_CHECK(waitForRunning(), "Unable to initialize HTTP server."); + LOG(INFO) << "HTTP server is up and running in local port " + << location_.getUrl(); + } + + ~RemoteFunctionRestTest() { + server_->stop(); + thread_->join(); + LOG(INFO) << "HTTP server stopped."; + } + + private: + // Loop until the server is up and running. + bool waitForRunning() { + for (size_t i = 0; i < 100; ++i) { + using boost::asio::ip::tcp; + boost::asio::io_context io_context; + + tcp::socket socket(io_context); + tcp::resolver resolver(io_context); + + try { + boost::asio::connect( + socket, + resolver.resolve( + location_.getHost(), std::to_string(location_.getPort()))); + return true; + } catch (std::exception& e) { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + } + return false; + } + + std::shared_ptr server_; + std::unique_ptr thread_; + + URL location_{URL("http://127.0.0.1:83211/")}; + const std::string remotePrefix_{"remote"}; +}; + +TEST_P(RemoteFunctionRestTest, simple) { + auto inputVector = makeFlatVector({1, 2, 3, 4, 5}); + auto results = evaluate>( + "remote_plus(c0, c0)", makeRowVector({inputVector})); + + auto expected = makeFlatVector({2, 4, 6, 8, 10}); + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionRestTest, string) { + auto inputVector = + makeFlatVector({"hello", "my", "remote", "world"}); + auto inputVector1 = makeFlatVector({2, 1, 3, 5}); + auto results = evaluate>( + "remote_substr(c0, c1)", makeRowVector({inputVector, inputVector1})); + + auto expected = makeFlatVector({"ello", "my", "mote", "d"}); + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionRestTest, connectionError) { + auto inputVector = makeFlatVector({1, 2, 3, 4, 5}); + auto func = [&]() { + evaluate>( + "remote_wrong_port(c0, c0)", makeRowVector({inputVector})); + }; + + // Check it throw and that the exception has the "connection refused" + // substring. + EXPECT_THROW(func(), VeloxRuntimeError); + try { + func(); + } catch (const VeloxRuntimeError& e) { + EXPECT_THAT(e.message(), testing::HasSubstr("Channel is !good()")); + } +} + +VELOX_INSTANTIATE_TEST_SUITE_P( + RemoteFunctionRestTestFixture, + RemoteFunctionRestTest, + ::testing::Values( + remote::PageFormat::PRESTO_PAGE, + remote::PageFormat::SPARK_UNSAFE_ROW)); + +} // namespace +} // namespace facebook::velox::functions + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::Init init{&argc, &argv, false}; + return RUN_ALL_TESTS(); +} diff --git a/velox/functions/remote/server/CMakeLists.txt b/velox/functions/remote/server/CMakeLists.txt index ff2afa0fed6a8..1772a1ddfcc2a 100644 --- a/velox/functions/remote/server/CMakeLists.txt +++ b/velox/functions/remote/server/CMakeLists.txt @@ -24,3 +24,19 @@ add_executable(velox_functions_remote_server_main RemoteFunctionServiceMain.cpp) target_link_libraries( velox_functions_remote_server_main velox_functions_remote_server velox_functions_prestosql) + +add_library(velox_functions_remote_server_rest RemoteFunctionRestService.cpp) +target_link_libraries( + velox_functions_remote_server_rest + ${PROXYGEN_LIBRARIES} + velox_functions_remote_get_serde + velox_type_fbhive + velox_memory + velox_functions_prestosql) + +add_executable(velox_functions_remote_server_rest_main + RemoteFunctionServiceRestMain.cpp) + +target_link_libraries( + velox_functions_remote_server_rest_main velox_functions_remote_server_rest + velox_functions_prestosql) diff --git a/velox/functions/remote/server/RemoteFunctionRestService.cpp b/velox/functions/remote/server/RemoteFunctionRestService.cpp new file mode 100644 index 0000000000000..8abe97c1945ec --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionRestService.cpp @@ -0,0 +1,224 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/remote/server/RemoteFunctionRestService.h" +#include +#include +#include "velox/expression/Expr.h" +#include "velox/functions/remote/if/GetSerde.h" +#include "velox/type/fbhive/HiveTypeParser.h" +#include "velox/vector/VectorStream.h" + +namespace facebook::velox::functions { +namespace { +std::string iobufToString(const folly::IOBuf& buf) { + std::string result; + result.reserve(buf.computeChainDataLength()); + + for (auto range : buf) { + result.append(reinterpret_cast(range.data()), range.size()); + } + + return result; +} + +TypePtr deserializeType(const std::string& input) { + // Use hive type parser/serializer. + return type::fbhive::HiveTypeParser().parse(input); +} + +RowTypePtr deserializeArgTypes(const std::vector& argTypes) { + const size_t argCount = argTypes.size(); + + std::vector argumentTypes; + std::vector typeNames; + argumentTypes.reserve(argCount); + typeNames.reserve(argCount); + + for (size_t i = 0; i < argCount; ++i) { + argumentTypes.emplace_back(deserializeType(argTypes[i])); + typeNames.emplace_back(fmt::format("c{}", i)); + } + return ROW(std::move(typeNames), std::move(argumentTypes)); +} + +std::string getFunctionName( + const std::string& prefix, + const std::string& functionName) { + return prefix.empty() ? functionName + : fmt::format("{}.{}", prefix, functionName); +} +} // namespace + +std::vector getExpressions( + const RowTypePtr& inputType, + const TypePtr& returnType, + const std::string& functionName) { + std::vector inputs; + for (size_t i = 0; i < inputType->size(); ++i) { + inputs.push_back(std::make_shared( + inputType->childAt(i), inputType->nameOf(i))); + } + + return {std::make_shared( + returnType, std::move(inputs), functionName)}; +} + +// RestRequestHandler +void RestRequestHandler::onRequest( + std::unique_ptr headers) noexcept {} + +void RestRequestHandler::onEOM() noexcept { + try { + auto jsonObj = folly::parseJson(body_); + + auto payload = jsonObj["inputs"]["payload"]; + auto rowCount = jsonObj["inputs"]["rowCount"]; + auto remoteFunctionHandle = jsonObj["remoteFunctionHandle"]; + + LOG(INFO) << "Got a request for '" << remoteFunctionHandle["functionName"] + << "': " << rowCount << " input rows."; + + if (!jsonObj["throwOnError"].asBool()) { + VELOX_NYI("throwOnError not implemented yet on remote server."); + } + + // A remote function service should handle the function execution by its + // own. We use Velox eval framework here for quick prototype. + // Start of Function execution + std::vector argumentTypes; + for (const auto& element : remoteFunctionHandle["argumentTypes"]) { + argumentTypes.push_back(element.asString()); + } + auto inputType = deserializeArgTypes(argumentTypes); + auto outputType = + deserializeType(remoteFunctionHandle["returnType"].asString()); + + auto serdeFormat = static_cast( + jsonObj["inputs"]["pageFormat"].asInt()); + auto serde = getSerde(serdeFormat); + + // jsonObj to RowVector + auto inputVector = IOBufToRowVector( + *folly::IOBuf::copyBuffer(payload.asString()), + inputType, + *pool_, + serde.get()); + + const vector_size_t numRows = inputVector->size(); + SelectivityVector rows{numRows}; + + // Expression boilerplate. + auto queryCtx = core::QueryCtx::create(); + core::ExecCtx execCtx{pool_.get(), queryCtx.get()}; + exec::ExprSet exprSet{ + getExpressions( + inputType, + outputType, + getFunctionName( + functionPrefix_, + remoteFunctionHandle["functionName"].asString())), + &execCtx}; + exec::EvalCtx evalCtx(&execCtx, &exprSet, inputVector.get()); + + std::vector expressionResult; + exprSet.eval(rows, evalCtx, expressionResult); + + // Create output vector. + auto outputRowVector = std::make_shared( + pool_.get(), ROW({outputType}), BufferPtr(), numRows, expressionResult); + + // Construct a json object for REST response + // End of Function execution. + folly::dynamic retObj = folly::dynamic::object; + retObj["payload"] = iobufToString( + rowVectorToIOBuf(outputRowVector, rows.end(), *pool_, serde.get())); + retObj["rowCount"] = outputRowVector->size(); + + // LOG(INFO) << "result:" << retObj; + ResponseBuilder(downstream_) + .status(200, "OK") + .body(folly::toJson(folly::dynamic::object("result", retObj))) + .sendWithEOM(); + + } catch (const std::exception& ex) { + LOG(ERROR) << ex.what(); + ResponseBuilder(downstream_) + .status(500, "Internal Server Error") + .body(folly::toJson(folly::dynamic::object("err", ex.what()))) + .sendWithEOM(); + } +} + +void RestRequestHandler::onBody(std::unique_ptr chain) noexcept { + if (chain) { + body_.append(reinterpret_cast(chain->data()), chain->length()); + } +} + +void RestRequestHandler::onUpgrade(UpgradeProtocol /*protocol*/) noexcept { + // handler doesn't support upgrades +} + +void RestRequestHandler::requestComplete() noexcept { + delete this; +} + +void RestRequestHandler::onError(ProxygenError /*err*/) noexcept { + delete this; +} + +// ErrorHandler +ErrorHandler::ErrorHandler(int statusCode, std::string message) + : statusCode_(statusCode), message_(std::move(message)) {} + +void ErrorHandler::onRequest(std::unique_ptr) noexcept { + ResponseBuilder(downstream_) + .status(statusCode_, "Error") + .body(std::move(message_)) + .sendWithEOM(); +} + +void ErrorHandler::onEOM() noexcept {} + +void ErrorHandler::onBody(std::unique_ptr body) noexcept {} + +void ErrorHandler::onUpgrade(UpgradeProtocol protocol) noexcept { + // handler doesn't support upgrades +} + +void ErrorHandler::requestComplete() noexcept { + delete this; +} + +void ErrorHandler::onError(ProxygenError err) noexcept { + delete this; +} + +// RestRequestHandlerFactory +void RestRequestHandlerFactory::onServerStart(folly::EventBase* evb) noexcept {} + +void RestRequestHandlerFactory::onServerStop() noexcept {} + +RequestHandler* RestRequestHandlerFactory::onRequest( + proxygen::RequestHandler*, + proxygen::HTTPMessage* msg) noexcept { + if (msg->getMethod() != HTTPMethod::POST) { + return new ErrorHandler(405, "Only POST method is allowed"); + } + return new RestRequestHandler(functionPrefix_); +} +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionRestService.h b/velox/functions/remote/server/RemoteFunctionRestService.h new file mode 100644 index 0000000000000..254f51a4bbaaa --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionRestService.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "velox/common/memory/Memory.h" + +using namespace proxygen; + +namespace facebook::velox::functions { +class ErrorHandler : public RequestHandler { + public: + explicit ErrorHandler(int statusCode, std::string message); + void onRequest(std::unique_ptr headers) noexcept override; + void onBody(std::unique_ptr) noexcept override; + void onEOM() noexcept override; + void onUpgrade(UpgradeProtocol protocol) noexcept override; + void requestComplete() noexcept override; + void onError(ProxygenError err) noexcept override; + + private: + int statusCode_; + std::string message_; +}; + +class RestRequestHandler : public RequestHandler { + public: + explicit RestRequestHandler(const std::string& functionPrefix = "") + : functionPrefix_(functionPrefix) {} + void onRequest(std::unique_ptr headers) noexcept override; + void onBody(std::unique_ptr body) noexcept override; + void onEOM() noexcept override; + void onUpgrade(UpgradeProtocol protocol) noexcept override; + void requestComplete() noexcept override; + void onError(ProxygenError err) noexcept override; + + private: + std::string body_; + std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool()}; + const std::string functionPrefix_; +}; + +class RestRequestHandlerFactory : public RequestHandlerFactory { + public: + explicit RestRequestHandlerFactory(const std::string& functionPrefix = "") + : functionPrefix_(functionPrefix) {} + void onServerStart(folly::EventBase* evb) noexcept override; + void onServerStop() noexcept override; + RequestHandler* onRequest(RequestHandler*, HTTPMessage* msg) noexcept + override; + + private: + const std::string functionPrefix_; +}; +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionServiceMain.cpp b/velox/functions/remote/server/RemoteFunctionServiceMain.cpp index c92ab9231d114..92ff2791bb1f8 100644 --- a/velox/functions/remote/server/RemoteFunctionServiceMain.cpp +++ b/velox/functions/remote/server/RemoteFunctionServiceMain.cpp @@ -18,6 +18,7 @@ #include #include #include +#include "velox/functions/prestosql/StringFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/functions/remote/server/RemoteFunctionService.h" @@ -36,7 +37,7 @@ DEFINE_string( DEFINE_string( function_prefix, - "json.test_schema.", + "remote.schema.", "Prefix to be added to the functions being registered"); using namespace ::facebook::velox; @@ -46,11 +47,14 @@ int main(int argc, char* argv[]) { folly::Init init{&argc, &argv, false}; FLAGS_logtostderr = true; + memory::initializeMemoryManager({}); + // Always registers all Presto functions and make them available under a // certain prefix/namespace. LOG(INFO) << "Registering Presto functions"; functions::prestosql::registerAllScalarFunctions(FLAGS_function_prefix); + std::remove(FLAGS_uds_path.c_str()); folly::SocketAddress location{ folly::SocketAddress::makeFromPath(FLAGS_uds_path)}; diff --git a/velox/functions/remote/server/RemoteFunctionServiceRestMain.cpp b/velox/functions/remote/server/RemoteFunctionServiceRestMain.cpp new file mode 100644 index 0000000000000..5c9c0259814e4 --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionServiceRestMain.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "velox/common/memory/Memory.h" + +#include "velox/functions/Registerer.h" +#include "velox/functions/prestosql/Arithmetic.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/functions/remote/server/RemoteFunctionRestService.h" + +DEFINE_string( + service_host, + "127.0.0.1", + "Prefix to be added to the functions being registered"); + +DEFINE_int32( + service_port, + 8321, + "Prefix to be added to the functions being registered"); + +DEFINE_string( + function_prefix, + "remote.schema.", + "Prefix to be added to the functions being registered"); + +using namespace ::facebook::velox; + +int main(int argc, char* argv[]) { + folly::Init init(&argc, &argv); + FLAGS_logtostderr = true; + memory::initializeMemoryManager({}); + + // A remote function service should handle the function execution by its own. + // But we use Velox framework for quick prototype here + functions::prestosql::registerAllScalarFunctions(FLAGS_function_prefix); + // registerFunction( + // {"remote_plus"}); + // End of function registration + + LOG(INFO) << "Start HTTP Server at " << "http://" << FLAGS_service_host << ":" + << FLAGS_service_port; + + HTTPServerOptions options; + // options.threads = static_cast(sysconf(_SC_NPROCESSORS_ONLN)); + options.idleTimeout = std::chrono::milliseconds(60000); + options.handlerFactories = + RequestHandlerChain() + .addThen() + .build(); + options.h2cEnabled = true; + + std::vector IPs = { + {folly::SocketAddress(FLAGS_service_host, FLAGS_service_port, true), + HTTPServer::Protocol::HTTP}}; + + proxygen::HTTPServer server(std::move(options)); + server.bind(IPs); + + std::thread t([&]() { server.start(); }); + + t.join(); + return 0; +}