From 7047b9ae3f2eb2fe54ff590e85798c6e5cc12ac4 Mon Sep 17 00:00:00 2001 From: Ruoyu Ying Date: Thu, 29 Feb 2024 11:03:07 +0800 Subject: [PATCH] vmsdk: update vsock implementation for tdx quote fetching * add more error handling * generalize tdx configuration loading Signed-off-by: Ruoyu Ying --- src/python/cctrusted_vm/cvm.py | 142 +++++++++++++++++++++------------ 1 file changed, 89 insertions(+), 53 deletions(-) diff --git a/src/python/cctrusted_vm/cvm.py b/src/python/cctrusted_vm/cvm.py index e225f86..120b3ed 100644 --- a/src/python/cctrusted_vm/cvm.py +++ b/src/python/cctrusted_vm/cvm.py @@ -20,7 +20,7 @@ from cctrusted_base.tcg import TcgAlgorithmRegistry from cctrusted_base.tdx.common import TDX_VERSION_1_0, TDX_VERSION_1_5 from cctrusted_base.tdx.rtmr import TdxRTMR -from cctrusted_base.tdx.quote import TdxQuoteReq10, TdxQuoteReq15, TdxQuote +from cctrusted_base.tdx.quote import TdxQuoteReq10, TdxQuoteReq15, TdxQuote, TdxQuoteReq from cctrusted_base.tdx.report import TdxReportReq10, TdxReportReq15 LOG = logging.getLogger(__name__) @@ -223,6 +223,7 @@ def __init__(self): ConfidentialVM.__init__(self, CCTrustedApi.TYPE_CC_TDX) self._version:str = None self._tdreport = None + self._config:dict = self._load_config() @property def version(self): @@ -241,6 +242,34 @@ def tdreport(self): """TDREPORT structure""" return self._tdreport + def _load_config(self): + """Process TDX attest config file and fetch params within the config.""" + tdx_config_dict = {} + if os.path.exists(TdxVM.CFG_FILE_PATH): + LOG.debug("Found TDX Config file at %s", TdxVM.CFG_FILE_PATH) + try: + with open(TdxVM.CFG_FILE_PATH, 'rb') as cfg_file: + cfg_info = [line.rstrip() for line in cfg_file] + for line in cfg_info: + # remove spaces in each line + # save all configs into tdx_config_dict + line = line.decode("utf-8").replace(" ", "") + param = line.partition("=") + tdx_config_dict[param[0]] = param[2] + except(PermissionError, OSError): + LOG.error("Need root permission to open file %s for params.", + TdxVM.CFG_FILE_PATH) + return None + + # convert port param into integer and check its validity + if "port" in tdx_config_dict: + tdx_config_dict["port"] = int(tdx_config_dict["port"]) + if tdx_config_dict["port"] < 0 or tdx_config_dict["port"] > 65535: + LOG.debug("Invalid vsock port specified in the config.") + del tdx_config_dict["port"] + + return tdx_config_dict + def process_cc_report(self, report_data=None) -> bool: """Process the confidential computing REPORT.""" dev_path = self.DEVICE_NODE_PATH[self.version] @@ -398,61 +427,20 @@ def get_cc_report(self, nonce: bytearray, data: bytearray, extraArgs) -> CcRepor elif self.version is TDX_VERSION_1_5: quote_req = TdxQuoteReq15() - # Use tdvmcall to get TD Quote by default - tdvmcall_flag = True - - # Check if vsock port specified in TDX attest config - # If specified, use vsock to get quote - if os.path.exists(TdxVM.CFG_FILE_PATH): - LOG.info("Found TDX Config file at %s", TdxVM.CFG_FILE_PATH) - try: - with open(TdxVM.CFG_FILE_PATH, 'rb') as cfg_file: - cfg_info = [line.rstrip() for line in cfg_file] - for line in cfg_info: - line = line.decode("utf-8").replace(" ", "") - if "port=" in line: - LOG.info("Vsock port number specified. Use vsock for quote fetching.") - tdvmcall_flag = False - port = int(line.partition("port=")[2]) - if port <= 0 or port > 65535: - LOG.error( - "Invalid vsock port number specified. Fallback to tdvmcall.") - tdvmcall_flag = True - break - except(PermissionError, OSError): - LOG.error("Need root permission to open file %s", TdxVM.CFG_FILE_PATH) - - if not tdvmcall_flag: - # Setup socket to connect qgs socket on host - with socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM, 0) as sock: - sock.settimeout(30) - sock.connect((socket.VMADDR_CID_HOST, port)) - - header_size = 4 - # Generate p_blob_payload buffer - qgs_msg = quote_req.qgs_msg_quote_req(report_bytes) - msg_size = len(qgs_msg) - - p_blob_payload = bytearray(msg_size.to_bytes(header_size, "big")) - p_blob_payload[header_size:] = qgs_msg[:msg_size] - - # Send quote request - nsent = sock.send(p_blob_payload) - LOG.debug("Sent %d bytes for Quote request.", nsent) - - # Receive quote - header = sock.recv(header_size) - in_msg_size = 0 - for i in range(header_size): - in_msg_size = (in_msg_size << 8) + (header[i] & 0xFF) - qgs_resp = sock.recv(in_msg_size) - LOG.debug("Received %d bytes as Quote response", in_msg_size) + # Check if appropriate qgs vsock port specified in TDX attest config + # If specified, use vsock to get quote and return TdxQuote object + if self._config and "port" in self._config: + LOG.info("Use vsock for TDX quote fetching.") + td_report = self._invoke_quote_fetching_on_vsock( + report_bytes, quote_req, self._config["port"]) - sock.close() - tdquote = quote_req.qgs_msg_quote_resp(qgs_resp) - return TdxQuote(tdquote) + # Check if quote fetching by vsock has been done successfully + # If yes, return result and skip following steps + if td_report: + return td_report # Fetch quote through tdvmcall + LOG.info("Use tdvmcall for TDX quote fetching.") # pylint: disable=E1111 req_buf = quote_req.prepare_reqbuf(report_bytes) @@ -477,3 +465,51 @@ def get_cc_report(self, nonce: bytearray, data: bytearray, extraArgs) -> CcRepor # Get TD Quote from ioctl command output return quote_req.process_output(req_buf) + + def _invoke_quote_fetching_on_vsock( + self, + report_bytes:bytes, + quote_req:TdxQuoteReq, + port:int=None + ) -> TdxQuote: + """Invoke TDX quote fetching through vsock. + + Args: + report_bytes(bytes): report data included in quote request + quote_req(TdxQuoteReq): the TDX quote request instance to call QGS + port(integer): the port number of QGS vsock + + Returns: + A TdxQuote object fetched through vsock + """ + # Setup socket to connect qgs socket on host + try: + with socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM, 0) as sock: + sock.settimeout(30) + sock.connect((socket.VMADDR_CID_HOST, port)) + + header_size = 4 + # Generate p_blob_payload buffer + qgs_msg = quote_req.qgs_msg_quote_req(report_bytes) + msg_size = len(qgs_msg) + p_blob_payload = bytearray(msg_size.to_bytes(header_size, "big")) + p_blob_payload[header_size:] = qgs_msg[:msg_size] + + # Send quote request + nsent = sock.send(p_blob_payload) + LOG.debug("Sent %d bytes for Quote request.", nsent) + + # Receive quote response + header = sock.recv(header_size) + in_msg_size = 0 + for i in range(header_size): + in_msg_size = (in_msg_size << 8) + (header[i] & 0xFF) + qgs_resp = sock.recv(in_msg_size) + LOG.debug("Received %d bytes as Quote response", in_msg_size) + + sock.close() + except socket.error as msg: + LOG.error("Socket Error: %s", msg) + return None + tdquote = quote_req.qgs_msg_quote_resp(qgs_resp) + return TdxQuote(tdquote)