Skip to content

Commit

Permalink
sm2: Fix heap allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxKingPor committed Nov 6, 2024
1 parent 1bfea04 commit 8eabb52
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 51 deletions.
2 changes: 1 addition & 1 deletion sm2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ elliptic-curve = { version = "0.14.0-rc.0", default-features = false, features =
primeorder = { version = "=0.14.0-pre.2", optional = true, path = "../primeorder" }
rfc6979 = { version = "=0.5.0-pre.4", optional = true }
serdect = { version = "0.3.0-rc.0", optional = true, default-features = false }
signature = { version = "=2.3.0-pre.4", optional = true, features = ["rand_core"] }
signature = { version = "=2.3.0-pre.4", optional = true, features = ["rand_core", "digest"] }
sm3 = { version = "=0.5.0-pre.4", optional = true, default-features = false }

[dev-dependencies]
Expand Down
28 changes: 14 additions & 14 deletions sm2/src/pke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,20 @@
use core::cmp::min;

use crate::AffinePoint;

#[cfg(feature = "alloc")]
use alloc::vec;

use elliptic_curve::{
bigint::{Encoding, Uint, U256},
pkcs8::der::{
asn1::UintRef, Decode, DecodeValue, Encode, Length, Reader, Sequence, Tag, Writer,
},
};

use elliptic_curve::{
pkcs8::der::{asn1::OctetStringRef, EncodeValue},
sec1::ToEncodedPoint,
Result,
sec1::{ModulusSize, ToEncodedPoint},
CurveArithmetic, FieldBytesSize, Result,
};
use sm3::digest::DynDigest;
use primeorder::{AffinePoint, PrimeCurveParams};
use signature::digest::{FixedOutputReset, Output, Update};

#[cfg(feature = "arithmetic")]
mod decrypting;
Expand Down Expand Up @@ -131,22 +127,26 @@ impl<'a> DecodeValue<'a> for Cipher<'a> {
}

