diff --git a/.github/actions/action.yml b/.github/actions/action.yml index 153cabaf6..db1a753ea 100644 --- a/.github/actions/action.yml +++ b/.github/actions/action.yml @@ -30,6 +30,10 @@ runs: run: forge compile --contracts proposals/ shell: bash + - name: Compile MultiRewarder + run: forge compile --contracts crv-rewards/ + shell: bash + - name: Compile Contracts run: forge build shell: bash diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml index acb332be8..457655097 100644 --- a/.github/workflows/unit.yml +++ b/.github/workflows/unit.yml @@ -28,3 +28,12 @@ jobs: max_attempts: 3 command: time forge test -vvv --match-contract UnitTest + - name: Run MultiRewards Tests + uses: nick-fields/retry@v3 + with: + polling_interval_seconds: 30 + retry_wait_seconds: 60 + timeout_minutes: 20 + max_attempts: 3 + command: time forge test --match-path test/unit/MultiRewards.t.sol -vvv + diff --git a/crv-rewards/IMultiRewards.sol b/crv-rewards/IMultiRewards.sol new file mode 100644 index 000000000..fb0ce76af --- /dev/null +++ b/crv-rewards/IMultiRewards.sol @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity 0.8.19; + +interface IMultiRewards { + struct Reward { + address rewardsDistributor; + uint256 rewardsDuration; + uint256 periodFinish; + uint256 rewardRate; + uint256 lastUpdateTime; + uint256 rewardPerTokenStored; + } + + // Events + event RewardAdded(uint256 reward); + event Staked(address indexed user, uint256 amount); + event Withdrawn(address indexed user, uint256 amount); + event RewardPaid( + address indexed user, + address indexed rewardsToken, + uint256 reward + ); + event RewardsDurationUpdated(address token, uint256 newDuration); + event Recovered(address token, uint256 amount); + event OwnerNominated(address newOwner); + event OwnerChanged(address oldOwner, address newOwner); + event PauseChanged(bool isPaused); + + // View functions + function owner() external view returns (address); + function nominatedOwner() external view returns (address); + function stakingToken() external view returns (address); + function rewardData( + address + ) + external + view + returns ( + address rewardsDistributor, + uint256 rewardsDuration, + uint256 periodFinish, + uint256 rewardRate, + uint256 lastUpdateTime, + uint256 rewardPerTokenStored + ); + function rewardTokens(uint256) external view returns (address); + function totalSupply() external view returns (uint256); + function balanceOf(address account) external view returns (uint256); + function lastTimeRewardApplicable( + address _rewardsToken + ) external view returns (uint256); + function rewardPerToken( + address _rewardsToken + ) external view returns (uint256); + function earned( + address account, + address _rewardsToken + ) external view returns (uint256); + function getRewardForDuration( + address _rewardsToken + ) external view returns (uint256); + + // Mutative functions + function nominateNewOwner(address _owner) external; + function acceptOwnership() external; + function setPaused(bool _paused) external; + function addReward( + address _rewardsToken, + address _rewardsDistributor, + uint256 _rewardsDuration + ) external; + function setRewardsDistributor( + address _rewardsToken, + address _rewardsDistributor + ) external; + function stake(uint256 amount) external; + function withdraw(uint256 amount) external; + function getReward() external; + function exit() external; + function notifyRewardAmount(address _rewardsToken, uint256 reward) external; + function recoverERC20(address tokenAddress, uint256 tokenAmount) external; + function setRewardsDuration( + address _rewardsToken, + uint256 _rewardsDuration + ) external; +} diff --git a/crv-rewards/MultiRewards.sol b/crv-rewards/MultiRewards.sol new file mode 100644 index 000000000..f5e6baf31 --- /dev/null +++ b/crv-rewards/MultiRewards.sol @@ -0,0 +1,698 @@ +pragma solidity 0.5.17; + +library Address { + /** + * @dev Returns true if `account` is a contract. + * + * This test is non-exhaustive, and there may be false-negatives: during the + * execution of a contract's constructor, its address will be reported as + * not containing a contract. + * + * > It is unsafe to assume that an address for which this function returns + * false is an externally-owned account (EOA) and not a contract. + */ + function isContract(address account) internal view returns (bool) { + // This method relies in extcodesize, which returns 0 for contracts in + // construction, since the code is only stored at the end of the + // constructor execution. + + uint256 size; + // solhint-disable-next-line no-inline-assembly + assembly { + size := extcodesize(account) + } + return size > 0; + } +} + +interface IERC20 { + /** + * @dev Returns the amount of tokens in existence. + */ + function totalSupply() external view returns (uint256); + + /** + * @dev Returns the amount of tokens owned by `account`. + */ + function balanceOf(address account) external view returns (uint256); + + /** + * @dev Moves `amount` tokens from the caller's account to `recipient`. + * + * Returns a boolean value indicating whether the operation succeeded. + * + * Emits a `Transfer` event. + */ + function transfer( + address recipient, + uint256 amount + ) external returns (bool); + + /** + * @dev Returns the remaining number of tokens that `spender` will be + * allowed to spend on behalf of `owner` through `transferFrom`. This is + * zero by default. + * + * This value changes when `approve` or `transferFrom` are called. + */ + function allowance( + address owner, + address spender + ) external view returns (uint256); + + /** + * @dev Sets `amount` as the allowance of `spender` over the caller's tokens. + * + * Returns a boolean value indicating whether the operation succeeded. + * + * > Beware that changing an allowance with this method brings the risk + * that someone may use both the old and the new allowance by unfortunate + * transaction ordering. One possible solution to mitigate this race + * condition is to first reduce the spender's allowance to 0 and set the + * desired value afterwards: + * https://github.com/ethereum/EIPs/issues/20#issuecomment-263524729 + * + * Emits an `Approval` event. + */ + function approve(address spender, uint256 amount) external returns (bool); + + /** + * @dev Moves `amount` tokens from `sender` to `recipient` using the + * allowance mechanism. `amount` is then deducted from the caller's + * allowance. + * + * Returns a boolean value indicating whether the operation succeeded. + * + * Emits a `Transfer` event. + */ + function transferFrom( + address sender, + address recipient, + uint256 amount + ) external returns (bool); + + /** + * @dev Emitted when `value` tokens are moved from one account (`from`) to + * another (`to`). + * + * Note that `value` may be zero. + */ + event Transfer(address indexed from, address indexed to, uint256 value); + + /** + * @dev Emitted when the allowance of a `spender` for an `owner` is set by + * a call to `approve`. `value` is the new allowance. + */ + event Approval( + address indexed owner, + address indexed spender, + uint256 value + ); +} + +library Math { + /** + * @dev Returns the largest of two numbers. + */ + function max(uint256 a, uint256 b) internal pure returns (uint256) { + return a >= b ? a : b; + } + + /** + * @dev Returns the smallest of two numbers. + */ + function min(uint256 a, uint256 b) internal pure returns (uint256) { + return a < b ? a : b; + } + + /** + * @dev Returns the average of two numbers. The result is rounded towards + * zero. + */ + function average(uint256 a, uint256 b) internal pure returns (uint256) { + // (a + b) / 2 can overflow, so we distribute + return (a / 2) + (b / 2) + (((a % 2) + (b % 2)) / 2); + } +} + +contract Owned { + address public owner; + address public nominatedOwner; + + constructor(address _owner) public { + require(_owner != address(0), "Owner address cannot be 0"); + owner = _owner; + emit OwnerChanged(address(0), _owner); + } + + function nominateNewOwner(address _owner) external onlyOwner { + nominatedOwner = _owner; + emit OwnerNominated(_owner); + } + + function acceptOwnership() external { + require( + msg.sender == nominatedOwner, + "You must be nominated before you can accept ownership" + ); + emit OwnerChanged(owner, nominatedOwner); + owner = nominatedOwner; + nominatedOwner = address(0); + } + + modifier onlyOwner() { + _onlyOwner(); + _; + } + + function _onlyOwner() private view { + require( + msg.sender == owner, + "Only the contract owner may perform this action" + ); + } + + event OwnerNominated(address newOwner); + event OwnerChanged(address oldOwner, address newOwner); +} + +contract Pausable is Owned { + uint public lastPauseTime; + bool public paused; + + constructor() internal { + // This contract is abstract, and thus cannot be instantiated directly + require(owner != address(0), "Owner must be set"); + // Paused will be false, and lastPauseTime will be 0 upon initialisation + } + + /** + * @notice Change the paused state of the contract + * @dev Only the contract owner may call this. + */ + function setPaused(bool _paused) external onlyOwner { + // Ensure we're actually changing the state before we do anything + if (_paused == paused) { + return; + } + + // Set our paused state. + paused = _paused; + + // If applicable, set the last pause time. + if (paused) { + lastPauseTime = now; + } + + // Let everyone know that our pause state has changed. + emit PauseChanged(paused); + } + + event PauseChanged(bool isPaused); + + modifier notPaused() { + require( + !paused, + "This action cannot be performed while the contract is paused" + ); + _; + } +} + +contract ReentrancyGuard { + /// @dev counter to allow mutex lock with only one SSTORE operation + uint256 private _guardCounter; + + constructor() internal { + // The counter starts at one to prevent changing it from zero to a non-zero + // value, which is a more expensive operation. + _guardCounter = 1; + } + + /** + * @dev Prevents a contract from calling itself, directly or indirectly. + * Calling a `nonReentrant` function from another `nonReentrant` + * function is not supported. It is possible to prevent this from happening + * by making the `nonReentrant` function external, and make it call a + * `private` function that does the actual work. + */ + modifier nonReentrant() { + _guardCounter += 1; + uint256 localCounter = _guardCounter; + _; + require( + localCounter == _guardCounter, + "ReentrancyGuard: reentrant call" + ); + } +} + +library SafeERC20 { + using SafeMath for uint256; + using Address for address; + + function safeTransfer(IERC20 token, address to, uint256 value) internal { + callOptionalReturn( + token, + abi.encodeWithSelector(token.transfer.selector, to, value) + ); + } + + function safeTransferFrom( + IERC20 token, + address from, + address to, + uint256 value + ) internal { + callOptionalReturn( + token, + abi.encodeWithSelector(token.transferFrom.selector, from, to, value) + ); + } + + function safeApprove( + IERC20 token, + address spender, + uint256 value + ) internal { + // safeApprove should only be called when setting an initial allowance, + // or when resetting it to zero. To increase and decrease it, use + // 'safeIncreaseAllowance' and 'safeDecreaseAllowance' + // solhint-disable-next-line max-line-length + require( + (value == 0) || (token.allowance(address(this), spender) == 0), + "SafeERC20: approve from non-zero to non-zero allowance" + ); + callOptionalReturn( + token, + abi.encodeWithSelector(token.approve.selector, spender, value) + ); + } + + function safeIncreaseAllowance( + IERC20 token, + address spender, + uint256 value + ) internal { + uint256 newAllowance = token.allowance(address(this), spender).add( + value + ); + callOptionalReturn( + token, + abi.encodeWithSelector( + token.approve.selector, + spender, + newAllowance + ) + ); + } + + function safeDecreaseAllowance( + IERC20 token, + address spender, + uint256 value + ) internal { + uint256 newAllowance = token.allowance(address(this), spender).sub( + value + ); + callOptionalReturn( + token, + abi.encodeWithSelector( + token.approve.selector, + spender, + newAllowance + ) + ); + } + + /** + * @dev Imitates a Solidity high-level call (i.e. a regular function call to a contract), relaxing the requirement + * on the return value: the return value is optional (but if data is returned, it must not be false). + * @param token The token targeted by the call. + * @param data The call data (encoded using abi.encode or one of its variants). + */ + function callOptionalReturn(IERC20 token, bytes memory data) private { + // We need to perform a low level call here, to bypass Solidity's return data size checking mechanism, since + // we're implementing it ourselves. + + // A Solidity high level call has three parts: + // 1. The target address is checked to verify it contains contract code + // 2. The call itself is made, and success asserted + // 3. The return value is decoded, which in turn checks the size of the returned data. + // solhint-disable-next-line max-line-length + require(address(token).isContract(), "SafeERC20: call to non-contract"); + + // solhint-disable-next-line avoid-low-level-calls + (bool success, bytes memory returndata) = address(token).call(data); + require(success, "SafeERC20: low-level call failed"); + + if (returndata.length > 0) { + // Return data is optional + // solhint-disable-next-line max-line-length + require( + abi.decode(returndata, (bool)), + "SafeERC20: ERC20 operation did not succeed" + ); + } + } +} + +library SafeMath { + /** + * @dev Returns the addition of two unsigned integers, reverting on + * overflow. + * + * Counterpart to Solidity's `+` operator. + * + * Requirements: + * - Addition cannot overflow. + */ + function add(uint256 a, uint256 b) internal pure returns (uint256) { + uint256 c = a + b; + require(c >= a, "SafeMath: addition overflow"); + + return c; + } + + /** + * @dev Returns the subtraction of two unsigned integers, reverting on + * overflow (when the result is negative). + * + * Counterpart to Solidity's `-` operator. + * + * Requirements: + * - Subtraction cannot overflow. + */ + function sub(uint256 a, uint256 b) internal pure returns (uint256) { + require(b <= a, "SafeMath: subtraction overflow"); + uint256 c = a - b; + + return c; + } + + /** + * @dev Returns the multiplication of two unsigned integers, reverting on + * overflow. + * + * Counterpart to Solidity's `*` operator. + * + * Requirements: + * - Multiplication cannot overflow. + */ + function mul(uint256 a, uint256 b) internal pure returns (uint256) { + // Gas optimization: this is cheaper than requiring 'a' not being zero, but the + // benefit is lost if 'b' is also tested. + // See: https://github.com/OpenZeppelin/openzeppelin-solidity/pull/522 + if (a == 0) { + return 0; + } + + uint256 c = a * b; + require(c / a == b, "SafeMath: multiplication overflow"); + + return c; + } + + /** + * @dev Returns the integer division of two unsigned integers. Reverts on + * division by zero. The result is rounded towards zero. + * + * Counterpart to Solidity's `/` operator. Note: this function uses a + * `revert` opcode (which leaves remaining gas untouched) while Solidity + * uses an invalid opcode to revert (consuming all remaining gas). + * + * Requirements: + * - The divisor cannot be zero. + */ + function div(uint256 a, uint256 b) internal pure returns (uint256) { + // Solidity only automatically asserts when dividing by 0 + require(b > 0, "SafeMath: division by zero"); + uint256 c = a / b; + // assert(a == b * c + a % b); // There is no case in which this doesn't hold + + return c; + } + + /** + * @dev Returns the remainder of dividing two unsigned integers. (unsigned integer modulo), + * Reverts when dividing by zero. + * + * Counterpart to Solidity's `%` operator. This function uses a `revert` + * opcode (which leaves remaining gas untouched) while Solidity uses an + * invalid opcode to revert (consuming all remaining gas). + * + * Requirements: + * - The divisor cannot be zero. + */ + function mod(uint256 a, uint256 b) internal pure returns (uint256) { + require(b != 0, "SafeMath: modulo by zero"); + return a % b; + } +} + +contract MultiRewards is ReentrancyGuard, Pausable { + using SafeMath for uint256; + using SafeERC20 for IERC20; + + /* ========== STATE VARIABLES ========== */ + + struct Reward { + address rewardsDistributor; + uint256 rewardsDuration; + uint256 periodFinish; + uint256 rewardRate; + uint256 lastUpdateTime; + uint256 rewardPerTokenStored; + } + IERC20 public stakingToken; + mapping(address => Reward) public rewardData; + address[] public rewardTokens; + + // user -> reward token -> amount + mapping(address => mapping(address => uint256)) + public userRewardPerTokenPaid; + mapping(address => mapping(address => uint256)) public rewards; + + uint256 private _totalSupply; + mapping(address => uint256) private _balances; + + /* ========== CONSTRUCTOR ========== */ + + constructor(address _owner, address _stakingToken) public Owned(_owner) { + stakingToken = IERC20(_stakingToken); + } + + function addReward( + address _rewardsToken, + address _rewardsDistributor, + uint256 _rewardsDuration + ) public onlyOwner { + require(rewardData[_rewardsToken].rewardsDuration == 0); + rewardTokens.push(_rewardsToken); + rewardData[_rewardsToken].rewardsDistributor = _rewardsDistributor; + rewardData[_rewardsToken].rewardsDuration = _rewardsDuration; + } + + /* ========== VIEWS ========== */ + + function totalSupply() external view returns (uint256) { + return _totalSupply; + } + + function balanceOf(address account) external view returns (uint256) { + return _balances[account]; + } + + function lastTimeRewardApplicable( + address _rewardsToken + ) public view returns (uint256) { + return + Math.min(block.timestamp, rewardData[_rewardsToken].periodFinish); + } + + function rewardPerToken( + address _rewardsToken + ) public view returns (uint256) { + if (_totalSupply == 0) { + return rewardData[_rewardsToken].rewardPerTokenStored; + } + return + rewardData[_rewardsToken].rewardPerTokenStored.add( + lastTimeRewardApplicable(_rewardsToken) + .sub(rewardData[_rewardsToken].lastUpdateTime) + .mul(rewardData[_rewardsToken].rewardRate) + .mul(1e18) + .div(_totalSupply) + ); + } + + function earned( + address account, + address _rewardsToken + ) public view returns (uint256) { + return + _balances[account] + .mul( + rewardPerToken(_rewardsToken).sub( + userRewardPerTokenPaid[account][_rewardsToken] + ) + ) + .div(1e18) + .add(rewards[account][_rewardsToken]); + } + + function getRewardForDuration( + address _rewardsToken + ) external view returns (uint256) { + return + rewardData[_rewardsToken].rewardRate.mul( + rewardData[_rewardsToken].rewardsDuration + ); + } + + /* ========== MUTATIVE FUNCTIONS ========== */ + + function setRewardsDistributor( + address _rewardsToken, + address _rewardsDistributor + ) external onlyOwner { + rewardData[_rewardsToken].rewardsDistributor = _rewardsDistributor; + } + + function stake( + uint256 amount + ) external nonReentrant notPaused updateReward(msg.sender) { + require(amount > 0, "Cannot stake 0"); + _totalSupply = _totalSupply.add(amount); + _balances[msg.sender] = _balances[msg.sender].add(amount); + stakingToken.safeTransferFrom(msg.sender, address(this), amount); + emit Staked(msg.sender, amount); + } + + function withdraw( + uint256 amount + ) public nonReentrant updateReward(msg.sender) { + require(amount > 0, "Cannot withdraw 0"); + _totalSupply = _totalSupply.sub(amount); + _balances[msg.sender] = _balances[msg.sender].sub(amount); + stakingToken.safeTransfer(msg.sender, amount); + emit Withdrawn(msg.sender, amount); + } + + function getReward() public nonReentrant updateReward(msg.sender) { + for (uint i; i < rewardTokens.length; i++) { + address _rewardsToken = rewardTokens[i]; + uint256 reward = rewards[msg.sender][_rewardsToken]; + if (reward > 0) { + rewards[msg.sender][_rewardsToken] = 0; + IERC20(_rewardsToken).safeTransfer(msg.sender, reward); + emit RewardPaid(msg.sender, _rewardsToken, reward); + } + } + } + + function exit() external { + withdraw(_balances[msg.sender]); + getReward(); + } + + /* ========== RESTRICTED FUNCTIONS ========== */ + + function notifyRewardAmount( + address _rewardsToken, + uint256 reward + ) external updateReward(address(0)) { + require(rewardData[_rewardsToken].rewardsDistributor == msg.sender); + // handle the transfer of reward tokens via `transferFrom` to reduce the number + // of transactions required and ensure correctness of the reward amount + IERC20(_rewardsToken).safeTransferFrom( + msg.sender, + address(this), + reward + ); + + if (block.timestamp >= rewardData[_rewardsToken].periodFinish) { + rewardData[_rewardsToken].rewardRate = reward.div( + rewardData[_rewardsToken].rewardsDuration + ); + } else { + uint256 remaining = rewardData[_rewardsToken].periodFinish.sub( + block.timestamp + ); + uint256 leftover = remaining.mul( + rewardData[_rewardsToken].rewardRate + ); + rewardData[_rewardsToken].rewardRate = reward.add(leftover).div( + rewardData[_rewardsToken].rewardsDuration + ); + } + + rewardData[_rewardsToken].lastUpdateTime = block.timestamp; + rewardData[_rewardsToken].periodFinish = block.timestamp.add( + rewardData[_rewardsToken].rewardsDuration + ); + emit RewardAdded(reward); + } + + // Added to support recovering LP Rewards from other systems such as BAL to be distributed to holders + function recoverERC20( + address tokenAddress, + uint256 tokenAmount + ) external onlyOwner { + require( + tokenAddress != address(stakingToken), + "Cannot withdraw staking token" + ); + // note: the admin (Temporal Governor) can withdraw reward tokens. + // the only time this should happen is for reward proposals that + // occur after Morpho rewards have been claimed by an autotask. + IERC20(tokenAddress).safeTransfer(owner, tokenAmount); + emit Recovered(tokenAddress, tokenAmount); + } + + function setRewardsDuration( + address _rewardsToken, + uint256 _rewardsDuration + ) external { + require( + block.timestamp > rewardData[_rewardsToken].periodFinish, + "Reward period still active" + ); + require(rewardData[_rewardsToken].rewardsDistributor == msg.sender); + require(_rewardsDuration > 0, "Reward duration must be non-zero"); + rewardData[_rewardsToken].rewardsDuration = _rewardsDuration; + emit RewardsDurationUpdated( + _rewardsToken, + rewardData[_rewardsToken].rewardsDuration + ); + } + + /* ========== MODIFIERS ========== */ + + modifier updateReward(address account) { + for (uint i; i < rewardTokens.length; i++) { + address token = rewardTokens[i]; + rewardData[token].rewardPerTokenStored = rewardPerToken(token); + rewardData[token].lastUpdateTime = lastTimeRewardApplicable(token); + if (account != address(0)) { + rewards[account][token] = earned(account, token); + userRewardPerTokenPaid[account][token] = rewardData[token] + .rewardPerTokenStored; + } + } + _; + } + + /* ========== EVENTS ========== */ + + event RewardAdded(uint256 reward); + event Staked(address indexed user, uint256 amount); + event Withdrawn(address indexed user, uint256 amount); + event RewardPaid( + address indexed user, + address indexed rewardsToken, + uint256 reward + ); + event RewardsDurationUpdated(address token, uint256 newDuration); + event Recovered(address token, uint256 amount); +} diff --git a/script/DeployMultiRewards.s.sol b/script/DeployMultiRewards.s.sol new file mode 100644 index 000000000..685cfa84e --- /dev/null +++ b/script/DeployMultiRewards.s.sol @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity 0.8.19; + +import {Script} from "@forge-std/Script.sol"; +import {console} from "@forge-std/console.sol"; +import {AllChainAddresses as Addresses} from "@proposals/Addresses.sol"; +import {MultiRewardsDeploy} from "src/rewards/MultiRewardsDeploy.sol"; + +/// test commands: +/// +/// forge script DeployMultiRewards -vvv --fork-url base +/// +/// forge script DeployMultiRewards -vvv --fork-url optimism +/// +contract DeployMultiRewards is Script, MultiRewardsDeploy { + function run() public { + Addresses addresses = new Addresses(); + + address owner = addresses.getAddress("TEMPORAL_GOVERNOR"); + address stakingToken = addresses.getAddress("USDC_METAMORPHO_VAULT"); + + vm.startBroadcast(); + + address multiRewards = deployMultiRewards(owner, stakingToken); + validateMultiRewards(multiRewards, owner, stakingToken); + + vm.stopBroadcast(); + + addresses.addAddress("MULTI_REWARDS", multiRewards); + addresses.printAddresses(); + } +} diff --git a/src/rewards/MultiRewardsDeploy.sol b/src/rewards/MultiRewardsDeploy.sol new file mode 100644 index 000000000..c50170500 --- /dev/null +++ b/src/rewards/MultiRewardsDeploy.sol @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity 0.8.19; + +import {Test} from "@forge-std/Test.sol"; +import {IMultiRewards} from "crv-rewards/IMultiRewards.sol"; + +contract MultiRewardsDeploy is Test { + function deployMultiRewards( + address owner, + address stakingToken + ) public returns (address multiRewards) { + multiRewards = deployCode( + "artifacts/foundry/MultiRewards.sol/MultiRewards.json", + abi.encode(owner, stakingToken) + ); + } + + function validateMultiRewards( + address multiRewards, + address expectedOwner, + address expectedStakingToken + ) public view { + IMultiRewards rewards = IMultiRewards(multiRewards); + assertEq(rewards.owner(), expectedOwner, "Owner not set correctly"); + assertEq( + address(rewards.stakingToken()), + expectedStakingToken, + "Staking token not set correctly" + ); + } +} diff --git a/test/unit/MultiRewards.t.sol b/test/unit/MultiRewards.t.sol new file mode 100644 index 000000000..cbe2f5136 --- /dev/null +++ b/test/unit/MultiRewards.t.sol @@ -0,0 +1,986 @@ +pragma solidity 0.5.17; + +import "../../crv-rewards/MultiRewards.sol"; + +// Simple mock ERC20 token compatible with Solidity 0.5.17 +contract MockERC20 { + string public name; + string public symbol; + uint8 public decimals; + uint256 public totalSupply; + + mapping(address => uint256) public balanceOf; + mapping(address => mapping(address => uint256)) public allowance; + + event Transfer(address indexed from, address indexed to, uint256 value); + event Approval( + address indexed owner, + address indexed spender, + uint256 value + ); + + constructor( + string memory _name, + string memory _symbol, + uint8 _decimals + ) public { + name = _name; + symbol = _symbol; + decimals = _decimals; + } + + function mint(address to, uint256 amount) public { + balanceOf[to] += amount; + totalSupply += amount; + emit Transfer(address(0), to, amount); + } + + function transfer(address to, uint256 value) public returns (bool) { + require( + balanceOf[msg.sender] >= value, + "ERC20: transfer amount exceeds balance" + ); + balanceOf[msg.sender] -= value; + balanceOf[to] += value; + emit Transfer(msg.sender, to, value); + return true; + } + + function approve(address spender, uint256 value) public returns (bool) { + allowance[msg.sender][spender] = value; + emit Approval(msg.sender, spender, value); + return true; + } + + function transferFrom( + address from, + address to, + uint256 value + ) public returns (bool) { + require( + balanceOf[from] >= value, + "ERC20: transfer amount exceeds balance" + ); + require( + allowance[from][msg.sender] >= value, + "ERC20: transfer amount exceeds allowance" + ); + balanceOf[from] -= value; + balanceOf[to] += value; + allowance[from][msg.sender] -= value; + emit Transfer(from, to, value); + return true; + } +} + +// Simple testing contract that doesn't rely on forge-std +contract MultiRewardsTest { + // Contracts + MultiRewards public multiRewards; + MockERC20 public stakingToken; + MockERC20 public rewardTokenA; + MockERC20 public rewardTokenB; + + // Addresses + address public owner; + address public user; + address public user2; + address public rewardDistributorA; + address public rewardDistributorB; + + // Constants + uint256 public constant INITIAL_STAKE_AMOUNT = 100 ether; + uint256 public constant REWARD_AMOUNT = 1000 ether; + uint256 public constant REWARDS_DURATION = 7 days; + + // Events for logging test results + event LogAssertEq(bool passed, string message); + event LogAssertEqUint(uint256 a, uint256 b, string message); + event LogAssertGt(bool passed, string message); + event LogAssertLt(bool passed, string message); + event LogAssertTrue(bool passed, string message); + event LogAssertFalse(bool passed, string message); + + address public vm = address(uint160(uint256(keccak256("hevm cheat code")))); + + constructor() public { + // Set up addresses + owner = address(this); + user = address(0x1); + user2 = address(0x2); + rewardDistributorA = address(0x3); + rewardDistributorB = address(0x4); + + // Deploy mock tokens + stakingToken = new MockERC20("Staking Token", "STK", 18); + rewardTokenA = new MockERC20("Reward Token A", "RWDA", 18); + rewardTokenB = new MockERC20("Reward Token B", "RWDB", 18); + + // Deploy MultiRewards contract + multiRewards = new MultiRewards(owner, address(stakingToken)); + + // Add first reward token + multiRewards.addReward( + address(rewardTokenA), + rewardDistributorA, + REWARDS_DURATION + ); + + // Mint tokens to user and reward distributors + stakingToken.mint(user, INITIAL_STAKE_AMOUNT); + rewardTokenA.mint(rewardDistributorA, REWARD_AMOUNT); + rewardTokenB.mint(rewardDistributorB, REWARD_AMOUNT); + + // Approve spending of reward tokens by the MultiRewards contract + prank(rewardDistributorA); + rewardTokenA.approve(address(multiRewards), REWARD_AMOUNT); + + prank(rewardDistributorB); + rewardTokenB.approve(address(multiRewards), REWARD_AMOUNT); + } + + // ------------------------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------------------------ + // Simple testing utilities since forge-std minimum solidity version is 0.6.20 and MultiRewards is 0.5.17 + // ------------------------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------------------------ + + function prank(address sender) internal { + (bool success, ) = vm.call( + abi.encodeWithSignature("prank(address)", sender) + ); + require(success, "call to prank failed"); + } + + function warp(uint256 timestamp) internal { + (bool success, ) = vm.call( + abi.encodeWithSignature("warp(uint256)", timestamp) + ); + require(success, "call to warp failed"); + } + + function assertEq(uint256 a, uint256 b, string memory message) internal { + emit LogAssertEqUint(a, b, message); + require(a == b, message); + } + + function assertEq(address a, address b, string memory message) internal { + emit LogAssertEq(a == b, message); + require(a == b, message); + } + + function assertApproxEq( + uint256 a, + uint256 b, + uint256 tolerance, + string memory message + ) internal pure { + bool withinTolerance = (a >= b ? a - b : b - a) <= tolerance; + require(withinTolerance, message); + } + + // Test function + function testStakeAndClaimNewRewardStream() public { + // 1. User stakes tokens + prank(user); + stakingToken.approve(address(multiRewards), INITIAL_STAKE_AMOUNT); + + // Check initial state + assertEq( + multiRewards.totalSupply(), + 0, + "Initial total supply should be 0" + ); + assertEq( + multiRewards.balanceOf(user), + 0, + "Initial user balance should be 0" + ); + assertEq( + stakingToken.balanceOf(user), + INITIAL_STAKE_AMOUNT, + "User should have initial tokens" + ); + + // Perform stake + prank(user); + multiRewards.stake(INITIAL_STAKE_AMOUNT); + + // Check state after staking + assertEq( + multiRewards.totalSupply(), + INITIAL_STAKE_AMOUNT, + "Total supply should equal staked amount" + ); + assertEq( + multiRewards.balanceOf(user), + INITIAL_STAKE_AMOUNT, + "User balance should equal staked amount" + ); + assertEq( + stakingToken.balanceOf(user), + 0, + "User should have 0 tokens after staking" + ); + assertEq( + stakingToken.balanceOf(address(multiRewards)), + INITIAL_STAKE_AMOUNT, + "Contract should have staked tokens" + ); + + // 2. Notify reward amount for first reward token + prank(rewardDistributorA); + multiRewards.notifyRewardAmount(address(rewardTokenA), REWARD_AMOUNT); + + assertEq( + rewardTokenA.balanceOf(address(multiRewards)), + REWARD_AMOUNT, + "reward token balance multi rewards incorrect token a" + ); + // Check reward state after notification + ( + address rewardsDistributorA, + uint256 rewardsDurationA, + uint256 periodFinishA, + uint256 rewardRateA, + uint256 lastUpdateTimeA, + + ) = multiRewards.rewardData(address(rewardTokenA)); + + assertEq( + rewardsDistributorA, + rewardDistributorA, + "Rewards distributor should be set correctly" + ); + assertEq( + rewardsDurationA, + REWARDS_DURATION, + "Rewards duration should be set correctly" + ); + assertEq( + periodFinishA, + block.timestamp + REWARDS_DURATION, + "Period finish should be set correctly" + ); + assertEq( + rewardRateA, + REWARD_AMOUNT / REWARDS_DURATION, + "Reward rate should be set correctly" + ); + assertEq( + lastUpdateTimeA, + block.timestamp, + "Last update time should be set correctly" + ); + + // 3. Add a new reward token AFTER user has staked + multiRewards.addReward( + address(rewardTokenB), + rewardDistributorB, + REWARDS_DURATION + ); + + // Check reward token was added correctly + ( + address rewardsDistributorB, + uint256 rewardsDurationB, + , + , + , + + ) = multiRewards.rewardData(address(rewardTokenB)); + assertEq( + rewardsDistributorB, + rewardDistributorB, + "New rewards distributor should be set correctly" + ); + assertEq( + rewardsDurationB, + REWARDS_DURATION, + "New rewards duration should be set correctly" + ); + + // 4. Notify reward amount for the new reward token + prank(rewardDistributorB); + multiRewards.notifyRewardAmount(address(rewardTokenB), REWARD_AMOUNT); + + assertEq( + rewardTokenB.balanceOf(address(multiRewards)), + REWARD_AMOUNT, + "reward token balance multi rewards incorrect token b" + ); + + // Check reward state after notification + ( + , + , + uint256 periodFinishB, + uint256 rewardRateB, + uint256 lastUpdateTimeB, + + ) = multiRewards.rewardData(address(rewardTokenB)); + + assertEq( + periodFinishB, + block.timestamp + REWARDS_DURATION, + "Period finish should be set correctly for token B" + ); + assertEq( + rewardRateB, + REWARD_AMOUNT / REWARDS_DURATION, + "Reward rate should be set correctly for token B" + ); + assertEq( + lastUpdateTimeB, + block.timestamp, + "Last update time should be set correctly for token B" + ); + + // 5. Fast forward time to accrue rewards (half the duration) + warp(block.timestamp + REWARDS_DURATION / 2); + + // 6. Check earned rewards + uint256 earnedA = multiRewards.earned(user, address(rewardTokenA)); + uint256 earnedB = multiRewards.earned(user, address(rewardTokenB)); + + // Should have earned approximately half the rewards (slight precision loss is expected) + uint256 expectedRewardA = REWARD_AMOUNT / 2; + uint256 expectedRewardB = REWARD_AMOUNT / 2; + uint256 tolerance = REWARD_AMOUNT / 10000; // 0.01% tolerance + + assertApproxEq( + earnedA, + expectedRewardA, + tolerance, + "Should have earned ~half of reward A" + ); + assertApproxEq( + earnedB, + expectedRewardB, + tolerance, + "Should have earned ~half of reward B" + ); + + // 7. User claims rewards + uint256 userRewardBalanceA_Before = rewardTokenA.balanceOf(user); + uint256 userRewardBalanceB_Before = rewardTokenB.balanceOf(user); + + prank(user); + multiRewards.getReward(); + + // 8. Check state after claiming rewards + uint256 userRewardBalanceA_After = rewardTokenA.balanceOf(user); + uint256 userRewardBalanceB_After = rewardTokenB.balanceOf(user); + + // Verify user received rewards + assertEq( + userRewardBalanceA_After - userRewardBalanceA_Before, + earnedA, + "User should have received earned rewards for token A" + ); + assertEq( + userRewardBalanceB_After - userRewardBalanceB_Before, + earnedB, + "User should have received earned rewards for token B" + ); + + // Verify rewards state was updated + assertEq( + multiRewards.rewards(user, address(rewardTokenA)), + 0, + "User rewards for token A should be reset to 0" + ); + assertEq( + multiRewards.rewards(user, address(rewardTokenB)), + 0, + "User rewards for token B should be reset to 0" + ); + + // 9. Fast forward to the end of the reward period + warp(block.timestamp + REWARDS_DURATION / 2); + + // 10. User claims remaining rewards + userRewardBalanceA_Before = rewardTokenA.balanceOf(user); + userRewardBalanceB_Before = rewardTokenB.balanceOf(user); + + prank(user); + multiRewards.getReward(); + + userRewardBalanceA_After = rewardTokenA.balanceOf(user); + userRewardBalanceB_After = rewardTokenB.balanceOf(user); + + // Verify user received remaining rewards + uint256 remainingRewardsA = userRewardBalanceA_After - + userRewardBalanceA_Before; + uint256 remainingRewardsB = userRewardBalanceB_After - + userRewardBalanceB_Before; + + // Should have received the remaining ~half of rewards + assertApproxEq( + remainingRewardsA, + expectedRewardA, + tolerance, + "Should have received remaining rewards for token A" + ); + assertApproxEq( + remainingRewardsB, + expectedRewardB, + tolerance, + "Should have received remaining rewards for token B" + ); + + // 11. Verify total rewards received + uint256 totalRewardsA = userRewardBalanceA_After; + uint256 totalRewardsB = userRewardBalanceB_After; + + // Should have received approximately all rewards + assertApproxEq( + totalRewardsA, + REWARD_AMOUNT, + tolerance, + "Should have received ~all rewards for token A" + ); + assertApproxEq( + totalRewardsB, + REWARD_AMOUNT, + tolerance, + "Should have received ~all rewards for token B" + ); + } + + // Define structs to group related variables and reduce stack usage + struct UserStakeInfo { + uint256 stakeAmount; + uint256 expectedReward; + } + + struct RewardBalances { + uint256 tokenA_Before; + uint256 tokenA_After; + uint256 tokenB_Before; + uint256 tokenB_After; + } + + struct RewardAmounts { + uint256 earnedA; + uint256 earnedB; + uint256 receivedA; + uint256 receivedB; + uint256 remainingA; + uint256 remainingB; + uint256 totalA; + uint256 totalB; + } + + // Test function with two users staking + function testMultipleUsersStakeAndClaimNewRewardStream() public { + // Constants for this test + uint256 tolerance = REWARD_AMOUNT / 10000; // 0.01% tolerance + + // Use memory structs to group related variables + UserStakeInfo memory user1Info = UserStakeInfo({ + stakeAmount: 75 ether, + expectedReward: 0 // Will set this later + }); + + UserStakeInfo memory user2Info = UserStakeInfo({ + stakeAmount: 25 ether, + expectedReward: 0 // Will set this later + }); + + // Setup phase - stake tokens + { + uint256 totalStakeAmount = user1Info.stakeAmount + + user2Info.stakeAmount; + + // Mint tokens to users + stakingToken.mint(user, user1Info.stakeAmount); + stakingToken.mint(user2, user2Info.stakeAmount); + + // Check initial state + assertEq( + multiRewards.totalSupply(), + 0, + "Initial total supply should be 0" + ); + assertEq( + multiRewards.balanceOf(user), + 0, + "Initial user1 balance should be 0" + ); + assertEq( + multiRewards.balanceOf(user2), + 0, + "Initial user2 balance should be 0" + ); + assertEq( + stakingToken.balanceOf(user), + user1Info.stakeAmount + INITIAL_STAKE_AMOUNT, + "User1 should have initial tokens" + ); + assertEq( + stakingToken.balanceOf(user2), + user2Info.stakeAmount, + "User2 should have initial tokens" + ); + + // 1. First user stakes tokens + prank(user); + stakingToken.approve(address(multiRewards), user1Info.stakeAmount); + prank(user); + multiRewards.stake(user1Info.stakeAmount); + + // Check state after first user staking + assertEq( + multiRewards.totalSupply(), + user1Info.stakeAmount, + "Total supply should equal user1 staked amount" + ); + assertEq( + multiRewards.balanceOf(user), + user1Info.stakeAmount, + "User1 balance should equal staked amount" + ); + assertEq( + stakingToken.balanceOf(user), + INITIAL_STAKE_AMOUNT, + "User1 should have INITIAL_STAKE_AMOUNT after staking" + ); + assertEq( + stakingToken.balanceOf(address(multiRewards)), + user1Info.stakeAmount, + "Contract should have user1 staked tokens" + ); + + // 2. Second user stakes tokens + prank(user2); + stakingToken.approve(address(multiRewards), user2Info.stakeAmount); + prank(user2); + multiRewards.stake(user2Info.stakeAmount); + + // Check state after second user staking + assertEq( + multiRewards.totalSupply(), + totalStakeAmount, + "Total supply should equal total staked amount" + ); + assertEq( + multiRewards.balanceOf(user2), + user2Info.stakeAmount, + "User2 balance should equal staked amount" + ); + assertEq( + stakingToken.balanceOf(user2), + 0, + "User2 should have 0 tokens after staking" + ); + assertEq( + stakingToken.balanceOf(address(multiRewards)), + totalStakeAmount, + "Contract should have total staked tokens" + ); + } + + // 3. Notify reward amount for first reward token + prank(rewardDistributorA); + multiRewards.notifyRewardAmount(address(rewardTokenA), REWARD_AMOUNT); + + // Check reward state after notification + { + address rewardsDistributorA; + uint256 rewardsDurationA; + uint256 periodFinishA; + uint256 rewardRateA; + uint256 lastUpdateTimeA; + + ( + rewardsDistributorA, + rewardsDurationA, + periodFinishA, + rewardRateA, + lastUpdateTimeA, + + ) = multiRewards.rewardData(address(rewardTokenA)); + + assertEq( + rewardsDistributorA, + rewardDistributorA, + "Rewards distributor should be set correctly" + ); + assertEq( + rewardsDurationA, + REWARDS_DURATION, + "Rewards duration should be set correctly" + ); + assertEq( + periodFinishA, + block.timestamp + REWARDS_DURATION, + "Period finish should be set correctly" + ); + assertEq( + rewardRateA, + REWARD_AMOUNT / REWARDS_DURATION, + "Reward rate should be set correctly" + ); + assertEq( + lastUpdateTimeA, + block.timestamp, + "Last update time should be set correctly" + ); + } + + // 4. Add a new reward token AFTER users have staked + multiRewards.addReward( + address(rewardTokenB), + rewardDistributorB, + REWARDS_DURATION + ); + + // Check reward token was added correctly + { + address rewardsDistributorB; + uint256 rewardsDurationB; + + (rewardsDistributorB, rewardsDurationB, , , , ) = multiRewards + .rewardData(address(rewardTokenB)); + + assertEq( + rewardsDistributorB, + rewardDistributorB, + "New rewards distributor should be set correctly" + ); + assertEq( + rewardsDurationB, + REWARDS_DURATION, + "New rewards duration should be set correctly" + ); + } + + // 5. Notify reward amount for the new reward token + prank(rewardDistributorB); + multiRewards.notifyRewardAmount(address(rewardTokenB), REWARD_AMOUNT); + + // Check reward state after notification + { + uint256 periodFinishB; + uint256 rewardRateB; + uint256 lastUpdateTimeB; + + (, , periodFinishB, rewardRateB, lastUpdateTimeB, ) = multiRewards + .rewardData(address(rewardTokenB)); + + assertEq( + periodFinishB, + block.timestamp + REWARDS_DURATION, + "Period finish should be set correctly for token B" + ); + assertEq( + rewardRateB, + REWARD_AMOUNT / REWARDS_DURATION, + "Reward rate should be set correctly for token B" + ); + assertEq( + lastUpdateTimeB, + block.timestamp, + "Last update time should be set correctly for token B" + ); + } + + // 6. Fast forward time to accrue rewards (half the duration) + warp(block.timestamp + REWARDS_DURATION / 2); + + // 7. Check earned rewards for both users + RewardAmounts memory user1Rewards; + RewardAmounts memory user2Rewards; + + user1Rewards.earnedA = multiRewards.earned(user, address(rewardTokenA)); + user1Rewards.earnedB = multiRewards.earned(user, address(rewardTokenB)); + user2Rewards.earnedA = multiRewards.earned( + user2, + address(rewardTokenA) + ); + user2Rewards.earnedB = multiRewards.earned( + user2, + address(rewardTokenB) + ); + + // Calculate expected rewards based on stake proportions + // User1 has 75% of the stake, User2 has 25% + user1Info.expectedReward = ((REWARD_AMOUNT / 2) * 75) / 100; + user2Info.expectedReward = ((REWARD_AMOUNT / 2) * 25) / 100; + + // Verify earned rewards are proportional to stake + assertApproxEq( + user1Rewards.earnedA, + user1Info.expectedReward, + tolerance, + "User1 should have earned ~75% of half reward A" + ); + assertApproxEq( + user1Rewards.earnedB, + user1Info.expectedReward, + tolerance, + "User1 should have earned ~75% of half reward B" + ); + assertApproxEq( + user2Rewards.earnedA, + user2Info.expectedReward, + tolerance, + "User2 should have earned ~25% of half reward A" + ); + assertApproxEq( + user2Rewards.earnedB, + user2Info.expectedReward, + tolerance, + "User2 should have earned ~25% of half reward B" + ); + + // 8. Users claim rewards + { + // User1 claims + RewardBalances memory user1Balances; + user1Balances.tokenA_Before = rewardTokenA.balanceOf(user); + user1Balances.tokenB_Before = rewardTokenB.balanceOf(user); + + prank(user); + multiRewards.getReward(); + + user1Balances.tokenA_After = rewardTokenA.balanceOf(user); + user1Balances.tokenB_After = rewardTokenB.balanceOf(user); + + // User2 claims + RewardBalances memory user2Balances; + user2Balances.tokenA_Before = rewardTokenA.balanceOf(user2); + user2Balances.tokenB_Before = rewardTokenB.balanceOf(user2); + + prank(user2); + multiRewards.getReward(); + + user2Balances.tokenA_After = rewardTokenA.balanceOf(user2); + user2Balances.tokenB_After = rewardTokenB.balanceOf(user2); + + // 9. Verify users received correct rewards + user1Rewards.receivedA = + user1Balances.tokenA_After - + user1Balances.tokenA_Before; + user1Rewards.receivedB = + user1Balances.tokenB_After - + user1Balances.tokenB_Before; + user2Rewards.receivedA = + user2Balances.tokenA_After - + user2Balances.tokenA_Before; + user2Rewards.receivedB = + user2Balances.tokenB_After - + user2Balances.tokenB_Before; + + assertEq( + user1Rewards.receivedA, + user1Rewards.earnedA, + "User1 should have received earned rewards for token A" + ); + assertEq( + user1Rewards.receivedB, + user1Rewards.earnedB, + "User1 should have received earned rewards for token B" + ); + assertEq( + user2Rewards.receivedA, + user2Rewards.earnedA, + "User2 should have received earned rewards for token A" + ); + assertEq( + user2Rewards.receivedB, + user2Rewards.earnedB, + "User2 should have received earned rewards for token B" + ); + + // Verify rewards state was updated + assertEq( + multiRewards.rewards(user, address(rewardTokenA)), + 0, + "User1 rewards for token A should be reset to 0" + ); + assertEq( + multiRewards.rewards(user, address(rewardTokenB)), + 0, + "User1 rewards for token B should be reset to 0" + ); + assertEq( + multiRewards.rewards(user2, address(rewardTokenA)), + 0, + "User2 rewards for token A should be reset to 0" + ); + assertEq( + multiRewards.rewards(user2, address(rewardTokenB)), + 0, + "User2 rewards for token B should be reset to 0" + ); + } + + // 10. Fast forward to the end of the reward period + warp(block.timestamp + REWARDS_DURATION / 2); + + // 11. Users claim remaining rewards + { + // User1 claims + RewardBalances memory user1Balances; + user1Balances.tokenA_Before = rewardTokenA.balanceOf(user); + user1Balances.tokenB_Before = rewardTokenB.balanceOf(user); + + prank(user); + multiRewards.getReward(); + + user1Balances.tokenA_After = rewardTokenA.balanceOf(user); + user1Balances.tokenB_After = rewardTokenB.balanceOf(user); + + // User2 claims + RewardBalances memory user2Balances; + user2Balances.tokenA_Before = rewardTokenA.balanceOf(user2); + user2Balances.tokenB_Before = rewardTokenB.balanceOf(user2); + + prank(user2); + multiRewards.getReward(); + + user2Balances.tokenA_After = rewardTokenA.balanceOf(user2); + user2Balances.tokenB_After = rewardTokenB.balanceOf(user2); + + // 12. Verify users received remaining rewards + user1Rewards.remainingA = + user1Balances.tokenA_After - + user1Balances.tokenA_Before; + user1Rewards.remainingB = + user1Balances.tokenB_After - + user1Balances.tokenB_Before; + user2Rewards.remainingA = + user2Balances.tokenA_After - + user2Balances.tokenA_Before; + user2Rewards.remainingB = + user2Balances.tokenB_After - + user2Balances.tokenB_Before; + + assertApproxEq( + user1Rewards.remainingA, + user1Info.expectedReward, + tolerance, + "User1 should have received remaining rewards for token A" + ); + assertApproxEq( + user1Rewards.remainingB, + user1Info.expectedReward, + tolerance, + "User1 should have received remaining rewards for token B" + ); + assertApproxEq( + user2Rewards.remainingA, + user2Info.expectedReward, + tolerance, + "User2 should have received remaining rewards for token A" + ); + assertApproxEq( + user2Rewards.remainingB, + user2Info.expectedReward, + tolerance, + "User2 should have received remaining rewards for token B" + ); + + // Store total rewards for verification + user1Rewards.totalA = user1Balances.tokenA_After; + user1Rewards.totalB = user1Balances.tokenB_After; + user2Rewards.totalA = user2Balances.tokenA_After; + user2Rewards.totalB = user2Balances.tokenB_After; + } + + // 13. Verify total rewards received by both users + { + // User1 should have received ~75% of total rewards + uint256 expectedUser1TotalA = (REWARD_AMOUNT * 75) / 100; + uint256 expectedUser1TotalB = (REWARD_AMOUNT * 75) / 100; + + // User2 should have received ~25% of total rewards + uint256 expectedUser2TotalA = (REWARD_AMOUNT * 25) / 100; + uint256 expectedUser2TotalB = (REWARD_AMOUNT * 25) / 100; + + assertApproxEq( + user1Rewards.totalA, + expectedUser1TotalA, + tolerance, + "User1 should have received ~75% of total rewards for token A" + ); + assertApproxEq( + user1Rewards.totalB, + expectedUser1TotalB, + tolerance, + "User1 should have received ~75% of total rewards for token B" + ); + assertApproxEq( + user2Rewards.totalA, + expectedUser2TotalA, + tolerance, + "User2 should have received ~25% of total rewards for token A" + ); + assertApproxEq( + user2Rewards.totalB, + expectedUser2TotalB, + tolerance, + "User2 should have received ~25% of total rewards for token B" + ); + } + + // 14. Verify that the sum of rewards equals the total rewards + { + uint256 totalRewardsDistributedA = user1Rewards.totalA + + user2Rewards.totalA; + uint256 totalRewardsDistributedB = user1Rewards.totalB + + user2Rewards.totalB; + + assertApproxEq( + totalRewardsDistributedA, + REWARD_AMOUNT, + tolerance, + "Total distributed rewards for token A should equal REWARD_AMOUNT" + ); + assertApproxEq( + totalRewardsDistributedB, + REWARD_AMOUNT, + tolerance, + "Total distributed rewards for token B should equal REWARD_AMOUNT" + ); + } + } + + // Test function to verify recoverERC20 works for reward tokens + function testRecoverRewardToken() public { + // 1. Setup - Add reward token and notify reward amount + prank(rewardDistributorA); + multiRewards.notifyRewardAmount(address(rewardTokenA), REWARD_AMOUNT); + + // Verify reward token balance in the contract + assertEq( + rewardTokenA.balanceOf(address(multiRewards)), + REWARD_AMOUNT, + "Contract should have the reward token amount" + ); + + // 2. Attempt to recover half of the reward tokens + uint256 amountToRecover = REWARD_AMOUNT / 2; + uint256 ownerBalanceBefore = rewardTokenA.balanceOf(owner); + + // Call recoverERC20 as the owner + multiRewards.recoverERC20(address(rewardTokenA), amountToRecover); + + // 3. Verify tokens were successfully transferred to the owner + uint256 ownerBalanceAfter = rewardTokenA.balanceOf(owner); + assertEq( + ownerBalanceAfter - ownerBalanceBefore, + amountToRecover, + "Owner should have received the recovered tokens" + ); + + // 4. Verify remaining balance in the contract + assertEq( + rewardTokenA.balanceOf(address(multiRewards)), + REWARD_AMOUNT - amountToRecover, + "Contract should have the remaining reward tokens" + ); + } +}