Skip to content

Commit

Permalink
Merge pull request #100 from kuefmz:fix_config
Browse files Browse the repository at this point in the history
Use Enum for all config parameters
  • Loading branch information
JJ-Author authored Oct 20, 2024
2 parents 706f826 + 5da7694 commit c49fedc
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 88 deletions.
30 changes: 23 additions & 7 deletions ontologytimemachine/custom_proxy.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
from proxy.http.proxy import HttpProxyBasePlugin
from proxy.http import httpHeaders
from proxy.http.parser import HttpParser
from proxy.common.utils import build_http_response
from ontologytimemachine.utils.mock_responses import mock_response_403
from ontologytimemachine.utils.mock_responses import (
mock_response_403,
mock_response_500,
)
from ontologytimemachine.proxy_wrapper import HttpRequestWrapper
from ontologytimemachine.utils.proxy_logic import (
get_response_from_request,
do_block_CONNECT_request,
is_archivo_ontology_request,
evaluate_configuration,
)
from ontologytimemachine.utils.config import Config, HttpsInterception, parse_arguments
from http.client import responses
import proxy
import sys
import logging
from ontologytimemachine.utils.config import HttpsInterception, ClientConfigViaProxyAuth


IP = "0.0.0.0"
Expand All @@ -31,17 +37,23 @@ def __init__(self, *args, **kwargs):
logger.info("Init")
super().__init__(*args, **kwargs)
self.config = config
self.current_config = None

def before_upstream_connection(self, request: HttpParser) -> HttpParser | None:
print(config)
# self.client.config = None
logger.info("Before upstream connection hook")
logger.info(
f"Request method: {request.method} - Request host: {request.host} - Request path: {request.path} - Request headers: {request.headers}"
)
wrapped_request = HttpRequestWrapper(request)

# if self.config.clientConfigViaProxyAuth == ClientConfigViaProxyAuth.REQUIRED:
# self.client.config = evaluate_configuration(wrapped_request, self.config)

if wrapped_request.is_connect_request():
logger.info(f"Handling CONNECT request: configured HTTPS interception mode: {self.config.httpsInterception}")
logger.info(
f"Handling CONNECT request: configured HTTPS interception mode: {self.config.httpsInterception}"
)

# Check whether to allow CONNECT requests since they can impose a security risk
if not do_block_CONNECT_request(self.config):
Expand All @@ -56,24 +68,28 @@ def before_upstream_connection(self, request: HttpParser) -> HttpParser | None:
response = get_response_from_request(wrapped_request, self.config)
if response:
self.queue_response(response)
self.current_config = None
return None

return request

def do_intercept(self, _request: HttpParser) -> bool:
    """Decide whether the HTTPS traffic of ``_request`` is intercepted.

    The decision follows ``self.config.httpsInterception``:

    * ``HttpsInterception.ALL``     -> always intercept.
    * ``HttpsInterception.NONE``    -> never intercept.
    * ``HttpsInterception.ARCHIVO`` -> intercept only when
      ``is_archivo_ontology_request`` accepts the request.
    * anything else                 -> log the value and do not intercept.

    Returns:
        ``True`` to intercept, ``False`` otherwise.
    """
    wrapped_request = HttpRequestWrapper(_request)
    mode = self.config.httpsInterception

    # Enum members are not containers: ``mode in HttpsInterception.ALL``
    # raises TypeError, so the configured mode is compared by equality.
    if mode == HttpsInterception.ALL:
        return True
    if mode == HttpsInterception.NONE:
        return False
    # HttpsInterception.BLOCK is deliberately not handled here; per the
    # original note it should not reach this hook. If it does, it falls
    # through to the logged fallback below.
    if mode == HttpsInterception.ARCHIVO:
        return bool(is_archivo_ontology_request(wrapped_request))

    logger.info(
        "Unknown Option for httpsInterception: %s -> fallback to no interception",
        mode,
    )
    return False

def handle_client_request(self, request: HttpParser) -> HttpParser:
Expand Down
25 changes: 24 additions & 1 deletion ontologytimemachine/proxy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from proxy.http.parser import HttpParser
import logging
from typing import Tuple, Dict, Any
import base64

