From 8eabb52bcd0d12fede3096c03716518e08aff843 Mon Sep 17 00:00:00 2001 From: MaxKing <1512347620@qq.com> Date: Wed, 6 Nov 2024 11:02:20 +0800 Subject: [PATCH] sm2: Fix heap allocation --- sm2/Cargo.toml | 2 +- sm2/src/pke.rs | 28 +++++++++++++-------------- sm2/src/pke/decrypting.rs | 25 ++++++++++++------------ sm2/src/pke/encrypting.rs | 40 +++++++++++++++++---------------------- 4 files changed, 44 insertions(+), 51 deletions(-) diff --git a/sm2/Cargo.toml b/sm2/Cargo.toml index 4af59da9..e00c87c1 100644 --- a/sm2/Cargo.toml +++ b/sm2/Cargo.toml @@ -24,7 +24,7 @@ elliptic-curve = { version = "0.14.0-rc.0", default-features = false, features = primeorder = { version = "=0.14.0-pre.2", optional = true, path = "../primeorder" } rfc6979 = { version = "=0.5.0-pre.4", optional = true } serdect = { version = "0.3.0-rc.0", optional = true, default-features = false } -signature = { version = "=2.3.0-pre.4", optional = true, features = ["rand_core"] } +signature = { version = "=2.3.0-pre.4", optional = true, features = ["rand_core", "digest"] } sm3 = { version = "=0.5.0-pre.4", optional = true, default-features = false } [dev-dependencies] diff --git a/sm2/src/pke.rs b/sm2/src/pke.rs index 61875511..f86dd0db 100644 --- a/sm2/src/pke.rs +++ b/sm2/src/pke.rs @@ -43,10 +43,6 @@ use core::cmp::min; -use crate::AffinePoint; - -#[cfg(feature = "alloc")] -use alloc::vec; use elliptic_curve::{ bigint::{Encoding, Uint, U256}, @@ -54,13 +50,13 @@ use elliptic_curve::{ asn1::UintRef, Decode, DecodeValue, Encode, Length, Reader, Sequence, Tag, Writer, }, }; - use elliptic_curve::{ pkcs8::der::{asn1::OctetStringRef, EncodeValue}, - sec1::ToEncodedPoint, - Result, + sec1::{ModulusSize, ToEncodedPoint}, + CurveArithmetic, FieldBytesSize, Result, }; -use sm3::digest::DynDigest; +use primeorder::{AffinePoint, PrimeCurveParams}; +use signature::digest::{FixedOutputReset, Output, Update}; #[cfg(feature = "arithmetic")] mod decrypting; @@ -131,12 +127,18 @@ impl<'a> DecodeValue<'a> for Cipher<'a> { } /// Performs key derivation using a hash function and elliptic curve point. -fn kdf(hasher: &mut dyn DynDigest, kpb: AffinePoint, c2: &mut [u8]) -> Result<()> { +fn kdf(hasher: &mut D, kpb: AffinePoint, c2: &mut [u8]) -> Result<()> +where + D: Update + FixedOutputReset, + C: CurveArithmetic + PrimeCurveParams, + FieldBytesSize: ModulusSize, + AffinePoint: ToEncodedPoint, +{ let klen = c2.len(); let mut ct: i32 = 0x00000001; let mut offset = 0; - let digest_size = hasher.output_size(); - let mut ha = vec![0u8; digest_size]; + let digest_size = D::output_size(); + let mut ha = Output::::default(); let encode_point = kpb.to_encoded_point(false); while offset < klen { @@ -144,9 +146,7 @@ fn kdf(hasher: &mut dyn DynDigest, kpb: AffinePoint, c2: &mut [u8]) -> Result<() hasher.update(encode_point.y().ok_or(elliptic_curve::Error)?); hasher.update(&ct.to_be_bytes()); - hasher - .finalize_into_reset(&mut ha) - .map_err(|_e| elliptic_curve::Error)?; + hasher.finalize_into_reset(&mut ha); let xor_len = min(digest_size, klen - offset); xor(c2, &ha, offset, xor_len); diff --git a/sm2/src/pke/decrypting.rs b/sm2/src/pke/decrypting.rs index 5a57a633..8509ea15 100644 --- a/sm2/src/pke/decrypting.rs +++ b/sm2/src/pke/decrypting.rs @@ -16,9 +16,10 @@ use elliptic_curve::{ }; use primeorder::PrimeField; -use sm3::{digest::DynDigest, Digest, Sm3}; +use signature::digest::{Digest, FixedOutputReset, Output, OutputSizeUser, Update}; +use sm3::Sm3; -use super::{encrypting::EncryptingKey, kdf, vec, Cipher, Mode}; +use super::{encrypting::EncryptingKey, kdf, Cipher, Mode}; /// Represents a decryption key used for decrypting messages using elliptic curve cryptography. #[derive(Clone)] pub struct DecryptingKey { @@ -91,7 +92,7 @@ impl DecryptingKey { /// Decrypts a ciphertext in-place using the specified digest algorithm. pub fn decrypt_digest(&self, ciphertext: &[u8]) -> Result> where - D: 'static + Digest + DynDigest + Send + Sync, + D: Digest + OutputSizeUser + Update + FixedOutputReset, { let mut digest = D::new(); decrypt(&self.secret_scalar, self.mode, &mut digest, ciphertext) @@ -105,7 +106,7 @@ impl DecryptingKey { /// Decrypts a ciphertext in-place from ASN.1 format using the specified digest algorithm. pub fn decrypt_der_digest(&self, ciphertext: &[u8]) -> Result> where - D: 'static + Digest + DynDigest + Send + Sync, + D: Digest + OutputSizeUser + Update + FixedOutputReset, { let cipher = Cipher::from_der(ciphertext).map_err(elliptic_curve::pkcs8::Error::from)?; let prefix: &[u8] = &[0x04]; @@ -153,12 +154,10 @@ impl PartialEq for DecryptingKey { } } -fn decrypt( - secret_scalar: &Scalar, - mode: Mode, - hasher: &mut dyn DynDigest, - cipher: &[u8], -) -> Result> { +fn decrypt(secret_scalar: &Scalar, mode: Mode, hasher: &mut D, cipher: &[u8]) -> Result> +where + D: Update + OutputSizeUser + FixedOutputReset, +{ let q = U256::from_be_hex(FieldElement::MODULUS); let c1_len = (q.bits() + 7) / 8 * 2 + 1; @@ -177,7 +176,7 @@ fn decrypt( // B3: compute [𝑑𝐵]𝐶1 = (𝑥2, 𝑦2) c1_point = (c1_point * secret_scalar).to_affine(); - let digest_size = hasher.output_size(); + let digest_size = D::output_size(); let (c2, c3) = match mode { Mode::C1C3C2 => { let (c3, c2) = c.split_at(digest_size); @@ -192,12 +191,12 @@ fn decrypt( kdf(hasher, c1_point, &mut c2)?; // compute 𝑢 = 𝐻𝑎𝑠ℎ(𝑥2 ∥ 𝑀′∥ 𝑦2). - let mut u = vec![0u8; digest_size]; + let mut u = Output::::default(); let encode_point = c1_point.to_encoded_point(false); hasher.update(encode_point.x().ok_or(Error)?); hasher.update(&c2); hasher.update(encode_point.y().ok_or(Error)?); - hasher.finalize_into_reset(&mut u).map_err(|_e| Error)?; + hasher.finalize_into_reset(&mut u); let checked = u .iter() .zip(c3) diff --git a/sm2/src/pke/encrypting.rs b/sm2/src/pke/encrypting.rs index a0bcb55a..c3eddfe7 100644 --- a/sm2/src/pke/encrypting.rs +++ b/sm2/src/pke/encrypting.rs @@ -1,9 +1,7 @@ use core::fmt::Debug; use crate::{ - arithmetic::field::FieldElement, - pke::{kdf, vec}, - AffinePoint, ProjectivePoint, PublicKey, Scalar, Sm2, + arithmetic::field::FieldElement, pke::kdf, AffinePoint, ProjectivePoint, PublicKey, Scalar, Sm2, }; #[cfg(feature = "alloc")] @@ -18,12 +16,10 @@ use elliptic_curve::{ }; use primeorder::PrimeField; -use sm3::{ - digest::{Digest, DynDigest}, - Sm3, -}; +use sm3::Sm3; use super::{Cipher, Mode}; +use signature::digest::{Digest, FixedOutputReset, Output, OutputSizeUser, Update}; /// Represents an encryption key used for encrypting messages using elliptic curve cryptography. #[derive(Clone, Debug)] pub struct EncryptingKey { @@ -91,7 +87,7 @@ impl EncryptingKey { /// Encrypts a message using a specified digest algorithm. pub fn encrypt_digest(&self, msg: &[u8]) -> Result> where - D: 'static + Digest + DynDigest + Send + Sync, + D: Digest + Update + FixedOutputReset, { let mut digest = D::new(); encrypt(&self.public_key, self.mode, &mut digest, msg) @@ -100,11 +96,11 @@ impl EncryptingKey { /// Encrypts a message using a specified digest algorithm and returns the result in ASN.1 format. pub fn encrypt_der_digest(&self, msg: &[u8]) -> Result> where - D: 'static + Digest + DynDigest + Send + Sync, + D: Update + OutputSizeUser + Digest + FixedOutputReset, { let mut digest = D::new(); let cipher = encrypt(&self.public_key, self.mode, &mut digest, msg)?; - let digest_size = digest.output_size(); + let digest_size = ::output_size(); let (_, cipher) = cipher.split_at(1); let (x, cipher) = cipher.split_at(32); let (y, cipher) = cipher.split_at(32); @@ -133,14 +129,13 @@ impl From for EncryptingKey { } /// Encrypts a message using the specified public key, mode, and digest algorithm. -fn encrypt( - public_key: &PublicKey, - mode: Mode, - digest: &mut dyn DynDigest, - msg: &[u8], -) -> Result> { +fn encrypt(public_key: &PublicKey, mode: Mode, digest: &mut D, msg: &[u8]) -> Result> +where + D: Update + FixedOutputReset, +{ const N_BYTES: u32 = (Sm2::ORDER.bits() + 7) / 8; - let mut c1 = vec![0; (N_BYTES * 2 + 1) as usize]; + #[allow(unused_assignments)] + let mut c1 = Default::default(); let mut c2 = msg.to_owned(); let mut hpb: AffinePoint; loop { @@ -167,24 +162,23 @@ fn encrypt( // // If 𝑡 is an all-zero bit string, go to A1. // if all of t are 0, xor(c2) == c2 if c2.iter().zip(msg).any(|(pre, cur)| pre != cur) { - let uncompress_kg = kg.to_encoded_point(false); - c1.copy_from_slice(uncompress_kg.as_bytes()); + c1 = kg.to_encoded_point(false); break; } } let encode_point = hpb.to_encoded_point(false); // A7: compute 𝐶3 = 𝐻𝑎𝑠ℎ(𝑥2||𝑀||𝑦2) - let mut c3 = vec![0; digest.output_size()]; + let mut c3 = Output::::default(); digest.update(encode_point.x().ok_or(Error)?); digest.update(msg); digest.update(encode_point.y().ok_or(Error)?); - digest.finalize_into_reset(&mut c3).map_err(|_e| Error)?; + digest.finalize_into_reset(&mut c3); // A8: output the ciphertext 𝐶 = 𝐶1||𝐶2||𝐶3. Ok(match mode { - Mode::C1C2C3 => [c1.as_slice(), &c2, &c3].concat(), - Mode::C1C3C2 => [c1.as_slice(), &c3, &c2].concat(), + Mode::C1C2C3 => [c1.as_bytes(), &c2, &c3].concat(), + Mode::C1C3C2 => [c1.as_bytes(), &c3, &c2].concat(), }) }