Skip to content

Commit

Permalink
domain matcher
Browse files Browse the repository at this point in the history
Signed-off-by: Rohit Agrawal <[email protected]>
  • Loading branch information
agrawroh committed Nov 21, 2024
1 parent eb2033e commit 9dc9748
Show file tree
Hide file tree
Showing 9 changed files with 1,095 additions and 275 deletions.
3 changes: 3 additions & 0 deletions changelogs/current.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ new_features:
change: |
Added support in SNI dynamic forward proxy for saving the resolved upstream address in the filter state.
The state is saved with the key ``envoy.stream.upstream_address``.
- area: matcher
change: |
Added support for ``xds.type.matcher.v3.ServerNameMatcher`` trie-based matching.
deprecated:
- area: rbac
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ are available in some contexts:

* :ref:`Trie-based IP matcher <envoy_v3_api_msg_.xds.type.matcher.v3.IPMatcher>` applies to network inputs.

.. _extension_envoy.matching.custom_matchers.domain_matcher:

* :ref:`Trie-based server name matcher <envoy_v3_api_msg_.xds.type.matcher.v3.ServerNameMatcher>` applies to network inputs.

* `Common Expression Language <https://github.com/google/cel-spec>`_ (CEL) based matching:

.. _extension_envoy.matching.inputs.cel_data_input:
Expand Down
18 changes: 18 additions & 0 deletions source/extensions/common/matcher/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,21 @@ envoy_cc_extension(
"@com_github_cncf_xds//xds/type/matcher/v3:pkg_cc_proto",
],
)

# Library target for the domain (server name) trie matcher extension, built from
# domain_matcher.cc / domain_matcher.h. In addition to the default visibility it is
# exposed to the listener manager packages and to tests.
envoy_cc_extension(
name = "domain_matcher_lib",
srcs = ["domain_matcher.cc"],
hdrs = ["domain_matcher.h"],
extra_visibility = [
"//source/common/listener_manager:__subpackages__",
"//test:__subpackages__",
],
deps = [
"//envoy/matcher:matcher_interface",
"//envoy/network:filter_interface",
"//envoy/registry",
"//envoy/server:factory_context_interface",
"//source/common/matcher:matcher_lib",
# Provides the xds.type.matcher.v3 proto messages consumed by the matcher factory.
"@com_github_cncf_xds//xds/type/matcher/v3:pkg_cc_proto",
],
)
18 changes: 18 additions & 0 deletions source/extensions/common/matcher/domain_matcher.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "source/extensions/common/matcher/domain_matcher.h"

#include "envoy/registry/registry.h"

namespace Envoy {
namespace Extensions {
namespace Common {
namespace Matcher {

// Register the domain matcher as a custom matcher factory once per supported matching data
// type: network-level matching data and HTTP matching data.
REGISTER_FACTORY(NetworkDomainMatcherFactory,
::Envoy::Matcher::CustomMatcherFactory<Network::MatchingData>);
REGISTER_FACTORY(HttpDomainMatcherFactory,
::Envoy::Matcher::CustomMatcherFactory<Http::HttpMatchingData>);

} // namespace Matcher
} // namespace Common
} // namespace Extensions
} // namespace Envoy
316 changes: 316 additions & 0 deletions source/extensions/common/matcher/domain_matcher.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
#pragma once

#include "envoy/matcher/matcher.h"
#include "envoy/network/filter.h"
#include "envoy/server/factory_context.h"

#include "source/common/matcher/matcher.h"

#include "xds/type/matcher/v3/domain.pb.h"
#include "xds/type/matcher/v3/domain.pb.validate.h"

