diff --git a/Cargo.lock b/Cargo.lock index 02dd0a583e..70706b4d10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3375,6 +3375,13 @@ dependencies = [ "thiserror 2.0.17", ] +[[package]] +name = "light-array-map" +version = "0.1.0" +dependencies = [ + "tinyvec", +] + [[package]] name = "light-batched-merkle-tree" version = "0.5.0" @@ -3885,9 +3892,11 @@ name = "light-system-program-pinocchio" version = "1.2.0" dependencies = [ "aligned-sized", + "arrayvec", "borsh 0.10.4", "bytemuck", "light-account-checks", + "light-array-map", "light-batched-merkle-tree", "light-compressed-account", "light-concurrent-merkle-tree", diff --git a/Cargo.toml b/Cargo.toml index 3641c22599..97dbdb0882 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "program-libs/account-checks", + "program-libs/array-map", "program-libs/compressed-account", "program-libs/aligned-sized", "program-libs/batched-merkle-tree", @@ -200,11 +201,13 @@ light-bounded-vec = { version = "2.0.0" } light-poseidon = { version = "0.3.0" } light-test-utils = { path = "program-tests/utils", version = "1.2.1" } light-indexed-array = { path = "program-libs/indexed-array", version = "0.2.0" } +light-array-map = { path = "program-libs/array-map", version = "0.1.0" } light-program-profiler = { version = "0.1.0" } create-address-program-test = { path = "program-tests/create-address-test-program", version = "1.0.0" } groth16-solana = { version = "0.2.0" } bytemuck = { version = "1.19.0" } arrayvec = "0.7" +tinyvec = "1.10.0" # Math and crypto num-bigint = "0.4.6" diff --git a/program-libs/array-map/Cargo.toml b/program-libs/array-map/Cargo.toml new file mode 100644 index 0000000000..51f424159c --- /dev/null +++ b/program-libs/array-map/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "light-array-map" +version = "0.1.0" +description = "Generic array-backed map with O(n) lookup for small collections" +repository = "https://github.com/Lightprotocol/light-protocol" +license = "Apache-2.0" +edition = "2021" + +[features] +default = [] +alloc = ["tinyvec/alloc"] + +[dependencies] +tinyvec = { workspace = true } diff --git a/program-libs/array-map/src/lib.rs b/program-libs/array-map/src/lib.rs new file mode 100644 index 0000000000..9f830c8702 --- /dev/null +++ b/program-libs/array-map/src/lib.rs @@ -0,0 +1,184 @@ +#![no_std] + +#[cfg(feature = "alloc")] +extern crate alloc; + +use core::ptr::read_unaligned; + +use tinyvec::ArrayVec; + +/// A generic tinyvec::ArrayVec backed map with O(n) lookup. +/// Maintains insertion order and tracks the last accessed entry index. +pub struct ArrayMap +where + K: PartialEq + Default, + V: Default, +{ + entries: ArrayVec<[(K, V); N]>, + last_accessed_index: Option, +} + +impl ArrayMap +where + K: PartialEq + Default, + V: Default, +{ + pub fn new() -> Self { + Self { + entries: ArrayVec::new(), + last_accessed_index: None, + } + } + + pub fn len(&self) -> usize { + self.entries.len() + } + + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } + + pub fn last_accessed_index(&self) -> Option { + self.last_accessed_index + } + + pub fn get(&self, index: usize) -> Option<&(K, V)> { + self.entries.get(index) + } + + pub fn get_mut(&mut self, index: usize) -> Option<&mut (K, V)> { + self.entries.get_mut(index) + } + + pub fn get_u8(&self, index: u8) -> Option<&(K, V)> { + self.get(index as usize) + } + + pub fn get_mut_u8(&mut self, index: u8) -> Option<&mut (K, V)> { + self.get_mut(index as usize) + } + + pub fn get_by_key(&self, key: &K) -> Option<&V> { + self.entries.iter().find(|(k, _)| k == key).map(|(_, v)| v) + } + + pub fn get_mut_by_key(&mut self, key: &K) -> Option<&mut V> { + self.entries + .iter_mut() + .find(|(k, _)| k == key) + .map(|(_, v)| v) + } + + pub fn find(&self, key: &K) -> Option<(usize, &(K, V))> { + self.entries.iter().enumerate().find(|(_, (k, _))| k == key) + } + + pub fn find_mut(&mut self, key: &K) -> Option<(usize, &mut (K, V))> { + self.entries + .iter_mut() + .enumerate() + .find(|(_, (k, _))| k == key) + } + + pub fn find_index(&self, key: &K) -> Option { + self.find(key).map(|(idx, _)| idx) + } + + pub fn set_last_accessed_index(&mut self, index: usize) -> Result<(), E> + where + E: From, + { + if index < self.entries.len() { + self.last_accessed_index = Some(index); + Ok(()) + } else { + Err(ArrayMapError::IndexOutOfBounds.into()) + } + } + + pub fn insert(&mut self, key: K, value: V, error: E) -> Result { + let new_idx = self.entries.len(); + // tinyvec's try_push returns Some(item) on failure, None on success + if self.entries.try_push((key, value)).is_some() { + return Err(error); + } + self.last_accessed_index = Some(new_idx); + Ok(new_idx) + } +} + +impl Default for ArrayMap +where + K: PartialEq + Default, + V: Default, +{ + fn default() -> Self { + Self::new() + } +} + +// Optimized [u8; 32] key methods (4x u64 comparison instead of 32x u8). +impl ArrayMap<[u8; 32], V, N> +where + V: Default, +{ + pub fn get_by_pubkey(&self, key: &[u8; 32]) -> Option<&V> { + self.entries + .iter() + .find(|(k, _)| pubkey_eq(k, key)) + .map(|(_, v)| v) + } + + pub fn get_mut_by_pubkey(&mut self, key: &[u8; 32]) -> Option<&mut V> { + self.entries + .iter_mut() + .find(|(k, _)| pubkey_eq(k, key)) + .map(|(_, v)| v) + } + + pub fn find_by_pubkey(&self, key: &[u8; 32]) -> Option<(usize, &([u8; 32], V))> { + self.entries + .iter() + .enumerate() + .find(|(_, (k, _))| pubkey_eq(k, key)) + } + + pub fn find_mut_by_pubkey(&mut self, key: &[u8; 32]) -> Option<(usize, &mut ([u8; 32], V))> { + self.entries + .iter_mut() + .enumerate() + .find(|(_, (k, _))| pubkey_eq(k, key)) + } + + pub fn find_pubkey_index(&self, key: &[u8; 32]) -> Option { + self.find_by_pubkey(key).map(|(idx, _)| idx) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ArrayMapError { + CapacityExceeded, + IndexOutOfBounds, +} + +impl core::fmt::Display for ArrayMapError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + ArrayMapError::CapacityExceeded => write!(f, "ArrayMap capacity exceeded"), + ArrayMapError::IndexOutOfBounds => write!(f, "ArrayMap index out of bounds"), + } + } +} + +#[inline(always)] +pub const fn pubkey_eq(p1: &[u8; 32], p2: &[u8; 32]) -> bool { + let p1_ptr = p1.as_ptr() as *const u64; + let p2_ptr = p2.as_ptr() as *const u64; + + unsafe { + read_unaligned(p1_ptr) == read_unaligned(p2_ptr) + && read_unaligned(p1_ptr.add(1)) == read_unaligned(p2_ptr.add(1)) + && read_unaligned(p1_ptr.add(2)) == read_unaligned(p2_ptr.add(2)) + && read_unaligned(p1_ptr.add(3)) == read_unaligned(p2_ptr.add(3)) + } +} diff --git a/program-libs/array-map/tests/array_map_tests.rs b/program-libs/array-map/tests/array_map_tests.rs new file mode 100644 index 0000000000..8789435a0b --- /dev/null +++ b/program-libs/array-map/tests/array_map_tests.rs @@ -0,0 +1,245 @@ +use light_array_map::{ArrayMap, ArrayMapError}; + +// Test error type for testing +#[derive(Debug, PartialEq)] +enum TestError { + ArrayMap(ArrayMapError), + Custom, +} + +impl From for TestError { + fn from(e: ArrayMapError) -> Self { + TestError::ArrayMap(e) + } +} + +#[test] +fn test_new_map() { + let map = ArrayMap::::new(); + assert_eq!(map.len(), 0); + assert!(map.is_empty()); + assert!(map.last_accessed_index().is_none()); +} + +#[test] +fn test_insert() { + let mut map = ArrayMap::::new(); + + let idx = map.insert(1, "one".to_string(), TestError::Custom).unwrap(); + + assert_eq!(idx, 0); + assert_eq!(map.len(), 1); + assert_eq!(map.last_accessed_index(), Some(0)); + assert_eq!(map.get(0).unwrap().1, "one"); +} + +#[test] +fn test_get_by_key() { + let mut map = ArrayMap::::new(); + + map.insert(1, "one".to_string(), TestError::Custom).unwrap(); + map.insert(2, "two".to_string(), TestError::Custom).unwrap(); + + assert_eq!(map.get_by_key(&1), Some(&"one".to_string())); + assert_eq!(map.get_by_key(&2), Some(&"two".to_string())); + assert_eq!(map.get_by_key(&3), None); +} + +#[test] +fn test_get_mut_by_key() { + let mut map = ArrayMap::::new(); + + map.insert(1, "one".to_string(), TestError::Custom).unwrap(); + + if let Some(val) = map.get_mut_by_key(&1) { + *val = "ONE".to_string(); + } + + assert_eq!(map.get_by_key(&1), Some(&"ONE".to_string())); +} + +#[test] +fn test_find_index() { + let mut map = ArrayMap::::new(); + + map.insert(10, "ten".to_string(), TestError::Custom) + .unwrap(); + map.insert(20, "twenty".to_string(), TestError::Custom) + .unwrap(); + + assert_eq!(map.find_index(&10), Some(0)); + assert_eq!(map.find_index(&20), Some(1)); + assert_eq!(map.find_index(&30), None); +} + +#[test] +fn test_set_last_accessed_index() { + let mut map = ArrayMap::::new(); + + map.insert(1, "one".to_string(), TestError::Custom).unwrap(); + map.insert(2, "two".to_string(), TestError::Custom).unwrap(); + + // Should be at index 1 after last insert + assert_eq!(map.last_accessed_index(), Some(1)); + + // Set to 0 + map.set_last_accessed_index::(0).unwrap(); + assert_eq!(map.last_accessed_index(), Some(0)); + + // Out of bounds should fail + let result = map.set_last_accessed_index::(10); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err(), + TestError::ArrayMap(ArrayMapError::IndexOutOfBounds) + ); +} + +#[test] +fn test_capacity_limit() { + let mut map = ArrayMap::::new(); + + // Fill to capacity + for i in 0..5 { + map.insert(i, format!("val{}", i), TestError::Custom) + .unwrap(); + } + + assert_eq!(map.len(), 5); + + // 6th entry should fail + let result = map.insert(5, "val5".to_string(), TestError::Custom); + assert!(result.is_err()); +} + +#[test] +fn test_get_mut_direct() { + let mut map = ArrayMap::::new(); + + map.insert(1, 100, TestError::Custom).unwrap(); + + if let Some(entry) = map.get_mut(0) { + entry.1 += 50; + } + + assert_eq!(map.get(0).unwrap().1, 150); +} + +#[test] +fn test_last_accessed_index_updates() { + let mut map = ArrayMap::::new(); + + // Insert first entry + map.insert(1, 100, TestError::Custom).unwrap(); + assert_eq!(map.last_accessed_index(), Some(0)); + + // Insert second entry + map.insert(2, 200, TestError::Custom).unwrap(); + assert_eq!(map.last_accessed_index(), Some(1)); +} + +#[cfg(feature = "alloc")] +#[test] +fn test_with_alloc_feature() { + extern crate alloc; + use alloc::{string::String, vec::Vec}; + + // NOTE: ArrayVec is ALWAYS fixed-capacity (stack-only), even with alloc feature. + // The alloc feature just enables using heap-allocated VALUE types like String/Vec. + // ArrayVec itself will still error when capacity is exceeded. + + let mut map = ArrayMap::::new(); + + // Fill to capacity with heap-allocated strings + for i in 0..5 { + map.insert(i, format!("string_{}", i), TestError::Custom) + .unwrap(); + } + + assert_eq!(map.len(), 5); + + // ArrayVec still has fixed capacity - 6th insert should fail + let result = map.insert(5, String::from("overflow"), TestError::Custom); + assert!( + result.is_err(), + "ArrayVec should fail when capacity exceeded, even with alloc feature" + ); + + // Test with Vec values (heap-allocated) + let mut vec_map = ArrayMap::, 3>::new(); + vec_map.insert(1, vec![1, 2, 3], TestError::Custom).unwrap(); + vec_map + .insert(2, vec![4, 5, 6, 7, 8], TestError::Custom) + .unwrap(); + vec_map.insert(3, vec![9, 10], TestError::Custom).unwrap(); + + // The Vec VALUES can be any size (heap-allocated) + assert_eq!(vec_map.get_by_key(&1).map(|v| v.len()), Some(3)); + assert_eq!(vec_map.get_by_key(&2).map(|v| v.len()), Some(5)); + + // But the ArrayVec container itself still has fixed capacity + let result = vec_map.insert(4, vec![99], TestError::Custom); + assert!( + result.is_err(), + "ArrayVec container is still fixed capacity" + ); +} + +#[test] +fn test_capacity_overflow_without_alloc() { + // Demonstrate that ArrayVec has fixed capacity regardless of alloc feature + let mut map = ArrayMap::::new(); + + // Fill to capacity + map.insert(1, 100, TestError::Custom).unwrap(); + map.insert(2, 200, TestError::Custom).unwrap(); + map.insert(3, 300, TestError::Custom).unwrap(); + + assert_eq!(map.len(), 3); + + // 4th insert should fail - fixed capacity + let result = map.insert(4, 400, TestError::Custom); + assert!(result.is_err(), "ArrayVec has fixed capacity"); +} + +#[test] +fn test_get_u8() { + let mut map = ArrayMap::::new(); + + map.insert(1, "one".to_string(), TestError::Custom).unwrap(); + map.insert(2, "two".to_string(), TestError::Custom).unwrap(); + map.insert(3, "three".to_string(), TestError::Custom) + .unwrap(); + + // Test valid indices + assert_eq!(map.get_u8(0).unwrap().1, "one"); + assert_eq!(map.get_u8(1).unwrap().1, "two"); + assert_eq!(map.get_u8(2).unwrap().1, "three"); + + // Test out of bounds + assert!(map.get_u8(3).is_none()); + assert!(map.get_u8(255).is_none()); +} + +#[test] +fn test_get_mut_u8() { + let mut map = ArrayMap::::new(); + + map.insert(1, 100, TestError::Custom).unwrap(); + map.insert(2, 200, TestError::Custom).unwrap(); + map.insert(3, 300, TestError::Custom).unwrap(); + + // Modify via get_mut_u8 + if let Some(entry) = map.get_mut_u8(1) { + entry.1 += 50; + } + + // Verify modification + assert_eq!(map.get_u8(1).unwrap().1, 250); + assert_eq!(map.get_u8(0).unwrap().1, 100); + assert_eq!(map.get_u8(2).unwrap().1, 300); + + // Test out of bounds + assert!(map.get_mut_u8(3).is_none()); + assert!(map.get_mut_u8(255).is_none()); +} diff --git a/program-libs/compressed-account/src/pubkey.rs b/program-libs/compressed-account/src/pubkey.rs index 9325f96988..7a6d3cc99b 100644 --- a/program-libs/compressed-account/src/pubkey.rs +++ b/program-libs/compressed-account/src/pubkey.rs @@ -74,6 +74,10 @@ impl Pubkey { array.copy_from_slice(slice); Self(array) } + + pub fn array_ref(&self) -> &[u8; 32] { + &self.0 + } } impl AsRef for Pubkey { diff --git a/program-tests/system-test/tests/test.rs b/program-tests/system-test/tests/test.rs index c5ee61b95f..88fc5a2deb 100644 --- a/program-tests/system-test/tests/test.rs +++ b/program-tests/system-test/tests/test.rs @@ -522,26 +522,26 @@ pub async fn failing_transaction_inputs_inner( .await .unwrap(); } - // output Merkle tree is not unique (we need at least 2 outputs for this test) - if num_outputs > 1 { - let mut inputs_struct = inputs_struct.clone(); - let mut remaining_accounts = remaining_accounts.clone(); - let remaining_mt_acc = remaining_accounts - [inputs_struct.output_compressed_accounts[1].merkle_tree_index as usize] - .clone(); - remaining_accounts.push(remaining_mt_acc); - inputs_struct.output_compressed_accounts[1].merkle_tree_index = - (remaining_accounts.len() - 1) as u8; - create_instruction_and_failing_transaction( - rpc, - payer, - inputs_struct, - remaining_accounts.clone(), - SystemProgramError::OutputMerkleTreeNotUnique.into(), - ) - .await - .unwrap(); - } + // // output Merkle tree is not unique (we need at least 2 outputs for this test) + // if num_outputs > 1 { + // let mut inputs_struct = inputs_struct.clone(); + // let mut remaining_accounts = remaining_accounts.clone(); + // let remaining_mt_acc = remaining_accounts + // [inputs_struct.output_compressed_accounts[1].merkle_tree_index as usize] + // .clone(); + // remaining_accounts.push(remaining_mt_acc); + // inputs_struct.output_compressed_accounts[1].merkle_tree_index = + // (remaining_accounts.len() - 1) as u8; + // create_instruction_and_failing_transaction( + // rpc, + // payer, + // inputs_struct, + // remaining_accounts.clone(), + // SystemProgramError::OutputMerkleTreeNotUnique.into(), + // ) + // .await + // .unwrap(); + // } Ok(()) } diff --git a/programs/system/Cargo.toml b/programs/system/Cargo.toml index 9817ee655f..52426471b0 100644 --- a/programs/system/Cargo.toml +++ b/programs/system/Cargo.toml @@ -56,6 +56,8 @@ pinocchio-pubkey = { workspace = true } solana-msg = { workspace = true } light-program-profiler = { workspace = true } light-heap = { workspace = true, optional = true } +light-array-map = { workspace = true } +arrayvec = { workspace = true } [dev-dependencies] rand = { workspace = true } light-compressed-account = { workspace = true, features = [ diff --git a/programs/system/src/errors.rs b/programs/system/src/errors.rs index b69ae8af1c..cae3e25e64 100644 --- a/programs/system/src/errors.rs +++ b/programs/system/src/errors.rs @@ -1,4 +1,5 @@ use light_account_checks::error::AccountError; +use light_array_map::ArrayMapError; use light_batched_merkle_tree::errors::BatchedMerkleTreeError; use light_concurrent_merkle_tree::errors::ConcurrentMerkleTreeError; use light_indexed_merkle_tree::errors::IndexedMerkleTreeError; @@ -140,6 +141,10 @@ pub enum SystemProgramError { PackedAccountIndexOutOfBounds, #[error("Unimplemented.")] Unimplemented, + #[error("Too many output V2 queues (max 30).")] + TooManyOutputV2Queues, + #[error("Too many output V1 trees (max 30).")] + TooManyOutputV1Trees, #[error("Batched Merkle tree error {0}")] BatchedMerkleTreeError(#[from] BatchedMerkleTreeError), #[error("Concurrent Merkle tree error {0}")] @@ -223,6 +228,8 @@ impl From for u32 { SystemProgramError::Unimplemented => 6063, SystemProgramError::CpiContextDeactivated => 6064, SystemProgramError::InputMerkleTreeIndexOutOfBounds => 6065, + SystemProgramError::TooManyOutputV2Queues => 6066, + SystemProgramError::TooManyOutputV1Trees => 6067, SystemProgramError::BatchedMerkleTreeError(e) => e.into(), SystemProgramError::IndexedMerkleTreeError(e) => e.into(), SystemProgramError::ConcurrentMerkleTreeError(e) => e.into(), @@ -238,3 +245,12 @@ impl From for ProgramError { ProgramError::Custom(e.into()) } } + +impl From for SystemProgramError { + fn from(e: ArrayMapError) -> Self { + match e { + ArrayMapError::CapacityExceeded => SystemProgramError::TooManyOutputV2Queues, + ArrayMapError::IndexOutOfBounds => SystemProgramError::OutputMerkleTreeIndexOutOfBounds, + } + } +} diff --git a/programs/system/src/processor/create_outputs_cpi_data.rs b/programs/system/src/processor/create_outputs_cpi_data.rs index 84fc7a3b88..d80f1abfa8 100644 --- a/programs/system/src/processor/create_outputs_cpi_data.rs +++ b/programs/system/src/processor/create_outputs_cpi_data.rs @@ -1,3 +1,4 @@ +use light_array_map::ArrayMap; use light_compressed_account::{ hash_to_bn254_field_size_be, instruction_data::{ @@ -10,6 +11,7 @@ use light_hasher::{Hasher, Poseidon}; use light_program_profiler::profile; use pinocchio::{account_info::AccountInfo, msg, program_error::ProgramError}; +use super::tree_leaf_tracker_ext::TreeLeafTrackerTupleExt; use crate::{ accounts::remaining_account_checks::AcpAccount, context::{SystemContext, WrappedInstructionData}, @@ -42,18 +44,15 @@ pub fn create_outputs_cpi_data<'a, 'info, T: InstructionData<'a>>( if inputs.output_len() == 0 { return Ok([0u8; 32]); } - let mut current_index: i16 = -1; - let mut num_leaves_in_tree: u32 = 0; - let mut mt_next_index: u32 = 0; let mut hashed_merkle_tree = [0u8; 32]; cpi_ix_data.start_output_appends = context.account_indices.len() as u8; - let mut index_merkle_tree_account_account = cpi_ix_data.start_output_appends; + // TODO: check correct index use and deduplicate if possible. + let mut current_index: i16 = -1; + let mut next_account_index = cpi_ix_data.start_output_appends; let mut index_merkle_tree_account = 0; - let number_of_merkle_trees = - inputs.output_accounts().last().unwrap().merkle_tree_index() as usize + 1; - let mut merkle_tree_pubkeys = - Vec::::with_capacity(number_of_merkle_trees); + // Track (tree_pubkey, (next_leaf_index, account_index)) for each unique tree + let mut tree_leaf_tracker = ArrayMap::<[u8; 32], (u64, u8), 30>::new(); let mut hash_chain = [0u8; 32]; let mut rollover_fee = 0; let mut is_batched = true; @@ -62,50 +61,95 @@ pub fn create_outputs_cpi_data<'a, 'info, T: InstructionData<'a>>( // if mt index == current index Merkle tree account info has already been added. // if mt index != current index, Merkle tree account info is new, add it. #[allow(clippy::comparison_chain)] - if account.merkle_tree_index() as i16 == current_index { - // Do nothing, but it is the most common case. - } else if account.merkle_tree_index() as i16 > current_index { + let (leaf_index, account_index) = if account.merkle_tree_index() as i16 == current_index { + // Same tree as previous iteration - just increment leaf index + tree_leaf_tracker.increment_current_tuple()? + } else { current_index = account.merkle_tree_index().into(); - - let pubkey = match &accounts + // Get tree/queue pubkey and metadata + match &accounts .get(current_index as usize) .ok_or(SystemProgramError::OutputMerkleTreeIndexOutOfBounds)? { AcpAccount::OutputQueue(output_queue) => { - context.set_network_fee( - output_queue.metadata.rollover_metadata.network_fee, - current_index as u8, - ); + let initial_leaf_index = output_queue.batch_metadata.next_index; + + // Get or insert tree entry - returns ((leaf_idx, account_idx), is_new) + let ((leaf_idx, account_idx), is_new) = tree_leaf_tracker.get_or_insert_tuple( + output_queue.pubkey().array_ref(), + (initial_leaf_index, next_account_index), + SystemProgramError::TooManyOutputV2Queues, + )?; - hashed_merkle_tree = output_queue.hashed_merkle_tree_pubkey; - rollover_fee = output_queue.metadata.rollover_metadata.rollover_fee; - mt_next_index = output_queue.batch_metadata.next_index as u32; - cpi_ix_data.output_sequence_numbers[index_merkle_tree_account as usize] = - MerkleTreeSequenceNumber { - tree_pubkey: output_queue.metadata.associated_merkle_tree, - queue_pubkey: *output_queue.pubkey(), - tree_type: (TreeType::StateV2 as u64).into(), - seq: output_queue.batch_metadata.next_index.into(), - }; - is_batched = true; - *output_queue.pubkey() + // Only set up metadata if this is a new tree (first time seeing this pubkey) + if is_new { + // TODO: depulicate logic + context.set_network_fee( + output_queue.metadata.rollover_metadata.network_fee, + current_index as u8, + ); + hashed_merkle_tree = output_queue.hashed_merkle_tree_pubkey; + rollover_fee = output_queue.metadata.rollover_metadata.rollover_fee; + is_batched = true; + + cpi_ix_data.output_sequence_numbers[index_merkle_tree_account as usize] = + MerkleTreeSequenceNumber { + tree_pubkey: output_queue.metadata.associated_merkle_tree, + queue_pubkey: *output_queue.pubkey(), + tree_type: (TreeType::StateV2 as u64).into(), + seq: initial_leaf_index.into(), + }; + + context.get_index_or_insert( + account.merkle_tree_index(), + remaining_accounts, + "Output queue for V2 state trees (Merkle tree for V1 state trees)", + )?; + + index_merkle_tree_account += 1; + next_account_index += 1; + } + + (leaf_idx, account_idx) } AcpAccount::StateTree((pubkey, tree)) => { - cpi_ix_data.output_sequence_numbers[index_merkle_tree_account as usize] = - MerkleTreeSequenceNumber { - tree_pubkey: *pubkey, - queue_pubkey: *pubkey, - tree_type: (TreeType::StateV1 as u64).into(), - seq: (tree.sequence_number() as u64 + 1).into(), - }; - let merkle_context = context - .get_legacy_merkle_context(current_index as u8) - .unwrap(); - hashed_merkle_tree = merkle_context.hashed_pubkey; - rollover_fee = merkle_context.rollover_fee; - mt_next_index = tree.next_index() as u32; - is_batched = false; - *pubkey + let initial_leaf_index = tree.next_index() as u64; + + // Get or insert tree entry - returns ((leaf_idx, account_idx), is_new) + let ((leaf_idx, account_idx), is_new) = tree_leaf_tracker.get_or_insert_tuple( + pubkey.array_ref(), + (initial_leaf_index, next_account_index), + SystemProgramError::TooManyOutputV1Trees, + )?; + + // Only set up metadata if this is a new tree (first time seeing this pubkey) + if is_new { + cpi_ix_data.output_sequence_numbers[index_merkle_tree_account as usize] = + MerkleTreeSequenceNumber { + tree_pubkey: *pubkey, + queue_pubkey: *pubkey, + tree_type: (TreeType::StateV1 as u64).into(), + seq: (tree.sequence_number() as u64 + 1).into(), + }; + + let merkle_context = context + .get_legacy_merkle_context(current_index as u8) + .unwrap(); + hashed_merkle_tree = merkle_context.hashed_pubkey; + rollover_fee = merkle_context.rollover_fee; + is_batched = false; + + context.get_index_or_insert( + account.merkle_tree_index(), + remaining_accounts, + "Output queue for V2 state trees (Merkle tree for V1 state trees)", + )?; + + index_merkle_tree_account += 1; + next_account_index += 1; + } + + (leaf_idx, account_idx) } AcpAccount::Unknown() => { msg!( @@ -144,30 +188,8 @@ pub fn create_outputs_cpi_data<'a, 'info, T: InstructionData<'a>>( SystemProgramError::StateMerkleTreeAccountDiscriminatorMismatch.into(), ); } - }; - // check Merkle tree uniqueness - if merkle_tree_pubkeys.contains(&pubkey) { - return Err(SystemProgramError::OutputMerkleTreeNotUnique.into()); - } else { - merkle_tree_pubkeys.push(pubkey); } - - context.get_index_or_insert( - account.merkle_tree_index(), - remaining_accounts, - "Output queue for V2 state trees (Merkle tree for V1 state trees)", - )?; - num_leaves_in_tree = 0; - index_merkle_tree_account += 1; - index_merkle_tree_account_account += 1; - } else { - // Check 2. - // Output Merkle tree indices must be in order since we use the - // number of leaves in a Merkle tree to determine the correct leaf - // index. Since the leaf index is part of the hash this is security - // critical. - return Err(SystemProgramError::OutputMerkleTreeIndicesNotInOrder.into()); - } + }; // Check 3. if let Some(address) = account.address() { @@ -183,9 +205,11 @@ pub fn create_outputs_cpi_data<'a, 'info, T: InstructionData<'a>>( return Err(SystemProgramError::InvalidAddress.into()); } } - cpi_ix_data.output_leaf_indices[j] = (mt_next_index + num_leaves_in_tree).into(); - num_leaves_in_tree += 1; + // Use the tracked leaf index from our ArrayVec + cpi_ix_data.output_leaf_indices[j] = u32::try_from(leaf_index) + .map_err(|_| SystemProgramError::PackedAccountIndexOutOfBounds)? + .into(); if account.has_data() && context.invoking_program_id.is_none() { msg!("Invoking program is not provided."); msg!("Only program owned compressed accounts can have data."); @@ -214,7 +238,7 @@ pub fn create_outputs_cpi_data<'a, 'info, T: InstructionData<'a>>( is_batched, ) .map_err(ProgramError::from)?; - cpi_ix_data.leaves[j].account_index = index_merkle_tree_account_account - 1; + cpi_ix_data.leaves[j].account_index = account_index; if !cpi_ix_data.nullifiers.is_empty() { if j == 0 { diff --git a/programs/system/src/processor/mod.rs b/programs/system/src/processor/mod.rs index 40e658457b..663d9cdc6c 100644 --- a/programs/system/src/processor/mod.rs +++ b/programs/system/src/processor/mod.rs @@ -7,4 +7,5 @@ pub mod read_only_account; pub mod read_only_address; pub mod sol_compression; pub mod sum_check; +pub mod tree_leaf_tracker_ext; pub mod verify_proof; diff --git a/programs/system/src/processor/tree_leaf_tracker_ext.rs b/programs/system/src/processor/tree_leaf_tracker_ext.rs new file mode 100644 index 0000000000..ca0558c449 --- /dev/null +++ b/programs/system/src/processor/tree_leaf_tracker_ext.rs @@ -0,0 +1,75 @@ +use light_array_map::ArrayMap; + +use crate::{errors::SystemProgramError, Result}; + +/// Extension trait for ArrayMap with tuple values (u64, u8). +/// Only increments the first element of the tuple. +pub trait TreeLeafTrackerTupleExt +where + K: PartialEq + Default, +{ + /// Increments the first element of the tuple for the last accessed entry. + /// Returns the tuple value before incrementing. + /// + /// # Errors + /// Returns error if no entry has been accessed (last_accessed_index is None). + fn increment_current_tuple(&mut self) -> Result<(u64, u8)>; + + /// Gets or inserts a tuple entry, incrementing only the first element. + /// If the key exists, returns its current value and increments the first element. + /// If the key doesn't exist, inserts it with the given initial value, + /// returns that value, and increments the first element for next use. + /// Sets this entry as the last accessed entry. + /// + /// # Arguments + /// * `key` - The key to look up or insert + /// * `initial_value` - The tuple (leaf_index, account_index) to use if this is a new entry + /// * `error` - The error to return if capacity is exceeded + /// + /// # Returns + /// A tuple of ((leaf_index, account_index), is_new) where is_new indicates if this was a new entry. + fn get_or_insert_tuple( + &mut self, + key: &K, + initial_value: (u64, u8), + error: SystemProgramError, + ) -> Result<((u64, u8), bool)>; +} + +impl TreeLeafTrackerTupleExt for ArrayMap +where + K: PartialEq + Copy + Default, +{ + fn increment_current_tuple(&mut self) -> Result<(u64, u8)> { + let idx = self + .last_accessed_index() + .ok_or(SystemProgramError::OutputMerkleTreeIndexOutOfBounds)?; + + let entry = self + .get_mut(idx) + .ok_or(SystemProgramError::OutputMerkleTreeIndexOutOfBounds)?; + + let prev = entry.1; + entry.1 .0 += 1; // Only increment the leaf index (first element) + Ok(prev) + } + + fn get_or_insert_tuple( + &mut self, + key: &K, + initial_value: (u64, u8), + error: SystemProgramError, + ) -> Result<((u64, u8), bool)> { + // Find existing key + if let Some((idx, entry)) = self.find_mut(key) { + let prev = entry.1; + entry.1 .0 += 1; // Only increment the leaf index (first element) + self.set_last_accessed_index::(idx)?; + return Ok((prev, false)); + } + // Insert new entry with first element incremented + self.insert(*key, (initial_value.0 + 1, initial_value.1), error)?; + + Ok((initial_value, true)) + } +}