# Configure logger
logging.basicConfig(
Expand Down Expand Up @@ -54,6 +55,10 @@ def set_request_accept_header(self, mime_type: str) -> None:
def get_request_url_host_path(self) -> Tuple[str, str, str]:
pass

@abstractmethod
def get_authentication_from_request(self) -> str:
pass


class HttpRequestWrapper(AbstractRequestWrapper):
def __init__(self, request: HttpParser) -> None:
Expand Down Expand Up @@ -95,7 +100,7 @@ def set_request_accept_header(self, mime_type: str) -> None:

def get_request_url_host_path(self) -> Tuple[str, str, str]:
logger.info("Get ontology from request")
if (self.request.method in {b"GET", b"HEAD"}) and not self.request.host:
if (self.is_get_request or self.is_head_request) and not self.request.host:
for k, v in self.request.headers.items():
if v[0].decode("utf-8") == "Host":
host = v[1].decode("utf-8")
Expand All @@ -108,3 +113,21 @@ def get_request_url_host_path(self) -> Tuple[str, str, str]:

logger.info(f"Ontology: {url}")
return url, host, path

def get_authentication_from_request(self) -> str:
is_auth = False
# if b"authorization" in self.request.headers.keys():
# auth_header = self.request.headers[b"authorization"]
# is_auth = True
if b"proxy-authorization" in self.request.headers.keys():
auth_header = self.request.headers[b"proxy-authorization"]
is_auth = True
if is_auth:
auth_header = auth_header[1]
auth_type, encoded_credentials = auth_header.split()
auth_type = auth_type.decode("utf-8")
if auth_type.lower() != "basic":
return None
decoded_credentials = base64.b64decode(encoded_credentials).decode()
return decoded_credentials
return None
104 changes: 72 additions & 32 deletions ontologytimemachine/utils/config.py
Original file line number Diff line number Diff line change
@@ -1,133 +1,174 @@
import argparse
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Any
from typing import Dict, Any, Type, TypeVar


class LogLevel(Enum):
class EnumValuePrint(Enum):
    """Enum base class whose members print as their value.

    Redefines how the enum is printed so that members show up properly in
    the command-line help message (the argparse ``choices`` listing).
    """

    def __str__(self) -> str:
        # ``__str__`` must return a str; wrapping keeps non-string values
        # (e.g. ints) printable instead of raising TypeError.
        return str(self.value)

class LogLevel(EnumValuePrint):
    """Allowed values for the ``--logLevel`` command-line option."""

    DEBUG = "debug"
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"


class OntoFormat(EnumValuePrint):
    """Allowed values for the ``--ontoFormat`` command-line option."""

    TURTLE = "turtle"
    NTRIPLES = "ntriples"
    RDFXML = "rdfxml"
    HTMLDOCU = "htmldocu"


class OntoPrecedence(EnumValuePrint):
    """Allowed values for the ``--ontoPrecedence`` command-line option."""

    DEFAULT = "default"
    ENFORCED_PRIORITY = "enforcedPriority"
    ALWAYS = "always"


class OntoVersion(EnumValuePrint):
    """Allowed values for the ``--ontoVersion`` command-line option."""

    ORIGINAL = "original"
    ORIGINAL_FAILOVER_LIVE_LATEST = "originalFailoverLiveLatest"
    LATEST_ARCHIVED = "latestArchived"
    TIMESTAMP_ARCHIVED = "timestampArchived"
    DEPENDENCY_MANIFEST = "dependencyManifest"


class HttpsInterception(EnumValuePrint):
    """Allowed values for the ``--httpsInterception`` command-line option."""

    NONE = "none"
    ALL = "all"
    # NOTE(review): BLOCK presumably rejects CONNECT requests before any
    # interception decision is made -- confirm against do_block_CONNECT_request.
    BLOCK = "block"
    ARCHIVO = "archivo"


class ClientConfigViaProxyAuth(EnumValuePrint):
    """Allowed values for the ``--clientConfigViaProxyAuth`` command-line option."""

    IGNORE = "ignore"
    REQUIRED = "required"
    OPTIONAL = "optional"


@dataclass
class OntoFormatConfig:
    """Format-related settings, nested inside ``Config.ontoFormatConf``."""

    # Set from --ontoFormat.
    format: OntoFormat = OntoFormat.NTRIPLES
    # Set from --ontoPrecedence.
    precedence: OntoPrecedence = OntoPrecedence.ENFORCED_PRIORITY
    # Set from --patchAcceptUpstream.
    patchAcceptUpstream: bool = False


@dataclass
class Config:
    """Proxy configuration; every default is a plain enum member or scalar.

    The defaults declared here are also the defaults of the matching
    command-line options built in ``parse_arguments``.
    """

    logLevel: LogLevel = LogLevel.INFO
    # default_factory gives every Config its own OntoFormatConfig instance.
    ontoFormatConf: OntoFormatConfig = field(default_factory=OntoFormatConfig)
    # Plain enum members, not one-element tuples, so comparisons against
    # enum members behave as expected.
    ontoVersion: OntoVersion = OntoVersion.ORIGINAL_FAILOVER_LIVE_LATEST
    restrictedAccess: bool = False
    clientConfigViaProxyAuth: ClientConfigViaProxyAuth = (
        ClientConfigViaProxyAuth.REQUIRED
    )
    httpsInterception: HttpsInterception = HttpsInterception.ALL
    disableRemovingRedirects: bool = False
    timestamp: str = ""
    # manifest: Dict[str, Any] = None


# Define a TypeVar for the enum class
E = TypeVar("E", bound=Enum)


def enum_parser(enum_class: Type[E], value: str) -> E:
    """Convert a command-line string into the matching member of ``enum_class``.

    Matching is case-insensitive against the members' values.

    Args:
        enum_class: Enum class whose members have string values.
        value: Raw option text supplied by the user.

    Returns:
        The enum member (not its raw value) whose value matches ``value``.

    Raises:
        argparse.ArgumentTypeError: If no member matches, so that argparse
            reports the message as a regular usage error.
    """
    wanted = value.lower()
    for member in enum_class:
        if member.value.lower() == wanted:
            return member

    valid_options = ", ".join(member.value for member in enum_class)
    raise argparse.ArgumentTypeError(
        f"Invalid value '{value}'. Available options are: {valid_options}"
    )


def parse_arguments() -> Config:
def parse_arguments(config_str: str = "") -> Config:
default_cfg: Config = Config()
parser = argparse.ArgumentParser(description="Process ontology format and version.")

# Defining ontoFormat argument with nested options
parser.add_argument(
"--ontoFormat",
type=lambda s: enum_parser(OntoFormat, s),
default=OntoFormat.TURTLE.value,
default=default_cfg.ontoFormatConf.format,
choices=list(OntoFormat),
help="Format of the ontology: turtle, ntriples, rdfxml, htmldocu",
)

parser.add_argument(
"--ontoPrecedence",
type=lambda s: enum_parser(OntoPrecedence, s),
default=OntoPrecedence.ENFORCED_PRIORITY.value,
default=default_cfg.ontoFormatConf.precedence,
choices=list(OntoPrecedence),
help="Precedence of the ontology: default, enforcedPriority, always",
)

parser.add_argument(
"--patchAcceptUpstream",
type=bool,
default=False,
default=default_cfg.ontoFormatConf.patchAcceptUpstream,
help="Defines if the Accept Header is patched upstream in original mode.",
)

# Defining ontoVersion argument
parser.add_argument(
"--ontoVersion",
type=lambda s: enum_parser(OntoVersion, s),
default=OntoVersion.ORIGINAL_FAILOVER_LIVE_LATEST.value,
default=default_cfg.ontoVersion,
choices=list(OntoVersion),
help="Version of the ontology: original, originalFailoverLive, originalFailoverArchivoMonitor, latestArchive, timestampArchive, dependencyManifest",
)

# Enable/disable mode to only proxy requests to ontologies
parser.add_argument(
"--restrictedAccess",
type=bool,
default=False,
default=default_cfg.restrictedAccess,
help="Enable/disable mode to only proxy requests to ontologies stored in Archivo.",
)

# Enable HTTPS interception for specific domains
parser.add_argument(
"--httpsInterception",
type=lambda s: enum_parser(HttpsInterception, s),
default=HttpsInterception.ALL.value,
default=default_cfg.httpsInterception,
choices=list(HttpsInterception),
help="Enable HTTPS interception for specific domains: none, archivo, all, listfilename.",
)

# Enable/disable inspecting or removing redirects
parser.add_argument(
"--disableRemovingRedirects",
type=bool,
default=False,
default=default_cfg.disableRemovingRedirects,
help="Enable/disable inspecting or removing redirects.",
)

parser.add_argument(
"--clientConfigViaProxyAuth",
type=lambda s: enum_parser(ClientConfigViaProxyAuth, s),
default=default_cfg.clientConfigViaProxyAuth,
choices=list(ClientConfigViaProxyAuth),
help="Define the config.",
)

# Log level
parser.add_argument(
"--logLevel",
type=lambda s: enum_parser(LogLevel, s),
default=LogLevel.INFO.value,
default=default_cfg.logLevel,
choices=list(LogLevel),
help="Level of the logging: debug, info, warning, error.",
)

args = parser.parse_args()
args = parser.parse_args(config_str)

# Check the value of --ontoVersion and prompt for additional arguments if needed
if args.ontoVersion == "timestampArchived":
Expand All @@ -148,20 +189,19 @@ def parse_arguments() -> Config:
# else:
# manifest = None

# Create ontoFormat dictionary
ontoFormat = {
"format": args.ontoFormat,
"precedence": args.ontoPrecedence,
"patchAcceptUpstream": args.patchAcceptUpstream,
}
# print the default configuration with all nested members
print(default_cfg) # TODO remove

# Initialize the Config class with parsed arguments
config = Config(
logLevel=args.logLevel,
ontoFormat=ontoFormat,
ontoFormatConf=OntoFormatConfig(
args.ontoFormat, args.ontoPrecedence, args.patchAcceptUpstream
),
ontoVersion=args.ontoVersion,
restrictedAccess=args.restrictedAccess,
httpsInterception=args.httpsInterception,
clientConfigViaProxyAuth=args.clientConfigViaProxyAuth,
disableRemovingRedirects=args.disableRemovingRedirects,
timestamp=args.timestamp if hasattr(args, "timestamp") else "",
)
Expand Down
Loading

0 comments on commit c49fedc

Please sign in to comment.