diff --git a/token/program-2022/src/extension/mod.rs b/token/program-2022/src/extension/mod.rs index adc0ed05c23..ddbc827e188 100644 --- a/token/program-2022/src/extension/mod.rs +++ b/token/program-2022/src/extension/mod.rs @@ -413,6 +413,37 @@ fn get_extension_bytes_mut( Ok(&mut tlv_data[value_start..value_end]) } +fn get_repeating_extension_bytes_mut( + tlv_data: &mut [u8], + repetition: usize, +) -> Result<&mut [u8], ProgramError> { + if V::TYPE.get_account_type() != S::ACCOUNT_TYPE { + return Err(ProgramError::InvalidAccountData); + } + + let mut start_index = 0; + let mut value_start = 0; + let mut value_end = 0; + + for _ in 0..repetition { + let indices = get_extension_indices::(&tlv_data[start_index..], false)?; + + let global_length_start = indices.length_start.saturating_add(start_index); + value_start = indices.value_start.saturating_add(start_index); + + let length = pod_from_bytes::(&tlv_data[global_length_start..value_start])?; + value_end = value_start.saturating_add(usize::from(*length)); + + if tlv_data.len() < value_end { + return Err(ProgramError::InvalidAccountData); + } + + start_index = value_end; + } + + Ok(&mut tlv_data[value_start..value_end]) +} + /// Calculate the new expected size if the state allocates the given number /// of bytes for the given extension type. /// @@ -719,6 +750,39 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> { pod_from_bytes_mut::(self.get_extension_bytes_mut::()?) } + /// Unpack a portion of the TLV data as the base mutable bytes, + /// for a repeating extension + fn get_repeating_extension_bytes_mut( + &mut self, + repetition: usize, + ) -> Result<&mut [u8], ProgramError> { + get_repeating_extension_bytes_mut::(self.tlv_data, repetition) + } + + /// Unpack a portion of the TLV data as the desired type that allows + /// modifying the type, for a repeating extension + pub fn get_repeating_extension_mut( + &mut self, + repetition: usize, + ) -> Result<&mut V, ProgramError> { + pod_from_bytes_mut::(self.get_repeating_extension_bytes_mut::(repetition)?) + } + + /// Returns an unpacked portion of TLV data that allows modifying the type, + /// based on the specified match criteria + pub fn get_first_matched_repeating_extension_mut( + &mut self, + match_critera: impl Fn(&V) -> bool, + ) -> Result<&mut V, ProgramError> { + for (index, extension) in self.get_all_extensions::()?.iter().enumerate() { + let repetition = index + 1; + if match_critera(extension) { + return self.get_repeating_extension_mut(repetition); + } + } + Err(TokenError::ExtensionNotFound.into()) + } + /// Packs a variable-length extension into its appropriate data segment. /// Fails if space hasn't already been allocated for the given extension pub fn pack_variable_len_extension( @@ -752,6 +816,18 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> { Ok(extension_ref) } + /// Packs the default extension data into an open slot, disregarding if + /// the extension has already been found in the data buffer. + pub fn init_extension_allow_repeating( + &mut self, + ) -> Result<&mut V, ProgramError> { + let length = pod_get_packed_len::(); + let buffer = self.alloc_allow_repeating::(length)?; + let extension_ref = pod_from_bytes_mut::(buffer)?; + *extension_ref = V::default(); + Ok(extension_ref) + } + /// Reallocate and overwite the TLV entry for the given variable-length /// extension. /// @@ -880,6 +956,53 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> { } } + fn alloc_allow_repeating( + &mut self, + length: usize, + ) -> Result<&mut [u8], ProgramError> { + if V::TYPE.get_account_type() != S::ACCOUNT_TYPE { + return Err(ProgramError::InvalidAccountData); + } + + let mut start_index = 0; + let mut type_start = 0; + let mut length_start = 0; + let mut value_start = 0; + let mut extension_type = V::TYPE; + let required_len = add_type_and_length_to_len(length); + + while extension_type != ExtensionType::Uninitialized { + let indices = get_extension_indices::(&self.tlv_data[start_index..], true)?; + (type_start, length_start, value_start) = ( + indices.type_start.saturating_add(start_index), + indices.length_start.saturating_add(start_index), + indices.value_start.saturating_add(start_index), + ); + + if self.tlv_data[type_start..].len() < required_len { + return Err(ProgramError::InvalidAccountData); + } + + extension_type = ExtensionType::try_from(&self.tlv_data[type_start..length_start])?; + start_index = value_start.saturating_add(usize::from(*pod_from_bytes::( + &self.tlv_data[length_start..value_start], + )?)); + } + + // write extension type + let extension_type_array: [u8; 2] = V::TYPE.into(); + let extension_type_ref = &mut self.tlv_data[type_start..length_start]; + extension_type_ref.copy_from_slice(&extension_type_array); + + // write length + let length_ref = + pod_from_bytes_mut::(&mut self.tlv_data[length_start..value_start])?; + *length_ref = Length::try_from(length)?; + let value_end = value_start.saturating_add(length); + + Ok(&mut self.tlv_data[value_start..value_end]) + } + /// If `extension_type` is an Account-associated ExtensionType that requires /// initialization on InitializeAccount, this method packs the default /// relevant Extension of an ExtensionType into an open slot if not @@ -1883,6 +2006,116 @@ mod test { ); } + #[test] + fn mint_with_repeating_extensions_pack_unpack() { + // Have to manually add the other two repeating entries, since + // `try_calculate_account_len` will skip duplicates. + let mint_size = + ExtensionType::try_calculate_account_len::(&[ExtensionType::MintCloseAuthority]) + .unwrap() + .saturating_add(36) + .saturating_add(36); + let mut buffer = vec![0; mint_size]; + + let close_authority1 = + OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap(); + let close_authority2 = + OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([2; 32]))).unwrap(); + let close_authority3 = + OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([3; 32]))).unwrap(); + + let mut state = StateWithExtensionsMut::::unpack_uninitialized(&mut buffer).unwrap(); + let extension = state + .init_extension_allow_repeating::() + .unwrap(); + extension.close_authority = close_authority1; + let extension = state + .init_extension_allow_repeating::() + .unwrap(); + extension.close_authority = close_authority2; + let extension = state + .init_extension_allow_repeating::() + .unwrap(); + extension.close_authority = close_authority3; + + assert_eq!( + &state.get_extension_types().unwrap(), + &[ + ExtensionType::MintCloseAuthority, + ExtensionType::MintCloseAuthority, + ExtensionType::MintCloseAuthority, + ] + ); + + let mut state = StateWithExtensionsMut::::unpack_uninitialized(&mut buffer).unwrap(); + state.base = TEST_MINT; + state.pack_base(); + state.init_account_type().unwrap(); + + let mint_close_auth_type_bytes = (ExtensionType::MintCloseAuthority as u16).to_le_bytes(); + let mint_close_auth_len_bytes = + (pod_get_packed_len::() as u16).to_le_bytes(); + + let mut expect = TEST_MINT_SLICE.to_vec(); + expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - Mint::LEN]); // padding + expect.push(AccountType::Mint.into()); + expect.extend_from_slice(&mint_close_auth_type_bytes); + expect.extend_from_slice(&mint_close_auth_len_bytes); + expect.extend_from_slice(&[1; 32]); + expect.extend_from_slice(&mint_close_auth_type_bytes); + expect.extend_from_slice(&mint_close_auth_len_bytes); + expect.extend_from_slice(&[2; 32]); + expect.extend_from_slice(&mint_close_auth_type_bytes); + expect.extend_from_slice(&mint_close_auth_len_bytes); + expect.extend_from_slice(&[3; 32]); + assert_eq!(expect, buffer); + + // check unpacking + let mut state = StateWithExtensionsMut::::unpack(&mut buffer).unwrap(); + let unpacked_extension = state + .get_repeating_extension_mut::(1) + .unwrap(); + assert_eq!( + *unpacked_extension, + MintCloseAuthority { + close_authority: close_authority1 + } + ); + + // update extension + let close_authority = OptionalNonZeroPubkey::try_from(None).unwrap(); + unpacked_extension.close_authority = close_authority; + + // check updates are propagated + let base = state.base; + let state = StateWithExtensions::::unpack(&buffer).unwrap(); + assert_eq!(state.base, base); + let unpacked_extension = state + .get_repeating_extension::(1) + .unwrap(); + assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority }); + + // check the rest + let unpacked_extension = state + .get_repeating_extension::(2) + .unwrap(); + assert_eq!( + *unpacked_extension, + MintCloseAuthority { + close_authority: close_authority2 + } + ); + let unpacked_extension = state + .get_repeating_extension::(3) + .unwrap(); + assert_eq!( + *unpacked_extension, + MintCloseAuthority { + close_authority: close_authority3 + } + ); + } + #[test] fn mint_extension_any_order() { let mint_size = ExtensionType::try_calculate_account_len::(&[