Skip to content

Commit

Permalink
filter and session store abstraction and filter chain refactor (#165)
Browse files Browse the repository at this point in the history
* filter and session store abstraction and filter chain refactor

Signed-off-by: Shikugawa <[email protected]>

* fix reviews

Signed-off-by: Shikugawa <[email protected]>
  • Loading branch information
Shikugawa authored Oct 6, 2021
1 parent 0ee8e7d commit e91ca69
Show file tree
Hide file tree
Showing 17 changed files with 289 additions and 91 deletions.
9 changes: 9 additions & 0 deletions src/filters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ xx_library(
],
)

xx_library(
name = "filter_factory",
hdrs = ["filter_factory.h"],
deps = [
":pipe",
]
)

xx_library(
name = "filter_chain",
srcs = [
Expand All @@ -33,6 +41,7 @@ xx_library(
"filter_chain.h",
],
deps = [
":filter_factory",
"//config:config_cc",
"//src/config",
"//src/filters:filter",
Expand Down
146 changes: 62 additions & 84 deletions src/filters/filter_chain.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "filter_chain.h"

#include <algorithm>
#include <memory>
#include <stdexcept>

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
Expand All @@ -25,33 +27,63 @@ FilterChainImpl::FilterChainImpl(boost::asio::io_context& ioc,
: threads_(threads),
config_(std::move(config)),
oidc_session_store_(nullptr) {
for (auto& filter : config_.filters()) {
// TODO(shikugawa): We need an abstraction to handle many types of filters.
if (filter.type_case() == config::Filter::TypeCase::kOidc) {
switch (filter.oidc().jwks_config_case()) {
case config::oidc::OIDCConfig::kJwks:
jwks_resolver_map_.emplace_back(
std::make_shared<oidc::StaticJwksResolverImpl>(
filter.oidc().jwks()));
break;
case config::oidc::OIDCConfig::kJwksFetcher: {
uint32_t periodic_fetch_interval_sec =
filter.oidc().jwks_fetcher().periodic_fetch_interval_sec();
if (periodic_fetch_interval_sec == 0) {
periodic_fetch_interval_sec = 1200;
}

auto http_ptr = common::http::ptr_t(new common::http::HttpImpl);

jwks_resolver_map_.emplace_back(
std::make_shared<oidc::DynamicJwksResolverImpl>(
filter.oidc().jwks_fetcher().jwks_uri(),
std::chrono::seconds(periodic_fetch_interval_sec), http_ptr,
ioc));
} break;
default:
throw std::runtime_error("invalid JWKs config type");
}
// Validate that filter chain has only one OIDC filter.
const int oidc_filter_count =
std::count_if(config_.filters().begin(), config_.filters().end(),
[](const auto& filter) { return filter.has_oidc(); });
if (oidc_filter_count > 1) {
throw std::runtime_error(
"only one filter of type \"oidc\" is allowed in a chain");
}

bool skip_oidc_preparation = false;

if (oidc_filter_count == 0) {
skip_oidc_preparation = true;
}

if (!skip_oidc_preparation) {
// Setup OIDC related modules.
const auto& oidc_filter =
std::find_if(config_.filters().begin(), config_.filters().end(),
[](const auto& filter) { return filter.has_oidc(); });
assert(oidc_filter != config_.filters().end());

// Note that each incoming request gets a new instance of Filter to handle
// it, so here we ensure that each instance returned by New() shares the
// same session store.
auto absolute_session_timeout =
oidc_filter->oidc().absolute_session_timeout();
auto idle_session_timeout = oidc_filter->oidc().idle_session_timeout();

if (oidc_filter->oidc().has_redis_session_store_config()) {
spdlog::trace("{}: using RedisSession Store", __func__);
oidc_session_store_ =
oidc::RedisSessionStoreFactory(
oidc_filter->oidc().redis_session_store_config(),
absolute_session_timeout, idle_session_timeout, threads_)
.create();
} else {
spdlog::trace("{}: using InMemorySession Store", __func__);
oidc_session_store_ = oidc::InMemorySessionStoreFactory(
absolute_session_timeout, idle_session_timeout)
.create();
}

jwks_resolver_cache_ =
std::make_unique<oidc::JwksResolverCache>(oidc_filter->oidc(), ioc);
}

// Create filter chain factory
for (const auto& filter : config_.filters()) {
if (filter.has_mock()) {
filter_factories_.emplace_back(
std::make_unique<mock::FilterFactory>(filter.mock()));
} else if (filter.has_oidc()) {
filter_factories_.emplace_back(std::make_unique<oidc::FilterFactory>(
filter.oidc(), oidc_session_store_, jwks_resolver_cache_));
} else {
throw std::runtime_error("invalid filter type");
}
}
}
Expand Down Expand Up @@ -83,65 +115,11 @@ bool FilterChainImpl::Matches(
std::unique_ptr<Filter> FilterChainImpl::New() {
spdlog::trace("{}", __func__);
std::unique_ptr<Pipe> result(new Pipe);
int oidc_filter_count = 0;
for (auto i = 0; i < config_.filters_size(); ++i) {
auto& filter = *config_.mutable_filters(i);
if (filter.has_oidc()) {
++oidc_filter_count;
} else if (filter.has_mock()) {
result->AddFilter(std::make_unique<mock::MockFilter>(filter.mock()));
continue;
} else {
throw std::runtime_error("unsupported filter type");
}

if (oidc_filter_count > 1) {
throw std::runtime_error(
"only one filter of type \"oidc\" is allowed in a chain");
}
auto token_response_parser =
std::make_shared<oidc::TokenResponseParserImpl>(
jwks_resolver_map_[i]->jwks());
auto session_string_generator =
std::make_shared<common::session::SessionStringGenerator>();

auto http = common::http::ptr_t(new common::http::HttpImpl);

if (oidc_session_store_ == nullptr) {
// Note that each incoming request gets a new instance of Filter to handle
// it, so here we ensure that each instance returned by New() shares the
// same session store.
auto absolute_session_timeout = filter.oidc().absolute_session_timeout();
auto idle_session_timeout = filter.oidc().idle_session_timeout();

if (filter.oidc().has_redis_session_store_config()) {
auto redis_sever_uri =
filter.oidc().redis_session_store_config().server_uri();
spdlog::trace(
"{}: redis configuration found. attempting to connect to: {}",
__func__, redis_sever_uri);
auto redis_wrapper =
std::make_shared<oidc::RedisWrapper>(redis_sever_uri, threads_);
auto redis_retry_wrapper =
std::make_shared<oidc::RedisRetryWrapper>(redis_wrapper);
oidc_session_store_ = std::static_pointer_cast<oidc::RedisSessionStore>(
std::make_shared<oidc::RedisSessionStore>(
std::make_shared<common::utilities::TimeService>(),
absolute_session_timeout, idle_session_timeout,
redis_retry_wrapper));
} else {
spdlog::trace("{}: using InMemorySession Store", __func__);
oidc_session_store_ = std::static_pointer_cast<oidc::SessionStore>(
std::make_shared<oidc::InMemorySessionStore>(
std::make_shared<common::utilities::TimeService>(),
absolute_session_timeout, idle_session_timeout));
}
}

result->AddFilter(FilterPtr(
new oidc::OidcFilter(http, filter.oidc(), token_response_parser,
session_string_generator, oidc_session_store_)));
for (auto&& filter_factory : filter_factories_) {
result->AddFilter(filter_factory->create());
}

return result;
}

Expand Down
7 changes: 5 additions & 2 deletions src/filters/filter_chain.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
#define AUTHSERVICE_FILTER_CHAIN_H

#include <memory>
#include <vector>

#include "boost/asio/io_context.hpp"
#include "config/config.pb.h"
#include "config/oidc/config.pb.h"
#include "envoy/service/auth/v3/external_auth.grpc.pb.h"
#include "src/filters/filter.h"
#include "src/filters/filter_factory.h"
#include "src/filters/oidc/jwks_resolver.h"
#include "src/filters/oidc/session_store.h"

Expand Down Expand Up @@ -54,8 +56,9 @@ class FilterChainImpl : public FilterChain {
private:
unsigned int threads_;
config::FilterChain config_;
std::shared_ptr<oidc::SessionStore> oidc_session_store_;
std::vector<std::shared_ptr<oidc::JwksResolver>> jwks_resolver_map_;
oidc::SessionStorePtr oidc_session_store_;
oidc::JwksResolverCachePtr jwks_resolver_cache_;
std::vector<FilterFactoryPtr> filter_factories_;

public:
explicit FilterChainImpl(boost::asio::io_context &ioc,
Expand Down
26 changes: 26 additions & 0 deletions src/filters/filter_factory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef AUTHSERVICE_FILTER_FACTORY_H
#define AUTHSERVICE_FILTER_FACTORY_H

#include <memory>

#include "src/filters/pipe.h"

namespace authservice {
namespace filters {

class FilterFactory {
public:
virtual ~FilterFactory() = default;

/**
* Creates an authentication filter.
*/
virtual FilterPtr create() = 0;
};

using FilterFactoryPtr = std::unique_ptr<FilterFactory>;

} // namespace filters
} // namespace authservice

#endif
1 change: 1 addition & 0 deletions src/filters/mock/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ xx_library(
deps = [
"//config/mock:config_cc",
"//src/filters:filter",
"//src/filters:filter_factory",
"@com_github_gabime_spdlog//:spdlog",
"@com_google_googleapis//google/rpc:code_cc_proto",
],
Expand Down
6 changes: 6 additions & 0 deletions src/filters/mock/mock_filter.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "mock_filter.h"

#include <memory>

#include "spdlog/spdlog.h"

namespace authservice {
Expand All @@ -21,6 +23,10 @@ enum google::rpc::Code MockFilter::Process(

absl::string_view MockFilter::Name() const { return "mock"; }

filters::FilterPtr FilterFactory::create() {
return std::make_unique<MockFilter>(config_);
}

} // namespace mock
} // namespace filters
} // namespace authservice
12 changes: 12 additions & 0 deletions src/filters/mock/mock_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "config/mock/config.pb.h"
#include "google/rpc/code.pb.h"
#include "src/filters/filter.h"
#include "src/filters/filter_factory.h"

namespace authservice {
namespace filters {
Expand All @@ -23,6 +24,17 @@ class MockFilter final : public filters::Filter {

absl::string_view Name() const override;
};

class FilterFactory : public filters::FilterFactory {
public:
FilterFactory(const config::mock::MockConfig &config) : config_(config) {}

filters::FilterPtr create() override;

private:
const config::mock::MockConfig config_;
};

} // namespace mock
} // namespace filters
} // namespace authservice
Expand Down
12 changes: 12 additions & 0 deletions src/filters/oidc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,20 @@ xx_library(
],
)

xx_library(
name = "session_store_factory",
hdrs = ["session_store_factory.h"],
deps = [
"//config:config_cc",
]
)

xx_library(
name = "in_memory_session_store",
srcs = ["in_memory_session_store.cc"],
hdrs = ["in_memory_session_store.h", "session_store.h"],
deps = [
":session_store_factory",
"//src/filters/oidc:token_response",
"//src/filters/oidc:authorization_state",
"//src/common/utilities:time_service",
Expand Down Expand Up @@ -58,6 +67,7 @@ xx_library(
srcs = ["redis_session_store.cc"],
hdrs = ["redis_session_store.h", "session_store.h"],
deps = [
":session_store_factory",
"//src/filters/oidc:session_store",
"//src/filters/oidc:token_response",
"//src/filters/oidc:authorization_state",
Expand Down Expand Up @@ -103,6 +113,7 @@ xx_library(
srcs = ["jwks_resolver.cc"],
hdrs = ["jwks_resolver.h"],
deps = [
"//config/oidc:config_cc",
"//src/common/http",
"@boost//:all",
"@com_github_abseil-cpp//absl/synchronization:synchronization",
Expand All @@ -120,6 +131,7 @@ xx_library(
"//src/common/session:session_string_generator",
"//src/common/utilities:random",
"//src/filters:filter",
"//src/filters:filter_factory",
"//src/filters/oidc:in_memory_session_store",
"//src/common/utilities:time_service",
"//src/filters/oidc:token_response",
Expand Down
6 changes: 6 additions & 0 deletions src/filters/oidc/in_memory_session_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ void InMemorySessionStore::Set(absl::string_view session_id,
}
}

SessionStorePtr InMemorySessionStoreFactory::create() {
return std::make_shared<oidc::InMemorySessionStore>(
std::make_shared<common::utilities::TimeService>(),
absolute_session_timeout_, idle_session_timeout_);
}

} // namespace oidc
} // namespace filters
} // namespace authservice
15 changes: 15 additions & 0 deletions src/filters/oidc/in_memory_session_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "src/common/utilities/synchronized.h"
#include "src/common/utilities/time_service.h"
#include "src/filters/oidc/session_store.h"
#include "src/filters/oidc/session_store_factory.h"

namespace authservice {
namespace filters {
Expand Down Expand Up @@ -52,6 +53,20 @@ class InMemorySessionStore : public SessionStore {
virtual void RemoveAllExpired() override;
};

class InMemorySessionStoreFactory : public SessionStoreFactory {
public:
InMemorySessionStoreFactory(uint32_t absolute_session_timeout,
uint32_t idle_session_timeout)
: absolute_session_timeout_(absolute_session_timeout),
idle_session_timeout_(idle_session_timeout) {}

SessionStorePtr create() override;

private:
const uint32_t absolute_session_timeout_ = 0;
const uint32_t idle_session_timeout_ = 0;
};

} // namespace oidc
} // namespace filters
} // namespace authservice
Expand Down
Loading

0 comments on commit e91ca69

Please sign in to comment.