diff --git a/contracts/basic_staking/src/contract.rs b/contracts/basic_staking/src/contract.rs index e447999a..d0161fe7 100644 --- a/contracts/basic_staking/src/contract.rs +++ b/contracts/basic_staking/src/contract.rs @@ -7,7 +7,7 @@ use shade_protocol::{ StdError, StdResult, Uint128, }, query_auth::helpers::{authenticate_permit, authenticate_vk, PermitAuthentication}, - snip20::helpers::{register_receive, set_viewing_key_msg}, + snip20::helpers::{balance_query, register_receive, set_viewing_key_msg}, utils::{asset::Contract, pad_handle_result}, }; @@ -241,5 +241,37 @@ pub fn migrate(deps: DepsMut, env: Env, _msg: MigrateMsg) -> StdResult &stake_token, )?); } - Ok(Response::default().add_messages(msgs)) + + // Rectify Total Staked + // (only referencing stake token amounts) + // total_staked = total balance - unclaimed rewards + let viewing_key = VIEWING_KEY.load(deps.storage)?; + let stake_token_balance = balance_query( + &deps.querier, + env.contract.address, + viewing_key, + &stake_token, + )?; + let reward_pools = REWARD_POOLS.load(deps.storage)?; + let mut unclaimed_stake_token_rewards = Uint128::zero(); + + for reward_pool in reward_pools { + if reward_pool.token.address == stake_token.address { + unclaimed_stake_token_rewards += reward_pool.amount - reward_pool.claimed; + } + } + + let prev_total_staked = TOTAL_STAKED.load(deps.storage)?; + let total_staked = stake_token_balance - unclaimed_stake_token_rewards; + TOTAL_STAKED.save(deps.storage, &total_staked)?; + + Ok(Response::default() + .add_messages(msgs) + .add_attribute("total_staked", total_staked) + .add_attribute("prev_total_staked", prev_total_staked) + .add_attribute("stake_token_balance", stake_token_balance) + .add_attribute( + "unclaimed_stake_token_rewards", + unclaimed_stake_token_rewards, + )) } diff --git a/contracts/basic_staking/src/execute.rs b/contracts/basic_staking/src/execute.rs index a61fb1bc..6224a464 100644 --- a/contracts/basic_staking/src/execute.rs +++ b/contracts/basic_staking/src/execute.rs @@ -538,6 +538,7 @@ pub fn unbond( // Reduce by unbonding user_staked -= amount; + total_staked -= amount; TOTAL_STAKED.save(deps.storage, &total_staked)?; USER_STAKED.save(deps.storage, info.sender.clone(), &user_staked)?; diff --git a/contracts/basic_staking/tests/unbonding_withdrawals.rs b/contracts/basic_staking/tests/unbonding_withdrawals.rs index d1f962f7..3d9fd6d9 100644 --- a/contracts/basic_staking/tests/unbonding_withdrawals.rs +++ b/contracts/basic_staking/tests/unbonding_withdrawals.rs @@ -247,6 +247,26 @@ fn unbonding_withdrawals( }; } + let unbonded = unbonding_amounts.iter().sum::(); + // check total staked + match (basic_staking::QueryMsg::TotalStaked {} + .test_query(&basic_staking, &app) + .unwrap()) + { + basic_staking::QueryAnswer::TotalStaked { amount } => { + assert_eq!( + amount, + stake_amount - unbonded, + "Total staked {} != {} expected", + amount, + stake_amount - unbonded + ); + } + _ => { + panic!("Total staked query failed"); + } + }; + // Check snip20 received by user match (snip20::QueryMsg::Balance { key: viewing_key.clone(),