diff --git a/programs/drift/src/error.rs b/programs/drift/src/error.rs index e576077c83..67d7a96f6a 100644 --- a/programs/drift/src/error.rs +++ b/programs/drift/src/error.rs @@ -651,6 +651,8 @@ pub enum ErrorCode { WrongNumberOfConstituents, #[msg("Oracle too stale for LP AUM update")] OracleTooStaleForLPAUMUpdate, + #[msg("Insufficient constituent token balance")] + InsufficientConstituentTokenBalance, } #[macro_export] diff --git a/programs/drift/src/instructions/admin.rs b/programs/drift/src/instructions/admin.rs index ac4352921d..9f6daad6f4 100644 --- a/programs/drift/src/instructions/admin.rs +++ b/programs/drift/src/instructions/admin.rs @@ -4545,13 +4545,17 @@ pub fn handle_initialize_constituent<'info>( .resize_with((current_len + 1) as usize, WeightDatum::default); constituent_target_weights.validate()?; + msg!("initializing constituent {}", lp_pool.constituents); + constituent.spot_market_index = spot_market_index; + constituent.constituent_index = lp_pool.constituents; constituent.decimals = decimals; constituent.max_weight_deviation = max_weight_deviation; constituent.swap_fee_min = swap_fee_min; constituent.swap_fee_max = swap_fee_max; constituent.oracle_staleness_threshold = oracle_staleness_threshold; constituent.pubkey = ctx.accounts.constituent.key(); + constituent.mint = ctx.accounts.spot_market_mint.key(); constituent.constituent_index = (constituent_target_weights.weights.len() - 1) as u16; lp_pool.constituents += 1; diff --git a/programs/drift/src/instructions/lp_pool.rs b/programs/drift/src/instructions/lp_pool.rs index f4d8cae334..3e8544e065 100644 --- a/programs/drift/src/instructions/lp_pool.rs +++ b/programs/drift/src/instructions/lp_pool.rs @@ -1,4 +1,5 @@ use anchor_lang::{prelude::*, Accounts, Key, Result}; +use anchor_spl::token_interface::{Mint, TokenAccount, TokenInterface}; use crate::{ error::ErrorCode, @@ -12,22 +13,30 @@ use crate::{ state::{ constituent_map::{ConstituentMap, ConstituentSet}, lp_pool::{ - AmmConstituentDatum, AmmConstituentMappingFixed, LPPool, WeightValidationFlags, - CONSTITUENT_PDA_SEED, + AmmConstituentDatum, AmmConstituentMappingFixed, Constituent, LPPool, + WeightValidationFlags, }, oracle::OraclePriceData, perp_market::{AmmCacheFixed, CacheInfo, AMM_POSITIONS_CACHE}, perp_market_map::MarketSet, + spot_market_map::get_writable_spot_market_set_from_many, state::State, user::MarketType, zero_copy::{AccountZeroCopy, ZeroCopyLoader}, + events::LPSwapRecord, }, validate, }; + use solana_program::sysvar::clock::Clock; use super::optional_accounts::{load_maps, AccountMaps}; -use crate::state::lp_pool::{AMM_MAP_PDA_SEED, CONSTITUENT_TARGET_WEIGHT_PDA_SEED}; +use crate::controller::spot_balance::update_spot_market_cumulative_interest; +use crate::controller::token::{receive, send_from_program_vault}; +use crate::instructions::constraints::*; +use crate::state::lp_pool::{ + AMM_MAP_PDA_SEED, CONSTITUENT_PDA_SEED, CONSTITUENT_TARGET_WEIGHT_PDA_SEED, +}; pub fn handle_update_constituent_target_weights<'c: 'info, 'info>( ctx: Context<'_, '_, 'c, 'info, UpdateConstituentTargetWeights<'info>>, @@ -191,7 +200,6 @@ pub fn handle_update_lp_pool_aum<'c: 'info, 'info>( }; if oracle_price.is_none() { - msg!("hi"); return Err(ErrorCode::OracleTooStaleForLPAUMUpdate.into()); } @@ -226,16 +234,214 @@ pub fn handle_update_lp_pool_aum<'c: 'info, 'info>( Ok(()) } +#[access_control( + fill_not_paused(&ctx.accounts.state) +)] +pub fn handle_lp_pool_swap<'c: 'info, 'info>( + ctx: Context<'_, '_, 'c, 'info, LPPoolSwap<'info>>, + in_market_index: u16, + out_market_index: u16, + in_amount: u64, + min_out_amount: u64, +) -> Result<()> { + validate!( + in_market_index != out_market_index, + ErrorCode::InvalidSpotMarketAccount, + "In and out spot market indices cannot be the same" + )?; + + let slot = Clock::get()?.slot; + let now = Clock::get()?.unix_timestamp; + let state = &ctx.accounts.state; + let lp_pool = &ctx.accounts.lp_pool.load()?; + + let mut in_constituent = ctx.accounts.in_constituent.load_mut()?; + let mut out_constituent = ctx.accounts.out_constituent.load_mut()?; + + let constituent_target_weights = ctx.accounts.constituent_target_weights.load_zc()?; + + let AccountMaps { + perp_market_map: _, + spot_market_map, + mut oracle_map, + } = load_maps( + &mut ctx.remaining_accounts.iter().peekable(), + &MarketSet::new(), + &get_writable_spot_market_set_from_many(vec![in_market_index, out_market_index]), + slot, + Some(state.oracle_guard_rails), + )?; + + let mut in_spot_market = spot_market_map.get_ref_mut(&in_market_index)?; + let mut out_spot_market = spot_market_map.get_ref_mut(&out_market_index)?; + + let in_oracle_id = in_spot_market.oracle_id(); + let out_oracle_id = out_spot_market.oracle_id(); + + let (in_oracle, in_oracle_validity) = oracle_map.get_price_data_and_validity( + MarketType::Spot, + in_spot_market.market_index, + &in_oracle_id, + in_spot_market.historical_oracle_data.last_oracle_price_twap, + in_spot_market.get_max_confidence_interval_multiplier()?, + )?; + let in_oracle = in_oracle.clone(); + + let (out_oracle, out_oracle_validity) = oracle_map.get_price_data_and_validity( + MarketType::Spot, + out_spot_market.market_index, + &out_oracle_id, + out_spot_market + .historical_oracle_data + .last_oracle_price_twap, + out_spot_market.get_max_confidence_interval_multiplier()?, + )?; + + if !is_oracle_valid_for_action(in_oracle_validity, Some(DriftAction::LpPoolSwap))? { + msg!( + "In oracle data for spot market {} is invalid for lp pool swap.", + in_spot_market.market_index, + ); + return Err(ErrorCode::InvalidOracle.into()); + } + + if !is_oracle_valid_for_action(out_oracle_validity, Some(DriftAction::LpPoolSwap))? { + msg!( + "Out oracle data for spot market {} is invalid for lp pool swap.", + out_spot_market.market_index, + ); + return Err(ErrorCode::InvalidOracle.into()); + } + + update_spot_market_cumulative_interest(&mut in_spot_market, Some(&in_oracle), now)?; + update_spot_market_cumulative_interest(&mut out_spot_market, Some(&out_oracle), now)?; + + let in_target_weight = + constituent_target_weights.get_target_weight(in_constituent.constituent_index)?; + let out_target_weight = + constituent_target_weights.get_target_weight(out_constituent.constituent_index)?; + + let (in_amount, out_amount, in_fee, out_fee) = lp_pool.get_swap_amount( + &in_oracle, + &out_oracle, + &in_constituent, + &out_constituent, + &in_spot_market, + &out_spot_market, + in_target_weight, + out_target_weight, + in_amount, + )?; + msg!( + "in_amount: {}, out_amount: {}, in_fee: {}, out_fee: {}", + in_amount, + out_amount, + in_fee, + out_fee + ); + let out_amount_net_fees = if out_fee > 0 { + out_amount.safe_sub(out_fee.unsigned_abs() as u64)? + } else { + out_amount.safe_add(out_fee.unsigned_abs() as u64)? + }; + + validate!( + out_amount_net_fees >= min_out_amount, + ErrorCode::SlippageOutsideLimit, + format!( + "Slippage outside limit: out_amount_net_fees({}) < min_out_amount({})", + out_amount_net_fees, min_out_amount + ) + .as_str() + )?; + + validate!( + out_amount_net_fees <= out_constituent.token_balance, + ErrorCode::InsufficientConstituentTokenBalance, + format!( + "Insufficient out constituent balance: out_amount_net_fees({}) > out_constituent.token_balance({})", + out_amount_net_fees, out_constituent.token_balance + ) + .as_str() + )?; + + in_constituent.record_swap_fees(in_fee)?; + out_constituent.record_swap_fees(out_fee)?; + + emit!(LPSwapRecord { + ts: now, + authority: ctx.accounts.authority.key(), + amount_out: out_amount_net_fees, + amount_in: in_amount, + fee_out: out_fee, + fee_in: in_fee, + out_spot_market_index: out_market_index, + in_spot_market_index: in_market_index, + out_constituent_index: out_constituent.constituent_index, + in_constituent_index: in_constituent.constituent_index, + out_oracle_price: out_oracle.price, + in_oracle_price: in_oracle.price, + mint_out: out_constituent.mint, + mint_in: in_constituent.mint, + }); + + receive( + &ctx.accounts.token_program, + &ctx.accounts.user_in_token_account, + &ctx.accounts.constituent_in_token_account, + &ctx.accounts.authority, + in_amount, + &Some((*ctx.accounts.in_market_mint).clone()), + )?; + + send_from_program_vault( + &ctx.accounts.token_program, + &ctx.accounts.constituent_out_token_account, + &ctx.accounts.user_out_token_account, + &ctx.accounts.drift_signer, + state.signer_nonce, + out_amount_net_fees, + &Some((*ctx.accounts.out_market_mint).clone()), + )?; + + ctx.accounts.constituent_in_token_account.reload()?; + ctx.accounts.constituent_out_token_account.reload()?; + + in_constituent.sync_token_balance(ctx.accounts.constituent_in_token_account.amount); + out_constituent.sync_token_balance(ctx.accounts.constituent_out_token_account.amount); + + Ok(()) +} + #[derive(Accounts)] #[instruction( lp_pool_name: [u8; 32], )] -pub struct UpdateLPPoolAum<'info> { +pub struct UpdateConstituentTargetWeights<'info> { pub state: Box>, #[account(mut)] pub keeper: Signer<'info>, + #[account( + seeds = [AMM_MAP_PDA_SEED.as_ref(), lp_pool.key().as_ref()], + bump, + )] + /// CHECK: checked in AmmConstituentMappingZeroCopy checks + pub amm_constituent_mapping: AccountInfo<'info>, + #[account( + mut, + seeds = [CONSTITUENT_TARGET_WEIGHT_PDA_SEED.as_ref(), lp_pool.key().as_ref()], + bump, + )] + /// CHECK: checked in ConstituentTargetWeightsZeroCopy checks + pub constituent_target_weights: AccountInfo<'info>, #[account( mut, + seeds = [AMM_POSITIONS_CACHE.as_ref()], + bump, + )] + /// CHECK: checked in ConstituentTargetWeightsZeroCopy checks + pub amm_cache: AccountInfo<'info>, + #[account( seeds = [b"lp_pool", lp_pool_name.as_ref()], bump, )] @@ -246,16 +452,30 @@ pub struct UpdateLPPoolAum<'info> { #[instruction( lp_pool_name: [u8; 32], )] -pub struct UpdateConstituentTargetWeights<'info> { +pub struct UpdateLPPoolAum<'info> { pub state: Box>, #[account(mut)] pub keeper: Signer<'info>, #[account( - seeds = [AMM_MAP_PDA_SEED.as_ref(), lp_pool.key().as_ref()], + mut, + seeds = [b"lp_pool", lp_pool_name.as_ref()], bump, )] - /// CHECK: checked in AmmConstituentMappingZeroCopy checks - pub amm_constituent_mapping: AccountInfo<'info>, + pub lp_pool: AccountLoader<'info, LPPool>, +} + +/// `in`/`out` is in the program's POV for this swap. So `user_in_token_account` is the user owned token account +/// for the `in` token for this swap. +#[derive(Accounts)] +#[instruction( + in_market_index: u16, + out_market_index: u16, +)] +pub struct LPPoolSwap<'info> { + /// CHECK: forced drift_signer + pub drift_signer: AccountInfo<'info>, + pub state: Box>, + pub lp_pool: AccountLoader<'info, LPPool>, #[account( mut, seeds = [CONSTITUENT_TARGET_WEIGHT_PDA_SEED.as_ref(), lp_pool.key().as_ref()], @@ -263,16 +483,49 @@ pub struct UpdateConstituentTargetWeights<'info> { )] /// CHECK: checked in ConstituentTargetWeightsZeroCopy checks pub constituent_target_weights: AccountInfo<'info>, + + #[account(mut)] + pub constituent_in_token_account: Box>, + #[account(mut)] + pub constituent_out_token_account: Box>, + #[account( mut, - seeds = [AMM_POSITIONS_CACHE.as_ref()], + constraint = user_in_token_account.mint.eq(&constituent_in_token_account.mint) + )] + pub user_in_token_account: Box>, + #[account( + mut, + constraint = user_out_token_account.mint.eq(&constituent_out_token_account.mint) + )] + pub user_out_token_account: Box>, + + #[account( + mut, + seeds = [CONSTITUENT_PDA_SEED.as_ref(), lp_pool.key().as_ref(), in_market_index.to_le_bytes().as_ref()], bump, + constraint = in_constituent.load()?.mint.eq(&constituent_in_token_account.mint) )] - /// CHECK: checked in ConstituentTargetWeightsZeroCopy checks - pub amm_cache: AccountInfo<'info>, + pub in_constituent: AccountLoader<'info, Constituent>, #[account( - seeds = [b"lp_pool", lp_pool_name.as_ref()], + mut, + seeds = [CONSTITUENT_PDA_SEED.as_ref(), lp_pool.key().as_ref(), out_market_index.to_le_bytes().as_ref()], bump, + constraint = out_constituent.load()?.mint.eq(&constituent_out_token_account.mint) )] - pub lp_pool: AccountLoader<'info, LPPool>, + pub out_constituent: AccountLoader<'info, Constituent>, + + #[account( + constraint = in_market_mint.key() == in_constituent.load()?.mint, + )] + pub in_market_mint: Box>, + #[account( + constraint = out_market_mint.key() == out_constituent.load()?.mint, + )] + pub out_market_mint: Box>, + + pub authority: Signer<'info>, + + // TODO: in/out token program + pub token_program: Interface<'info, TokenInterface>, } diff --git a/programs/drift/src/instructions/user.rs b/programs/drift/src/instructions/user.rs index 9ae0b462b2..59cfc18184 100644 --- a/programs/drift/src/instructions/user.rs +++ b/programs/drift/src/instructions/user.rs @@ -67,6 +67,7 @@ use crate::state::fulfillment_params::openbook_v2::OpenbookV2FulfillmentParams; use crate::state::fulfillment_params::phoenix::PhoenixFulfillmentParams; use crate::state::fulfillment_params::serum::SerumFulfillmentParams; use crate::state::high_leverage_mode_config::HighLeverageModeConfig; +use crate::state::lp_pool::{Constituent, LPPool}; use crate::state::margin_calculation::MarginContext; use crate::state::oracle::StrictOraclePrice; use crate::state::order_params::{ diff --git a/programs/drift/src/lib.rs b/programs/drift/src/lib.rs index 738b766b24..3ad633f2c2 100644 --- a/programs/drift/src/lib.rs +++ b/programs/drift/src/lib.rs @@ -41,7 +41,7 @@ declare_id!("dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH"); #[program] pub mod drift { use super::*; - use crate::{instruction::UpdateLpPoolAum, state::spot_market::SpotFulfillmentConfigStatus}; + use crate::state::spot_market::SpotFulfillmentConfigStatus; // User Instructions @@ -1780,6 +1780,22 @@ pub mod drift { ) -> Result<()> { handle_update_amm_cache(ctx) } + + pub fn lp_pool_swap<'c: 'info, 'info>( + ctx: Context<'_, '_, 'c, 'info, LPPoolSwap<'info>>, + in_market_index: u16, + out_market_index: u16, + in_amount: u64, + min_out_amount: u64, + ) -> Result<()> { + handle_lp_pool_swap( + ctx, + in_market_index, + out_market_index, + in_amount, + min_out_amount, + ) + } } #[cfg(not(feature = "no-entrypoint"))] diff --git a/programs/drift/src/math/oracle.rs b/programs/drift/src/math/oracle.rs index 07e8337eab..ba255a85ab 100644 --- a/programs/drift/src/math/oracle.rs +++ b/programs/drift/src/math/oracle.rs @@ -71,6 +71,7 @@ pub enum DriftAction { OracleOrderPrice, UpdateDlpConstituentTargetWeights, UpdateLpPoolAum, + LpPoolSwap, } pub fn is_oracle_valid_for_action( @@ -133,6 +134,12 @@ pub fn is_oracle_valid_for_action( DriftAction::UpdateDlpConstituentTargetWeights | DriftAction::UpdateLpPoolAum => { !matches!(oracle_validity, OracleValidity::NonPositive) } + DriftAction::LpPoolSwap => !matches!( + oracle_validity, + OracleValidity::NonPositive + | OracleValidity::StaleForAMM + | OracleValidity::InsufficientDataPoints + ), }, None => { matches!(oracle_validity, OracleValidity::Valid) diff --git a/programs/drift/src/state/constituent_map.rs b/programs/drift/src/state/constituent_map.rs index 8e2300bb9a..f79f97e7e5 100644 --- a/programs/drift/src/state/constituent_map.rs +++ b/programs/drift/src/state/constituent_map.rs @@ -29,7 +29,7 @@ impl<'a> ConstituentMap<'a> { None => { let caller = Location::caller(); msg!( - "Could not find costituent {} at {}:{}", + "Could not find constituent {} at {}:{}", constituent_index, caller.file(), caller.line() diff --git a/programs/drift/src/state/events.rs b/programs/drift/src/state/events.rs index aee1509910..8497ddcd43 100644 --- a/programs/drift/src/state/events.rs +++ b/programs/drift/src/state/events.rs @@ -672,3 +672,34 @@ pub fn emit_buffers( Ok(()) } + +#[event] +#[derive(Default)] +pub struct LPSwapRecord { + pub ts: i64, + pub authority: Pubkey, + /// precision: out market mint precision, gross fees + pub amount_out: u64, + /// precision: in market mint precision, gross fees + pub amount_in: u64, + /// precision: fee on amount_out, in market mint precision + pub fee_out: i64, + /// precision: fee on amount_in, out market mint precision + pub fee_in: i64, + // out spot market index + pub out_spot_market_index: u16, + // in spot market index + pub in_spot_market_index: u16, + // out constituent index + pub out_constituent_index: u16, + // in constituent index + pub in_constituent_index: u16, + /// precision: PRICE_PRECISION + pub out_oracle_price: i64, + /// precision: PRICE_PRECISION + pub in_oracle_price: i64, + /// out token mint + pub mint_out: Pubkey, + /// in token mint + pub mint_in: Pubkey, +} diff --git a/programs/drift/src/state/lp_pool.rs b/programs/drift/src/state/lp_pool.rs index 6606961576..f8c70e4821 100644 --- a/programs/drift/src/state/lp_pool.rs +++ b/programs/drift/src/state/lp_pool.rs @@ -7,8 +7,10 @@ use crate::math::safe_math::SafeMath; use crate::math::spot_balance::get_token_amount; use anchor_lang::prelude::*; use anchor_spl::token::Mint; +use anchor_spl::token_interface::TokenAccount; use borsh::{BorshDeserialize, BorshSerialize}; +use super::oracle::OraclePriceData; use super::oracle_map::OracleMap; use super::spot_market::SpotMarket; use super::zero_copy::{AccountZeroCopy, AccountZeroCopyMut, HasLen}; @@ -88,55 +90,39 @@ impl LPPool { } } - /// get the swap price between two (non-LP token) constituents + /// Get the swap price between two (non-LP token) constituents. + /// Accounts for precision differences between in and out constituents /// returns swap price in PRICE_PRECISION pub fn get_swap_price( &self, - oracle_map: &mut OracleMap, - in_spot_market: &SpotMarket, - out_spot_market: &SpotMarket, - in_amount: u64, - ) -> DriftResult { - let in_price = oracle_map - .get_price_data(&(in_spot_market.oracle, in_spot_market.oracle_source)) - .expect("failed to get price data") - .price - .cast::() - .expect("failed to cast price"); - - let out_price = oracle_map - .get_price_data(&(out_spot_market.oracle, out_spot_market.oracle_source)) - .expect("failed to get price data") - .price - .cast::() - .expect("failed to cast price"); - - let (prec_diff_numerator, prec_diff_denominator) = - if out_spot_market.decimals > in_spot_market.decimals { - ( - 10_u64.pow(out_spot_market.decimals as u32 - in_spot_market.decimals as u32), - 1, - ) - } else { - ( - 1, - 10_u64.pow(in_spot_market.decimals as u32 - out_spot_market.decimals as u32), - ) - }; + in_decimals: u32, + out_decimals: u32, + in_oracle: &OraclePriceData, + out_oracle: &OraclePriceData, + ) -> DriftResult<(u64, u64)> { + let in_price = in_oracle.price.cast::()?; + let out_price = out_oracle.price.cast::()?; + + let (prec_diff_numerator, prec_diff_denominator) = if out_decimals > in_decimals { + (10_u64.pow(out_decimals - in_decimals), 1) + } else { + (1, 10_u64.pow(in_decimals - out_decimals)) + }; - let swap_price = in_amount - .safe_mul(in_price)? - .safe_mul(prec_diff_numerator)? - .safe_div(out_price.safe_mul(prec_diff_denominator)?)?; + let swap_price_num = in_price.safe_mul(prec_diff_numerator)?; + let swap_price_denom = out_price.safe_mul(prec_diff_denominator)?; - Ok(swap_price) + Ok((swap_price_num, swap_price_denom)) } - /// - /// Returns the (out_amount, in_fee, out_fee) in the respective token units. Amounts are gross fees. + /// in the respective token units. Amounts are gross fees and in + /// token mint precision. + /// Positive fees are paid, negative fees are rebated + /// Returns (in_amount out_amount, in_fee, out_fee) pub fn get_swap_amount( &self, - oracle_map: &mut OracleMap, + in_oracle: &OraclePriceData, + out_oracle: &OraclePriceData, in_constituent: &Constituent, out_constituent: &Constituent, in_spot_market: &SpotMarket, @@ -145,51 +131,77 @@ impl LPPool { out_target_weight: i64, in_amount: u64, ) -> DriftResult<(u64, u64, i64, i64)> { - let swap_price = - self.get_swap_price(oracle_map, in_spot_market, out_spot_market, in_amount)?; + let (swap_price_num, swap_price_denom) = self.get_swap_price( + in_spot_market.decimals, + out_spot_market.decimals, + in_oracle, + out_oracle, + )?; let in_fee = self.get_swap_fees( - oracle_map, - in_constituent, in_spot_market, - in_amount, + in_oracle, + in_constituent, + in_amount.cast::()?, in_target_weight, )?; + let in_fee_amount = in_amount + .cast::()? + .safe_mul(in_fee)? + .safe_div(PERCENTAGE_PRECISION_I64.cast::()?)?; + let out_amount = in_amount .cast::()? - .safe_sub(in_fee)? - .safe_mul(swap_price.cast::()?)? - .safe_div(PRICE_PRECISION_I64)? + .safe_sub(in_fee_amount)? + .safe_mul(swap_price_num.cast::()?)? + .safe_div(swap_price_denom.cast::()?)? .cast::()?; let out_fee = self.get_swap_fees( - oracle_map, - out_constituent, out_spot_market, - out_amount, + out_oracle, + out_constituent, + out_amount + .cast::()? + .checked_neg() + .ok_or(ErrorCode::MathError.into())?, out_target_weight, )?; - // TODO: additional spot quoter logic can go here - // TODO: emit swap event + msg!("in_fee: {}, out_fee: {}", in_fee, out_fee); + let out_fee_amount = out_amount + .cast::()? + .safe_mul(out_fee)? + .safe_div(PERCENTAGE_PRECISION_I64.cast::()?)?; - Ok((in_amount, out_amount, in_fee, out_fee)) + Ok((in_amount, out_amount, in_fee_amount, out_fee_amount)) } /// returns fee in PERCENTAGE_PRECISION pub fn get_swap_fees( &self, - oracle_map: &mut OracleMap, // might not need oracle_map depending on how accounts are passed in - constituent: &Constituent, spot_market: &SpotMarket, - amount: u64, + oracle: &OraclePriceData, + constituent: &Constituent, + amount: i64, target_weight: i64, ) -> DriftResult { - let price = oracle_map - .get_price_data(&(spot_market.oracle, spot_market.oracle_source)) - .expect("failed to get price data") - .price; + // +4,976 CUs to log weight_before + let weight_before = constituent.get_weight(oracle.price, spot_market, 0, self.last_aum)?; + msg!( + "constituent {}: weight_before: {} target_weight: {}", + constituent.constituent_index, + weight_before, + target_weight + ); + let weight_after = - constituent.get_weight(price, spot_market, amount.cast::()?, self.last_aum)?; + constituent.get_weight(oracle.price, spot_market, amount, self.last_aum)?; + msg!( + "constituent {}: weight_after: {} target_weight: {}", + constituent.constituent_index, + weight_after, + target_weight + ); let fee = constituent.get_fee_to_charge(weight_after, target_weight)?; Ok(fee) @@ -255,7 +267,8 @@ impl BLPosition { pub struct Constituent { /// address of the constituent pub pubkey: Pubkey, - /// underlying drift spot market index + /// underlying drift spot market index. + /// TODO: redundant with spot_balance.market_index pub spot_market_index: u16, /// idx in LPPool.constituents pub constituent_index: u16, @@ -273,8 +286,11 @@ pub struct Constituent { /// precision: PERCENTAGE_PRECISION pub swap_fee_max: i64, + /// total fees received by the constituent. Positive = fees received, Negative = fees paid + pub total_swap_fees: i128, + /// ata token balance in token precision - pub token_balance: u128, + pub token_balance: u64, /// spot borrow-lend balance for constituent pub spot_balance: BLPosition, // should be in constituent base asset @@ -282,12 +298,14 @@ pub struct Constituent { pub last_oracle_price: i64, pub last_oracle_slot: u64, + pub mint: Pubkey, + pub oracle_staleness_threshold: u64, _padding2: [u8; 8], } impl Size for Constituent { - const SIZE: usize = 152; + const SIZE: usize = 192; } impl Constituent { @@ -308,6 +326,11 @@ impl Constituent { } } + pub fn record_swap_fees(&mut self, amount: i64) -> DriftResult { + self.total_swap_fees = self.total_swap_fees.safe_add(amount.cast::()?)?; + Ok(()) + } + /// Current weight of this constituent = price * token_balance / lp_pool_aum /// Note: lp_pool_aum is from LPPool.last_aum, which is a lagged value updated via crank pub fn get_weight( @@ -345,7 +368,9 @@ impl Constituent { let denom = target_weight.safe_sub(min_weight)?; (num, denom) }; - + if slope_denominator == 0 { + return Ok(self.swap_fee_min); + } let b = self .swap_fee_min .safe_mul(slope_denominator)? @@ -353,7 +378,12 @@ impl Constituent { Ok(post_swap_weight .safe_mul(slope_numerator)? .safe_add(b)? - .safe_div(slope_denominator)?) + .safe_div(slope_denominator)? + .clamp(self.swap_fee_min, self.swap_fee_max)) + } + + pub fn sync_token_balance(&mut self, token_account_amount: u64) { + self.token_balance = token_account_amount; } } diff --git a/programs/drift/src/state/lp_pool/tests.rs b/programs/drift/src/state/lp_pool/tests.rs index 63a8ae4451..e661018ffe 100644 --- a/programs/drift/src/state/lp_pool/tests.rs +++ b/programs/drift/src/state/lp_pool/tests.rs @@ -404,3 +404,333 @@ mod tests { assert_eq!(fee, PERCENTAGE_PRECISION_I64 / 10000); // 1 bps (min fee) } } + +#[cfg(test)] +mod swap_tests { + use crate::math::constants::PERCENTAGE_PRECISION_I64; + use crate::state::lp_pool::*; + + #[test] + fn test_get_swap_price() { + let lp_pool = LPPool::default(); + + let in_oracle = OraclePriceData { + price: 1_000_000, + ..OraclePriceData::default() + }; + let out_oracle = OraclePriceData { + price: 233_400_000, + ..OraclePriceData::default() + }; + + // same decimals + let (price_num, price_denom) = lp_pool + .get_swap_price(6, 6, &in_oracle, &out_oracle) + .unwrap(); + assert_eq!(price_num, 1_000_000); + assert_eq!(price_denom, 233_400_000); + + let (price_num, price_denom) = lp_pool + .get_swap_price(6, 6, &out_oracle, &in_oracle) + .unwrap(); + assert_eq!(price_num, 233_400_000); + assert_eq!(price_denom, 1_000_000); + } + + fn get_swap_amount_decimals_scenario( + in_decimals: u32, + out_decimals: u32, + in_amount: u64, + expected_in_amount: u64, + expected_out_amount: u64, + expected_in_fee: i64, + expected_out_fee: i64, + ) { + let lp_pool = LPPool { + last_aum: 1_000_000_000_000, + ..LPPool::default() + }; + + let oracle_0 = OraclePriceData { + price: 1_000_000, + ..OraclePriceData::default() + }; + let oracle_1 = OraclePriceData { + price: 233_400_000, + ..OraclePriceData::default() + }; + + let constituent_0 = Constituent { + decimals: in_decimals as u8, + swap_fee_min: PERCENTAGE_PRECISION_I64 / 10000, + swap_fee_max: PERCENTAGE_PRECISION_I64 / 1000, + // max_weight_deviation: PERCENTAGE_PRECISION_I64 / 10, + ..Constituent::default() + }; + let constituent_1 = Constituent { + decimals: out_decimals as u8, + swap_fee_min: PERCENTAGE_PRECISION_I64 / 10000, + swap_fee_max: PERCENTAGE_PRECISION_I64 / 1000, + // max_weight_deviation: PERCENTAGE_PRECISION_I64 / 10, + ..Constituent::default() + }; + let spot_market_0 = SpotMarket { + decimals: in_decimals, + ..SpotMarket::default() + }; + let spot_market_1 = SpotMarket { + decimals: out_decimals, + ..SpotMarket::default() + }; + + let (in_amount, out_amount, in_fee, out_fee) = lp_pool + .get_swap_amount( + &oracle_0, + &oracle_1, + &constituent_0, + &constituent_1, + &spot_market_0, + &spot_market_1, + 500_000, + 500_000, + in_amount, + ) + .unwrap(); + assert_eq!(in_amount, expected_in_amount); + assert_eq!(out_amount, expected_out_amount); + assert_eq!(in_fee, expected_in_fee); + assert_eq!(out_fee, expected_out_fee); + } + + #[test] + fn test_get_swap_amount_in_6_out_6() { + get_swap_amount_decimals_scenario( + 6, + 6, + 233_400_000, + 233_400_000, + 999900, + 23340, // 1 bps + 99, + ); + } + + #[test] + fn test_get_swap_amount_in_6_out_9() { + get_swap_amount_decimals_scenario(6, 9, 233_400_000, 233_400_000, 999900000, 23340, 99990); + } + + #[test] + fn test_get_swap_amount_in_9_out_6() { + get_swap_amount_decimals_scenario( + 9, + 6, + 233_400_000_000, + 233_400_000_000, + 999900, + 23340000, + 99, + ); + } + + #[test] + fn test_get_fee_to_charge_positive_min_fee() { + let c = Constituent { + swap_fee_min: PERCENTAGE_PRECISION_I64 / 10000, // 1 bps + swap_fee_max: PERCENTAGE_PRECISION_I64 / 100, // 100 bps + max_weight_deviation: PERCENTAGE_PRECISION_I64 / 10, // 10% + ..Constituent::default() + }; + + // swapping to target should incur minimum fee + let target_weight = PERCENTAGE_PRECISION_I64 / 2; // 50% + let post_swap_weight = target_weight; // 50% + let fee = c + .get_fee_to_charge(post_swap_weight, target_weight) + .unwrap(); + assert_eq!(fee, c.swap_fee_min); + + // positive target: swapping to max deviation above target should incur maximum fee + let post_swap_weight = target_weight + c.max_weight_deviation; + let fee = c + .get_fee_to_charge(post_swap_weight, target_weight) + .unwrap(); + assert_eq!(fee, c.swap_fee_max); + + // positive target: swapping to max deviation below target should incur minimum fee + let post_swap_weight = target_weight - c.max_weight_deviation; + let fee = c + .get_fee_to_charge(post_swap_weight, target_weight) + .unwrap(); + assert_eq!(fee, c.swap_fee_max); + + // negative target: swapping to max deviation above target should incur maximum fee + let post_swap_weight = -1 * target_weight + c.max_weight_deviation; + let fee = c + .get_fee_to_charge(post_swap_weight, -1 * target_weight) + .unwrap(); + assert_eq!(fee, c.swap_fee_max); + + // negative target: swapping to max deviation below target should incur minimum fee + let post_swap_weight = -1 * target_weight - c.max_weight_deviation; + let fee = c + .get_fee_to_charge(post_swap_weight, -1 * target_weight) + .unwrap(); + assert_eq!(fee, c.swap_fee_max); + + // positive target: swaps to +max_weight_deviation/2, should incur half of the max fee + let post_swap_weight = target_weight + c.max_weight_deviation / 2; + let fee = c + .get_fee_to_charge(post_swap_weight, target_weight) + .unwrap(); + assert_eq!(fee, (c.swap_fee_max + c.swap_fee_min) / 2); + + // positive target: swaps to -max_weight_deviation/2, should incur half of the max fee + let post_swap_weight = target_weight - c.max_weight_deviation / 2; + let fee = c + .get_fee_to_charge(post_swap_weight, target_weight) + .unwrap(); + assert_eq!(fee, (c.swap_fee_max + c.swap_fee_min) / 2); + + // negative target: swaps to +max_weight_deviation/2, should incur half of the max fee + let post_swap_weight = -1 * target_weight + c.max_weight_deviation / 2; + let fee = c + .get_fee_to_charge(post_swap_weight, -1 * target_weight) + .unwrap(); + assert_eq!(fee, (c.swap_fee_max + c.swap_fee_min) / 2); + + // negative target: swaps to -max_weight_deviation/2, should incur half of the max fee + let post_swap_weight = -1 * target_weight - c.max_weight_deviation / 2; + let fee = c + .get_fee_to_charge(post_swap_weight, -1 * target_weight) + .unwrap(); + assert_eq!(fee, (c.swap_fee_max + c.swap_fee_min) / 2); + } + + #[test] + fn test_get_fee_to_charge_negative_min_fee() { + let c = Constituent { + swap_fee_min: -1 * PERCENTAGE_PRECISION_I64 / 10000, // -1 bps (rebate) + swap_fee_max: PERCENTAGE_PRECISION_I64 / 100, // 100 bps + max_weight_deviation: PERCENTAGE_PRECISION_I64 / 10, // 10% + ..Constituent::default() + }; + + // swapping to target should incur minimum fee + let target_weight = PERCENTAGE_PRECISION_I64 / 2; // 50% + let post_swap_weight = target_weight; // 50% + let fee = c + .get_fee_to_charge(post_swap_weight, target_weight) + .unwrap(); + assert_eq!(fee, c.swap_fee_min); + + // positive target: swapping to max deviation above target should incur maximum fee + let post_swap_weight = target_weight + c.max_weight_deviation; + let fee = c + .get_fee_to_charge(post_swap_weight, target_weight) + .unwrap(); + assert_eq!(fee, c.swap_fee_max); + + // positive target: swapping to max deviation below target should incur minimum fee + let post_swap_weight = target_weight - c.max_weight_deviation; + let fee = c + .get_fee_to_charge(post_swap_weight, target_weight) + .unwrap(); + assert_eq!(fee, c.swap_fee_max); + + // negative target: swapping to max deviation above target should incur maximum fee + let post_swap_weight = -1 * target_weight + c.max_weight_deviation; + let fee = c + .get_fee_to_charge(post_swap_weight, -1 * target_weight) + .unwrap(); + assert_eq!(fee, c.swap_fee_max); + + // negative target: swapping to max deviation below target should incur minimum fee + let post_swap_weight = -1 * target_weight - c.max_weight_deviation; + let fee = c + .get_fee_to_charge(post_swap_weight, -1 * target_weight) + .unwrap(); + assert_eq!(fee, c.swap_fee_max); + + // positive target: swaps to +max_weight_deviation/2, should incur half of the max fee + let post_swap_weight = target_weight + c.max_weight_deviation / 2; + let fee = c + .get_fee_to_charge(post_swap_weight, target_weight) + .unwrap(); + assert_eq!(fee, (c.swap_fee_max + c.swap_fee_min) / 2); + + // positive target: swaps to -max_weight_deviation/2, should incur half of the max fee + let post_swap_weight = target_weight - c.max_weight_deviation / 2; + let fee = c + .get_fee_to_charge(post_swap_weight, target_weight) + .unwrap(); + assert_eq!(fee, (c.swap_fee_max + c.swap_fee_min) / 2); + + // negative target: swaps to +max_weight_deviation/2, should incur half of the max fee + let post_swap_weight = -1 * target_weight + c.max_weight_deviation / 2; + let fee = c + .get_fee_to_charge(post_swap_weight, -1 * target_weight) + .unwrap(); + assert_eq!(fee, (c.swap_fee_max + c.swap_fee_min) / 2); + + // negative target: swaps to -max_weight_deviation/2, should incur half of the max fee + let post_swap_weight = -1 * target_weight - c.max_weight_deviation / 2; + let fee = c + .get_fee_to_charge(post_swap_weight, -1 * target_weight) + .unwrap(); + assert_eq!(fee, (c.swap_fee_max + c.swap_fee_min) / 2); + } + + #[test] + fn test_get_weight() { + let c = Constituent { + swap_fee_min: -1 * PERCENTAGE_PRECISION_I64 / 10000, // -1 bps (rebate) + swap_fee_max: PERCENTAGE_PRECISION_I64 / 100, // 100 bps + max_weight_deviation: PERCENTAGE_PRECISION_I64 / 10, // 10% + spot_market_index: 0, + spot_balance: BLPosition { + scaled_balance: 500_000, + cumulative_deposits: 1_000_000, + balance_type: SpotBalanceType::Deposit, + market_index: 0, + ..BLPosition::default() + }, + token_balance: 500_000, + decimals: 6, + ..Constituent::default() + }; + + let spot_market = SpotMarket { + market_index: 0, + decimals: 6, + cumulative_deposit_interest: 10_000_000_000_000, + ..SpotMarket::default() + }; + + let full_balance = c.get_full_balance(&spot_market).unwrap(); + assert_eq!(full_balance, 1_000_000); + + // 1/10 = 10% + let weight = c + .get_weight( + 1_000_000, // $1 + &spot_market, + 0, + 10_000_000, + ) + .unwrap(); + assert_eq!(weight, 100_000); + + // (1+1)/10 = 20% + let weight = c + .get_weight(1_000_000, &spot_market, 1_000_000, 10_000_000) + .unwrap(); + assert_eq!(weight, 200_000); + + // (1-0.5)/10 = 0.5% + let weight = c + .get_weight(1_000_000, &spot_market, -500_000, 10_000_000) + .unwrap(); + assert_eq!(weight, 50_000); + } +} diff --git a/sdk/src/driftClient.ts b/sdk/src/driftClient.ts index 72c0e7951f..7b9eb7be2a 100644 --- a/sdk/src/driftClient.ts +++ b/sdk/src/driftClient.ts @@ -59,7 +59,8 @@ import { ProtectedMakerModeConfig, SignedMsgOrderParamsDelegateMessage, AmmConstituentMapping, - LPPool, + AmmConstituentDatum, + LPPoolAccount, } from './types'; import driftIDL from './idl/drift.json'; @@ -9738,7 +9739,7 @@ export class DriftClient { } public async updateDlpPoolAum( - lpPool: LPPool, + lpPool: LPPoolAccount, spotMarketIndexOfConstituents: number[], txParams?: TxParams ): Promise { @@ -9757,7 +9758,7 @@ export class DriftClient { } public async getUpdateDlpPoolAumIxs( - lpPool: LPPool, + lpPool: LPPoolAccount, spotMarketIndexOfConstituents: number[] ): Promise { const remainingAccounts = this.getRemainingAccounts({ @@ -9823,6 +9824,96 @@ export class DriftClient { }); } + public async lpPoolSwap( + inMarketIndex: number, + outMarketIndex: number, + inAmount: BN, + minOutAmount: BN, + lpPool: PublicKey, + constituentTargetWeights: PublicKey, + constituentInTokenAccount: PublicKey, + constituentOutTokenAccount: PublicKey, + userInTokenAccount: PublicKey, + userOutTokenAccount: PublicKey, + inConstituent: PublicKey, + outConstituent: PublicKey, + inMarketMint: PublicKey, + outMarketMint: PublicKey, + txParams?: TxParams + ): Promise { + const { txSig } = await this.sendTransaction( + await this.buildTransaction( + await this.getLpPoolSwapIx( + inMarketIndex, + outMarketIndex, + inAmount, + minOutAmount, + lpPool, + constituentTargetWeights, + constituentInTokenAccount, + constituentOutTokenAccount, + userInTokenAccount, + userOutTokenAccount, + inConstituent, + outConstituent, + inMarketMint, + outMarketMint + ), + txParams + ), + [], + this.opts + ); + return txSig; + } + + public async getLpPoolSwapIx( + inMarketIndex: number, + outMarketIndex: number, + inAmount: BN, + minOutAmount: BN, + lpPool: PublicKey, + constituentTargetWeights: PublicKey, + constituentInTokenAccount: PublicKey, + constituentOutTokenAccount: PublicKey, + userInTokenAccount: PublicKey, + userOutTokenAccount: PublicKey, + inConstituent: PublicKey, + outConstituent: PublicKey, + inMarketMint: PublicKey, + outMarketMint: PublicKey + ): Promise { + const remainingAccounts = this.getRemainingAccounts({ + userAccounts: [], + writableSpotMarketIndexes: [inMarketIndex, outMarketIndex], + }); + + return this.program.instruction.lpPoolSwap( + inMarketIndex, + outMarketIndex, + inAmount, + minOutAmount, + { + remainingAccounts, + accounts: { + driftSigner: this.getSignerPublicKey(), + state: await this.getStatePublicKey(), + lpPool, + constituentTargetWeights, + constituentInTokenAccount, + constituentOutTokenAccount, + userInTokenAccount, + userOutTokenAccount, + inConstituent, + outConstituent, + inMarketMint, + outMarketMint, + authority: this.wallet.publicKey, + tokenProgram: TOKEN_PROGRAM_ID, + }, + } + ); + } /** * Below here are the transaction sending functions */ diff --git a/sdk/src/idl/drift.json b/sdk/src/idl/drift.json index 03ff1ba51a..4383a5387e 100644 --- a/sdk/src/idl/drift.json +++ b/sdk/src/idl/drift.json @@ -7558,6 +7558,99 @@ } ], "args": [] + }, + { + "name": "lpPoolSwap", + "accounts": [ + { + "name": "driftSigner", + "isMut": false, + "isSigner": false + }, + { + "name": "state", + "isMut": false, + "isSigner": false + }, + { + "name": "lpPool", + "isMut": false, + "isSigner": false + }, + { + "name": "constituentTargetWeights", + "isMut": true, + "isSigner": false + }, + { + "name": "constituentInTokenAccount", + "isMut": true, + "isSigner": false + }, + { + "name": "constituentOutTokenAccount", + "isMut": true, + "isSigner": false + }, + { + "name": "userInTokenAccount", + "isMut": true, + "isSigner": false + }, + { + "name": "userOutTokenAccount", + "isMut": true, + "isSigner": false + }, + { + "name": "inConstituent", + "isMut": true, + "isSigner": false + }, + { + "name": "outConstituent", + "isMut": true, + "isSigner": false + }, + { + "name": "inMarketMint", + "isMut": false, + "isSigner": false + }, + { + "name": "outMarketMint", + "isMut": false, + "isSigner": false + }, + { + "name": "authority", + "isMut": false, + "isSigner": true + }, + { + "name": "tokenProgram", + "isMut": false, + "isSigner": false + } + ], + "args": [ + { + "name": "inMarketIndex", + "type": "u16" + }, + { + "name": "outMarketIndex", + "type": "u16" + }, + { + "name": "inAmount", + "type": "u64" + }, + { + "name": "minOutAmount", + "type": "u64" + } + ] } ], "accounts": [ @@ -8005,7 +8098,8 @@ { "name": "spotMarketIndex", "docs": [ - "underlying drift spot market index" + "underlying drift spot market index.", + "TODO: redundant with spot_balance.market_index" ], "type": "u16" }, @@ -8053,12 +8147,19 @@ ], "type": "i64" }, + { + "name": "totalSwapFees", + "docs": [ + "total fees received by the constituent. Positive = fees received, Negative = fees paid" + ], + "type": "i128" + }, { "name": "tokenBalance", "docs": [ "ata token balance in token precision" ], - "type": "u128" + "type": "u64" }, { "name": "spotBalance", @@ -8077,6 +8178,10 @@ "name": "lastOracleSlot", "type": "u64" }, + { + "name": "mint", + "type": "publicKey" + }, { "name": "oracleStalenessThreshold", "type": "u64" @@ -12387,6 +12492,9 @@ }, { "name": "UpdateLpPoolAum" + }, + { + "name": "LpPoolSwap" } ] } @@ -14575,6 +14683,81 @@ "index": false } ] + }, + { + "name": "LPSwapRecord", + "fields": [ + { + "name": "ts", + "type": "i64", + "index": false + }, + { + "name": "authority", + "type": "publicKey", + "index": false + }, + { + "name": "amountOut", + "type": "u64", + "index": false + }, + { + "name": "amountIn", + "type": "u64", + "index": false + }, + { + "name": "feeOut", + "type": "i64", + "index": false + }, + { + "name": "feeIn", + "type": "i64", + "index": false + }, + { + "name": "outSpotMarketIndex", + "type": "u16", + "index": false + }, + { + "name": "inSpotMarketIndex", + "type": "u16", + "index": false + }, + { + "name": "outConstituentIndex", + "type": "u16", + "index": false + }, + { + "name": "inConstituentIndex", + "type": "u16", + "index": false + }, + { + "name": "outOraclePrice", + "type": "i64", + "index": false + }, + { + "name": "inOraclePrice", + "type": "i64", + "index": false + }, + { + "name": "mintOut", + "type": "publicKey", + "index": false + }, + { + "name": "mintIn", + "type": "publicKey", + "index": false + } + ] } ], "errors": [ @@ -16187,6 +16370,11 @@ "code": 6321, "name": "OracleTooStaleForLPAUMUpdate", "msg": "Oracle too stale for LP AUM update" + }, + { + "code": 6322, + "name": "InsufficientConstituentTokenBalance", + "msg": "Insufficient constituent token balance" } ], "metadata": { diff --git a/sdk/src/types.ts b/sdk/src/types.ts index eb58c70c85..c04ac28519 100644 --- a/sdk/src/types.ts +++ b/sdk/src/types.ts @@ -1496,7 +1496,7 @@ export type ConstituentTargetWeights = { weights: WeightDatum[]; }; -export type LPPool = { +export type LPPoolAccount = { name: number[]; pubkey: PublicKey; mint: PublicKey; @@ -1504,6 +1504,7 @@ export type LPPool = { lastAum: BN; lastAumSlot: BN; lastAumTs: BN; + oldestOracleSlot: BN; lastRevenueRebalanceTs: BN; totalFeesReceived: BN; totalFeesPaid: BN; @@ -1517,7 +1518,7 @@ export type BLPosition = { balanceType: SpotBalanceType; }; -export type Constituent = { +export type ConstituentAccount = { pubkey: PublicKey; spotMarketIndex: number; constituentIndex: number; @@ -1525,11 +1526,12 @@ export type Constituent = { maxWeightDeviation: BN; swapFeeMin: BN; swapFeeMax: BN; + totalSwapFees: BN; tokenBalance: BN; spotBalance: BLPosition; lastOraclePrice: BN; lastOracleSlot: BN; - oracleStalenessThreshold: BN; + mint: PublicKey; }; export type CacheInfo = { diff --git a/tests/adminDeposit.ts b/tests/adminDeposit.ts new file mode 100644 index 0000000000..2eee301d8c --- /dev/null +++ b/tests/adminDeposit.ts @@ -0,0 +1,222 @@ +import * as anchor from '@coral-xyz/anchor'; +import { expect } from 'chai'; + +import { Program, Wallet } from '@coral-xyz/anchor'; + +import { Keypair } from '@solana/web3.js'; + +import { + BN, + TestClient, + getTokenAmount, + getSignedTokenAmount, +} from '../sdk/src'; + +import { + createFundedKeyPair, + initializeQuoteSpotMarket, + mockUSDCMint, + mockUserUSDCAccount, +} from './testHelpers'; +import { startAnchor } from 'solana-bankrun'; +import { TestBulkAccountLoader } from '../sdk/src/accounts/testBulkAccountLoader'; +import { BankrunContextWrapper } from '../sdk/src/bankrun/bankrunConnection'; +import dotenv from 'dotenv'; +dotenv.config(); + +describe('admin deposit', () => { + const chProgram = anchor.workspace.Drift as Program; + let bankrunContextWrapper: BankrunContextWrapper; + let bulkAccountLoader: TestBulkAccountLoader; + + let adminDriftClient: TestClient; + let adminUSDCAccount: Keypair; + + let userKeyPair: Keypair; + let userDriftClient: TestClient; + + let userKeyPair2: Keypair; + let userDriftClient2: TestClient; + let user2USDCAccount: Keypair; + + let usdcMint; + const usdcAmount = new BN(100 * 10 ** 6); + + before(async () => { + const context = await startAnchor('', [], []); + + // @ts-ignore + bankrunContextWrapper = new BankrunContextWrapper(context); + + userKeyPair = await createFundedKeyPair(bankrunContextWrapper); + userKeyPair2 = await createFundedKeyPair(bankrunContextWrapper); + + bulkAccountLoader = new TestBulkAccountLoader( + bankrunContextWrapper.connection, + 'processed', + 1 + ); + + usdcMint = await mockUSDCMint(bankrunContextWrapper); + adminUSDCAccount = await mockUserUSDCAccount( + usdcMint, + usdcAmount, + bankrunContextWrapper + ); + + user2USDCAccount = await mockUserUSDCAccount( + usdcMint, + usdcAmount, + bankrunContextWrapper, + userKeyPair2.publicKey + ); + + adminDriftClient = new TestClient({ + connection: bankrunContextWrapper.connection.toConnection(), + wallet: bankrunContextWrapper.provider.wallet, + programID: chProgram.programId, + opts: { + commitment: 'confirmed', + }, + activeSubAccountId: 0, + subAccountIds: [], + perpMarketIndexes: [], + spotMarketIndexes: [0], + oracleInfos: [], + accountSubscription: { + type: 'polling', + accountLoader: bulkAccountLoader, + }, + }); + await adminDriftClient.initialize(usdcMint.publicKey, true); + await adminDriftClient.subscribe(); + await initializeQuoteSpotMarket(adminDriftClient, usdcMint.publicKey); + // await adminDriftClient.initializeUserAccountAndDepositCollateral( + // QUOTE_PRECISION, + // adminUSDCAccount.publicKey + // ); + await adminDriftClient.initializeUserAccount(0, 'admin subacc 0'); + + userDriftClient = new TestClient({ + connection: bankrunContextWrapper.connection.toConnection(), + wallet: new Wallet(userKeyPair), + programID: chProgram.programId, + opts: { + commitment: 'confirmed', + }, + activeSubAccountId: 0, + perpMarketIndexes: [], + spotMarketIndexes: [0], + subAccountIds: [], + oracleInfos: [], + accountSubscription: { + type: 'polling', + accountLoader: bulkAccountLoader, + }, + }); + await userDriftClient.subscribe(); + await userDriftClient.initializeUserAccount(0, 'user account 0'); + + userKeyPair2 = await createFundedKeyPair(bankrunContextWrapper); + userDriftClient2 = new TestClient({ + connection: bankrunContextWrapper.connection.toConnection(), + wallet: new Wallet(userKeyPair2), + programID: chProgram.programId, + opts: { + commitment: 'confirmed', + }, + activeSubAccountId: 0, + perpMarketIndexes: [], + spotMarketIndexes: [0], + subAccountIds: [], + oracleInfos: [], + accountSubscription: { + type: 'polling', + accountLoader: bulkAccountLoader, + }, + }); + await userDriftClient2.subscribe(); + }); + + after(async () => { + await adminDriftClient.unsubscribe(); + await userDriftClient.unsubscribe(); + }); + + it('admin can deposit into user', async () => { + const userAccount = await userDriftClient.getUserAccountPublicKey(0); + console.log('user userAccount', userAccount.toBase58()); + + const state = adminDriftClient.getStateAccount().admin.toBase58(); + expect(state).to.be.equal(adminDriftClient.wallet.publicKey.toBase58()); + + // user has 0 balance + let spotPos = userDriftClient.getSpotPosition(0); + let spotMarket = userDriftClient.getSpotMarketAccount(0); + const userSpotBalBefore = getSignedTokenAmount( + getTokenAmount(spotPos.scaledBalance, spotMarket, spotPos.balanceType), + spotPos.balanceType + ); + expect(userSpotBalBefore.toString()).to.be.equal('0'); + + // admin deposits into user + await adminDriftClient.adminDeposit( + 0, + usdcAmount, + userAccount, + adminUSDCAccount.publicKey + ); + + await adminDriftClient.fetchAccounts(); + await userDriftClient.fetchAccounts(); + + // check user got the deposit + spotPos = userDriftClient.getSpotPosition(0); + spotMarket = userDriftClient.getSpotMarketAccount(0); + const userSpotBalAfter = getSignedTokenAmount( + getTokenAmount(spotPos.scaledBalance, spotMarket, spotPos.balanceType), + spotPos.balanceType + ); + const spotBalDiff = userSpotBalAfter.sub(userSpotBalBefore); + expect(spotBalDiff.toString()).to.be.equal(usdcAmount.toString()); + }); + + it('user2 cannot deposit into user', async () => { + const state = adminDriftClient.getStateAccount().admin.toBase58(); + expect(state).to.not.be.equal(userDriftClient2.wallet.publicKey.toBase58()); + + // user has 0 balance + let spotPos = userDriftClient.getSpotPosition(0); + let spotMarket = userDriftClient.getSpotMarketAccount(0); + const userSpotBalBefore = getSignedTokenAmount( + getTokenAmount(spotPos.scaledBalance, spotMarket, spotPos.balanceType), + spotPos.balanceType + ); + + // user2 attempts to deposit into user + try { + await userDriftClient2.adminDeposit( + 0, + usdcAmount, + await userDriftClient.getUserAccountPublicKey(0), + user2USDCAccount.publicKey + ); + expect.fail('should not allow non-admin to call adminDeposit'); + } catch (e) { + expect(e.message as string).to.contain('0x7d3'); + } + + await adminDriftClient.fetchAccounts(); + await userDriftClient.fetchAccounts(); + + // check user did not get the deposit + spotPos = userDriftClient.getSpotPosition(0); + spotMarket = userDriftClient.getSpotMarketAccount(0); + const userSpotBalAfter = getSignedTokenAmount( + getTokenAmount(spotPos.scaledBalance, spotMarket, spotPos.balanceType), + spotPos.balanceType + ); + const spotBalDiff = userSpotBalAfter.sub(userSpotBalBefore); + expect(spotBalDiff.toString()).to.be.equal('0'); + }); +}); diff --git a/tests/lpPool.ts b/tests/lpPool.ts index d7e4c6e91b..42d890d331 100644 --- a/tests/lpPool.ts +++ b/tests/lpPool.ts @@ -19,7 +19,7 @@ import { PEG_PRECISION, ConstituentTargetWeights, AmmConstituentMapping, - LPPool, + LPPoolAccount, getConstituentVaultPublicKey, OracleSource, SPOT_MARKET_WEIGHT_PRECISION, @@ -111,7 +111,7 @@ describe('LP Pool', () => { activeSubAccountId: 0, subAccountIds: [], perpMarketIndexes: [0, 1, 2], - spotMarketIndexes: [0], + spotMarketIndexes: [0, 1], oracleInfos: [{ publicKey: solUsd, source: OracleSource.PYTH }], accountSubscription: { type: 'polling', @@ -254,7 +254,7 @@ describe('LP Pool', () => { const lpPool = (await adminClient.program.account.lpPool.fetch( lpPoolKey - )) as LPPool; + )) as LPPoolAccount; assert(lpPool.constituents == 1); @@ -397,7 +397,7 @@ describe('LP Pool', () => { it('can update pool aum', async () => { const lpPool = (await adminClient.program.account.lpPool.fetch( lpPoolKey - )) as LPPool; + )) as LPPoolAccount; assert(lpPool.constituents == 1); await adminClient.updateDlpPoolAum(lpPool, [0]); diff --git a/tests/lpPoolSwap.ts b/tests/lpPoolSwap.ts index 7c29569ad2..7becd32217 100644 --- a/tests/lpPoolSwap.ts +++ b/tests/lpPoolSwap.ts @@ -1,33 +1,36 @@ import * as anchor from '@coral-xyz/anchor'; import { expect, assert } from 'chai'; - import { Program } from '@coral-xyz/anchor'; - import { Keypair, PublicKey } from '@solana/web3.js'; -import { TOKEN_PROGRAM_ID, getMint } from '@solana/spl-token'; - import { BN, TestClient, QUOTE_PRECISION, getLpPoolPublicKey, - getAmmConstituentMappingPublicKey, encodeName, getConstituentTargetWeightsPublicKey, PERCENTAGE_PRECISION, PRICE_PRECISION, PEG_PRECISION, ConstituentTargetWeights, - AmmConstituentMapping, - User, + OracleSource, + SPOT_MARKET_RATE_PRECISION, + SPOT_MARKET_WEIGHT_PRECISION, + LPPoolAccount, + convertToNumber, + getConstituentVaultPublicKey, + getConstituentPublicKey, + ConstituentAccount, } from '../sdk/src'; - import { - getPerpMarketDecoded, initializeQuoteSpotMarket, - mockOracleNoProgram, mockUSDCMint, mockUserUSDCAccount, + mockOracleNoProgram, + setFeedPriceNoProgram, + overWriteTokenAccountBalance, + overwriteConstituentAccount, + mockAtaTokenAccountForMint, } from './testHelpers'; import { startAnchor } from 'solana-bankrun'; import { TestBulkAccountLoader } from '../sdk/src/accounts/testBulkAccountLoader'; @@ -41,8 +44,9 @@ describe('LP Pool', () => { let bulkAccountLoader: TestBulkAccountLoader; let adminClient: TestClient; - let usdcMint; - let adminUser: User; + let usdcMint: Keypair; + let spotTokenMint: Keypair; + let spotMarketOracle: PublicKey; const mantissaSqrtScale = new BN(Math.sqrt(PRICE_PRECISION.toNumber())); const ammInitialQuoteAssetReserve = new anchor.BN(10 * 10 ** 13).mul( @@ -51,7 +55,6 @@ describe('LP Pool', () => { const ammInitialBaseAssetReserve = new anchor.BN(10 * 10 ** 13).mul( mantissaSqrtScale ); - let solUsd: PublicKey; const lpPoolName = 'test pool 1'; const tokenDecimals = 6; @@ -61,18 +64,7 @@ describe('LP Pool', () => { ); before(async () => { - const context = await startAnchor( - '', - [ - { - name: 'token_2022', - programId: new PublicKey( - 'TokenzQdBNbLqP5VEhdkAS6EPFLC1PHnBqCXEpPxuEb' - ), - }, - ], - [] - ); + const context = await startAnchor('', [], []); // @ts-ignore bankrunContextWrapper = new BankrunContextWrapper(context); @@ -84,12 +76,12 @@ describe('LP Pool', () => { ); usdcMint = await mockUSDCMint(bankrunContextWrapper); + spotTokenMint = await mockUSDCMint(bankrunContextWrapper); + spotMarketOracle = await mockOracleNoProgram(bankrunContextWrapper, 200.1); const keypair = new Keypair(); await bankrunContextWrapper.fundKeypair(keypair, 10 ** 9); - usdcMint = await mockUSDCMint(bankrunContextWrapper); - adminClient = new TestClient({ connection: bankrunContextWrapper.connection.toConnection(), wallet: new anchor.Wallet(keypair), @@ -100,8 +92,13 @@ describe('LP Pool', () => { activeSubAccountId: 0, subAccountIds: [], perpMarketIndexes: [0, 1], - spotMarketIndexes: [0], - oracleInfos: [], + spotMarketIndexes: [0, 1], + oracleInfos: [ + { + publicKey: spotMarketOracle, + source: OracleSource.PYTH, + }, + ], accountSubscription: { type: 'polling', accountLoader: bulkAccountLoader, @@ -122,21 +119,12 @@ describe('LP Pool', () => { new BN(10).mul(QUOTE_PRECISION), userUSDCAccount.publicKey ); - adminUser = new User({ - driftClient: adminClient, - userAccountPublicKey: await adminClient.getUserAccountPublicKey(), - accountSubscription: { - type: 'polling', - accountLoader: bulkAccountLoader, - }, - }); - solUsd = await mockOracleNoProgram(bankrunContextWrapper, 224.3); const periodicity = new BN(0); await adminClient.initializePerpMarket( 0, - solUsd, + spotMarketOracle, ammInitialBaseAssetReserve, ammInitialQuoteAssetReserve, periodicity, @@ -145,200 +133,260 @@ describe('LP Pool', () => { await adminClient.initializePerpMarket( 1, - solUsd, + spotMarketOracle, ammInitialBaseAssetReserve, ammInitialQuoteAssetReserve, periodicity, new BN(224 * PEG_PRECISION.toNumber()) ); + const optimalUtilization = SPOT_MARKET_RATE_PRECISION.div( + new BN(2) + ).toNumber(); // 50% utilization + const optimalRate = SPOT_MARKET_RATE_PRECISION.toNumber(); + const maxRate = SPOT_MARKET_RATE_PRECISION.toNumber(); + const initialAssetWeight = SPOT_MARKET_WEIGHT_PRECISION.toNumber(); + const maintenanceAssetWeight = SPOT_MARKET_WEIGHT_PRECISION.toNumber(); + const initialLiabilityWeight = SPOT_MARKET_WEIGHT_PRECISION.toNumber(); + const maintenanceLiabilityWeight = SPOT_MARKET_WEIGHT_PRECISION.toNumber(); + const imfFactor = 0; + + await adminClient.initializeSpotMarket( + spotTokenMint.publicKey, + optimalUtilization, + optimalRate, + maxRate, + spotMarketOracle, + OracleSource.PYTH, + initialAssetWeight, + maintenanceAssetWeight, + initialLiabilityWeight, + maintenanceLiabilityWeight, + imfFactor + ); + await adminClient.initializeLpPool( lpPoolName, new BN(100_000_000).mul(QUOTE_PRECISION), - Keypair.generate() + Keypair.generate() // dlp mint ); - }); - - after(async () => { - await adminClient.unsubscribe(); - }); - - it('can create a new LP Pool', async () => { - // check LpPool created - const lpPool = await adminClient.program.account.lpPool.fetch(lpPoolKey); - - // Check amm constituent map exists - const ammConstituentMapPublicKey = getAmmConstituentMappingPublicKey( - program.programId, - lpPoolKey - ); - const ammConstituentMap = - (await adminClient.program.account.ammConstituentMapping.fetch( - ammConstituentMapPublicKey - )) as AmmConstituentMapping; - expect(ammConstituentMap).to.not.be.null; - assert(ammConstituentMap.weights.length == 0); - - // check constituent target weights exists - const constituentTargetWeightsPublicKey = - getConstituentTargetWeightsPublicKey(program.programId, lpPoolKey); - const constituentTargetWeights = - (await adminClient.program.account.constituentTargetWeights.fetch( - constituentTargetWeightsPublicKey - )) as ConstituentTargetWeights; - expect(constituentTargetWeights).to.not.be.null; - assert(constituentTargetWeights.weights.length == 0); - - // check mint created correctly - const mintInfo = await getMint( - bankrunContextWrapper.connection.toConnection(), - lpPool.mint as PublicKey - ); - expect(mintInfo.decimals).to.equal(tokenDecimals); - expect(Number(mintInfo.supply)).to.equal(0); - expect(mintInfo.mintAuthority!.toBase58()).to.equal(lpPoolKey.toBase58()); - }); - - it('can add constituent to LP Pool', async () => { await adminClient.initializeConstituent( encodeName(lpPoolName), 0, 6, - new BN(10).mul(PERCENTAGE_PRECISION), - new BN(1).mul(PERCENTAGE_PRECISION), - new BN(2).mul(PERCENTAGE_PRECISION) + PERCENTAGE_PRECISION.divn(10), // 10% max dev + PERCENTAGE_PRECISION.divn(10000), // min fee 1 bps + PERCENTAGE_PRECISION.divn(100), // max 1% + new BN(100) + ); + await adminClient.initializeConstituent( + encodeName(lpPoolName), + 1, + 6, + PERCENTAGE_PRECISION.divn(10), // 10% max dev + PERCENTAGE_PRECISION.divn(10000), // min 1 bps + PERCENTAGE_PRECISION.divn(100), // max 1% + new BN(100) ); - const constituentTargetWeightsPublicKey = - getConstituentTargetWeightsPublicKey(program.programId, lpPoolKey); - const constituentTargetWeights = - (await adminClient.program.account.constituentTargetWeights.fetch( - constituentTargetWeightsPublicKey - )) as ConstituentTargetWeights; - expect(constituentTargetWeights).to.not.be.null; - assert(constituentTargetWeights.weights.length == 1); }); - it('can add amm mapping datum', async () => { - await adminClient.addInitAmmConstituentMappingData(encodeName(lpPoolName), [ - { - perpMarketIndex: 0, - constituentIndex: 0, - }, - { - perpMarketIndex: 1, - constituentIndex: 0, - }, - ]); - const ammConstituentMapping = getAmmConstituentMappingPublicKey( - program.programId, - lpPoolKey - ); - const ammMapping = - (await adminClient.program.account.ammConstituentMapping.fetch( - ammConstituentMapping - )) as AmmConstituentMapping; - expect(ammMapping).to.not.be.null; - assert(ammMapping.weights.length == 2); + after(async () => { + await adminClient.unsubscribe(); }); - it('fails adding datum with bad params', async () => { - // Bad perp market index + it('LP Pool init properly', async () => { + let lpPool: LPPoolAccount; try { - await adminClient.addInitAmmConstituentMappingData( - encodeName(lpPoolName), - [ - { - perpMarketIndex: 2, - constituentIndex: 0, - }, - ] - ); - expect.fail('should have failed'); + lpPool = (await adminClient.program.account.lpPool.fetch( + lpPoolKey + )) as LPPoolAccount; + expect(lpPool).to.not.be.null; } catch (e) { - expect(e.message).to.contain('0x18ab'); + expect.fail('LP Pool should have been created'); } - // Bad constituent index try { - await adminClient.addInitAmmConstituentMappingData( - encodeName(lpPoolName), - [ - { - perpMarketIndex: 0, - constituentIndex: 1, - }, - ] - ); - expect.fail('should have failed'); + const constituentTargetWeightsPublicKey = + getConstituentTargetWeightsPublicKey(program.programId, lpPoolKey); + const constituentTargetWeights = + (await adminClient.program.account.constituentTargetWeights.fetch( + constituentTargetWeightsPublicKey + )) as ConstituentTargetWeights; + expect(constituentTargetWeights).to.not.be.null; + assert(constituentTargetWeights.weights.length == 2); } catch (e) { - expect(e.message).to.contain('0x18ab'); + expect.fail('Amm constituent map should have been created'); } }); - it('can update constituent target weights', async () => { - // Override AMM to have a balance - const perpMarket = adminClient.getPerpMarketAccount(0); - const raw = await bankrunContextWrapper.connection.getAccountInfo( - perpMarket.pubkey + it('lp pool swap', async () => { + let spotOracle = adminClient.getOracleDataForSpotMarket(1); + const price1 = convertToNumber(spotOracle.price); + + await setFeedPriceNoProgram(bankrunContextWrapper, 224.3, spotMarketOracle); + + await adminClient.fetchAccounts(); + + spotOracle = adminClient.getOracleDataForSpotMarket(1); + const price2 = convertToNumber(spotOracle.price); + assert(price2 > price1); + + const const0TokenAccount = getConstituentVaultPublicKey( + program.programId, + lpPoolKey, + 0 + ); + const const1TokenAccount = getConstituentVaultPublicKey( + program.programId, + lpPoolKey, + 1 ); - const buf = raw.data; - buf.writeBigInt64LE(BigInt(1000000000), 304); + const const0Key = getConstituentPublicKey(program.programId, lpPoolKey, 0); + const const1Key = getConstituentPublicKey(program.programId, lpPoolKey, 1); - bankrunContextWrapper.context.setAccount(perpMarket.pubkey, { - executable: raw.executable, - owner: raw.owner, - lamports: raw.lamports, - rentEpoch: raw.rentEpoch, - data: buf, - }); + const c0TokenBalance = new BN(224_300_000_000); + const c1TokenBalance = new BN(1_000_000_000); - const perpMarketAccountAfter = await getPerpMarketDecoded( - adminClient, + await overWriteTokenAccountBalance( bankrunContextWrapper, - perpMarket.pubkey + const0TokenAccount, + BigInt(c0TokenBalance.toString()) + ); + await overwriteConstituentAccount( + bankrunContextWrapper, + adminClient.program, + const0Key, + [['tokenBalance', c0TokenBalance]] ); - assert(!perpMarketAccountAfter.amm.baseAssetAmountLong.isZero()); - // Override LP pool to have some aum - const lpraw = await bankrunContextWrapper.connection.getAccountInfo( - lpPoolKey + await overWriteTokenAccountBalance( + bankrunContextWrapper, + const1TokenAccount, + BigInt(c1TokenBalance.toString()) + ); + await overwriteConstituentAccount( + bankrunContextWrapper, + adminClient.program, + const1Key, + [['tokenBalance', c1TokenBalance]] ); - const lpbuf = lpraw.data; - buf.writeBigInt64LE(BigInt(1000000000), 152); + // check fields overwritten correctly + const c0 = (await adminClient.program.account.constituent.fetch( + const0Key + )) as ConstituentAccount; + expect(c0.tokenBalance.toString()).to.equal(c0TokenBalance.toString()); + console.log('c0', c0); - bankrunContextWrapper.context.setAccount(lpPoolKey, { - executable: lpraw.executable, - owner: lpraw.owner, - lamports: lpraw.lamports, - rentEpoch: lpraw.rentEpoch, - data: lpbuf, - }); + const c1 = (await adminClient.program.account.constituent.fetch( + const1Key + )) as ConstituentAccount; + expect(c1.tokenBalance.toString()).to.equal(c1TokenBalance.toString()); + console.log('c1', c1); - const ammConstituentMappingPublicKey = getAmmConstituentMappingPublicKey( - program.programId, + const prec = new BN(10).pow(new BN(tokenDecimals)); + console.log(`const0 balance: ${convertToNumber(c0.tokenBalance, prec)}`); + console.log(`const1 balance: ${convertToNumber(c1.tokenBalance, prec)}`); + + const lpPool1 = (await adminClient.program.account.lpPool.fetch( lpPoolKey - ); + )) as LPPoolAccount; + expect(lpPool1.lastAumSlot.toNumber()).to.be.equal(0); - const ammMapping = - (await adminClient.program.account.ammConstituentMapping.fetch( - ammConstituentMappingPublicKey - )) as AmmConstituentMapping; + await adminClient.updateDlpPoolAum(lpPool1, [1, 0]); + + const lpPool2 = (await adminClient.program.account.lpPool.fetch( + lpPoolKey + )) as LPPoolAccount; + + expect(lpPool2.lastAumSlot.toNumber()).to.be.greaterThan(0); + expect(lpPool2.lastAum.gt(lpPool1.lastAum)).to.be.true; + console.log(`AUM: ${convertToNumber(lpPool2.lastAum, QUOTE_PRECISION)}`); - console.log(`ok there should be ${ammMapping.weights.length} constituents`); - await adminClient.updateDlpConstituentTargetWeights( - encodeName(lpPoolName), - [0], - ammMapping - ); const constituentTargetWeightsPublicKey = getConstituentTargetWeightsPublicKey(program.programId, lpPoolKey); - const constituentTargetWeights = - (await adminClient.program.account.constituentTargetWeights.fetch( - constituentTargetWeightsPublicKey - )) as ConstituentTargetWeights; - expect(constituentTargetWeights).to.not.be.null; - assert(constituentTargetWeights.weights.length == 1); + + // swap c0 for c1 + + const adminAuth = adminClient.wallet.publicKey; + + // mint some tokens for user + const c0UserTokenAccount = await mockAtaTokenAccountForMint( + bankrunContextWrapper, + usdcMint.publicKey, + new BN(224_300_000_000), + adminAuth + ); + const c1UserTokenAccount = await mockAtaTokenAccountForMint( + bankrunContextWrapper, + spotTokenMint.publicKey, + new BN(1_000_000_000), + adminAuth + ); + + // console.log(`0 mint: ${usdcMint.publicKey.toBase58()}`) + // console.log(`const0:`, await adminClient.program.account.constituent.fetch(const0Key)) + // console.log(`1 mint: ${spotTokenMint.publicKey.toBase58()}`) + // console.log(`const1:`, await adminClient.program.account.constituent.fetch(const1Key)) + + // const m0 = await adminClient.getSpotMarketAccount(0); + // const m1 = await adminClient.getSpotMarketAccount(1); + // console.log(`m0 ${m0.pubkey.toBase58()}, ${m0.oracle.toBase58()}`) + // console.log(`m1 ${m1.pubkey.toBase58()}, ${m1.oracle.toBase58()}`) + + const inTokenBalanceBefore = + await bankrunContextWrapper.connection.getTokenAccount( + c0UserTokenAccount + ); + const outTokenBalanceBefore = + await bankrunContextWrapper.connection.getTokenAccount( + c1UserTokenAccount + ); + + // in = 0, out = 1 + await adminClient.lpPoolSwap( + 0, + 1, + new BN(224_300_000), + new BN(0), + lpPoolKey, + constituentTargetWeightsPublicKey, + const0TokenAccount, + const1TokenAccount, + c0UserTokenAccount, + c1UserTokenAccount, + const0Key, + const1Key, + usdcMint.publicKey, + spotTokenMint.publicKey + ); + + const inTokenBalanceAfter = + await bankrunContextWrapper.connection.getTokenAccount( + c0UserTokenAccount + ); + const outTokenBalanceAfter = + await bankrunContextWrapper.connection.getTokenAccount( + c1UserTokenAccount + ); + const diffInToken = + inTokenBalanceAfter.amount - inTokenBalanceBefore.amount; + const diffOutToken = + outTokenBalanceAfter.amount - outTokenBalanceBefore.amount; + + expect(Number(diffInToken)).to.be.equal(-224_300_000); + expect(Number(diffOutToken)).to.be.approximately(980100, 1); + + console.log( + `in Token: ${inTokenBalanceBefore.amount} -> ${ + inTokenBalanceAfter.amount + } (${Number(diffInToken) / 1e6})` + ); + console.log( + `out Token: ${outTokenBalanceBefore.amount} -> ${ + outTokenBalanceAfter.amount + } (${Number(diffOutToken) / 1e6})` + ); }); }); diff --git a/tests/testHelpers.ts b/tests/testHelpers.ts index 712fa6db2c..6738c476ad 100644 --- a/tests/testHelpers.ts +++ b/tests/testHelpers.ts @@ -15,6 +15,9 @@ import { createInitializePermanentDelegateInstruction, getMintLen, ExtensionType, + unpackAccount, + type RawAccount, + AccountState, } from '@solana/spl-token'; import { AccountInfo, @@ -35,7 +38,6 @@ import { OraclePriceData, OracleInfo, PerpMarketAccount, - UserAccount, } from '../sdk'; import { TestClient, @@ -45,6 +47,7 @@ import { QUOTE_PRECISION, User, OracleSource, + ConstituentAccount, } from '../sdk/src'; import { BankrunContextWrapper, @@ -224,6 +227,39 @@ export async function mockUserUSDCAccount( return userUSDCAccount; } +export async function mockAtaTokenAccountForMint( + context: BankrunContextWrapper, + tokenMint: PublicKey, + amount: BN, + owner: PublicKey +): Promise { + const userTokenAccount = getAssociatedTokenAddressSync(tokenMint, owner); + const newTx = new Transaction(); + + const tokenProgram = (await context.connection.getAccountInfo(tokenMint)) + .owner; + + newTx.add( + createAssociatedTokenAccountIdempotentInstruction( + context.context.payer.publicKey, + userTokenAccount, + owner, + tokenMint, + tokenProgram + ) + ); + + await context.sendTransaction(newTx, [context.context.payer]); + + await overWriteTokenAccountBalance( + context, + userTokenAccount, + BigInt(amount.toString()) + ); + + return userTokenAccount; +} + export function getMockUserUsdcAccountInfo( fakeUSDCMint: Keypair, usdcMintAmount: BN, @@ -1136,3 +1172,63 @@ export async function getPerpMarketDecoded( driftClient.program.coder.accounts.decode('PerpMarket', accountInfo!.data); return perpMarketAccount; } + +export async function overWriteTokenAccountBalance( + bankrunContextWrapper: BankrunContextWrapper, + tokenAccount: PublicKey, + newBalance: bigint +) { + const info = await bankrunContextWrapper.connection.getAccountInfo( + tokenAccount + ); + const account = unpackAccount(tokenAccount, info, info.owner); + account.amount = newBalance; + const data = Buffer.alloc(AccountLayout.span); + const rawAccount: RawAccount = { + mint: account.mint, + owner: account.owner, + amount: account.amount, + delegateOption: account.delegate ? 1 : 0, + delegate: account.delegate || PublicKey.default, + state: account.isFrozen ? AccountState.Frozen : AccountState.Initialized, + isNativeOption: account.isNative ? 1 : 0, + isNative: account.rentExemptReserve || BigInt(0), + delegatedAmount: account.delegatedAmount, + closeAuthorityOption: account.closeAuthority ? 1 : 0, + closeAuthority: account.closeAuthority || PublicKey.default, + }; + AccountLayout.encode(rawAccount, data); + bankrunContextWrapper.context.setAccount(tokenAccount, { + executable: info.executable, + owner: info.owner, + lamports: info.lamports, + data: data, + rentEpoch: info.rentEpoch, + }); +} + +export async function overwriteConstituentAccount( + bankrunContextWrapper: BankrunContextWrapper, + program: Program, + constituentPublicKey: PublicKey, + overwriteFields: Array<[key: keyof ConstituentAccount, value: any]> +) { + const acc = await program.account.constituent.fetch(constituentPublicKey); + if (!acc) { + throw new Error( + `Constituent account ${constituentPublicKey.toBase58()} not found` + ); + } + for (const [key, value] of overwriteFields) { + acc[key] = value; + } + bankrunContextWrapper.context.setAccount(constituentPublicKey, { + executable: false, + owner: program.programId, + lamports: LAMPORTS_PER_SOL, + data: await program.account.constituent.coder.accounts.encode( + 'Constituent', + acc + ), + }); +}