Skip to content

Commit

Permalink
feat(functions): Add support for REST based remote functions
Browse files Browse the repository at this point in the history
Co-authored-by: Wills Feng <[email protected]>
  • Loading branch information
Joe-Abraham and wills-feng committed Dec 13, 2024
1 parent e86ff05 commit 97483aa
Show file tree
Hide file tree
Showing 14 changed files with 1,077 additions and 68 deletions.
12 changes: 12 additions & 0 deletions velox/functions/remote/client/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,23 @@ velox_add_library(velox_functions_remote_thrift_client ThriftClient.cpp)
velox_link_libraries(velox_functions_remote_thrift_client
PUBLIC remote_function_thrift FBThrift::thriftcpp2)

# Resolve libcurl, which the REST client library below links against.
# NOTE(review): curl_SOURCE=BUNDLED presumably tells velox_resolve_dependency
# to build the bundled copy of curl rather than use a system install - confirm
# against the definition of velox_resolve_dependency.
set(curl_SOURCE BUNDLED)
velox_resolve_dependency(curl)

# REST client library for remote functions, built from RestClient.cpp and
# linked against Folly and the resolved curl libraries.
velox_add_library(velox_functions_remote_rest_client RestClient.cpp)
velox_link_libraries(velox_functions_remote_rest_client Folly::folly
${CURL_LIBRARIES})

# Remote function library built from Remote.cpp. It links both transport
# clients (thrift and REST) in addition to the serializer and type libraries
# listed below.
velox_add_library(velox_functions_remote Remote.cpp)
velox_link_libraries(
velox_functions_remote
PUBLIC velox_expression
velox_memory
velox_exec
velox_vector
velox_presto_serializer
velox_functions_remote_thrift_client
velox_functions_remote_rest_client
velox_functions_remote_get_serde
velox_type_fbhive
Folly::folly)
Expand Down
124 changes: 105 additions & 19 deletions velox/functions/remote/client/Remote.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,70 @@

#include "velox/functions/remote/client/Remote.h"

#include <fmt/format.h>
#include <folly/io/async/EventBase.h>
#include <sstream>
#include <string>

#include "velox/common/memory/ByteStream.h"
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/remote/client/RestClient.h"
#include "velox/functions/remote/client/ThriftClient.h"
#include "velox/functions/remote/if/GetSerde.h"
#include "velox/functions/remote/if/gen-cpp2/RemoteFunctionServiceAsyncClient.h"
#include "velox/serializers/PrestoSerializer.h"
#include "velox/type/fbhive/HiveTypeSerializer.h"
#include "velox/vector/VectorStream.h"

