Skip to content

Commit

Permalink
Protocolfees tests (#691)
Browse files Browse the repository at this point in the history
* protocol fees tests

* more tests

* underscore for internal var

* protocol fees tests

* re-add view

* gas snapshot

* slot0 verify

* shorten

* comments

* extra comment

* fix snapshot
  • Loading branch information
dianakocsis authored May 23, 2024
1 parent 48f79bc commit 3351c80
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 91 deletions.
2 changes: 1 addition & 1 deletion .forge-snapshots/set protocol fee.snap
Original file line number Diff line number Diff line change
@@ -1 +1 @@
32415
32444
16 changes: 8 additions & 8 deletions src/PoolManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim
/// @inheritdoc IPoolManager
int24 public constant MIN_TICK_SPACING = TickMath.MIN_TICK_SPACING;

mapping(PoolId id => Pool.State) internal pools;
mapping(PoolId id => Pool.State) internal _pools;

constructor(uint256 controllerGasLimit) ProtocolFees(controllerGasLimit) {}

function _getPool(PoolId id) internal view override returns (Pool.State storage) {
return pools[id];
return _pools[id];
}

/// @notice This will revert if the contract is locked
Expand Down Expand Up @@ -125,7 +125,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim
PoolId id = key.toId();
(, uint24 protocolFee) = _fetchProtocolFee(key);

tick = pools[id].initialize(sqrtPriceX96, protocolFee, lpFee);
tick = _pools[id].initialize(sqrtPriceX96, protocolFee, lpFee);

key.hooks.afterInitialize(key, sqrtPriceX96, tick, hookData);

Expand Down Expand Up @@ -174,7 +174,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim
}

function _checkPoolInitialized(PoolId id) internal view {
if (pools[id].isNotInitialized()) revert PoolNotInitialized();
if (_pools[id].isNotInitialized()) revert PoolNotInitialized();
}

/// @inheritdoc IPoolManager
Expand All @@ -189,7 +189,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim
key.hooks.beforeModifyLiquidity(key, params, hookData);

