diff --git a/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/include/mbedtls_wrap.hpp b/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/include/mbedtls_wrap.hpp index a6e3680220..5590801e84 100644 --- a/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/include/mbedtls_wrap.hpp +++ b/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/include/mbedtls_wrap.hpp @@ -6,6 +6,7 @@ #pragma once #include +#include #include #include "mbedtls/ssl.h" #include "mbedtls/entropy.h" @@ -22,6 +23,7 @@ class Tls { Tls(); virtual ~Tls(); bool init(is_server server, do_verify verify); + bool deinit(); int handshake(); int write(const unsigned char *buf, size_t len); int read(unsigned char *buf, size_t len); @@ -41,6 +43,11 @@ class Tls { mbedtls_entropy_context entropy_{}; virtual void delay() {} + bool set_session(); + bool get_session(); + void reset_session(); + bool is_session_loaded(); + private: static void print_error(const char *function, int error_code); static int bio_write(void *ctx, const unsigned char *buf, size_t len); @@ -48,5 +55,21 @@ class Tls { int mbedtls_pk_parse_key( mbedtls_pk_context *ctx, const unsigned char *key, size_t keylen, const unsigned char *pwd, size_t pwdlen); + struct unique_session { + unique_session() + { + ::mbedtls_ssl_session_init(&s); + } + ~unique_session() + { + ::mbedtls_ssl_session_free(&s); + } + mbedtls_ssl_session *ptr() + { + return &s; + } + mbedtls_ssl_session s; + }; + std::unique_ptr session_; }; diff --git a/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/mbedtls_wrap.cpp b/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/mbedtls_wrap.cpp index ff510880a6..026f41020d 100644 --- a/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/mbedtls_wrap.cpp +++ b/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/mbedtls_wrap.cpp @@ -35,6 +35,16 @@ bool Tls::init(is_server server, do_verify verify) return true; } +bool Tls::deinit() +{ + ::mbedtls_ssl_config_free(&conf_); + ::mbedtls_ssl_free(&ssl_); + ::mbedtls_pk_free(&pk_key_); + ::mbedtls_x509_crt_free(&public_cert_); + ::mbedtls_x509_crt_free(&ca_cert_); + return true; +} + void Tls::print_error(const char *function, int error_code) { static char error_buf[100]; @@ -132,3 +142,39 @@ Tls::~Tls() ::mbedtls_x509_crt_free(&public_cert_); ::mbedtls_x509_crt_free(&ca_cert_); } + +bool Tls::get_session() +{ + if (session_ == nullptr) { + session_ = std::make_unique(); + } + int ret = ::mbedtls_ssl_get_session(&ssl_, session_->ptr()); + if (ret != 0) { + print_error("mbedtls_ssl_get_session() failed", ret); + return false; + } + return true; +} + +bool Tls::set_session() +{ + if (session_ == nullptr) { + printf("session hasn't been initialized"); + return false; + } + int ret = mbedtls_ssl_set_session(&ssl_, session_->ptr()); + if (ret != 0) { + print_error("mbedtls_ssl_set_session() failed", ret); + return false; + } + return true; +} + +void Tls::reset_session() +{ + session_.reset(nullptr); +} +bool Tls::is_session_loaded() +{ + return session_ != nullptr; +}