From 6689eaf94cbd1c3348dc2d5b11d042524996299c Mon Sep 17 00:00:00 2001 From: Anton Ivashkin Date: Mon, 25 Mar 2024 13:26:04 +0100 Subject: [PATCH] CHINA-219: Check certificate chain in TLS check --- ch_tools/monrun_checks/ch_tls.py | 57 +++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/ch_tools/monrun_checks/ch_tls.py b/ch_tools/monrun_checks/ch_tls.py index 7f56b82a..af8d1f92 100644 --- a/ch_tools/monrun_checks/ch_tls.py +++ b/ch_tools/monrun_checks/ch_tls.py @@ -3,8 +3,10 @@ import subprocess from datetime import datetime from typing import List, Optional, Tuple +from functools import lru_cache import click +from OpenSSL import SSL from OpenSSL.crypto import FILETYPE_PEM, dump_certificate, load_certificate from ch_tools.common.clickhouse.client.clickhouse_client import ( @@ -31,13 +33,17 @@ default=None, help="Comma separated list of ports. By default read from ClickHouse config", ) +@click.option( + "--chain", "chain", is_flag=True, help="Verify certificate chain." +) @click.pass_context def tls_command( - ctx: click.Context, crit: int, warn: int, ports: Optional[str] + ctx: click.Context, crit: int, warn: int, ports: Optional[str], chain: bool ) -> Result: """ Check TLS certificate for expiration and that actual cert from fs used. """ + file_chain = read_file_cert_chain() file_certificate, _ = read_cert_file() for port in get_ports(ctx, ports): @@ -50,8 +56,22 @@ def tls_command( if certificate != file_certificate: return Result( - 2, f"certificate on {port} and {CERTIFICATE_PATH} is different" + 2, f"certificates on {port} and {CERTIFICATE_PATH} are different" ) + if chain: + try: + chain = get_client_cert_chain(addr) + if len(chain) != len(file_chain): + return Result( + 2, f"certificates on {port} and {CERTIFICATE_PATH} have different chain lenght" + ) + for file_cert, socket_cert in zip(file_chain, chain): + if file_cert != socket_cert: + return Result( + 2, f"certificates on {port} and {CERTIFICATE_PATH} have different chains" + ) + except Exception as e: + return Result(1, f"Failed to get certificate chain: {repr(e)}") if days_to_expire < crit: return Result(2, f"certificate {port} expires in {days_to_expire} days") if days_to_expire < warn: @@ -70,12 +90,32 @@ def get_ports(ctx: click.Context, ports: Optional[str]) -> List[str]: ] -def read_cert_file() -> Tuple[str, int]: +@lru_cache(maxsize=None) +def read_cert_file_content() -> str: cmd = ["sudo", "/bin/cat", CERTIFICATE_PATH] - stdout = subprocess.check_output(cmd, shell=False) + return subprocess.check_output(cmd, shell=False) + + +def read_cert_file() -> Tuple[str, int]: + stdout = read_cert_file_content() return load_certificate_info(stdout) +def read_file_cert_chain(): + stdout = read_cert_file_content() + return read_all_certs(stdout) + + +def read_all_certs(blob): + start_line = b'-----BEGIN CERTIFICATE-----' + result = [] + cert_slots = blob.split(start_line) + for single_pem_cert in cert_slots[1:]: + cert = load_certificate(FILETYPE_PEM, start_line+single_pem_cert) + result.append(dump_certificate(FILETYPE_PEM, cert).decode()) + return result + + def load_certificate_info(certificate: bytes) -> Tuple[str, int]: x509 = load_certificate(FILETYPE_PEM, certificate) x509_not_after: Optional[bytes] = x509.get_notAfter() @@ -85,3 +125,12 @@ def load_certificate_info(certificate: bytes) -> Tuple[str, int]: dump_certificate(FILETYPE_PEM, x509).decode(), (expire_date - datetime.now()).days, ) + + +def get_client_cert_chain(addr: Tuple[str, int]): + cmd = ["openssl", "s_client", "-showcerts", "-connect", f"{addr[0]}:{addr[1]}"] + proc = subprocess.Popen( + cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE # type: ignore[arg-type] + ) + stdout, _ = proc.communicate(input="".encode()) + return read_all_certs(stdout)