diff --git a/gunicorn/arbiter.py b/gunicorn/arbiter.py index 008a54efe..cee55bd1d 100644 --- a/gunicorn/arbiter.py +++ b/gunicorn/arbiter.py @@ -154,7 +154,7 @@ def start(self): self.LISTENERS = sock.create_sockets(self.cfg, self.log, fds) - listeners_str = ",".join([str(lnr) for lnr in self.LISTENERS]) + listeners_str = ",".join([sock.get_uri(lnr) for lnr in self.LISTENERS]) self.log.debug("Arbiter booted") self.log.info("Listening at: %s (%s)", listeners_str, self.pid) self.log.info("Using worker: %s", self.cfg.worker_class_str) @@ -461,7 +461,7 @@ def reload(self): lnr.close() # init new listeners self.LISTENERS = sock.create_sockets(self.cfg, self.log) - listeners_str = ",".join([str(lnr) for lnr in self.LISTENERS]) + listeners_str = ",".join([sock.get_uri(lnr) for lnr in self.LISTENERS]) self.log.info("Listening at: %s", listeners_str) # do some actions on reload diff --git a/gunicorn/config.py b/gunicorn/config.py index e7e4fac54..cfc2d83f8 100644 --- a/gunicorn/config.py +++ b/gunicorn/config.py @@ -2076,7 +2076,7 @@ class KeyFile(Setting): section = "SSL" cli = ["--keyfile"] meta = "FILE" - validator = validate_string + validator = validate_file_exists default = None desc = """\ SSL key file @@ -2088,7 +2088,7 @@ class CertFile(Setting): section = "SSL" cli = ["--certfile"] meta = "FILE" - validator = validate_string + validator = validate_file_exists default = None desc = """\ SSL certificate file diff --git a/gunicorn/sock.py b/gunicorn/sock.py index 7700146a8..8d5e60850 100644 --- a/gunicorn/sock.py +++ b/gunicorn/sock.py @@ -14,130 +14,56 @@ from gunicorn import util -class BaseSocket(object): - - def __init__(self, address, conf, log, fd=None): - self.log = log - self.conf = conf - - self.cfg_addr = address - if fd is None: - sock = socket.socket(self.FAMILY, socket.SOCK_STREAM) - bound = False +def _get_socket_family(addr): + if isinstance(addr, tuple): + if util.is_ipv6(addr[0]): + return socket.AF_INET6 else: - sock = socket.fromfd(fd, self.FAMILY, socket.SOCK_STREAM) - os.close(fd) - bound = True - - self.sock = self.set_options(sock, bound=bound) - - def __str__(self): - return "" % self.sock.fileno() - - def __getattr__(self, name): - return getattr(self.sock, name) - - def set_options(self, sock, bound=False): - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if (self.conf.reuse_port - and hasattr(socket, 'SO_REUSEPORT')): # pragma: no cover - try: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - except socket.error as err: - if err.errno not in (errno.ENOPROTOOPT, errno.EINVAL): - raise - if not bound: - self.bind(sock) - sock.setblocking(0) + return socket.AF_INET - # make sure that the socket can be inherited - if hasattr(sock, "set_inheritable"): - sock.set_inheritable(True) + if isinstance(addr, (str, bytes)): + return socket.AF_UNIX - sock.listen(self.conf.backlog) - return sock + raise TypeError("Unable to determine socket family for: %r" % addr) - def bind(self, sock): - sock.bind(self.cfg_addr) - def close(self): - if self.sock is None: - return +def create_socket(conf, log, addr): + family = _get_socket_family(addr) + if family is socket.AF_UNIX: + # remove any existing socket at the given path try: - self.sock.close() - except socket.error as e: - self.log.info("Error while closing socket %s", str(e)) - - self.sock = None - - -class TCPSocket(BaseSocket): - - FAMILY = socket.AF_INET - - def __str__(self): - if self.conf.is_ssl: - scheme = "https" + st = os.stat(addr) + except OSError as e: + if e.args[0] != errno.ENOENT: + raise else: - scheme = "http" - - addr = self.sock.getsockname() - return "%s://%s:%d" % (scheme, addr[0], addr[1]) - - def set_options(self, sock, bound=False): - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - return super().set_options(sock, bound=bound) - - -class TCP6Socket(TCPSocket): - - FAMILY = socket.AF_INET6 - - def __str__(self): - (host, port, _, _) = self.sock.getsockname() - return "http://[%s]:%d" % (host, port) - - -class UnixSocket(BaseSocket): - - FAMILY = socket.AF_UNIX - - def __init__(self, addr, conf, log, fd=None): - if fd is None: - try: - st = os.stat(addr) - except OSError as e: - if e.args[0] != errno.ENOENT: - raise + if stat.S_ISSOCK(st.st_mode): + os.remove(addr) else: - if stat.S_ISSOCK(st.st_mode): - os.remove(addr) - else: - raise ValueError("%r is not a socket" % addr) - super().__init__(addr, conf, log, fd=fd) - - def __str__(self): - return "unix:%s" % self.cfg_addr - - def bind(self, sock): - old_umask = os.umask(self.conf.umask) - sock.bind(self.cfg_addr) - util.chown(self.cfg_addr, self.conf.uid, self.conf.gid) - os.umask(old_umask) + raise ValueError("%r is not a socket" % addr) + for i in range(5): + try: + sock = socket.socket(family) + sock.bind(addr) + sock.listen(conf.backlog) + if family is socket.AF_UNIX: + util.chown(addr, conf.uid, conf.gid) + return sock + except socket.error as e: + if e.args[0] == errno.EADDRINUSE: + log.error("Connection in use: %s", str(addr)) + if e.args[0] == errno.EADDRNOTAVAIL: + log.error("Invalid address: %s", str(addr)) + if i < 5: + msg = "connection to {addr} failed: {error}" + log.debug(msg.format(addr=str(addr), error=str(e))) + log.error("Retrying in 1 second.") + time.sleep(1) -def _sock_type(addr): - if isinstance(addr, tuple): - if util.is_ipv6(addr[0]): - sock_type = TCP6Socket - else: - sock_type = TCPSocket - elif isinstance(addr, (str, bytes)): - sock_type = UnixSocket - else: - raise TypeError("Unable to create socket from: %r" % addr) - return sock_type + log.error("Can't connect to %s", str(addr)) + sys.exit(1) def create_sockets(conf, log, fds=None): @@ -150,67 +76,71 @@ def create_sockets(conf, log, fds=None): """ listeners = [] - # get it only once - addr = conf.address - fdaddr = [bind for bind in addr if isinstance(bind, int)] if fds: - fdaddr += list(fds) - laddr = [bind for bind in addr if not isinstance(bind, int)] - - # check ssl config early to raise the error on startup - # only the certfile is needed since it can contains the keyfile - if conf.certfile and not os.path.exists(conf.certfile): - raise ValueError('certfile "%s" does not exist' % conf.certfile) - - if conf.keyfile and not os.path.exists(conf.keyfile): - raise ValueError('keyfile "%s" does not exist' % conf.keyfile) - - # sockets are already bound - if fdaddr: - for fd in fdaddr: - sock = socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) - sock_name = sock.getsockname() - sock_type = _sock_type(sock_name) - listener = sock_type(sock_name, conf, log, fd=fd) - listeners.append(listener) - - return listeners - - # no sockets is bound, first initialization of gunicorn in this env. - for addr in laddr: - sock_type = _sock_type(addr) - sock = None - for i in range(5): - try: - sock = sock_type(addr, conf, log) - except socket.error as e: - if e.args[0] == errno.EADDRINUSE: - log.error("Connection in use: %s", str(addr)) - if e.args[0] == errno.EADDRNOTAVAIL: - log.error("Invalid address: %s", str(addr)) - if i < 5: - msg = "connection to {addr} failed: {error}" - log.debug(msg.format(addr=str(addr), error=str(e))) - log.error("Retrying in 1 second.") - time.sleep(1) - else: - break - - if sock is None: - log.error("Can't connect to %s", str(addr)) - sys.exit(1) - - listeners.append(sock) + # sockets are already bound + listeners = [] + for fd in list(fds) + [a for a in conf.address if isinstance(a, int)]: + sock = socket.socket(fileno=fd) + set_socket_options(conf, sock) + listeners.append(sock) + else: + # first initialization of gunicorn + old_umask = os.umask(conf.umask) + try: + for addr in [bind for bind in conf.address if not isinstance(bind, int)]: + sock = create_socket(conf, log, addr) + set_socket_options(conf, sock) + listeners.append(sock) + finally: + os.umask(old_umask) return listeners def close_sockets(listeners, unlink=True): for sock in listeners: - sock_name = sock.getsockname() - sock.close() - if unlink and _sock_type(sock_name) is UnixSocket: - os.unlink(sock_name) + try: + if unlink and sock.family is socket.AF_UNIX: + sock_name = sock.getsockname() + os.unlink(sock_name) + finally: + sock.close() + + +def get_uri(listener, is_ssl=False): + addr = listener.getsockname() + family = _get_socket_family(addr) + scheme = "https" if is_ssl else "http" + + if family is socket.AF_INET: + (host, port) = listener.getsockname() + return f"{scheme}://{host}:{port}" + + if family is socket.AF_INET6: + (host, port, _, _) = listener.getsockname() + return f"{scheme}://[{host}]:{port}" + + if family is socket.AF_UNIX: + path = listener.getsockname() + return f"unix://{path}" + + +def set_socket_options(conf, sock): + sock.setblocking(False) + + # make sure that the socket can be inherited + if hasattr(sock, "set_inheritable"): + sock.set_inheritable(True) + + if sock.family in (socket.AF_INET, socket.AF_INET6): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if (conf.reuse_port and hasattr(socket, 'SO_REUSEPORT')): # pragma: no cover + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + except socket.error as err: + if err.errno not in (errno.ENOPROTOOPT, errno.EINVAL): + raise def ssl_context(conf): diff --git a/tests/test_sock.py b/tests/test_sock.py index adc348c6f..adfac1013 100644 --- a/tests/test_sock.py +++ b/tests/test_sock.py @@ -3,30 +3,41 @@ # This file is part of gunicorn released under the MIT license. # See the NOTICE for more information. +import socket from unittest import mock -from gunicorn import sock +import pytest +from gunicorn import sock -@mock.patch('os.stat') -def test_create_sockets_unix_bytes(stat): - conf = mock.Mock(address=[b'127.0.0.1:8000']) - log = mock.Mock() - with mock.patch.object(sock.UnixSocket, '__init__', lambda *args: None): - listeners = sock.create_sockets(conf, log) - assert len(listeners) == 1 - print(type(listeners[0])) - assert isinstance(listeners[0], sock.UnixSocket) +@pytest.fixture(scope='function') +def addr(request, tmp_path): + if isinstance(request.param, str): + return str(tmp_path / request.param) + return request.param -@mock.patch('os.stat') -def test_create_sockets_unix_strings(stat): - conf = mock.Mock(address=['127.0.0.1:8000']) +@pytest.mark.parametrize( + 'addr, family', + [ + ('gunicorn.sock', socket.AF_UNIX), + (('0.0.0.0', 0), socket.AF_INET), + (('::', 0), socket.AF_INET6), + ], + indirect=['addr'], +) +@mock.patch('socket.socket') +@mock.patch('gunicorn.util.chown') +def test_create_socket(chown, socket, addr, family): + conf = mock.Mock(address=[addr], umask=0o22) log = mock.Mock() - with mock.patch.object(sock.UnixSocket, '__init__', lambda *args: None): - listeners = sock.create_sockets(conf, log) - assert len(listeners) == 1 - assert isinstance(listeners[0], sock.UnixSocket) + listener = sock.create_socket(conf, log, addr) + assert listener == socket.return_value + socket.assert_called_with(family) + listener.bind.assert_called_with(addr) + listener.listen.assert_called_with(conf.backlog) + if family is socket.AF_UNIX: + chown.assert_called_with(addr, conf.uid, conf.gid) def test_socket_close(): @@ -41,7 +52,7 @@ def test_socket_close(): @mock.patch('os.unlink') def test_unix_socket_close_unlink(unlink): - listener = mock.Mock() + listener = mock.Mock(family=socket.AF_UNIX) listener.getsockname.return_value = '/var/run/test.sock' sock.close_sockets([listener]) listener.close.assert_called_with() @@ -50,7 +61,7 @@ def test_unix_socket_close_unlink(unlink): @mock.patch('os.unlink') def test_unix_socket_close_without_unlink(unlink): - listener = mock.Mock() + listener = mock.Mock(family=socket.AF_UNIX) listener.getsockname.return_value = '/var/run/test.sock' sock.close_sockets([listener], False) listener.close.assert_called_with()