From 3351c80e58e6300cb263d33a4efe75b88ad7b9b2 Mon Sep 17 00:00:00 2001 From: diana Date: Thu, 23 May 2024 12:39:34 -0400 Subject: [PATCH] Protocolfees tests (#691) * protocol fees tests * more tests * underscore for internal var * protocol fees tests * re-add view * gas snapshot * slot0 verify * shorten * comments * extra comment * fix snapshot --- .forge-snapshots/set protocol fee.snap | 2 +- src/PoolManager.sol | 16 +- src/ProtocolFees.sol | 5 +- src/test/ProtocolFeesImplementation.sol | 36 ++++ test/PoolManager.t.sol | 103 +++--------- test/ProtocolFeesImplementation.t.sol | 215 ++++++++++++++++++++++++ 6 files changed, 286 insertions(+), 91 deletions(-) create mode 100644 src/test/ProtocolFeesImplementation.sol create mode 100644 test/ProtocolFeesImplementation.t.sol diff --git a/.forge-snapshots/set protocol fee.snap b/.forge-snapshots/set protocol fee.snap index ba1ce5094..4896f554b 100644 --- a/.forge-snapshots/set protocol fee.snap +++ b/.forge-snapshots/set protocol fee.snap @@ -1 +1 @@ -32415 \ No newline at end of file +32444 \ No newline at end of file diff --git a/src/PoolManager.sol b/src/PoolManager.sol index 643843432..a6732861c 100644 --- a/src/PoolManager.sol +++ b/src/PoolManager.sol @@ -91,12 +91,12 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim /// @inheritdoc IPoolManager int24 public constant MIN_TICK_SPACING = TickMath.MIN_TICK_SPACING; - mapping(PoolId id => Pool.State) internal pools; + mapping(PoolId id => Pool.State) internal _pools; constructor(uint256 controllerGasLimit) ProtocolFees(controllerGasLimit) {} function _getPool(PoolId id) internal view override returns (Pool.State storage) { - return pools[id]; + return _pools[id]; } /// @notice This will revert if the contract is locked @@ -125,7 +125,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim PoolId id = key.toId(); (, uint24 protocolFee) = _fetchProtocolFee(key); - tick = pools[id].initialize(sqrtPriceX96, protocolFee, lpFee); + tick = _pools[id].initialize(sqrtPriceX96, protocolFee, lpFee); key.hooks.afterInitialize(key, sqrtPriceX96, tick, hookData); @@ -174,7 +174,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim } function _checkPoolInitialized(PoolId id) internal view { - if (pools[id].isNotInitialized()) revert PoolNotInitialized(); + if (_pools[id].isNotInitialized()) revert PoolNotInitialized(); } /// @inheritdoc IPoolManager @@ -189,7 +189,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim key.hooks.beforeModifyLiquidity(key, params, hookData); BalanceDelta principalDelta; - (principalDelta, feesAccrued) = pools[id].modifyLiquidity( + (principalDelta, feesAccrued) = _pools[id].modifyLiquidity( Pool.ModifyLiquidityParams({ owner: msg.sender, tickLower: params.tickLower, @@ -257,7 +257,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim // Internal swap function to execute a swap, take protocol fees on input token, and emit the swap event function _swap(PoolId id, Pool.SwapParams memory params, Currency inputCurrency) internal returns (BalanceDelta) { (BalanceDelta delta, uint256 feeForProtocol, uint24 swapFee, Pool.SwapState memory state) = - pools[id].swap(params); + _pools[id].swap(params); // The fee is on the input currency. if (feeForProtocol > 0) _updateProtocolFees(inputCurrency, feeForProtocol); @@ -281,7 +281,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim key.hooks.beforeDonate(key, amount0, amount1, hookData); - delta = pools[id].donate(amount0, amount1); + delta = _pools[id].donate(amount0, amount1); _accountPoolBalanceDelta(key, delta, msg.sender); @@ -330,6 +330,6 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim if (!key.fee.isDynamicFee() || msg.sender != address(key.hooks)) revert UnauthorizedDynamicLPFeeUpdate(); newDynamicLPFee.validate(); PoolId id = key.toId(); - pools[id].setLPFee(newDynamicLPFee); + _pools[id].setLPFee(newDynamicLPFee); } } diff --git a/src/ProtocolFees.sol b/src/ProtocolFees.sol index 0c4d91af9..3e7263a49 100644 --- a/src/ProtocolFees.sol +++ b/src/ProtocolFees.sol @@ -59,7 +59,7 @@ abstract contract ProtocolFees is IProtocolFees, Owned { /// @dev to prevent an invalid protocol fee controller from blocking pools from being initialized /// the success of this function is NOT checked on initialize and if the call fails, the protocol fees are set to 0. /// @dev the success of this function must be checked when called in setProtocolFee - function _fetchProtocolFee(PoolKey memory key) internal returns (bool success, uint24 protocolFees) { + function _fetchProtocolFee(PoolKey memory key) internal returns (bool success, uint24 protocolFee) { if (address(protocolFeeController) != address(0)) { // note that EIP-150 mandates that calls requesting more than 63/64ths of remaining gas // will be allotted no more than this amount, so controllerGasLimit must be set with this @@ -76,8 +76,9 @@ abstract contract ProtocolFees is IProtocolFees, Owned { assembly { returnData := mload(add(_data, 0x20)) } + // Ensure return data does not overflow a uint24 and that the underlying fees are within bounds. - (success, protocolFees) = (returnData == uint24(returnData)) && uint24(returnData).isValidProtocolFee() + (success, protocolFee) = (returnData == uint24(returnData)) && uint24(returnData).isValidProtocolFee() ? (true, uint24(returnData)) : (false, 0); } diff --git a/src/test/ProtocolFeesImplementation.sol b/src/test/ProtocolFeesImplementation.sol new file mode 100644 index 000000000..c31be6140 --- /dev/null +++ b/src/test/ProtocolFeesImplementation.sol @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.20; + +import {ProtocolFees} from "../ProtocolFees.sol"; +import {IProtocolFeeController} from "../interfaces/IProtocolFeeController.sol"; +import {PoolKey} from "../types/PoolKey.sol"; +import {Currency} from "../types/Currency.sol"; +import {PoolId, PoolIdLibrary} from "../types/PoolId.sol"; +import {Pool} from "../libraries/Pool.sol"; +import {Slot0} from "../types/Slot0.sol"; + +contract ProtocolFeesImplementation is ProtocolFees { + using PoolIdLibrary for PoolKey; + + mapping(PoolId id => Pool.State) internal _pools; + + constructor(uint256 _controllerGasLimit) ProtocolFees(_controllerGasLimit) {} + + // Used to set the price of a pool to pretend that the pool has been initialized in order to successfully set a protocol fee + function setPrice(PoolKey memory key, uint160 sqrtPriceX96) public { + Pool.State storage pool = _getPool(key.toId()); + pool.slot0 = pool.slot0.setSqrtPriceX96(sqrtPriceX96); + } + + function _getPool(PoolId id) internal view override returns (Pool.State storage) { + return _pools[id]; + } + + function fetchProtocolFee(PoolKey memory key) public returns (bool, uint24) { + return ProtocolFees._fetchProtocolFee(key); + } + + function updateProtocolFees(Currency currency, uint256 amount) public { + ProtocolFees._updateProtocolFees(currency, amount); + } +} diff --git a/test/PoolManager.t.sol b/test/PoolManager.t.sol index 5cb0efd17..7994c5659 100644 --- a/test/PoolManager.t.sol +++ b/test/PoolManager.t.sol @@ -81,25 +81,6 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { snapSize("poolManager bytecode size", address(manager)); } - function test_setProtocolFeeController_succeeds() public { - deployFreshManager(); - assertEq(address(manager.protocolFeeController()), address(0)); - vm.expectEmit(false, false, false, true, address(manager)); - emit ProtocolFeeControllerUpdated(address(feeController)); - manager.setProtocolFeeController(feeController); - assertEq(address(manager.protocolFeeController()), address(feeController)); - } - - function test_setProtocolFeeController_failsIfNotOwner() public { - deployFreshManager(); - assertEq(address(manager.protocolFeeController()), address(0)); - - vm.prank(address(1)); // not the owner address - vm.expectRevert("UNAUTHORIZED"); - manager.setProtocolFeeController(feeController); - assertEq(address(manager.protocolFeeController()), address(0)); - } - function test_addLiquidity_failsIfNotInitialized() public { vm.expectRevert(Pool.PoolNotInitialized.selector); modifyLiquidityRouter.modifyLiquidity(uninitializedKey, LIQUIDITY_PARAMS, ZERO_BYTES); @@ -1059,10 +1040,13 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { uint24 protocolFee = (uint24(protocolFee1) << 12) | uint24(protocolFee0); + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); + assertEq(slot0ProtocolFee, 0); + vm.prank(address(feeController)); manager.setProtocolFee(key, protocolFee); - (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); + (,, slot0ProtocolFee,) = manager.getSlot0(key.toId()); assertEq(slot0ProtocolFee, protocolFee); // Add liquidity - Fees dont accrue for positive liquidity delta. @@ -1246,69 +1230,16 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { manager.burn(address(this), key.currency0.toId(), 1); } - function test_setProtocolFee_gas() public { - vm.prank(address(feeController)); - manager.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS); - snapLastCall("set protocol fee"); - } - - function test_setProtocolFee_updatesProtocolFeeForInitializedPool(uint24 protocolFee) public { - (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); - assertEq(slot0ProtocolFee, 0); - - uint16 fee0 = protocolFee.getZeroForOneFee(); - uint16 fee1 = protocolFee.getOneForZeroFee(); - vm.prank(address(feeController)); - if ((fee0 > 1000) || (fee1 > 1000)) { - vm.expectRevert(IProtocolFees.InvalidProtocolFee.selector); - manager.setProtocolFee(key, protocolFee); - } else { - vm.expectEmit(false, false, false, true); - emit IProtocolFees.ProtocolFeeUpdated(key.toId(), protocolFee); - manager.setProtocolFee(key, protocolFee); - - (,, slot0ProtocolFee,) = manager.getSlot0(key.toId()); - assertEq(slot0ProtocolFee, protocolFee); - } - } - - function test_setProtocolFee_failsWithInvalidFee() public { - (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); - assertEq(slot0ProtocolFee, 0); - - vm.prank(address(feeController)); - vm.expectRevert(IProtocolFees.InvalidProtocolFee.selector); - manager.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS + 1); - } + function test_collectProtocolFees_ERC20_accumulateFees_gas() public { + uint256 expectedFees = 10; - function test_setProtocolFee_failsWithInvalidCaller() public { (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); assertEq(slot0ProtocolFee, 0); - vm.expectRevert(IProtocolFees.InvalidCaller.selector); - manager.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS); - } - - function test_collectProtocolFees_initializesWithProtocolFeeIfCalled() public { - feeController.setProtocolFeeForPool(uninitializedKey.toId(), MAX_PROTOCOL_FEE_BOTH_TOKENS); - - manager.initialize(uninitializedKey, SQRT_PRICE_1_1, ZERO_BYTES); - (,, uint24 slot0ProtocolFee,) = manager.getSlot0(uninitializedKey.toId()); - assertEq(slot0ProtocolFee, MAX_PROTOCOL_FEE_BOTH_TOKENS); - } - - function test_collectProtocolFees_revertsIfCallerIsNotController() public { - vm.expectRevert(IProtocolFees.InvalidCaller.selector); - manager.collectProtocolFees(address(1), currency0, 0); - } - - function test_collectProtocolFees_ERC20_accumulateFees_gas() public { - uint256 expectedFees = 10; - vm.prank(address(feeController)); manager.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS); - (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); + (,, slot0ProtocolFee,) = manager.getSlot0(key.toId()); assertEq(slot0ProtocolFee, MAX_PROTOCOL_FEE_BOTH_TOKENS); swapRouter.swap( @@ -1331,10 +1262,13 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { function test_collectProtocolFees_ERC20_accumulateFees_exactOutput() public { uint256 expectedFees = 10; + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); + assertEq(slot0ProtocolFee, 0); + vm.prank(address(feeController)); manager.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS); - (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); + (,, slot0ProtocolFee,) = manager.getSlot0(key.toId()); assertEq(slot0ProtocolFee, MAX_PROTOCOL_FEE_BOTH_TOKENS); swapRouter.swap( @@ -1356,10 +1290,13 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { function test_collectProtocolFees_ERC20_returnsAllFeesIf0IsProvidedAsParameter() public { uint256 expectedFees = 10; + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); + assertEq(slot0ProtocolFee, 0); + vm.prank(address(feeController)); manager.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS); - (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); + (,, slot0ProtocolFee,) = manager.getSlot0(key.toId()); assertEq(slot0ProtocolFee, MAX_PROTOCOL_FEE_BOTH_TOKENS); swapRouter.swap( @@ -1382,10 +1319,13 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { uint256 expectedFees = 10; Currency nativeCurrency = CurrencyLibrary.NATIVE; + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId()); + assertEq(slot0ProtocolFee, 0); + vm.prank(address(feeController)); manager.setProtocolFee(nativeKey, MAX_PROTOCOL_FEE_BOTH_TOKENS); - (,, uint24 slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId()); + (,, slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId()); assertEq(slot0ProtocolFee, MAX_PROTOCOL_FEE_BOTH_TOKENS); swapRouter.swap{value: 10000}( @@ -1409,10 +1349,13 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { uint256 expectedFees = 10; Currency nativeCurrency = CurrencyLibrary.NATIVE; + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId()); + assertEq(slot0ProtocolFee, 0); + vm.prank(address(feeController)); manager.setProtocolFee(nativeKey, MAX_PROTOCOL_FEE_BOTH_TOKENS); - (,, uint24 slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId()); + (,, slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId()); assertEq(slot0ProtocolFee, MAX_PROTOCOL_FEE_BOTH_TOKENS); swapRouter.swap{value: 10000}( diff --git a/test/ProtocolFeesImplementation.t.sol b/test/ProtocolFeesImplementation.t.sol new file mode 100644 index 000000000..05d624492 --- /dev/null +++ b/test/ProtocolFeesImplementation.t.sol @@ -0,0 +1,215 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.15; + +import {Test} from "forge-std/Test.sol"; +import {MockERC20} from "solmate/test/utils/mocks/MockERC20.sol"; +import {CurrencyLibrary, Currency} from "../src/types/Currency.sol"; +import {ProtocolFeesImplementation} from "../src/test/ProtocolFeesImplementation.sol"; +import {GasSnapshot} from "forge-gas-snapshot/GasSnapshot.sol"; +import {IProtocolFees} from "../src/interfaces/IProtocolFees.sol"; +import {ProtocolFeeLibrary} from "../src/libraries/ProtocolFeeLibrary.sol"; +import {PoolKey} from "../src/types/PoolKey.sol"; +import {Currency, CurrencyLibrary} from "../src/types/Currency.sol"; +import {Deployers} from "../test/utils/Deployers.sol"; +import {PoolId, PoolIdLibrary} from "../src/types/PoolId.sol"; +import {IHooks} from "../src/interfaces/IHooks.sol"; +import {Constants} from "../test/utils/Constants.sol"; +import { + ProtocolFeeControllerTest, + OutOfBoundsProtocolFeeControllerTest, + RevertingProtocolFeeControllerTest, + OverflowProtocolFeeControllerTest, + InvalidReturnSizeProtocolFeeControllerTest +} from "../src/test/ProtocolFeeControllerTest.sol"; + +contract ProtocolFeesTest is Test, GasSnapshot, Deployers { + using CurrencyLibrary for Currency; + using PoolIdLibrary for PoolKey; + using ProtocolFeeLibrary for uint24; + + event ProtocolFeeControllerUpdated(address feeController); + event ProtocolFeeUpdated(PoolId indexed id, uint24 protocolFee); + + uint24 constant MAX_PROTOCOL_FEE_BOTH_TOKENS = (1000 << 12) | 1000; // 1000 1000 + + ProtocolFeesImplementation protocolFees; + + function setUp() public { + protocolFees = new ProtocolFeesImplementation(5000); + feeController = new ProtocolFeeControllerTest(); + (currency0, currency1) = deployAndMint2Currencies(); + MockERC20(Currency.unwrap(currency0)).transfer(address(protocolFees), 2 ** 255); + } + + function test_setProtocolFeeController_succeedsNoRevert() public { + assertEq(address(protocolFees.protocolFeeController()), address(0)); + vm.expectEmit(false, false, false, true, address(protocolFees)); + emit ProtocolFeeControllerUpdated(address(feeController)); + protocolFees.setProtocolFeeController(feeController); + assertEq(address(protocolFees.protocolFeeController()), address(feeController)); + } + + function test_setProtocolFeeController_revertsWithNotAuthorized() public { + assertEq(address(protocolFees.protocolFeeController()), address(0)); + + vm.prank(address(1)); // not the owner address + vm.expectRevert("UNAUTHORIZED"); + protocolFees.setProtocolFeeController(feeController); + assertEq(address(protocolFees.protocolFeeController()), address(0)); + } + + function test_setProtocolFee_succeeds_gas() public { + PoolKey memory key = PoolKey(currency0, currency1, 3000, 60, IHooks(address(0))); + protocolFees.setProtocolFeeController(feeController); + // Set price to pretend that the pool is initialized + protocolFees.setPrice(key, Constants.SQRT_PRICE_1_1); + vm.prank(address(feeController)); + vm.expectEmit(true, false, false, true, address(protocolFees)); + emit ProtocolFeeUpdated(key.toId(), MAX_PROTOCOL_FEE_BOTH_TOKENS); + protocolFees.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS); + snapLastCall("set protocol fee"); + } + + function test_setProtocolFee_revertsWithInvalidCaller() public { + protocolFees.setProtocolFeeController(feeController); + vm.expectRevert(IProtocolFees.InvalidCaller.selector); + protocolFees.setProtocolFee(key, 1); + } + + function test_setProtocolFee_revertsWithInvalidFee() public { + protocolFees.setProtocolFeeController(feeController); + vm.prank(address(feeController)); + vm.expectRevert(IProtocolFees.InvalidProtocolFee.selector); + protocolFees.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS + 1); + + vm.prank(address(feeController)); + vm.expectRevert(IProtocolFees.InvalidProtocolFee.selector); + protocolFees.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS + (1 << 12)); + } + + function test_fuzz_setProtocolFee(PoolKey memory key, uint24 protocolFee) public { + protocolFees.setProtocolFeeController(feeController); + // Set price to pretend that the pool is initialized + protocolFees.setPrice(key, Constants.SQRT_PRICE_1_1); + uint16 fee0 = protocolFee.getZeroForOneFee(); + uint16 fee1 = protocolFee.getOneForZeroFee(); + vm.prank(address(feeController)); + if ((fee0 > 1000) || (fee1 > 1000)) { + vm.expectRevert(IProtocolFees.InvalidProtocolFee.selector); + protocolFees.setProtocolFee(key, protocolFee); + } else { + vm.expectEmit(true, false, false, true, address(protocolFees)); + emit IProtocolFees.ProtocolFeeUpdated(key.toId(), protocolFee); + protocolFees.setProtocolFee(key, protocolFee); + } + } + + function test_collectProtocolFees_revertsWithInvalidCaller() public { + vm.expectRevert(IProtocolFees.InvalidCaller.selector); + protocolFees.collectProtocolFees(address(1), currency0, 0); + } + + function test_collectProtocolFees_succeeds() public { + // set a balance of protocol fees that can be collected + protocolFees.updateProtocolFees(currency0, 100); + assertEq(protocolFees.protocolFeesAccrued(currency0), 100); + + protocolFees.setProtocolFeeController(feeController); + vm.prank(address(feeController)); + protocolFees.collectProtocolFees(address(this), currency0, 100); + assertEq(protocolFees.protocolFeesAccrued(currency0), 0); + assertEq(currency0.balanceOf(address(this)), 100); + } + + function test_fuzz_collectProtocolFees(address recipient, uint256 amount, uint256 feesAccrued) public { + vm.assume(feesAccrued <= currency0.balanceOf(address(protocolFees))); + vm.assume(amount <= feesAccrued); + vm.assume(recipient != address(protocolFees)); + + uint256 recipientBalanceBefore = currency0.balanceOf(recipient); + uint256 senderBalanceBefore = currency0.balanceOf(address(protocolFees)); + + // set a balance of protocol fees that can be collected + protocolFees.updateProtocolFees(currency0, feesAccrued); + assertEq(protocolFees.protocolFeesAccrued(currency0), feesAccrued); + if (amount == 0) { + amount = protocolFees.protocolFeesAccrued(currency0); + } + + protocolFees.setProtocolFeeController(feeController); + vm.prank(address(feeController)); + uint256 amountCollected = protocolFees.collectProtocolFees(recipient, currency0, amount); + + assertEq(protocolFees.protocolFeesAccrued(currency0), feesAccrued - amount); + assertEq(currency0.balanceOf(recipient), recipientBalanceBefore + amount); + assertEq(currency0.balanceOf(address(protocolFees)), senderBalanceBefore - amount); + assertEq(amountCollected, amount); + } + + function test_updateProtocolFees_succeeds() public { + // set a starting balance of protocol fees + protocolFees.updateProtocolFees(currency0, 100); + assertEq(protocolFees.protocolFeesAccrued(currency0), 100); + + protocolFees.updateProtocolFees(currency0, 200); + assertEq(protocolFees.protocolFeesAccrued(currency0), 300); + } + + function test_fuzz_updateProtocolFees(uint256 amount, uint256 startingAmount) public { + // set a starting balance of protocol fees + protocolFees.updateProtocolFees(currency0, startingAmount); + assertEq(protocolFees.protocolFeesAccrued(currency0), startingAmount); + + uint256 newAmount; + unchecked { + newAmount = startingAmount + amount; + } + + protocolFees.updateProtocolFees(currency0, amount); + assertEq(protocolFees.protocolFeesAccrued(currency0), newAmount); + } + + function test_fetchProtocolFee_succeeds() public { + protocolFees.setProtocolFeeController(feeController); + vm.prank(address(feeController)); + (bool success, uint24 protocolFee) = protocolFees.fetchProtocolFee(key); + assertTrue(success); + assertEq(protocolFee, 0); + } + + function test_fetchProtocolFee_outOfBounds() public { + outOfBoundsFeeController = new OutOfBoundsProtocolFeeControllerTest(); + protocolFees.setProtocolFeeController(outOfBoundsFeeController); + vm.prank(address(outOfBoundsFeeController)); + (bool success, uint24 protocolFee) = protocolFees.fetchProtocolFee(key); + assertFalse(success); + assertEq(protocolFee, 0); + } + + function test_fetchProtocolFee_overflowFee() public { + overflowFeeController = new OverflowProtocolFeeControllerTest(); + protocolFees.setProtocolFeeController(overflowFeeController); + vm.prank(address(overflowFeeController)); + (bool success, uint24 protocolFee) = protocolFees.fetchProtocolFee(key); + assertFalse(success); + assertEq(protocolFee, 0); + } + + function test_fetchProtocolFee_invalidReturnSize() public { + invalidReturnSizeFeeController = new InvalidReturnSizeProtocolFeeControllerTest(); + protocolFees.setProtocolFeeController(invalidReturnSizeFeeController); + vm.prank(address(invalidReturnSizeFeeController)); + (bool success, uint24 protocolFee) = protocolFees.fetchProtocolFee(key); + assertFalse(success); + assertEq(protocolFee, 0); + } + + function test_fetchProtocolFee_revert() public { + revertingFeeController = new RevertingProtocolFeeControllerTest(); + protocolFees.setProtocolFeeController(revertingFeeController); + vm.prank(address(revertingFeeController)); + (bool success, uint24 protocolFee) = protocolFees.fetchProtocolFee(key); + assertFalse(success); + assertEq(protocolFee, 0); + } +}