diff --git a/wolfssl/src/context.rs b/wolfssl/src/context.rs index 071e110..dc41f76 100644 --- a/wolfssl/src/context.rs +++ b/wolfssl/src/context.rs @@ -2,7 +2,7 @@ use crate::{ callback::IOCallbacks, error::{Error, Result}, ssl::{Session, SessionConfig}, - NewSessionError, Protocol, RootCertificate, Secret, + CurveGroup, NewSessionError, Protocol, RootCertificate, Secret, }; use std::ptr::NonNull; use thiserror::Error; @@ -163,6 +163,35 @@ impl ContextBuilder { } } + /// Wraps [`wolfSSL_CTX_set_groups`][0] + /// + /// [0]: https://www.wolfssl.com/documentation/manuals/wolfssl/group__Setup.html#function-wolfssl_ctx_set_groups + pub fn with_groups(self, groups: &[CurveGroup]) -> Result { + let mut ffi_curves = groups.iter().map(|g| g.as_ffi() as i32).collect::>(); + + // SAFETY: [`wolfSSL_CTX_set_groups`][0] ([also][1]) requires + // a valid `ctx` pointer from `wolfSSL_CTX_new()` and `groups` + // parameter which should be a pointer to int with length + // corresponding to the `count` argument which is guaranteed + // by our use of a `Vec` here. + // + // [0]: https://www.wolfssl.com/documentation/manuals/wolfssl/group__Setup.html#function-wolfssl_ctx_set_groups + // [1]: https://www.wolfssl.com/doxygen/group__Setup.html#ga5bab039f79486d3ac31be72bc5f4e1e8 + let result = unsafe { + wolfssl_sys::wolfSSL_CTX_set_groups( + self.ctx.as_ptr(), + ffi_curves.as_mut_ptr(), + ffi_curves.len() as i32, + ) + }; + + if result == wolfssl_sys::WOLFSSL_SUCCESS { + Ok(self) + } else { + Err(Error::fatal(result)) + } + } + /// Wraps [`wolfSSL_CTX_use_certificate_file`][0] and [`wolfSSL_CTX_use_certificate_buffer`][1] /// /// [0]: https://www.wolfssl.com/documentation/manuals/wolfssl/group__CertsKeys.html#function-wolfssl_ctx_use_certificate_file diff --git a/wolfssl/src/lib.rs b/wolfssl/src/lib.rs index 39df899..0701340 100644 --- a/wolfssl/src/lib.rs +++ b/wolfssl/src/lib.rs @@ -209,6 +209,36 @@ impl Protocol { } } +/// Corresponds to the various defined `WOLFSSL_*` curves +#[derive(Debug, Copy, Clone)] +pub enum CurveGroup { + /// `WOLFSSL_ECC_SECP256R1` + EccSecp256R1, + + /// `WOLFSSL_ECC_X25519` + EccX25519, + + /// `WOLFSSL_P256_KYBER_LEVEL1` + P256KyberLevel1, + /// `WOLFSSL_P384_KYBER_LEVEL3` + P384KyberLevel3, + /// `WOLFSSL_P521_KYBER_LEVEL5` + P521KyberLevel5, +} + +impl CurveGroup { + fn as_ffi(&self) -> std::os::raw::c_uint { + use CurveGroup::*; + match self { + EccSecp256R1 => wolfssl_sys::WOLFSSL_ECC_SECP256R1, + EccX25519 => wolfssl_sys::WOLFSSL_ECC_X25519, + P256KyberLevel1 => wolfssl_sys::WOLFSSL_P256_KYBER_LEVEL1, + P384KyberLevel3 => wolfssl_sys::WOLFSSL_P384_KYBER_LEVEL3, + P521KyberLevel5 => wolfssl_sys::WOLFSSL_P521_KYBER_LEVEL5, + } + } +} + /// Defines a CA certificate pub enum RootCertificate<'a> { /// In-memory PEM buffer diff --git a/wolfssl/src/ssl.rs b/wolfssl/src/ssl.rs index 096764a..c515358 100644 --- a/wolfssl/src/ssl.rs +++ b/wolfssl/src/ssl.rs @@ -2,7 +2,7 @@ use crate::{ callback::{IOCallbackResult, IOCallbacks}, context::Context, error::{Error, Poll, PollResult, Result}, - Protocol, ProtocolVersion, TLS_MAX_RECORD_SIZE, + CurveGroup, Protocol, ProtocolVersion, TLS_MAX_RECORD_SIZE, }; use bytes::{Buf, Bytes, BytesMut}; @@ -59,6 +59,8 @@ pub struct SessionConfig { /// If set, configures the session to check the given domain against the /// peer certificate during connection. pub checked_domain_name: Option, + /// If set, specifies a curve group to use for key share + pub keyshare_group: Option, } impl SessionConfig { @@ -71,6 +73,7 @@ impl SessionConfig { dtls_mtu: Default::default(), server_name_indicator: Default::default(), checked_domain_name: Default::default(), + keyshare_group: Default::default(), } } @@ -97,6 +100,12 @@ impl SessionConfig { self.checked_domain_name = Some(domain.to_string()); self } + + /// Sets [`Self::keyshare_group`] + pub fn with_keyshare_group(mut self, curve: CurveGroup) -> Self { + self.keyshare_group = Some(curve); + self + } } // Wrap a valid pointer to a [`wolfssl_sys::WOLFSSL`] such that we can @@ -192,6 +201,12 @@ impl Session { .map_err(|e| NewSessionError::SetupFailed("set_domain_name_to_check", e))?; } + if let Some(curve) = config.keyshare_group { + session + .use_key_share_curve(curve) + .map_err(|e| NewSessionError::SetupFailed("use_key_share_curve", e))?; + } + Ok(session) } @@ -1040,6 +1055,27 @@ impl Session { e => unreachable!("{e:?}"), } } + + /// Invokes [`wolfSSL_UseKeyShare`][0] + /// + /// [0]: https://www.wolfssl.com/documentation/manuals/wolfssl/ssl_8h.html#function-wolfssl_usekeyshare + fn use_key_share_curve(&self, curve: CurveGroup) -> Result<()> { + // SAFETY: [`wolfSSL_UseKeyShare`][0] ([also][1]) expects a valid pointer to `WOLFSSL`. Per the + // [Library design][2] access is synchronized via the containing [`Mutex`] + // + // [0]: https://www.wolfssl.com/documentation/manuals/wolfssl/ssl_8h.html#function-wolfssl_usekeyshare + // [1]: https://www.wolfssl.com/doxygen/group__Setup.html#gac2d00ac65513f10e0ccd1b67d9a99e3d + // [2]: https://www.wolfssl.com/documentation/manuals/wolfssl/chapter09.html#thread-safety + match unsafe { + let ssl = self.ssl.lock(); + wolfssl_sys::wolfSSL_UseKeyShare(ssl.as_ptr(), curve.as_ffi() as wolfssl_sys::word16) + } { + wolfssl_sys::WOLFSSL_SUCCESS => Ok(()), + wolfssl_sys::MEMORY_E => panic!("Memory Allocation Failed"), + e @ wolfssl_sys::BAD_FUNC_ARG => unreachable!("{e:?}"), + e => unreachable!("{e:?}"), + } + } } impl Drop for Session {