Skip to content

Commit

Permalink
SPL errors from hashes (#5169)
Browse files Browse the repository at this point in the history
* SPL errors from hashes

* hashed error code is first variant only

* add check for collision error codes

* address feedback!

* stupid `0`!
  • Loading branch information
Joe C authored Sep 1, 2023
1 parent 25381b2 commit cfaabb5
Show file tree
Hide file tree
Showing 17 changed files with 300 additions and 108 deletions.
6 changes: 2 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libraries/program-error/derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ proc-macro = true
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
solana-program = "1.16.3"
syn = { version = "2.0", features = ["full"] }
61 changes: 47 additions & 14 deletions libraries/program-error/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,72 @@
extern crate proc_macro;

mod macro_impl;
mod parser;

use macro_impl::MacroType;
use proc_macro::TokenStream;
use syn::{parse_macro_input, ItemEnum};
use {
crate::parser::SplProgramErrorArgs,
macro_impl::MacroType,
proc_macro::TokenStream,
syn::{parse_macro_input, ItemEnum},
};

/// Derive macro to add `Into<solana_program::program_error::ProgramError>` traits
/// Derive macro to add `Into<solana_program::program_error::ProgramError>`
/// trait
#[proc_macro_derive(IntoProgramError)]
pub fn into_program_error(input: TokenStream) -> TokenStream {
MacroType::IntoProgramError
.generate_tokens(parse_macro_input!(input as ItemEnum))
let ItemEnum { ident, .. } = parse_macro_input!(input as ItemEnum);
MacroType::IntoProgramError { ident }
.generate_tokens()
.into()
}

/// Derive macro to add `solana_program::decode_error::DecodeError` trait
#[proc_macro_derive(DecodeError)]
pub fn decode_error(input: TokenStream) -> TokenStream {
MacroType::DecodeError
.generate_tokens(parse_macro_input!(input as ItemEnum))
.into()
let ItemEnum { ident, .. } = parse_macro_input!(input as ItemEnum);
MacroType::DecodeError { ident }.generate_tokens().into()
}

/// Derive macro to add `solana_program::program_error::PrintProgramError` trait
#[proc_macro_derive(PrintProgramError)]
pub fn print_program_error(input: TokenStream) -> TokenStream {
MacroType::PrintProgramError
.generate_tokens(parse_macro_input!(input as ItemEnum))
let ItemEnum {
ident, variants, ..
} = parse_macro_input!(input as ItemEnum);
MacroType::PrintProgramError { ident, variants }
.generate_tokens()
.into()
}

/// Proc macro attribute to turn your enum into a Solana Program Error
///
/// Adds:
/// - `Clone`
/// - `Debug`
/// - `Eq`
/// - `PartialEq`
/// - `thiserror::Error`
/// - `num_derive::FromPrimitive`
/// - `Into<solana_program::program_error::ProgramError>`
/// - `solana_program::decode_error::DecodeError`
/// - `solana_program::program_error::PrintProgramError`
///
/// Optionally, you can add `hash_error_code_start: u32` argument to create
/// a unique `u32` _starting_ error codes from the names of the enum variants.
/// Notes:
/// - The _error_ variant will start at this value, and the rest will be
/// incremented by one
/// - The value provided is only for code readability, the actual error code
/// will be a hash of the input string and is checked against your input
///
/// Syntax: `#[spl_program_error(hash_error_code_start = 1275525928)]`
/// Hash Input: `spl_program_error:<enum name>:<variant name>`
/// Value: `u32::from_le_bytes(<hash of input>[13..17])`
#[proc_macro_attribute]
pub fn spl_program_error(_: TokenStream, input: TokenStream) -> TokenStream {
MacroType::SplProgramError
.generate_tokens(parse_macro_input!(input as ItemEnum))
pub fn spl_program_error(attr: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as SplProgramErrorArgs);
let item_enum = parse_macro_input!(input as ItemEnum);
MacroType::SplProgramError { args, item_enum }
.generate_tokens()
.into()
}
122 changes: 102 additions & 20 deletions libraries/program-error/derive/src/macro_impl.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,52 @@
//! The actual token generator for the macro
use quote::quote;
use syn::{punctuated::Punctuated, token::Comma, Ident, ItemEnum, LitStr, Variant};
use {
crate::parser::SplProgramErrorArgs,
proc_macro2::Span,
quote::quote,
syn::{
punctuated::Punctuated, token::Comma, Expr, ExprLit, Ident, ItemEnum, Lit, LitInt, LitStr,
Token, Variant,
},
};

const SPL_ERROR_HASH_NAMESPACE: &str = "spl_program_error";
const SPL_ERROR_HASH_MIN_VALUE: u32 = 7_000;

/// The type of macro being called, thus directing which tokens to generate
#[allow(clippy::enum_variant_names)]
pub enum MacroType {
IntoProgramError,
DecodeError,
PrintProgramError,
SplProgramError,
IntoProgramError {
ident: Ident,
},
DecodeError {
ident: Ident,
},
PrintProgramError {
ident: Ident,
variants: Punctuated<Variant, Comma>,
},
SplProgramError {
args: SplProgramErrorArgs,
item_enum: ItemEnum,
},
}

impl MacroType {
/// Generates the corresponding tokens based on variant selection
pub fn generate_tokens(&self, item_enum: ItemEnum) -> proc_macro2::TokenStream {
pub fn generate_tokens(&mut self) -> proc_macro2::TokenStream {
match self {
MacroType::IntoProgramError => into_program_error(&item_enum.ident),
MacroType::DecodeError => decode_error(&item_enum.ident),
MacroType::PrintProgramError => {
print_program_error(&item_enum.ident, &item_enum.variants)
}
MacroType::SplProgramError => spl_program_error(item_enum),
Self::IntoProgramError { ident } => into_program_error(ident),
Self::DecodeError { ident } => decode_error(ident),
Self::PrintProgramError { ident, variants } => print_program_error(ident, variants),
Self::SplProgramError { args, item_enum } => spl_program_error(args, item_enum),
}
}
}

/// Builds the implementation of `Into<solana_program::program_error::ProgramError>`
/// More specifically, implements `From<Self> for solana_program::program_error::ProgramError`
/// Builds the implementation of
/// `Into<solana_program::program_error::ProgramError>` More specifically,
/// implements `From<Self> for solana_program::program_error::ProgramError`
pub fn into_program_error(ident: &Ident) -> proc_macro2::TokenStream {
quote! {
impl From<#ident> for solana_program::program_error::ProgramError {
Expand All @@ -48,7 +68,8 @@ pub fn decode_error(ident: &Ident) -> proc_macro2::TokenStream {
}
}

/// Builds the implementation of `solana_program::program_error::PrintProgramError`
/// Builds the implementation of
/// `solana_program::program_error::PrintProgramError`
pub fn print_program_error(
ident: &Ident,
variants: &Punctuated<Variant, Comma>,
Expand Down Expand Up @@ -96,16 +117,25 @@ fn get_error_message(variant: &Variant) -> Option<String> {

/// The main function that produces the tokens required to turn your
/// error enum into a Solana Program Error
pub fn spl_program_error(input: ItemEnum) -> proc_macro2::TokenStream {
let ident = &input.ident;
let variants = &input.variants;
pub fn spl_program_error(
args: &SplProgramErrorArgs,
item_enum: &mut ItemEnum,
) -> proc_macro2::TokenStream {
if let Some(error_code_start) = args.hash_error_code_start {
set_first_discriminant(item_enum, error_code_start);
}

let ident = &item_enum.ident;
let variants = &item_enum.variants;
let into_program_error = into_program_error(ident);
let decode_error = decode_error(ident);
let print_program_error = print_program_error(ident, variants);

quote! {
#[repr(u32)]
#[derive(Clone, Debug, Eq, thiserror::Error, num_derive::FromPrimitive, PartialEq)]
#[num_traits = "num_traits"]
#input
#item_enum

#into_program_error

Expand All @@ -114,3 +144,55 @@ pub fn spl_program_error(input: ItemEnum) -> proc_macro2::TokenStream {
#print_program_error
}
}

/// This function adds a discriminant to the first enum variant based on the
/// hash of the `SPL_ERROR_HASH_NAMESPACE` constant, the enum name and variant
/// name.
/// It will then check to make sure the provided `hash_error_code_start` is
/// equal to the hash-produced `u32`.
///
/// See https://docs.rs/syn/latest/syn/struct.Variant.html
fn set_first_discriminant(item_enum: &mut ItemEnum, error_code_start: u32) {
let enum_ident = &item_enum.ident;
if item_enum.variants.is_empty() {
panic!("Enum must have at least one variant");
}
let first_variant = &mut item_enum.variants[0];
let discriminant = u32_from_hash(enum_ident);
if discriminant == error_code_start {
let eq = Token![=](Span::call_site());
let expr = Expr::Lit(ExprLit {
attrs: Vec::new(),
lit: Lit::Int(LitInt::new(&discriminant.to_string(), Span::call_site())),
});
first_variant.discriminant = Some((eq, expr));
} else {
panic!(
"Error code start value from hash must be {0}. Update your macro attribute to \
`#[spl_program_error(hash_error_code_start = {0})]`.",
discriminant
);
}
}

/// Hashes the `SPL_ERROR_HASH_NAMESPACE` constant, the enum name and variant
/// name and returns four middle bytes (13 through 16) as a u32.
fn u32_from_hash(enum_ident: &Ident) -> u32 {
let hash_input = format!("{}:{}", SPL_ERROR_HASH_NAMESPACE, enum_ident);

// We don't want our error code to start at any number below
// `SPL_ERROR_HASH_MIN_VALUE`!
let mut nonce: u32 = 0;
loop {
let hash = solana_program::hash::hashv(&[hash_input.as_bytes(), &nonce.to_le_bytes()]);
let d = u32::from_le_bytes(
hash.to_bytes()[13..17]
.try_into()
.expect("Unable to convert hash to u32"),
);
if d >= SPL_ERROR_HASH_MIN_VALUE {
return d;
}
nonce += 1;
}
}
64 changes: 64 additions & 0 deletions libraries/program-error/derive/src/parser.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//! Token parsing
use {
proc_macro2::Ident,
syn::{
parse::{Parse, ParseStream},
token::Comma,
LitInt, Token,
},
};

/// Possible arguments to the `#[spl_program_error]` attribute
pub struct SplProgramErrorArgs {
/// Whether to hash the error codes using `solana_program::hash`
/// or to use the default error code assigned by `num_traits`.
pub hash_error_code_start: Option<u32>,
}

impl Parse for SplProgramErrorArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
if input.is_empty() {
return Ok(Self {
hash_error_code_start: None,
});
}
match SplProgramErrorArgParser::parse(input)? {
SplProgramErrorArgParser::HashErrorCodes { value, .. } => Ok(Self {
hash_error_code_start: Some(value.base10_parse::<u32>()?),
}),
}
}
}

/// Parser for args to the `#[spl_program_error]` attribute
/// ie. `#[spl_program_error(hash_error_code_start = 1275525928)]`
enum SplProgramErrorArgParser {
HashErrorCodes {
_ident: Ident,
_equals_sign: Token![=],
value: LitInt,
_comma: Option<Comma>,
},
}

impl Parse for SplProgramErrorArgParser {
fn parse(input: ParseStream) -> syn::Result<Self> {
let _ident = {
let ident = input.parse::<Ident>()?;
if ident != "hash_error_code_start" {
return Err(input.error("Expected argument 'hash_error_code_start'"));
}
ident
};
let _equals_sign = input.parse::<Token![=]>()?;
let value = input.parse::<LitInt>()?;
let _comma: Option<Comma> = input.parse().unwrap_or(None);
Ok(Self::HashErrorCodes {
_ident,
_equals_sign,
value,
_comma,
})
}
}
12 changes: 6 additions & 6 deletions libraries/program-error/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ extern crate self as spl_program_error;

// Make these available downstream for the macro to work without
// additional imports
pub use num_derive;
pub use num_traits;
pub use solana_program;
pub use spl_program_error_derive::{
spl_program_error, DecodeError, IntoProgramError, PrintProgramError,
pub use {
num_derive, num_traits, solana_program,
spl_program_error_derive::{
spl_program_error, DecodeError, IntoProgramError, PrintProgramError,
},
thiserror,
};
pub use thiserror;
2 changes: 1 addition & 1 deletion libraries/program-error/tests/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Tests `#[derive(DecodeError)]`
//!
use spl_program_error::*;

/// Example error
Expand Down
2 changes: 1 addition & 1 deletion libraries/program-error/tests/into.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Tests `#[derive(IntoProgramError)]`
//!
use spl_program_error::*;

/// Example error
Expand Down
14 changes: 8 additions & 6 deletions libraries/program-error/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ pub mod spl;

#[cfg(test)]
mod tests {
use super::*;
use serial_test::serial;
use solana_program::{
decode_error::DecodeError,
program_error::{PrintProgramError, ProgramError},
use {
super::*,
serial_test::serial,
solana_program::{
decode_error::DecodeError,
program_error::{PrintProgramError, ProgramError},
},
std::sync::{Arc, RwLock},
};
use std::sync::{Arc, RwLock};

// Used to capture output for `PrintProgramError` for testing
lazy_static::lazy_static! {
Expand Down
Loading

0 comments on commit cfaabb5

Please sign in to comment.