/// Performs key derivation using a hash function and elliptic curve point.
fn kdf(hasher: &mut dyn DynDigest, kpb: AffinePoint, c2: &mut [u8]) -> Result<()> {
fn kdf<D, C>(hasher: &mut D, kpb: AffinePoint<C>, c2: &mut [u8]) -> Result<()>
where
D: Update + FixedOutputReset,
C: CurveArithmetic + PrimeCurveParams,
FieldBytesSize<C>: ModulusSize,
AffinePoint<C>: ToEncodedPoint<C>,
{
let klen = c2.len();
let mut ct: i32 = 0x00000001;
let mut offset = 0;
let digest_size = hasher.output_size();
let mut ha = vec![0u8; digest_size];
let digest_size = D::output_size();
let mut ha = Output::<D>::default();
let encode_point = kpb.to_encoded_point(false);

while offset < klen {
hasher.update(encode_point.x().ok_or(elliptic_curve::Error)?);
hasher.update(encode_point.y().ok_or(elliptic_curve::Error)?);
hasher.update(&ct.to_be_bytes());

hasher
.finalize_into_reset(&mut ha)
.map_err(|_e| elliptic_curve::Error)?;
hasher.finalize_into_reset(&mut ha);

let xor_len = min(digest_size, klen - offset);
xor(c2, &ha, offset, xor_len);
Expand Down
25 changes: 12 additions & 13 deletions sm2/src/pke/decrypting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ use elliptic_curve::{
};
use primeorder::PrimeField;

use sm3::{digest::DynDigest, Digest, Sm3};
use signature::digest::{Digest, FixedOutputReset, Output, OutputSizeUser, Update};
use sm3::Sm3;

use super::{encrypting::EncryptingKey, kdf, vec, Cipher, Mode};
use super::{encrypting::EncryptingKey, kdf, Cipher, Mode};
/// Represents a decryption key used for decrypting messages using elliptic curve cryptography.
#[derive(Clone)]
pub struct DecryptingKey {
Expand Down Expand Up @@ -91,7 +92,7 @@ impl DecryptingKey {
/// Decrypts a ciphertext in-place using the specified digest algorithm.
pub fn decrypt_digest<D>(&self, ciphertext: &[u8]) -> Result<Vec<u8>>
where
D: 'static + Digest + DynDigest + Send + Sync,
D: Digest + OutputSizeUser + Update + FixedOutputReset,
{
let mut digest = D::new();
decrypt(&self.secret_scalar, self.mode, &mut digest, ciphertext)
Expand All @@ -105,7 +106,7 @@ impl DecryptingKey {
/// Decrypts a ciphertext in-place from ASN.1 format using the specified digest algorithm.
pub fn decrypt_der_digest<D>(&self, ciphertext: &[u8]) -> Result<Vec<u8>>
where
D: 'static + Digest + DynDigest + Send + Sync,
D: Digest + OutputSizeUser + Update + FixedOutputReset,
{
let cipher = Cipher::from_der(ciphertext).map_err(elliptic_curve::pkcs8::Error::from)?;
let prefix: &[u8] = &[0x04];
Expand Down Expand Up @@ -153,12 +154,10 @@ impl PartialEq for DecryptingKey {
}
}

fn decrypt(
secret_scalar: &Scalar,
mode: Mode,
hasher: &mut dyn DynDigest,
cipher: &[u8],
) -> Result<Vec<u8>> {
fn decrypt<D>(secret_scalar: &Scalar, mode: Mode, hasher: &mut D, cipher: &[u8]) -> Result<Vec<u8>>
where
D: Update + OutputSizeUser + FixedOutputReset,
{
let q = U256::from_be_hex(FieldElement::MODULUS);
let c1_len = (q.bits() + 7) / 8 * 2 + 1;

Expand All @@ -177,7 +176,7 @@ fn decrypt(

// B3: compute [𝑑𝐵]𝐶1 = (𝑥2, 𝑦2)
c1_point = (c1_point * secret_scalar).to_affine();
let digest_size = hasher.output_size();
let digest_size = D::output_size();
let (c2, c3) = match mode {
Mode::C1C3C2 => {
let (c3, c2) = c.split_at(digest_size);
Expand All @@ -192,12 +191,12 @@ fn decrypt(
kdf(hasher, c1_point, &mut c2)?;

// compute 𝑢 = 𝐻𝑎𝑠ℎ(𝑥2 ∥ 𝑀′∥ 𝑦2).
let mut u = vec![0u8; digest_size];
let mut u = Output::<D>::default();
let encode_point = c1_point.to_encoded_point(false);
hasher.update(encode_point.x().ok_or(Error)?);
hasher.update(&c2);
hasher.update(encode_point.y().ok_or(Error)?);
hasher.finalize_into_reset(&mut u).map_err(|_e| Error)?;
hasher.finalize_into_reset(&mut u);
let checked = u
.iter()
.zip(c3)
Expand Down
40 changes: 17 additions & 23 deletions sm2/src/pke/encrypting.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use core::fmt::Debug;

use crate::{
arithmetic::field::FieldElement,
pke::{kdf, vec},
AffinePoint, ProjectivePoint, PublicKey, Scalar, Sm2,
arithmetic::field::FieldElement, pke::kdf, AffinePoint, ProjectivePoint, PublicKey, Scalar, Sm2,
};

#[cfg(feature = "alloc")]
Expand All @@ -18,12 +16,10 @@ use elliptic_curve::{
};

use primeorder::PrimeField;
use sm3::{
digest::{Digest, DynDigest},
Sm3,
};
use sm3::Sm3;

use super::{Cipher, Mode};
use signature::digest::{Digest, FixedOutputReset, Output, OutputSizeUser, Update};
/// Represents an encryption key used for encrypting messages using elliptic curve cryptography.
#[derive(Clone, Debug)]
pub struct EncryptingKey {
Expand Down Expand Up @@ -91,7 +87,7 @@ impl EncryptingKey {
/// Encrypts a message using a specified digest algorithm.
pub fn encrypt_digest<D>(&self, msg: &[u8]) -> Result<Vec<u8>>
where
D: 'static + Digest + DynDigest + Send + Sync,
D: Digest + Update + FixedOutputReset,
{
let mut digest = D::new();
encrypt(&self.public_key, self.mode, &mut digest, msg)
Expand All @@ -100,11 +96,11 @@ impl EncryptingKey {
/// Encrypts a message using a specified digest algorithm and returns the result in ASN.1 format.
pub fn encrypt_der_digest<D>(&self, msg: &[u8]) -> Result<Vec<u8>>
where
D: 'static + Digest + DynDigest + Send + Sync,
D: Update + OutputSizeUser + Digest + FixedOutputReset,
{
let mut digest = D::new();
let cipher = encrypt(&self.public_key, self.mode, &mut digest, msg)?;
let digest_size = digest.output_size();
let digest_size = <D as OutputSizeUser>::output_size();
let (_, cipher) = cipher.split_at(1);
let (x, cipher) = cipher.split_at(32);
let (y, cipher) = cipher.split_at(32);
Expand Down Expand Up @@ -133,14 +129,13 @@ impl From<PublicKey> for EncryptingKey {
}

/// Encrypts a message using the specified public key, mode, and digest algorithm.
fn encrypt(
public_key: &PublicKey,
mode: Mode,
digest: &mut dyn DynDigest,
msg: &[u8],
) -> Result<Vec<u8>> {
fn encrypt<D>(public_key: &PublicKey, mode: Mode, digest: &mut D, msg: &[u8]) -> Result<Vec<u8>>
where
D: Update + FixedOutputReset,
{
const N_BYTES: u32 = (Sm2::ORDER.bits() + 7) / 8;
let mut c1 = vec![0; (N_BYTES * 2 + 1) as usize];
#[allow(unused_assignments)]
let mut c1 = Default::default();
let mut c2 = msg.to_owned();
let mut hpb: AffinePoint;
loop {
Expand All @@ -167,24 +162,23 @@ fn encrypt(
// // If 𝑡 is an all-zero bit string, go to A1.
// if all of t are 0, xor(c2) == c2
if c2.iter().zip(msg).any(|(pre, cur)| pre != cur) {
let uncompress_kg = kg.to_encoded_point(false);
c1.copy_from_slice(uncompress_kg.as_bytes());
c1 = kg.to_encoded_point(false);
break;
}
}
let encode_point = hpb.to_encoded_point(false);

// A7: compute 𝐶3 = 𝐻𝑎𝑠ℎ(𝑥2||𝑀||𝑦2)
let mut c3 = vec![0; digest.output_size()];
let mut c3 = Output::<D>::default();
digest.update(encode_point.x().ok_or(Error)?);
digest.update(msg);
digest.update(encode_point.y().ok_or(Error)?);
digest.finalize_into_reset(&mut c3).map_err(|_e| Error)?;
digest.finalize_into_reset(&mut c3);

// A8: output the ciphertext 𝐶 = 𝐶1||𝐶2||𝐶3.
Ok(match mode {
Mode::C1C2C3 => [c1.as_slice(), &c2, &c3].concat(),
Mode::C1C3C2 => [c1.as_slice(), &c3, &c2].concat(),
Mode::C1C2C3 => [c1.as_bytes(), &c2, &c3].concat(),
Mode::C1C3C2 => [c1.as_bytes(), &c3, &c2].concat(),
})
}

Expand Down

0 comments on commit 8eabb52

Please sign in to comment.