diff --git a/velox/functions/prestosql/tests/IPPrefixCastTest.cpp b/velox/functions/prestosql/tests/IPPrefixCastTest.cpp new file mode 100644 index 000000000000..c8c77f40ec52 --- /dev/null +++ b/velox/functions/prestosql/tests/IPPrefixCastTest.cpp @@ -0,0 +1,110 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" + +namespace facebook::velox::functions::prestosql { + +class IPPrefixTypeTest : public functions::test::FunctionBaseTest { + protected: + std::optional castToVarchar( + const std::optional& input) { + auto result = evaluateOnce( + "cast(cast(c0 as ipprefix) as varchar)", input); + return result; + } +}; + +TEST_F(IPPrefixTypeTest, castToVarchar) { + EXPECT_EQ(castToVarchar("::ffff:1.2.3.4/24"), "1.2.3.0/24"); + EXPECT_EQ(castToVarchar("192.168.0.0/24"), "192.168.0.0/24"); + EXPECT_EQ(castToVarchar("255.2.3.4/0"), "0.0.0.0/0"); + EXPECT_EQ(castToVarchar("255.2.3.4/1"), "128.0.0.0/1"); + EXPECT_EQ(castToVarchar("255.2.3.4/2"), "192.0.0.0/2"); + EXPECT_EQ(castToVarchar("255.2.3.4/4"), "240.0.0.0/4"); + EXPECT_EQ(castToVarchar("1.2.3.4/8"), "1.0.0.0/8"); + EXPECT_EQ(castToVarchar("1.2.3.4/16"), "1.2.0.0/16"); + EXPECT_EQ(castToVarchar("1.2.3.4/24"), "1.2.3.0/24"); + EXPECT_EQ(castToVarchar("1.2.3.255/25"), "1.2.3.128/25"); + EXPECT_EQ(castToVarchar("1.2.3.255/26"), "1.2.3.192/26"); + EXPECT_EQ(castToVarchar("1.2.3.255/28"), "1.2.3.240/28"); + EXPECT_EQ(castToVarchar("1.2.3.255/30"), "1.2.3.252/30"); + EXPECT_EQ(castToVarchar("1.2.3.255/32"), "1.2.3.255/32"); + EXPECT_EQ( + castToVarchar("2001:0db8:0000:0000:0000:ff00:0042:8329/128"), + "2001:db8::ff00:42:8329/128"); + EXPECT_EQ( + castToVarchar("2001:db8::ff00:42:8329/128"), + "2001:db8::ff00:42:8329/128"); + EXPECT_EQ(castToVarchar("2001:db8:0:0:1:0:0:1/128"), "2001:db8::1:0:0:1/128"); + EXPECT_EQ(castToVarchar("2001:db8:0:0:1::1/128"), "2001:db8::1:0:0:1/128"); + EXPECT_EQ(castToVarchar("2001:db8::1:0:0:1/128"), "2001:db8::1:0:0:1/128"); + EXPECT_EQ( + castToVarchar("2001:DB8::FF00:ABCD:12EF/128"), + "2001:db8::ff00:abcd:12ef/128"); + EXPECT_EQ(castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/0"), "::/0"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/1"), "8000::/1"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/2"), "c000::/2"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/4"), "f000::/4"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/8"), "ff00::/8"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/16"), "ffff::/16"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/32"), + "ffff:ffff::/32"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/48"), + "ffff:ffff:ffff::/48"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/64"), + "ffff:ffff:ffff:ffff::/64"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/80"), + "ffff:ffff:ffff:ffff:ffff::/80"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/96"), + "ffff:ffff:ffff:ffff:ffff:ffff::/96"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/112"), + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:0/112"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/120"), + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00/120"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/124"), + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0/124"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/126"), + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffc/126"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/127"), + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe/127"); + EXPECT_EQ( + castToVarchar("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/128"), + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/128"); + EXPECT_THROW(castToVarchar("facebook.com/32"), VeloxUserError); + EXPECT_THROW(castToVarchar("localhost/32"), VeloxUserError); + EXPECT_THROW(castToVarchar("2001:db8::1::1/128"), VeloxUserError); + EXPECT_THROW(castToVarchar("2001:zxy::1::1/128"), VeloxUserError); + EXPECT_THROW(castToVarchar("789.1.1.1/32"), VeloxUserError); + EXPECT_THROW(castToVarchar("192.1.1.1"), VeloxUserError); + EXPECT_THROW(castToVarchar("192.1.1.1/128"), VeloxUserError); +} +} // namespace facebook::velox::functions::prestosql diff --git a/velox/functions/prestosql/types/IPPrefixType.cpp b/velox/functions/prestosql/types/IPPrefixType.cpp index aad808d7cfbc..b835876df232 100644 --- a/velox/functions/prestosql/types/IPPrefixType.cpp +++ b/velox/functions/prestosql/types/IPPrefixType.cpp @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include #include "velox/expression/CastExpr.h" +#include "velox/expression/VectorWriters.h" #include "velox/functions/prestosql/types/IPPrefixType.h" namespace facebook::velox { @@ -26,11 +26,21 @@ namespace { class IPPrefixCastOperator : public exec::CastOperator { public: bool isSupportedFromType(const TypePtr& other) const override { - return false; + switch (other->kind()) { + case TypeKind::VARCHAR: + return true; + default: + return false; + } } bool isSupportedToType(const TypePtr& other) const override { - return false; + switch (other->kind()) { + case TypeKind::VARCHAR: + return true; + default: + return false; + } } void castTo( @@ -40,8 +50,14 @@ class IPPrefixCastOperator : public exec::CastOperator { const TypePtr& resultType, VectorPtr& result) const override { context.ensureWritable(rows, resultType, result); - VELOX_NYI( - "Cast from {} to IPPrefix not yet supported", input.type()->toString()); + switch (input.typeKind()) { + case TypeKind::VARCHAR: + return castFromString(input, context, rows, *result); + default: + VELOX_NYI( + "Cast from {} to IPPrefix not yet supported", + input.type()->toString()); + } } void castFrom( @@ -51,8 +67,83 @@ class IPPrefixCastOperator : public exec::CastOperator { const TypePtr& resultType, VectorPtr& result) const override { context.ensureWritable(rows, resultType, result); - VELOX_NYI( - "Cast from IPPrefix to {} not yet supported", resultType->toString()); + switch (resultType->kind()) { + case TypeKind::VARCHAR: + return castToString(input, context, rows, *result); + default: + VELOX_NYI( + "Cast from IPPrefix to {} not yet supported", + resultType->toString()); + } + } + + private: + static void castToString( + const BaseVector& input, + exec::EvalCtx& context, + const SelectivityVector& rows, + BaseVector& result) { + auto* flatResult = result.as>(); + auto rowVector = input.as(); + auto rowType = rowVector->type(); + const auto* ipaddr = rowVector->childAt(ipaddress::kIpRowIndex) + ->as>(); + const auto* prefix = rowVector->childAt(ipaddress::kIpPrefixRowIndex) + ->as>(); + context.applyToSelectedNoThrow(rows, [&](auto row) { + const auto ipAddrVal = ipaddr->valueAt(row); + // The string representation of the last byte needs + // to be unsigned + const uint8_t prefixVal = prefix->valueAt(row); + + // Copy the first 16 bytes into a ByteArray16. + folly::ByteArray16 addrBytes; + memcpy(&addrBytes, &ipAddrVal, ipaddress::kIPAddressBytes); + // Reverse the bytes to get the correct order. Similar to + // IPAddressType. We assume we're ALWAYS on a little endian machine. + // Note: for big endian, we should not reverse the bytes. + std::reverse(addrBytes.begin(), addrBytes.end()); + // // Construct a V6 address from the ByteArray16. + folly::IPAddressV6 v6Addr(addrBytes); + + // Inline func to get string for ipv4 or ipv6 string + const auto ipString = + (v6Addr.isIPv4Mapped()) ? v6Addr.createIPv4().str() : v6Addr.str(); + + // Format of string is {ipString}/{mask} + auto stringRet = fmt::format("{}/{}", ipString, prefixVal); + + // Write the string to the result vector + exec::StringWriter result(flatResult, row); + result.append(stringRet); + result.finalize(); + }); + } + + static void castFromString( + const BaseVector& input, + exec::EvalCtx& context, + const SelectivityVector& rows, + BaseVector& result) { + auto* rowVectorResult = result.as(); + const auto* ipPrefixStrings = input.as>(); + + context.applyToSelectedNoThrow(rows, [&](auto row) { + auto ipAddressStringView = ipPrefixStrings->valueAt(row); + auto tryIpPrefix = ipaddress::tryParseIpPrefixString(ipAddressStringView); + if (tryIpPrefix.hasError()) { + context.setStatus(row, std::move(tryIpPrefix.error())); + } + + const auto& ipPrefix = tryIpPrefix.value(); + auto writer = exec::VectorWriter>(); + writer.init(*rowVectorResult); + writer.setOffset(row); + auto& rowWriter = writer.current(); + rowWriter.get_writer_at<0>() = ipPrefix.first; + rowWriter.get_writer_at<1>() = ipPrefix.second; + writer.commit(); + }); } }; diff --git a/velox/functions/prestosql/types/IPPrefixType.h b/velox/functions/prestosql/types/IPPrefixType.h index 22107bbd483c..bb387ef6f496 100644 --- a/velox/functions/prestosql/types/IPPrefixType.h +++ b/velox/functions/prestosql/types/IPPrefixType.h @@ -15,13 +15,145 @@ */ #pragma once +#include + +#include "velox/common/base/Status.h" +#include "velox/functions/prestosql/types/IPAddressType.h" #include "velox/type/SimpleFunctionApi.h" #include "velox/type/Type.h" namespace facebook::velox { +namespace ipaddress { +constexpr uint8_t kIPV4Bits = 32; +constexpr uint8_t kIPV6Bits = 128; +constexpr int kIPPrefixLengthIndex = 16; +constexpr int kIPPrefixBytes = 17; +constexpr auto kIpRowIndex = "ip"; +constexpr auto kIpPrefixRowIndex = "prefix"; + +namespace { +auto splitIpSlashCidr(folly::StringPiece ipSlashCidr) { + folly::small_vector vec; + folly::split('/', ipSlashCidr, vec); + return vec; +} + +Status handleFailedToCreateNetworkError( + folly::StringPiece ipaddress, + folly::CIDRNetworkError error) { + if (threadSkipErrorDetails()) { + return Status::UserError(); + } + + switch (error) { + case folly::CIDRNetworkError::INVALID_DEFAULT_CIDR: { + return Status::UserError( + "defaultCidr must be <= std::numeric_limits::max()"); + } + case folly::CIDRNetworkError::INVALID_IP_SLASH_CIDR: { + return Status::UserError( + "Invalid IP address string received. Received string:{} of length:{}", + ipaddress, + ipaddress.size()); + } + case folly::CIDRNetworkError::INVALID_IP: { + const auto vec = splitIpSlashCidr(ipaddress); + return Status::UserError( + "Invalid IP address '{}'", vec.size() > 0 ? vec.at(0) : ""); + } + case folly::CIDRNetworkError::INVALID_CIDR: { + auto const vec = splitIpSlashCidr(ipaddress); + return Status::UserError( + "Mask value '{}' not a valid mask", vec.size() > 1 ? vec.at(1) : ""); + } + case folly::CIDRNetworkError::CIDR_MISMATCH: { + const auto vec = splitIpSlashCidr(ipaddress); + if (!vec.empty()) { + const auto subnet = folly::IPAddress::tryFromString(vec.at(0)).value(); + return Status::UserError( + "CIDR value '{}' is > network bit count '{}'", + vec.size() == 2 ? vec.at(1) + : folly::to( + subnet.isV4() ? ipaddress::kIPV4Bits + : ipaddress::kIPV6Bits), + subnet.bitCount()); + } + return Status::UserError( + "Invalid IP address of size:{} received", ipaddress.size()); + } + default: + return Status::UserError( + "Unknown parsing error when parsing IP address: {} ", ipaddress); + } +} +} // namespace + +inline folly::Expected, Status> +tryParseIpPrefixString(folly::StringPiece ipprefixString) { + // Ensure '/' is present + if (ipprefixString.find('/') == std::string::npos) { + return folly::makeUnexpected( + threadSkipErrorDetails() + ? Status::UserError() + : Status::UserError( + "Invalid CIDR IP address specified. Expected IP/PREFIX format, got: {}", + ipprefixString)); + } + + auto tryCdirNetwork = folly::IPAddress::tryCreateNetwork( + ipprefixString, /*defaultCidr*/ -1, /*applyMask*/ false); + if (tryCdirNetwork.hasError()) { + return folly::makeUnexpected(handleFailedToCreateNetworkError( + ipprefixString, std::move(tryCdirNetwork.error()))); + } + + folly::ByteArray16 addrBytes; + const auto& cdirNetwork = tryCdirNetwork.value(); + if (cdirNetwork.first.isIPv4Mapped() || cdirNetwork.first.isV4()) { + // Validate that the prefix value is <= 32 for ipv4 + if (cdirNetwork.second > ipaddress::kIPV4Bits) { + return folly::makeUnexpected( + threadSkipErrorDetails() + ? Status::UserError() + : Status::UserError( + "CIDR value '{}' is > network bit count '{}'", + cdirNetwork.second, + ipaddress::kIPV4Bits)); + } + auto ipv4Addr = folly::IPAddress::createIPv4(cdirNetwork.first); + auto ipv4AddrWithMask = ipv4Addr.mask(cdirNetwork.second); + auto ipv6Addr = ipv4AddrWithMask.createIPv6(); + addrBytes = ipv6Addr.toByteArray(); + } else { + // Validate that the prefix value is <= 128 for ipv6 + if (cdirNetwork.second > ipaddress::kIPV6Bits) { + return folly::makeUnexpected( + threadSkipErrorDetails() + ? Status::UserError() + : Status::UserError( + "CIDR value '{}' is > network bit count '{}'", + cdirNetwork.second, + ipaddress::kIPV6Bits)); + } + auto ipv6Addr = folly::IPAddress::createIPv6(cdirNetwork.first); + auto ipv6AddrWithMask = ipv6Addr.mask(cdirNetwork.second); + addrBytes = ipv6AddrWithMask.toByteArray(); + } + + int128_t intAddr; + // Similar to IPAdressType, assume Velox is always on little endian systems + std::reverse(addrBytes.begin(), addrBytes.end()); + memcpy(&intAddr, &addrBytes, ipaddress::kIPAddressBytes); + return std::make_pair(intAddr, cdirNetwork.second); +} +}; // namespace ipaddress + class IPPrefixType : public RowType { - IPPrefixType() : RowType({"ip", "prefix"}, {HUGEINT(), TINYINT()}) {} + IPPrefixType() + : RowType( + {ipaddress::kIpRowIndex, ipaddress::kIpPrefixRowIndex}, + {HUGEINT(), TINYINT()}) {} public: static const std::shared_ptr& get() {