namespace Envoy {
namespace Extensions {
namespace Common {
namespace Matcher {

using ::Envoy::Matcher::DataInputFactoryCb;
using ::Envoy::Matcher::DataInputGetResult;
using ::Envoy::Matcher::DataInputPtr;
using ::Envoy::Matcher::evaluateMatch;
using ::Envoy::Matcher::MatchState;
using ::Envoy::Matcher::MatchTree;
using ::Envoy::Matcher::OnMatch;
using ::Envoy::Matcher::OnMatchFactory;
using ::Envoy::Matcher::OnMatchFactoryCb;

/**
 * A single node of the domain trie. Each node corresponds to one dot-separated label of a
 * domain name; children are keyed by label.
 */
template <class DataType> struct DomainNode {
  // Index assigned to the node when it was created.
  size_t index_{0};
  // The label this node stands for (e.g. "com", "example", or "*").
  std::string domain_part_;
  // True when the node represents a wildcard.
  bool is_wildcard_{false};
  // True when a configured domain pattern ends at this node.
  bool is_terminal_{false};
  // Action/matcher to apply when this node is selected; shared between a node and its clones.
  std::shared_ptr<OnMatch<DataType>> on_match_;
  // Child nodes keyed by label.
  absl::flat_hash_map<std::string, std::unique_ptr<DomainNode>> children_;

  DomainNode() {}

  /**
   * Produces a deep copy of the subtree rooted at this node.
   * The node structure is duplicated recursively, while `on_match_` is shared with the copy.
   * @return a newly allocated copy of this node and all of its descendants.
   */
  std::unique_ptr<DomainNode> clone() const {
    auto copy = std::make_unique<DomainNode>();
    copy->index_ = index_;
    copy->domain_part_ = domain_part_;
    copy->is_wildcard_ = is_wildcard_;
    copy->is_terminal_ = is_terminal_;
    copy->on_match_ = on_match_;

    for (const auto& entry : children_) {
      // A null child is preserved as a null entry under the same key.
      std::unique_ptr<DomainNode> copied_child;
      if (entry.second) {
        copied_child = entry.second->clone();
      }
      copy->children_.emplace(entry.first, std::move(copied_child));
    }
    return copy;
  }
};

/**
 * Implementation of a domain-specific trie matcher.
 *
 * Domain names are stored in a trie keyed by their dot-separated labels in reverse order
 * (e.g. "www.example.com" is stored as root -> "com" -> "example" -> "www"). A wildcard
 * pattern "*.example.com" is stored as root -> "com" -> "example" -> "*". The global wildcard
 * "*" is stored on the root node itself (root is terminal and carries the on-match).
 *
 * Match precedence, from most to least specific:
 *   1. Exact match of the full domain name.
 *   2. Longest suffix wildcard ("*.example.com" beats "*.com").
 *   3. Global wildcard ("*").
 * If none applies, the configured on-no-match (if any) is returned.
 */
template <class DataType> class DomainTrieMatcher : public MatchTree<DataType> {
public:
  /**
   * @param data_input the input that extracts the domain string from the matching data.
   * @param on_no_match optional result returned when no pattern matches.
   * @param root root of the domain trie; a null root is treated as an empty trie.
   * @throws EnvoyException if the data input does not produce the default (string) data type.
   */
  DomainTrieMatcher(DataInputPtr<DataType>&& data_input,
                    absl::optional<OnMatch<DataType>> on_no_match,
                    std::unique_ptr<DomainNode<DataType>> root)
      : data_input_(std::move(data_input)), on_no_match_(std::move(on_no_match)),
        root_(std::move(root)) {
    auto input_type = data_input_->dataInputType();
    if (input_type != Envoy::Matcher::DefaultMatchingDataType) {
      throw EnvoyException(
          absl::StrCat("Unsupported data input type: ", input_type,
                       ", currently only string type is supported in domain matcher"));
    }
    // match() dereferences the root unconditionally, so normalize a missing trie to an empty one.
    if (root_ == nullptr) {
      root_ = std::make_unique<DomainNode<DataType>>();
    }
  }

  /**
   * Splits a domain on '.' and returns the labels in reverse order, e.g.
   * "www.example.com" -> {"com", "example", "www"}.
   * A single trailing dot does not produce an extra empty label; an empty input yields an
   * empty vector.
   */
  static std::vector<std::string> splitAndReverseDomain(absl::string_view domain) {
    std::vector<std::string> parts;
    size_t pos = 0;
    while (pos < domain.length()) {
      size_t dot_pos = domain.find('.', pos);
      if (dot_pos == absl::string_view::npos) {
        parts.push_back(std::string(domain.substr(pos)));
        break;
      }
      parts.push_back(std::string(domain.substr(pos, dot_pos - pos)));
      pos = dot_pos + 1;
    }
    std::reverse(parts.begin(), parts.end());
    return parts;
  }

  /**
   * Matches the domain extracted from `data` against the trie.
   * @return UnableToMatch while the input data is not fully available; otherwise MatchComplete
   *         with the selected on-match, or with on-no-match when nothing matches (including
   *         when the input is absent or empty).
   */
  typename MatchTree<DataType>::MatchResult match(const DataType& data) override {
    const auto input = data_input_->get(data);
    if (input.data_availability_ != DataInputGetResult::DataAvailability::AllDataAvailable) {
      return {MatchState::UnableToMatch, absl::nullopt};
    }

    if (absl::holds_alternative<absl::monostate>(input.data_)) {
      return {MatchState::MatchComplete, on_no_match_};
    }

    const auto& domain = absl::get<std::string>(input.data_);
    if (domain.empty()) {
      return {MatchState::MatchComplete, on_no_match_};
    }

    const auto parts = splitAndReverseDomain(domain);

    const OnMatch<DataType>* selected = findExactMatch(parts);
    if (selected == nullptr) {
      selected = findWildcardMatch(parts);
    }
    if (selected == nullptr) {
      return {MatchState::MatchComplete, on_no_match_};
    }

    // Return the complete on-match (action callback and nested matcher) so that a nested
    // matcher configured for this domain is not silently discarded.
    return {MatchState::MatchComplete, *selected};
  }

private:
  // Returns the on-match of the node reached by consuming every label of the domain, or nullptr.
  // Only a terminal, non-wildcard node reached after the last label counts: a configured
  // "example.com" must not match "foo.example.com".
  const OnMatch<DataType>* findExactMatch(const std::vector<std::string>& parts) const {
    const DomainNode<DataType>* node = root_.get();
    for (const auto& part : parts) {
      const auto it = node->children_.find(part);
      if (it == node->children_.end() || it->second == nullptr) {
        return nullptr;
      }
      node = it->second.get();
    }
    if (node->is_wildcard_ || !node->is_terminal_ || node->on_match_ == nullptr) {
      return nullptr;
    }
    return node->on_match_.get();
  }

  // Returns the on-match of the most specific wildcard covering the domain, or nullptr.
  // Candidates are visited from least to most specific, so the last one found wins.
  const OnMatch<DataType>* findWildcardMatch(const std::vector<std::string>& parts) const {
    const OnMatch<DataType>* best = nullptr;

    // The global wildcard "*" is recorded on the root node itself by the trie builder.
    if (root_->is_terminal_ && root_->on_match_ != nullptr) {
      best = root_->on_match_.get();
    }

    const DomainNode<DataType>* node = root_.get();
    for (const auto& part : parts) {
      // `part` has not been consumed yet, so a wildcard child of `node` stands in for it (and
      // any further labels). This guarantees the wildcard covers at least one label.
      const auto wildcard = node->children_.find("*");
      if (wildcard != node->children_.end() && wildcard->second != nullptr &&
          wildcard->second->is_terminal_ && wildcard->second->on_match_ != nullptr) {
        best = wildcard->second->on_match_.get();
      }

      const auto next = node->children_.find(part);
      if (next == node->children_.end() || next->second == nullptr) {
        break;
      }
      node = next->second.get();
    }
    return best;
  }

  const DataInputPtr<DataType> data_input_;
  const absl::optional<OnMatch<DataType>> on_no_match_;
  std::unique_ptr<DomainNode<DataType>> root_;
};

/**
 * Custom matcher factory that builds a DomainTrieMatcher from an
 * xds.type.matcher.v3.ServerNameMatcher configuration.
 *
 * The trie is built once when the configuration is loaded; every matcher instance created by
 * the returned callback receives its own deep copy of that trie.
 */
template <class DataType>
class DomainTrieMatcherFactoryBase : public ::Envoy::Matcher::CustomMatcherFactory<DataType> {
public:
  /**
   * Validates the configuration, builds the domain trie and returns a callback that creates
   * matcher instances.
   * @throws EnvoyException if a domain is empty, duplicated, or uses an invalid wildcard form.
   */
  ::Envoy::Matcher::MatchTreeFactoryCb<DataType>
  createCustomMatcherFactoryCb(const Protobuf::Message& config,
                               Server::Configuration::ServerFactoryContext& factory_context,
                               DataInputFactoryCb<DataType> data_input,
                               absl::optional<OnMatchFactoryCb<DataType>> on_no_match,
                               OnMatchFactory<DataType>& on_match_factory) override {
    const auto& typed_config =
        MessageUtil::downcastAndValidate<const xds::type::matcher::v3::ServerNameMatcher&>(
            config, factory_context.messageValidationVisitor());

    validateDomains(typed_config);

    std::vector<OnMatchFactoryCb<DataType>> match_children;
    match_children.reserve(typed_config.domain_matchers().size());

    auto root = std::make_shared<DomainNode<DataType>>();
    buildDomainTrie(typed_config, on_match_factory, match_children, root.get());
    // Captured by the callback below so the on-match factory callbacks stay alive for as long
    // as the callback does.
    auto children =
        std::make_shared<std::vector<OnMatchFactoryCb<DataType>>>(std::move(match_children));

    return [data_input, root, children, on_no_match]() {
      return std::make_unique<DomainTrieMatcher<DataType>>(
          data_input(), on_no_match ? absl::make_optional(on_no_match.value()()) : absl::nullopt,
          root->clone());
    };
  }

