Skip to content

Commit

Permalink
fix: check cache if we have seen ip for domain before
Browse files Browse the repository at this point in the history
  • Loading branch information
emhagman committed Jun 27, 2022
1 parent 07cb073 commit 9725da7
Showing 1 changed file with 65 additions and 44 deletions.
109 changes: 65 additions & 44 deletions pytest_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

_true_socket = socket.socket
_true_connect = socket.socket.connect
_cached_domain_lookups = {}


class SocketBlockedError(RuntimeError):
Expand All @@ -17,31 +18,33 @@ def __init__(self, *args, **kwargs):
class SocketConnectBlockedError(RuntimeError):
def __init__(self, allowed, host, *args, **kwargs):
if allowed:
allowed = ','.join(allowed)
allowed = ",".join(allowed)
super(SocketConnectBlockedError, self).__init__(
'A test tried to use socket.socket.connect() with host "{0}" (allowed: "{1}").'.format(host, allowed)
'A test tried to use socket.socket.connect() with host "{0}" (allowed: "{1}").'.format(
host, allowed
)
)


def pytest_addoption(parser):
group = parser.getgroup('socket')
group = parser.getgroup("socket")
group.addoption(
'--disable-socket',
action='store_true',
dest='disable_socket',
help='Disable socket.socket by default to block network calls.'
"--disable-socket",
action="store_true",
dest="disable_socket",
help="Disable socket.socket by default to block network calls.",
)
group.addoption(
'--allow-hosts',
dest='allow_hosts',
metavar='ALLOWED_HOSTS_CSV',
help='Only allow specified hosts through socket.socket.connect((host, port)).'
"--allow-hosts",
dest="allow_hosts",
metavar="ALLOWED_HOSTS_CSV",
help="Only allow specified hosts through socket.socket.connect((host, port)).",
)
group.addoption(
'--allow-unix-socket',
action='store_true',
dest='allow_unix_socket',
help='Allow calls if they are to Unix domain sockets'
"--allow-unix-socket",
action="store_true",
dest="allow_unix_socket",
help="Allow calls if they are to Unix domain sockets",
)


Expand All @@ -54,39 +57,39 @@ def _socket_marker(request):
The expected behavior is that higher granularity options should override
lower granularity options.
"""
if request.config.getoption('--disable-socket'):
request.getfixturevalue('socket_disabled')
if request.config.getoption("--disable-socket"):
request.getfixturevalue("socket_disabled")

if request.node.get_closest_marker('disable_socket'):
request.getfixturevalue('socket_disabled')
if request.node.get_closest_marker('enable_socket'):
request.getfixturevalue('socket_enabled')
if request.node.get_closest_marker("disable_socket"):
request.getfixturevalue("socket_disabled")
if request.node.get_closest_marker("enable_socket"):
request.getfixturevalue("socket_enabled")


@pytest.fixture
def socket_disabled(pytestconfig):
""" disable socket.socket for duration of this test function """
allow_unix_socket = pytestconfig.getoption('--allow-unix-socket')
"""disable socket.socket for duration of this test function"""
allow_unix_socket = pytestconfig.getoption("--allow-unix-socket")
disable_socket(allow_unix_socket)
yield
enable_socket()


@pytest.fixture
def socket_enabled(pytestconfig):
""" enable socket.socket for duration of this test function """
"""enable socket.socket for duration of this test function"""
enable_socket()
yield
allow_unix_socket = pytestconfig.getoption('--allow-unix-socket')
allow_unix_socket = pytestconfig.getoption("--allow-unix-socket")
disable_socket(allow_unix_socket)


def disable_socket(allow_unix_socket=False):
""" disable socket.socket to disable the Internet. useful in testing.
"""
"""disable socket.socket to disable the Internet. useful in testing."""

class GuardedSocket(socket.socket):
""" socket guard to disable socket creation (from pytest-socket) """
"""socket guard to disable socket creation (from pytest-socket)"""

def __new__(cls, *args, **kwargs):
try:
is_unix_socket = args[0] == socket.AF_UNIX
Expand All @@ -103,20 +106,26 @@ def __new__(cls, *args, **kwargs):


def enable_socket():
""" re-enable socket.socket to enable the Internet. useful in testing.
"""
"""re-enable socket.socket to enable the Internet. useful in testing."""
socket.socket = _true_socket


def pytest_configure(config):
config.addinivalue_line("markers", "disable_socket(): Disable socket connections for a specific test")
config.addinivalue_line("markers", "enable_socket(): Enable socket connections for a specific test")
config.addinivalue_line("markers", "allow_hosts([hosts]): Restrict socket connection to defined list of hosts")
config.addinivalue_line(
"markers", "disable_socket(): Disable socket connections for a specific test"
)
config.addinivalue_line(
"markers", "enable_socket(): Enable socket connections for a specific test"
)
config.addinivalue_line(
"markers",
"allow_hosts([hosts]): Restrict socket connection to defined list of hosts",
)


def pytest_runtest_setup(item):
mark_restrictions = item.get_closest_marker('allow_hosts')
cli_restrictions = item.config.getoption('--allow-hosts')
mark_restrictions = item.get_closest_marker("allow_hosts")
cli_restrictions = item.config.getoption("--allow-hosts")
hosts = None
if mark_restrictions:
hosts = mark_restrictions.args[0]
Expand All @@ -143,10 +152,9 @@ def host_from_connect_args(args):


def socket_allow_hosts(allowed=None):
""" disable socket.socket.connect() to disable the Internet. useful in testing.
"""
"""disable socket.socket.connect() to disable the Internet. useful in testing."""
if isinstance(allowed, str):
allowed = allowed.split(',')
allowed = allowed.split(",")
if not isinstance(allowed, list):
return

Expand Down Expand Up @@ -205,12 +213,12 @@ def host_is_domain(host, domains):


def is_valid_domain(dn):
if dn.endswith('.'):
if dn.endswith("."):
dn = dn[:-1]
if len(dn) < 1 or len(dn) > 253:
return False
ldh_re = re.compile('^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?$', re.IGNORECASE)
return all(ldh_re.match(x) for x in dn.split('.'))
ldh_re = re.compile("^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?$", re.IGNORECASE)
return all(ldh_re.match(x) for x in dn.split("."))


def parse_cidrs_from_allowed(allowed):
Expand All @@ -225,11 +233,24 @@ def address_in_network(ip, net):
return ipaddress.ip_address(ip) in ipaddress.ip_network(net)


def cache_ip_for_domain(ip, domain):
if domain not in _cached_domain_lookups:
_cached_domain_lookups[domain] = set()
_cached_domain_lookups[domain].add(ip)


def ip_is_cached_for_domain(ip, domain):
if domain in _cached_domain_lookups:
return ip in _cached_domain_lookups[domain]
return False


def address_is_domain(ip, domain):
return socket.gethostbyname(domain) == ip
ip_for_domain = socket.gethostbyname(domain)
cache_ip_for_domain(ip_for_domain, domain)
return ip_for_domain == ip or ip_is_cached_for_domain(ip, domain)


def remove_host_restrictions():
""" restore socket.socket.connect() to allow access to the Internet. useful in testing.
"""
"""restore socket.socket.connect() to allow access to the Internet. useful in testing."""
socket.socket.connect = _true_connect

0 comments on commit 9725da7

Please sign in to comment.