diff --git a/test/ProtocolFeesImplementation.t.sol b/test/ProtocolFeesImplementation.t.sol index 05d624492..822254c07 100644 --- a/test/ProtocolFeesImplementation.t.sol +++ b/test/ProtocolFeesImplementation.t.sol @@ -123,8 +123,6 @@ contract ProtocolFeesTest is Test, GasSnapshot, Deployers { function test_fuzz_collectProtocolFees(address recipient, uint256 amount, uint256 feesAccrued) public { vm.assume(feesAccrued <= currency0.balanceOf(address(protocolFees))); - vm.assume(amount <= feesAccrued); - vm.assume(recipient != address(protocolFees)); uint256 recipientBalanceBefore = currency0.balanceOf(recipient); uint256 senderBalanceBefore = currency0.balanceOf(address(protocolFees)); @@ -138,12 +136,21 @@ contract ProtocolFeesTest is Test, GasSnapshot, Deployers { protocolFees.setProtocolFeeController(feeController); vm.prank(address(feeController)); + if (amount > feesAccrued) { + vm.expectRevert(); + } uint256 amountCollected = protocolFees.collectProtocolFees(recipient, currency0, amount); - assertEq(protocolFees.protocolFeesAccrued(currency0), feesAccrued - amount); - assertEq(currency0.balanceOf(recipient), recipientBalanceBefore + amount); - assertEq(currency0.balanceOf(address(protocolFees)), senderBalanceBefore - amount); - assertEq(amountCollected, amount); + if (amount <= feesAccrued) { + if (recipient == address(protocolFees)) { + assertEq(currency0.balanceOf(recipient), recipientBalanceBefore); + } else { + assertEq(currency0.balanceOf(recipient), recipientBalanceBefore + amount); + assertEq(currency0.balanceOf(address(protocolFees)), senderBalanceBefore - amount); + } + assertEq(protocolFees.protocolFeesAccrued(currency0), feesAccrued - amount); + assertEq(amountCollected, amount); + } } function test_updateProtocolFees_succeeds() public {