using namespace folly;
namespace facebook::velox::functions {
namespace {

/// Converts `type` to its string form by delegating to the Hive type
/// serializer.
std::string serializeType(const TypePtr& type) {
  std::string serialized = type::fbhive::HiveTypeSerializer::serialize(type);
  return serialized;
}

/// Returns the part of `input` that follows its last '.' character, or a copy
/// of `input` itself when it contains no '.'.
std::string extractFunctionName(const std::string& input) {
  const auto separatorPos = input.rfind('.');
  return separatorPos == std::string::npos ? input
                                           : input.substr(separatorPos + 1);
}

/// Percent-encodes `value` for use in a URL.
///
/// ASCII letters, ASCII digits and the characters '-', '_', '.', '~' are
/// copied through unchanged. Every other byte is emitted as '%' followed by
/// its two-digit lowercase hexadecimal value.
///
/// The character classification is done with explicit ASCII ranges rather than
/// isalnum() so the result does not depend on the process locale, and the hex
/// digits come from a lookup table so no stream manipulators are required.
///
/// @param value Raw string to encode; may be empty.
/// @return The percent-encoded string.
std::string urlEncode(const std::string& value) {
  static constexpr char kHexDigits[] = "0123456789abcdef";

  std::string escaped;
  escaped.reserve(value.size());
  for (const char c : value) {
    const bool isUnreserved = (c >= 'A' && c <= 'Z') ||
        (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' ||
        c == '_' || c == '.' || c == '~';
    if (isUnreserved) {
      escaped.push_back(c);
    } else {
      // Go through unsigned char so bytes >= 0x80 index the table correctly.
      const auto byte = static_cast<unsigned char>(c);
      escaped.push_back('%');
      escaped.push_back(kHexDigits[byte >> 4]);
      escaped.push_back(kHexDigits[byte & 0x0F]);
    }
  }
  return escaped;
}

class RemoteFunction : public exec::VectorFunction {
public:
RemoteFunction(
const std::string& functionName,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const RemoteVectorFunctionMetadata& metadata)
const RemoteVectorFunctionMetadata& metadata,
std::unique_ptr<HttpClient> httpClient = nullptr)
: functionName_(functionName),
location_(metadata.location),
thriftClient_(getThriftClient(location_, &eventBase_)),
serdeFormat_(metadata.serdeFormat),
serde_(getSerde(serdeFormat_)) {
restClient_(httpClient ? std::move(httpClient) : getRestClient()),
metadata_(metadata) {
if (metadata.location.type() == typeid(SocketAddress)) {
location_ = boost::get<SocketAddress>(metadata.location);
thriftClient_ = getThriftClient(location_, &eventBase_);
} else if (metadata.location.type() == typeid(std::string)) {
url_ = boost::get<std::string>(metadata.location);
}

std::vector<TypePtr> types;
types.reserve(inputArgs.size());
serializedInputTypes_.reserve(inputArgs.size());
Expand All @@ -62,7 +98,11 @@ class RemoteFunction : public exec::VectorFunction {
exec::EvalCtx& context,
VectorPtr& result) const override {
try {
applyRemote(rows, args, outputType, context, result);
if ((metadata_.location.type() == typeid(SocketAddress))) {
applyRemote(rows, args, outputType, context, result);
} else if (metadata_.location.type() == typeid(std::string)) {
applyRestRemote(rows, args, outputType, context, result);
}
} catch (const VeloxRuntimeError&) {
throw;
} catch (const std::exception&) {
Expand All @@ -71,6 +111,48 @@ class RemoteFunction : public exec::VectorFunction {
}

private:
void applyRestRemote(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const {
try {
serializer::presto::PrestoVectorSerde serde;
auto remoteRowVector = std::make_shared<RowVector>(
context.pool(),
remoteInputType_,
BufferPtr{},
rows.end(),
std::move(args));

std::unique_ptr<IOBuf> requestBody =
std::make_unique<IOBuf>(rowVectorToIOBuf(
remoteRowVector, rows.end(), *context.pool(), &serde));

const std::string fullUrl = fmt::format(
"{}/v1/functions/{}/{}/{}/{}",
url_,
metadata_.schema.value_or("default"),
extractFunctionName(functionName_),
urlEncode(metadata_.functionId.value_or("default_function_id")),
metadata_.version.value_or("1"));

std::unique_ptr<IOBuf> responseBody =
restClient_->invokeFunction(fullUrl, std::move(requestBody));

auto outputRowVector = IOBufToRowVector(
*responseBody, ROW({outputType}), *context.pool(), &serde);

result = outputRowVector->childAt(0);
} catch (const std::exception& e) {
VELOX_FAIL(
"Error while executing remote function '{}': {}",
functionName_,
e.what());
}
}

void applyRemote(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
Expand All @@ -97,11 +179,14 @@ class RemoteFunction : public exec::VectorFunction {

auto requestInputs = request.inputs_ref();
requestInputs->rowCount_ref() = remoteRowVector->size();
requestInputs->pageFormat_ref() = serdeFormat_;
requestInputs->pageFormat_ref() = metadata_.serdeFormat;

// TODO: serialize only active rows.
requestInputs->payload_ref() = rowVectorToIOBuf(
remoteRowVector, rows.end(), *context.pool(), serde_.get());
remoteRowVector,
rows.end(),
*context.pool(),
getSerde(metadata_.serdeFormat).get());

try {
thriftClient_->sync_invokeFunction(remoteResponse, request);
Expand All @@ -117,12 +202,15 @@ class RemoteFunction : public exec::VectorFunction {
remoteResponse.get_result().get_payload(),
ROW({outputType}),
*context.pool(),
serde_.get());
getSerde(metadata_.serdeFormat).get());
result = outputRowVector->childAt(0);

if (auto errorPayload = remoteResponse.get_result().errorPayload()) {
auto errorsRowVector = IOBufToRowVector(
*errorPayload, ROW({VARCHAR()}), *context.pool(), serde_.get());
*errorPayload,
ROW({VARCHAR()}),
*context.pool(),
getSerde(metadata_.serdeFormat).get());
auto errorsVector =
errorsRowVector->childAt(0)->asFlatVector<StringView>();
VELOX_CHECK(errorsVector, "Should be convertible to flat vector");
Expand All @@ -142,16 +230,14 @@ class RemoteFunction : public exec::VectorFunction {
}

const std::string functionName_;
folly::SocketAddress location_;

folly::EventBase eventBase_;
EventBase eventBase_;
std::unique_ptr<RemoteFunctionClient> thriftClient_;
remote::PageFormat serdeFormat_;
std::unique_ptr<VectorSerde> serde_;

// Structures we construct once to cache:
std::unique_ptr<HttpClient> restClient_;
SocketAddress location_;
std::string url_;
RowTypePtr remoteInputType_;
std::vector<std::string> serializedInputTypes_;
const RemoteVectorFunctionMetadata metadata_;
};

std::shared_ptr<exec::VectorFunction> createRemoteFunction(
Expand All @@ -169,7 +255,7 @@ void registerRemoteFunction(
std::vector<exec::FunctionSignaturePtr> signatures,
const RemoteVectorFunctionMetadata& metadata,
bool overwrite) {
exec::registerStatefulVectorFunction(
registerStatefulVectorFunction(
name,
signatures,
std::bind(
Expand Down
25 changes: 21 additions & 4 deletions velox/functions/remote/client/Remote.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,37 @@

#pragma once

#include <boost/variant.hpp>
#include <folly/SocketAddress.h>
#include "velox/expression/VectorFunction.h"
#include "velox/functions/remote/if/gen-cpp2/RemoteFunction_types.h"

namespace facebook::velox::functions {

struct RemoteVectorFunctionMetadata : public exec::VectorFunctionMetadata {
/// Network address of the servr to communicate with. Note that this can hold
/// a network location (ip/port pair) or a unix domain socket path (see
/// URL of the HTTP/REST server for remote function.
/// Or Network address of the server to communicate with. Note that this can
/// hold a network location (ip/port pair) or a unix domain socket path (see
/// SocketAddress::makeFromPath()).
folly::SocketAddress location;
boost::variant<folly::SocketAddress, std::string> location;

/// The serialization format to be used
/// The serialization format to be used when sending data to the remote.
remote::PageFormat serdeFormat{remote::PageFormat::PRESTO_PAGE};

/// Optional schema defining the structure of the data or input/output types
/// involved in the remote function. This may include details such as column
/// names and data types.
std::optional<std::string> schema;

/// Optional identifier for the specific remote function to be invoked.
/// This can be useful when the same server hosts multiple functions,
/// and the client needs to specify which function to call.
std::optional<std::string> functionId;

/// Optional version information to be used when calling the remote function.
/// This can help in ensuring compatibility with a particular version of the
/// function if multiple versions are available on the server.
std::optional<std::string> version;
};

/// Registers a new remote function. It will use the metadata defined in
Expand Down
128 changes: 128 additions & 0 deletions velox/functions/remote/client/RestClient.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/remote/client/RestClient.h"

#include <curl/curl.h>
#include <folly/io/IOBufQueue.h>

#include "velox/common/base/Exceptions.h"

using namespace folly;
namespace facebook::velox::functions {
namespace {

// Callback function for CURL to read data from the request payload.
// @param dest Destination buffer to copy data into.
// @param size Size of each data element.
// @param nmemb Number of elements to read.
// @param userp Pointer to user data (IOBufQueue containing the request
// payload).
// @return Number of bytes actually copied.
size_t readCallback(char* dest, size_t size, size_t nmemb, void* userp) {
auto* inputBufQueue = static_cast<IOBufQueue*>(userp);
size_t bufferSize = size * nmemb;
size_t totalCopied = 0;

while (totalCopied < bufferSize && !inputBufQueue->empty()) {
auto buf = inputBufQueue->front();
size_t remainingSize = bufferSize - totalCopied;
size_t copySize = std::min(remainingSize, buf->length());
std::memcpy(dest + totalCopied, buf->data(), copySize);
totalCopied += copySize;
inputBufQueue->pop_front();
}

return totalCopied;
}

// CURL write callback: appends a copy of the received bytes to the response
// queue.
// @param ptr Start of the data delivered by CURL.
// @param size Size of each data element.
// @param nmemb Number of elements delivered.
// @param userData Pointer to the IOBufQueue accumulating the response
// payload.
// @return Number of bytes consumed, which is always size * nmemb.
size_t writeCallback(char* ptr, size_t size, size_t nmemb, void* userData) {
  const size_t receivedBytes = size * nmemb;
  auto* responseQueue = static_cast<IOBufQueue*>(userData);
  responseQueue->append(IOBuf::copyBuffer(ptr, receivedBytes));
  return receivedBytes;
}
} // namespace

std::unique_ptr<IOBuf> RestClient::invokeFunction(
const std::string& fullUrl,
std::unique_ptr<IOBuf> requestPayload) {
try {
IOBufQueue inputBufQueue(IOBufQueue::cacheChainLength());
inputBufQueue.append(std::move(requestPayload));

CURL* curl = curl_easy_init();
if (!curl) {
VELOX_FAIL(fmt::format(
"Error initializing CURL: {}",
curl_easy_strerror(CURLE_FAILED_INIT)));
}

curl_easy_setopt(curl, CURLOPT_URL, fullUrl.c_str());
curl_easy_setopt(curl, CURLOPT_POST, 1L);
curl_easy_setopt(curl, CURLOPT_READFUNCTION, readCallback);
curl_easy_setopt(curl, CURLOPT_READDATA, &inputBufQueue);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeCallback);

IOBufQueue outputBuf(IOBufQueue::cacheChainLength());
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &outputBuf);
curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L);

struct curl_slist* headers = nullptr;
headers =
curl_slist_append(headers, "Content-Type: application/X-presto-pages");
headers = curl_slist_append(headers, "Accept: application/X-presto-pages");
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);

curl_easy_setopt(
curl,
CURLOPT_POSTFIELDSIZE,
static_cast<long>(inputBufQueue.chainLength()));

CURLcode res = curl_easy_perform(curl);
if (res != CURLE_OK) {
curl_slist_free_all(headers);
curl_easy_cleanup(curl);
VELOX_FAIL(fmt::format(
"Error communicating with server: {}\nURL: {}\nCURL Error: {}",
curl_easy_strerror(res),
fullUrl.c_str(),
curl_easy_strerror(res)));
}

curl_slist_free_all(headers);
curl_easy_cleanup(curl);

return outputBuf.move();

} catch (const std::exception& e) {
VELOX_FAIL(fmt::format("Exception during CURL request: {}", e.what()));
}
}

/// Creates a new RestClient and returns it through the HttpClient interface.
std::unique_ptr<HttpClient> getRestClient() {
  std::unique_ptr<HttpClient> client = std::make_unique<RestClient>();
  return client;
}

} // namespace facebook::velox::functions
Loading

0 comments on commit 97483aa

Please sign in to comment.