  ProtobufTypes::MessagePtr createEmptyConfigProto() override {
    return std::make_unique<xds::type::matcher::v3::ServerNameMatcher>();
  }

  std::string name() const override { return "envoy.matching.custom_matchers.domain_matcher"; }

private:
  /**
   * Rejects configurations the trie cannot represent unambiguously:
   *  - empty domain names (they have no labels to insert);
   *  - duplicate domains across all domain matchers;
   *  - wildcards in any form other than the global "*" or a leading "*." followed by a
   *    non-empty suffix. A '*' elsewhere would collide with the reserved "*" wildcard key.
   * @throws EnvoyException on the first offending domain.
   */
  void validateDomains(const xds::type::matcher::v3::ServerNameMatcher& config) const {
    absl::flat_hash_set<std::string> unique_domains;

    for (const auto& domain_matcher : config.domain_matchers()) {
      for (const auto& domain : domain_matcher.domains()) {
        if (domain.empty()) {
          throw EnvoyException("Empty domain in ServerNameMatcher");
        }

        if (!unique_domains.insert(domain).second) {
          throw EnvoyException(absl::StrCat("Duplicate domain in ServerNameMatcher: ", domain));
        }

        if (domain == "*") {
          continue;
        }

        const auto wildcard_pos = domain.find('*');
        if (wildcard_pos == std::string::npos) {
          continue;
        }

        // Any remaining '*' must be the leading character of a "*.<suffix>" pattern.
        const bool valid_wildcard = wildcard_pos == 0 && domain.size() >= 3 && domain[1] == '.' &&
                                    domain.find('*', 1) == std::string::npos;
        if (!valid_wildcard) {
          throw EnvoyException(absl::StrCat("Invalid wildcard domain format: ", domain));
        }
      }
    }
  }

