Skip to content

Commit

Permalink
[zk-sdk] Expose ElGamal and authenticated encryption ciphertext types…
Browse files Browse the repository at this point in the history
… for wasm target (anza-xyz#4171)

* add `PodElGamalPubkey::zeroed()` function

* use macro for wasm binding implementation for `PodElGamalPubkey`

* expose `ElGamalCiphertext` as wasm

* expose `AeCiphertext` and `AeKey` as wasm

* use camelCase for `new_rand()` for js

* fix order of declaration of cfgs

* use fully qualified types in `impl_wasm_bindgings` macro
  • Loading branch information
samkim-crypto authored Jan 7, 2025
1 parent 92ce0a8 commit 6f1282d
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 76 deletions.
6 changes: 6 additions & 0 deletions zk-sdk/src/encryption/auth_encryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
//!
//! This module is a simple wrapper of the `Aes128GcmSiv` implementation specialized for SPL
//! token-2022 where the plaintext is always `u64`.
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
use {
crate::{
encryption::{AE_CIPHERTEXT_LEN, AE_KEY_LEN},
Expand Down Expand Up @@ -85,12 +87,15 @@ impl AuthenticatedEncryption {
}
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
#[derive(Clone, Debug, Zeroize, Eq, PartialEq)]
pub struct AeKey([u8; AE_KEY_LEN]);
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
impl AeKey {
/// Generates a random authenticated encryption key.
///
/// This function is randomized. It internally samples a scalar element using `OsRng`.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen(js_name = newRand))]
pub fn new_rand() -> Self {
AuthenticatedEncryption::keygen()
}
Expand Down Expand Up @@ -240,6 +245,7 @@ type Nonce = [u8; NONCE_LEN];
type Ciphertext = [u8; CIPHERTEXT_LEN];

/// Authenticated encryption nonce and ciphertext
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
#[derive(Clone, Copy, Debug, Default)]
pub struct AeCiphertext {
nonce: Nonce,
Expand Down
3 changes: 3 additions & 0 deletions zk-sdk/src/encryption/elgamal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ impl ElGamalKeypair {
/// Generates the public and secret keys for ElGamal encryption.
///
/// This function is randomized. It internally samples a scalar element using `OsRng`.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen(js_name = newRand))]
pub fn new_rand() -> Self {
ElGamal::keygen()
}
Expand Down Expand Up @@ -615,6 +616,7 @@ impl ConstantTimeEq for ElGamalSecretKey {
}

/// Ciphertext for the ElGamal encryption scheme.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct ElGamalCiphertext {
pub commitment: PedersenCommitment,
Expand Down Expand Up @@ -750,6 +752,7 @@ define_mul_variants!(
);

/// Decryption handle for Pedersen commitment.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct DecryptHandle(RistrettoPoint);
impl DecryptHandle {
Expand Down
3 changes: 3 additions & 0 deletions zk-sdk/src/encryption/pedersen.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Pedersen commitment implementation using the Ristretto prime-order group.
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
use {
crate::encryption::{PEDERSEN_COMMITMENT_LEN, PEDERSEN_OPENING_LEN},
core::ops::{Add, Mul, Sub},
Expand Down Expand Up @@ -165,6 +167,7 @@ define_mul_variants!(
);

/// Pedersen commitment type.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct PedersenCommitment(RistrettoPoint);
impl PedersenCommitment {
Expand Down
7 changes: 6 additions & 1 deletion zk-sdk/src/encryption/pod/auth_encryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
#[cfg(not(target_os = "solana"))]
use crate::{encryption::auth_encryption::AeCiphertext, errors::AuthenticatedEncryptionError};
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
use {
crate::{
encryption::AE_CIPHERTEXT_LEN,
pod::{impl_from_bytes, impl_from_str},
pod::{impl_from_bytes, impl_from_str, impl_wasm_bindings},
},
base64::{prelude::BASE64_STANDARD, Engine},
bytemuck::{Pod, Zeroable},
Expand All @@ -16,10 +18,13 @@ use {
const AE_CIPHERTEXT_MAX_BASE64_LEN: usize = 48;

/// The `AeCiphertext` type as a `Pod`.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
#[derive(Clone, Copy, PartialEq, Eq)]
#[repr(transparent)]
pub struct PodAeCiphertext(pub(crate) [u8; AE_CIPHERTEXT_LEN]);

impl_wasm_bindings!(POD_TYPE = PodAeCiphertext, DECODED_TYPE = AeCiphertext);

// `PodAeCiphertext` is a wrapper type for a byte array, which is both `Pod` and `Zeroable`. However,
// the marker traits `bytemuck::Pod` and `bytemuck::Zeroable` can only be derived for power-of-two
// length byte arrays. Directly implement these traits for `PodAeCiphertext`.
Expand Down
88 changes: 13 additions & 75 deletions zk-sdk/src/encryption/pod/elgamal.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
//! Plain Old Data types for the ElGamal encryption scheme.
#[cfg(not(target_arch = "wasm32"))]
use bytemuck::Zeroable;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
#[cfg(not(target_os = "solana"))]
use {
crate::{
Expand All @@ -11,17 +15,11 @@ use {
use {
crate::{
encryption::{DECRYPT_HANDLE_LEN, ELGAMAL_CIPHERTEXT_LEN, ELGAMAL_PUBKEY_LEN},
pod::{impl_from_bytes, impl_from_str},
pod::{impl_from_bytes, impl_from_str, impl_wasm_bindings},
},
base64::{prelude::BASE64_STANDARD, Engine},
bytemuck::Zeroable,
std::fmt,
};
#[cfg(target_arch = "wasm32")]
use {
js_sys::{Array, Uint8Array},
wasm_bindgen::prelude::*,
};

/// Maximum length of a base64 encoded ElGamal public key
const ELGAMAL_PUBKEY_MAX_BASE64_LEN: usize = 44;
Expand All @@ -33,10 +31,16 @@ const ELGAMAL_CIPHERTEXT_MAX_BASE64_LEN: usize = 88;
const DECRYPT_HANDLE_MAX_BASE64_LEN: usize = 44;

/// The `ElGamalCiphertext` type as a `Pod`.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
#[derive(Clone, Copy, bytemuck_derive::Pod, bytemuck_derive::Zeroable, PartialEq, Eq)]
#[repr(transparent)]
pub struct PodElGamalCiphertext(pub(crate) [u8; ELGAMAL_CIPHERTEXT_LEN]);

impl_wasm_bindings!(
POD_TYPE = PodElGamalCiphertext,
DECODED_TYPE = ElGamalCiphertext
);

impl fmt::Debug for PodElGamalCiphertext {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self.0)
Expand Down Expand Up @@ -83,78 +87,12 @@ impl TryFrom<PodElGamalCiphertext> for ElGamalCiphertext {
}

/// The `ElGamalPubkey` type as a `Pod`.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
#[derive(Clone, Copy, Default, bytemuck_derive::Pod, bytemuck_derive::Zeroable, PartialEq, Eq)]
#[repr(transparent)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
pub struct PodElGamalPubkey(pub(crate) [u8; ELGAMAL_PUBKEY_LEN]);

#[cfg(target_arch = "wasm32")]
#[allow(non_snake_case)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
impl PodElGamalPubkey {
/// Create a new `PodElGamalPubkey` object
///
/// * `value` - optional public key as a base64 encoded string, `Uint8Array`, `[number]`
#[wasm_bindgen(constructor)]
pub fn constructor(value: JsValue) -> Result<PodElGamalPubkey, JsValue> {
if let Some(base64_str) = value.as_string() {
base64_str
.parse::<PodElGamalPubkey>()
.map_err(|e| e.to_string().into())
} else if let Some(uint8_array) = value.dyn_ref::<Uint8Array>() {
bytemuck::try_from_bytes(&uint8_array.to_vec())
.map_err(|err| JsValue::from(format!("Invalid Uint8Array ElGamalPubkey: {err:?}")))
.map(|pubkey| *pubkey)
} else if let Some(array) = value.dyn_ref::<Array>() {
let mut bytes = vec![];
let iterator = js_sys::try_iter(&array.values())?.expect("array to be iterable");
for x in iterator {
let x = x?;

if let Some(n) = x.as_f64() {
if (0. ..=255.).contains(&n) {
bytes.push(n as u8);
continue;
}
}
return Err(format!("Invalid array argument: {:?}", x).into());
}

bytemuck::try_from_bytes(&bytes)
.map_err(|err| JsValue::from(format!("Invalid Array pubkey: {err:?}")))
.map(|pubkey| *pubkey)
} else if value.is_undefined() {
Ok(PodElGamalPubkey::default())
} else {
Err("Unsupported argument".into())
}
}

/// Return the base64 string representation of the public key
pub fn toString(&self) -> String {
self.to_string()
}

/// Checks if two `ElGamalPubkey`s are equal
pub fn equals(&self, other: &PodElGamalPubkey) -> bool {
self == other
}

/// Return the `Uint8Array` representation of the public key
pub fn toBytes(&self) -> Box<[u8]> {
self.0.into()
}

pub fn compressed(decoded: &ElGamalPubkey) -> PodElGamalPubkey {
(*decoded).into()
}

pub fn decompressed(&self) -> Result<ElGamalPubkey, JsValue> {
(*self)
.try_into()
.map_err(|err| JsValue::from(format!("Invalid ElGamalPubkey: {err:?}")))
}
}
impl_wasm_bindings!(POD_TYPE = PodElGamalPubkey, DECODED_TYPE = ElGamalPubkey);

impl fmt::Debug for PodElGamalPubkey {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Expand Down
78 changes: 78 additions & 0 deletions zk-sdk/src/pod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,81 @@ macro_rules! impl_from_bytes {
};
}
pub(crate) use impl_from_bytes;

macro_rules! impl_wasm_bindings {
(POD_TYPE = $pod_type:ident, DECODED_TYPE = $decoded_type: ident) => {
#[cfg(target_arch = "wasm32")]
#[allow(non_snake_case)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen::prelude::wasm_bindgen)]
impl $pod_type {
#[wasm_bindgen::prelude::wasm_bindgen(constructor)]
pub fn constructor(
value: wasm_bindgen::JsValue,
) -> Result<$pod_type, wasm_bindgen::JsValue> {
if let Some(base64_str) = value.as_string() {
base64_str
.parse::<$pod_type>()
.map_err(|e| e.to_string().into())
} else if let Some(uint8_array) = value.dyn_ref::<js_sys::Uint8Array>() {
bytemuck::try_from_bytes(&uint8_array.to_vec())
.map_err(|err| {
wasm_bindgen::JsValue::from(format!("Invalid Uint8Array: {err:?}"))
})
.map(|value| *value)
} else if let Some(array) = value.dyn_ref::<js_sys::Array>() {
let mut bytes = vec![];
let iterator =
js_sys::try_iter(&array.values())?.expect("array to be iterable");
for x in iterator {
let x = x?;

if let Some(n) = x.as_f64() {
if (0. ..=255.).contains(&n) {
bytes.push(n as u8);
continue;
}
}
return Err(format!("Invalid array argument: {:?}", x).into());
}

bytemuck::try_from_bytes(&bytes)
.map_err(|err| {
wasm_bindgen::JsValue::from(format!("Invalid Array: {err:?}"))
})
.map(|value| *value)
} else if value.is_undefined() {
Ok($pod_type::default())
} else {
Err("Unsupported argument".into())
}
}

pub fn toString(&self) -> String {
self.to_string()
}

pub fn equals(&self, other: &$pod_type) -> bool {
self == other
}

pub fn toBytes(&self) -> Box<[u8]> {
self.0.into()
}

pub fn zeroed() -> Self {
Self::default()
}

pub fn encode(decoded: &$decoded_type) -> $pod_type {
(*decoded).into()
}

pub fn decode(&self) -> Result<$decoded_type, wasm_bindgen::JsValue> {
(*self)
.try_into()
.map_err(|err| JsValue::from(format!("Invalid encoding: {err:?}")))
}
}
};
}
pub(crate) use impl_wasm_bindings;

0 comments on commit 6f1282d

Please sign in to comment.