Skip to content

Commit

Permalink
CHINA-219: Check certificate chain in TLS check
Browse files Browse the repository at this point in the history
  • Loading branch information
ianton-ru committed Mar 25, 2024
1 parent 60a87b8 commit 6689eaf
Showing 1 changed file with 53 additions and 4 deletions.
57 changes: 53 additions & 4 deletions ch_tools/monrun_checks/ch_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import subprocess
from datetime import datetime
from typing import List, Optional, Tuple
from functools import lru_cache

import click
from OpenSSL import SSL
from OpenSSL.crypto import FILETYPE_PEM, dump_certificate, load_certificate

from ch_tools.common.clickhouse.client.clickhouse_client import (
Expand All @@ -31,13 +33,17 @@
default=None,
help="Comma separated list of ports. By default read from ClickHouse config",
)
@click.option(
"--chain", "chain", is_flag=True, help="Verify certificate chain."
)
@click.pass_context
def tls_command(
ctx: click.Context, crit: int, warn: int, ports: Optional[str]
ctx: click.Context, crit: int, warn: int, ports: Optional[str], chain: bool
) -> Result:
"""
Check TLS certificate for expiration and that actual cert from fs used.
"""
file_chain = read_file_cert_chain()
file_certificate, _ = read_cert_file()

for port in get_ports(ctx, ports):
Expand All @@ -50,8 +56,22 @@ def tls_command(

if certificate != file_certificate:
return Result(
2, f"certificate on {port} and {CERTIFICATE_PATH} is different"
2, f"certificates on {port} and {CERTIFICATE_PATH} are different"
)
if chain:
try:
chain = get_client_cert_chain(addr)
if len(chain) != len(file_chain):
return Result(
2, f"certificates on {port} and {CERTIFICATE_PATH} have different chain lenght"
)
for file_cert, socket_cert in zip(file_chain, chain):
if file_cert != socket_cert:
return Result(
2, f"certificates on {port} and {CERTIFICATE_PATH} have different chains"
)
except Exception as e:
return Result(1, f"Failed to get certificate chain: {repr(e)}")
if days_to_expire < crit:
return Result(2, f"certificate {port} expires in {days_to_expire} days")
if days_to_expire < warn:
Expand All @@ -70,12 +90,32 @@ def get_ports(ctx: click.Context, ports: Optional[str]) -> List[str]:
]


def read_cert_file() -> Tuple[str, int]:
@lru_cache(maxsize=None)
def read_cert_file_content() -> str:
cmd = ["sudo", "/bin/cat", CERTIFICATE_PATH]
stdout = subprocess.check_output(cmd, shell=False)
return subprocess.check_output(cmd, shell=False)


def read_cert_file() -> Tuple[str, int]:
stdout = read_cert_file_content()
return load_certificate_info(stdout)


def read_file_cert_chain():
stdout = read_cert_file_content()
return read_all_certs(stdout)


def read_all_certs(blob):
start_line = b'-----BEGIN CERTIFICATE-----'
result = []
cert_slots = blob.split(start_line)
for single_pem_cert in cert_slots[1:]:
cert = load_certificate(FILETYPE_PEM, start_line+single_pem_cert)
result.append(dump_certificate(FILETYPE_PEM, cert).decode())
return result


def load_certificate_info(certificate: bytes) -> Tuple[str, int]:
x509 = load_certificate(FILETYPE_PEM, certificate)
x509_not_after: Optional[bytes] = x509.get_notAfter()
Expand All @@ -85,3 +125,12 @@ def load_certificate_info(certificate: bytes) -> Tuple[str, int]:
dump_certificate(FILETYPE_PEM, x509).decode(),
(expire_date - datetime.now()).days,
)


def get_client_cert_chain(addr: Tuple[str, int]):
cmd = ["openssl", "s_client", "-showcerts", "-connect", f"{addr[0]}:{addr[1]}"]
proc = subprocess.Popen(
cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE # type: ignore[arg-type]
)
stdout, _ = proc.communicate(input="".encode())
return read_all_certs(stdout)

0 comments on commit 6689eaf

Please sign in to comment.