  /**
   * Populates the trie rooted at `root` from `config`.
   *
   * Labels are inserted in reverse order, so "*.api.example.com" becomes
   * root -> "com" -> "example" -> "api" -> "*" (terminal), and "api.example.com" becomes
   * root -> "com" -> "example" -> "api" (terminal). The global wildcard "*" is recorded on the
   * root node itself. One on-match factory callback per domain matcher is appended to
   * `match_children`, and the on-match it produces is shared by all domains of that matcher.
   */
  void buildDomainTrie(const xds::type::matcher::v3::ServerNameMatcher& config,
                       OnMatchFactory<DataType>& on_match_factory,
                       std::vector<OnMatchFactoryCb<DataType>>& match_children,
                       DomainNode<DataType>* root) const {
    size_t matcher_index = 0;

    // Returns the child of `parent` for `label`, creating (and indexing) it on first use.
    const auto child_of = [&matcher_index](DomainNode<DataType>& parent,
                                           const std::string& label) -> DomainNode<DataType>& {
      auto& slot = parent.children_[label];
      if (!slot) {
        slot = std::make_unique<DomainNode<DataType>>();
        slot->index_ = ++matcher_index;
        slot->domain_part_ = label;
      }
      return *slot;
    };

    for (const auto& domain_matcher : config.domain_matchers()) {
      match_children.push_back(*on_match_factory.createOnMatch(domain_matcher.on_match()));
      auto on_match = std::make_shared<OnMatch<DataType>>(match_children.back()());

      for (const auto& domain : domain_matcher.domains()) {
        if (domain == "*") {
          root->is_wildcard_ = true;
          root->is_terminal_ = true;
          root->on_match_ = on_match;
          continue;
        }

        // The length check keeps substr(2) in range even if validation was bypassed.
        const bool is_wildcard = domain.size() >= 2 && domain[0] == '*';
        const std::vector<std::string> parts =
            is_wildcard ? DomainTrieMatcher<DataType>::splitAndReverseDomain(domain.substr(2))
                        : DomainTrieMatcher<DataType>::splitAndReverseDomain(domain);
        // Nothing to insert (e.g. an empty name or a bare "*."). Skipping also avoids
        // walking an empty label list below.
        if (parts.empty()) {
          continue;
        }

        DomainNode<DataType>* current = root;
        for (const auto& part : parts) {
          current = &child_of(*current, part);
        }

        if (is_wildcard) {
          // The wildcard hangs below the last concrete label of the suffix.
          current = &child_of(*current, "*");
          current->is_wildcard_ = true;
        }
        current->is_terminal_ = true;
        current->on_match_ = on_match;
      }
    }
  }
};

// Concrete domain matcher factories, one per supported matching data type. They add nothing
// beyond fixing the template argument of DomainTrieMatcherFactoryBase.
class NetworkDomainMatcherFactory : public DomainTrieMatcherFactoryBase<Network::MatchingData> {};
class HttpDomainMatcherFactory : public DomainTrieMatcherFactoryBase<Http::HttpMatchingData> {};

} // namespace Matcher
} // namespace Common
} // namespace Extensions
} // namespace Envoy
Loading

0 comments on commit 9dc9748

Please sign in to comment.