From 5114e94aa2c2bf97f27a855bdf541068ff9033fc Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 13 Nov 2023 13:35:01 +0000 Subject: [PATCH] token 2022: add support for _reading_ repeating fixed-length extensions --- token/program-2022/src/extension/mod.rs | 176 ++++++++++++++++++++++++ 1 file changed, 176 insertions(+) diff --git a/token/program-2022/src/extension/mod.rs b/token/program-2022/src/extension/mod.rs index 7361d54d62d..adc0ed05c23 100644 --- a/token/program-2022/src/extension/mod.rs +++ b/token/program-2022/src/extension/mod.rs @@ -173,6 +173,35 @@ fn get_extension_indices( Err(ProgramError::InvalidAccountData) } +fn get_all_extension_indices( + tlv_data: &[u8], +) -> Result, ProgramError> { + let mut indices_list = vec![]; + let mut start_index = 0; + let v_account_type = V::TYPE.get_account_type(); + while start_index < tlv_data.len() { + let tlv_indices = get_tlv_indices(start_index); + if tlv_data.len() < tlv_indices.value_start { + return Err(ProgramError::InvalidAccountData); + } + let extension_type = + ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?; + let account_type = extension_type.get_account_type(); + + let length = + pod_from_bytes::(&tlv_data[tlv_indices.length_start..tlv_indices.value_start])?; + let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length)); + start_index = value_end_index; + + if extension_type == V::TYPE { + indices_list.push(tlv_indices); + } else if v_account_type != account_type && extension_type != ExtensionType::Uninitialized { + return Err(TokenError::ExtensionTypeMismatch.into()); + } + } + Ok(indices_list) +} + /// Basic information about the TLV buffer, collected from iterating through all /// entries #[derive(Debug, PartialEq)] @@ -339,6 +368,30 @@ fn get_extension_bytes(tlv_data: &[u8]) -> Result<&[ Ok(&tlv_data[value_start..value_end]) } +fn get_all_extension_bytes( + tlv_data: &[u8], +) -> Result, ProgramError> { + if V::TYPE.get_account_type() != S::ACCOUNT_TYPE { + return Err(ProgramError::InvalidAccountData); + } + let all_extension_indices = get_all_extension_indices::(tlv_data)?; + let mut all_extension_bytes = vec![]; + for TlvIndices { + type_start: _, + length_start, + value_start, + } in all_extension_indices.iter() + { + let length = pod_from_bytes::(&tlv_data[*length_start..*value_start])?; + let value_end = value_start.saturating_add(usize::from(*length)); + if tlv_data.len() < value_end { + return Err(ProgramError::InvalidAccountData); + } + all_extension_bytes.push(&tlv_data[*value_start..value_end]); + } + Ok(all_extension_bytes) +} + fn get_extension_bytes_mut( tlv_data: &mut [u8], ) -> Result<&mut [u8], ProgramError> { @@ -397,11 +450,46 @@ pub trait BaseStateWithExtensions { get_extension_bytes::(self.get_tlv_data()) } + /// Fetch the bytes for a TLV entry, where repetitions are allowed + fn get_repeating_extension_bytes( + &self, + repetition: usize, + ) -> Result<&[u8], ProgramError> { + get_all_extension_bytes::(self.get_tlv_data()).map(|x| { + *x.get(repetition.saturating_sub(1)) + .ok_or::(ProgramError::InvalidAccountData) + .unwrap() + }) + } + + /// Fetch all bytes for each entry of a particular TLV entry, where + /// repetitions are allowed + fn get_all_extension_bytes(&self) -> Result, ProgramError> { + get_all_extension_bytes::(self.get_tlv_data()) + } + /// Unpack a portion of the TLV data as the desired type fn get_extension(&self) -> Result<&V, ProgramError> { pod_from_bytes::(self.get_extension_bytes::()?) } + /// Unpack a portion of the TLV data as the desired type, where repetitions + /// are allowed + fn get_repeating_extension( + &self, + repetition: usize, + ) -> Result<&V, ProgramError> { + pod_from_bytes::(self.get_repeating_extension_bytes::(repetition)?) + } + + /// Unpack all extensions for the desired type + fn get_all_extensions(&self) -> Result, ProgramError> { + self.get_all_extension_bytes::()? + .iter() + .map(|bytes| pod_from_bytes::(bytes)) + .collect() + } + /// Unpacks a portion of the TLV data as the desired variable-length type fn get_variable_len_extension( &self, @@ -1410,6 +1498,54 @@ mod test { 1, 1, // data ]; + const MINT_WITH_DUPLICATED_EXTENSION: &[u8] = &[ + 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding + 1, // account type + 3, 0, // extension type + 32, 0, // length + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, // data + 3, 0, // extension type + 32, 0, // length + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, // data + 3, 0, // extension type + 32, 0, // length + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, // data + ]; + + const MINT_WITH_DUPLICATED_EXTENSION_AND_ONE_EXTRA: &[u8] = &[ + 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding + 1, // account type + 3, 0, // extension type + 32, 0, // length + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, // data + 3, 0, // extension type + 32, 0, // length + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, // data + 3, 0, // extension type + 32, 0, // length + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, // data + 14, 0, // extension type + 64, 0, // length + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + ]; + #[test] fn unpack_opaque_buffer() { let state = StateWithExtensions::::unpack(MINT_WITH_EXTENSION).unwrap(); @@ -1435,6 +1571,46 @@ mod test { assert_eq!(state.base, TEST_MINT); } + #[test] + fn unpack_opaque_buffer_with_duplicates() { + let state = StateWithExtensions::::unpack(MINT_WITH_DUPLICATED_EXTENSION).unwrap(); + assert_eq!(state.base, TEST_MINT); + let all_extensions = state.get_all_extensions::().unwrap(); + assert_eq!(all_extensions.len(), 3); + assert_eq!( + state.get_extension::(), + Err(ProgramError::InvalidAccountData) + ); + assert_eq!( + StateWithExtensions::::unpack(MINT_WITH_DUPLICATED_EXTENSION), + Err(ProgramError::InvalidAccountData) + ); + + // test we can get a single entry + let state = StateWithExtensions::::unpack(MINT_WITH_DUPLICATED_EXTENSION).unwrap(); + let extension = state + .get_repeating_extension::(1) + .unwrap(); + let close_authority = + OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap(); + assert_eq!(extension.close_authority, close_authority); + + let state = + StateWithExtensions::::unpack(MINT_WITH_DUPLICATED_EXTENSION_AND_ONE_EXTRA) + .unwrap(); + assert_eq!(state.base, TEST_MINT); + let all_extensions = state.get_all_extensions::().unwrap(); + assert_eq!(all_extensions.len(), 3); + assert_eq!( + state.get_extension::(), + Err(ProgramError::InvalidAccountData) + ); + assert_eq!( + StateWithExtensions::::unpack(MINT_WITH_DUPLICATED_EXTENSION_AND_ONE_EXTRA), + Err(ProgramError::InvalidAccountData) + ); + } + #[test] fn fail_unpack_opaque_buffer() { // input buffer too small