BalanceDelta principalDelta;
(principalDelta, feesAccrued) = pools[id].modifyLiquidity(
(principalDelta, feesAccrued) = _pools[id].modifyLiquidity(
Pool.ModifyLiquidityParams({
owner: msg.sender,
tickLower: params.tickLower,
Expand Down Expand Up @@ -257,7 +257,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim
// Internal swap function to execute a swap, take protocol fees on input token, and emit the swap event
function _swap(PoolId id, Pool.SwapParams memory params, Currency inputCurrency) internal returns (BalanceDelta) {
(BalanceDelta delta, uint256 feeForProtocol, uint24 swapFee, Pool.SwapState memory state) =
pools[id].swap(params);
_pools[id].swap(params);

// The fee is on the input currency.
if (feeForProtocol > 0) _updateProtocolFees(inputCurrency, feeForProtocol);
Expand All @@ -281,7 +281,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim

key.hooks.beforeDonate(key, amount0, amount1, hookData);

delta = pools[id].donate(amount0, amount1);
delta = _pools[id].donate(amount0, amount1);

_accountPoolBalanceDelta(key, delta, msg.sender);

Expand Down Expand Up @@ -330,6 +330,6 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim
if (!key.fee.isDynamicFee() || msg.sender != address(key.hooks)) revert UnauthorizedDynamicLPFeeUpdate();
newDynamicLPFee.validate();
PoolId id = key.toId();
pools[id].setLPFee(newDynamicLPFee);
_pools[id].setLPFee(newDynamicLPFee);
}
}
5 changes: 3 additions & 2 deletions src/ProtocolFees.sol
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ abstract contract ProtocolFees is IProtocolFees, Owned {
/// @dev to prevent an invalid protocol fee controller from blocking pools from being initialized
/// the success of this function is NOT checked on initialize and if the call fails, the protocol fees are set to 0.
/// @dev the success of this function must be checked when called in setProtocolFee
function _fetchProtocolFee(PoolKey memory key) internal returns (bool success, uint24 protocolFees) {
function _fetchProtocolFee(PoolKey memory key) internal returns (bool success, uint24 protocolFee) {
if (address(protocolFeeController) != address(0)) {
// note that EIP-150 mandates that calls requesting more than 63/64ths of remaining gas
// will be allotted no more than this amount, so controllerGasLimit must be set with this
Expand All @@ -76,8 +76,9 @@ abstract contract ProtocolFees is IProtocolFees, Owned {
assembly {
returnData := mload(add(_data, 0x20))
}

// Ensure return data does not overflow a uint24 and that the underlying fees are within bounds.
(success, protocolFees) = (returnData == uint24(returnData)) && uint24(returnData).isValidProtocolFee()
(success, protocolFee) = (returnData == uint24(returnData)) && uint24(returnData).isValidProtocolFee()
? (true, uint24(returnData))
: (false, 0);
}
Expand Down
36 changes: 36 additions & 0 deletions src/test/ProtocolFeesImplementation.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.20;

import {ProtocolFees} from "../ProtocolFees.sol";
import {IProtocolFeeController} from "../interfaces/IProtocolFeeController.sol";
import {PoolKey} from "../types/PoolKey.sol";
import {Currency} from "../types/Currency.sol";
import {PoolId, PoolIdLibrary} from "../types/PoolId.sol";
import {Pool} from "../libraries/Pool.sol";
import {Slot0} from "../types/Slot0.sol";

contract ProtocolFeesImplementation is ProtocolFees {
using PoolIdLibrary for PoolKey;

mapping(PoolId id => Pool.State) internal _pools;

constructor(uint256 _controllerGasLimit) ProtocolFees(_controllerGasLimit) {}

// Used to set the price of a pool to pretend that the pool has been initialized in order to successfully set a protocol fee
function setPrice(PoolKey memory key, uint160 sqrtPriceX96) public {
Pool.State storage pool = _getPool(key.toId());
pool.slot0 = pool.slot0.setSqrtPriceX96(sqrtPriceX96);
}

function _getPool(PoolId id) internal view override returns (Pool.State storage) {
return _pools[id];
}

function fetchProtocolFee(PoolKey memory key) public returns (bool, uint24) {
return ProtocolFees._fetchProtocolFee(key);
}

function updateProtocolFees(Currency currency, uint256 amount) public {
ProtocolFees._updateProtocolFees(currency, amount);
}
}
103 changes: 23 additions & 80 deletions test/PoolManager.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,6 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot {
snapSize("poolManager bytecode size", address(manager));
}

function test_setProtocolFeeController_succeeds() public {
deployFreshManager();
assertEq(address(manager.protocolFeeController()), address(0));
vm.expectEmit(false, false, false, true, address(manager));
emit ProtocolFeeControllerUpdated(address(feeController));
manager.setProtocolFeeController(feeController);
assertEq(address(manager.protocolFeeController()), address(feeController));
}

function test_setProtocolFeeController_failsIfNotOwner() public {
deployFreshManager();
assertEq(address(manager.protocolFeeController()), address(0));

vm.prank(address(1)); // not the owner address
vm.expectRevert("UNAUTHORIZED");
manager.setProtocolFeeController(feeController);
assertEq(address(manager.protocolFeeController()), address(0));
}

function test_addLiquidity_failsIfNotInitialized() public {
vm.expectRevert(Pool.PoolNotInitialized.selector);
modifyLiquidityRouter.modifyLiquidity(uninitializedKey, LIQUIDITY_PARAMS, ZERO_BYTES);
Expand Down Expand Up @@ -1059,10 +1040,13 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot {

uint24 protocolFee = (uint24(protocolFee1) << 12) | uint24(protocolFee0);

(,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId());
assertEq(slot0ProtocolFee, 0);

vm.prank(address(feeController));
manager.setProtocolFee(key, protocolFee);

(,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId());
(,, slot0ProtocolFee,) = manager.getSlot0(key.toId());
assertEq(slot0ProtocolFee, protocolFee);

// Add liquidity - Fees dont accrue for positive liquidity delta.
Expand Down Expand Up @@ -1246,69 +1230,16 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot {
manager.burn(address(this), key.currency0.toId(), 1);
}

function test_setProtocolFee_gas() public {
vm.prank(address(feeController));
manager.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS);
snapLastCall("set protocol fee");
}

function test_setProtocolFee_updatesProtocolFeeForInitializedPool(uint24 protocolFee) public {
(,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId());
assertEq(slot0ProtocolFee, 0);

uint16 fee0 = protocolFee.getZeroForOneFee();
uint16 fee1 = protocolFee.getOneForZeroFee();
vm.prank(address(feeController));
if ((fee0 > 1000) || (fee1 > 1000)) {
vm.expectRevert(IProtocolFees.InvalidProtocolFee.selector);
manager.setProtocolFee(key, protocolFee);
} else {
vm.expectEmit(false, false, false, true);
emit IProtocolFees.ProtocolFeeUpdated(key.toId(), protocolFee);
manager.setProtocolFee(key, protocolFee);

(,, slot0ProtocolFee,) = manager.getSlot0(key.toId());
assertEq(slot0ProtocolFee, protocolFee);
}
}

function test_setProtocolFee_failsWithInvalidFee() public {
(,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId());
assertEq(slot0ProtocolFee, 0);

vm.prank(address(feeController));
vm.expectRevert(IProtocolFees.InvalidProtocolFee.selector);
manager.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS + 1);
}
function test_collectProtocolFees_ERC20_accumulateFees_gas() public {
uint256 expectedFees = 10;

function test_setProtocolFee_failsWithInvalidCaller() public {
(,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId());
assertEq(slot0ProtocolFee, 0);

vm.expectRevert(IProtocolFees.InvalidCaller.selector);
manager.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS);
}

function test_collectProtocolFees_initializesWithProtocolFeeIfCalled() public {
feeController.setProtocolFeeForPool(uninitializedKey.toId(), MAX_PROTOCOL_FEE_BOTH_TOKENS);

manager.initialize(uninitializedKey, SQRT_PRICE_1_1, ZERO_BYTES);
(,, uint24 slot0ProtocolFee,) = manager.getSlot0(uninitializedKey.toId());
assertEq(slot0ProtocolFee, MAX_PROTOCOL_FEE_BOTH_TOKENS);
}

function test_collectProtocolFees_revertsIfCallerIsNotController() public {
vm.expectRevert(IProtocolFees.InvalidCaller.selector);
manager.collectProtocolFees(address(1), currency0, 0);
}

function test_collectProtocolFees_ERC20_accumulateFees_gas() public {
uint256 expectedFees = 10;

vm.prank(address(feeController));
manager.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS);

(,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId());
(,, slot0ProtocolFee,) = manager.getSlot0(key.toId());
assertEq(slot0ProtocolFee, MAX_PROTOCOL_FEE_BOTH_TOKENS);

swapRouter.swap(
Expand All @@ -1331,10 +1262,13 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot {
function test_collectProtocolFees_ERC20_accumulateFees_exactOutput() public {
uint256 expectedFees = 10;

(,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId());
assertEq(slot0ProtocolFee, 0);

vm.prank(address(feeController));
manager.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS);

(,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId());
(,, slot0ProtocolFee,) = manager.getSlot0(key.toId());
assertEq(slot0ProtocolFee, MAX_PROTOCOL_FEE_BOTH_TOKENS);

swapRouter.swap(
Expand All @@ -1356,10 +1290,13 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot {
function test_collectProtocolFees_ERC20_returnsAllFeesIf0IsProvidedAsParameter() public {
uint256 expectedFees = 10;

(,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId());
assertEq(slot0ProtocolFee, 0);

vm.prank(address(feeController));
manager.setProtocolFee(key, MAX_PROTOCOL_FEE_BOTH_TOKENS);

(,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId());
(,, slot0ProtocolFee,) = manager.getSlot0(key.toId());
assertEq(slot0ProtocolFee, MAX_PROTOCOL_FEE_BOTH_TOKENS);

swapRouter.swap(
Expand All @@ -1382,10 +1319,13 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot {
uint256 expectedFees = 10;
Currency nativeCurrency = CurrencyLibrary.NATIVE;

(,, uint24 slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId());
assertEq(slot0ProtocolFee, 0);

vm.prank(address(feeController));
manager.setProtocolFee(nativeKey, MAX_PROTOCOL_FEE_BOTH_TOKENS);

(,, uint24 slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId());
(,, slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId());
assertEq(slot0ProtocolFee, MAX_PROTOCOL_FEE_BOTH_TOKENS);

swapRouter.swap{value: 10000}(
Expand All @@ -1409,10 +1349,13 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot {
uint256 expectedFees = 10;
Currency nativeCurrency = CurrencyLibrary.NATIVE;

(,, uint24 slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId());
assertEq(slot0ProtocolFee, 0);

vm.prank(address(feeController));
manager.setProtocolFee(nativeKey, MAX_PROTOCOL_FEE_BOTH_TOKENS);

(,, uint24 slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId());
(,, slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId());
assertEq(slot0ProtocolFee, MAX_PROTOCOL_FEE_BOTH_TOKENS);

swapRouter.swap{value: 10000}(
Expand Down
Loading

0 comments on commit 3351c80

Please sign in to comment.