Skip to content

Commit

Permalink
refactor(contracts): make MultiPayment contract upgradable and emit…
Browse files Browse the repository at this point in the history
… payment events (#858)

* support partial payments and set gas limit

* tests

* add revert workaround: https://book.getfoundry.sh/cheatcodes/expect-revert

* make MultiPayment upgradeable

* deploy proxy

* regenerate genesis block

* update tests

* style: resolve style guide violations

* Update contracts/src/multi-payment/MultiPaymentV1.sol

Co-authored-by: Sebastijan K. <[email protected]>

* Update contracts/test/multi-payment/MultiPayment.sol

Co-authored-by: Sebastijan K. <[email protected]>

* Update contracts/test/multi-payment/MultiPayment.sol

Co-authored-by: Sebastijan K. <[email protected]>

* style: resolve style guide violations

* fix

---------

Co-authored-by: Sebastijan K. <[email protected]>
  • Loading branch information
oXtxNt9U and sebastijankuzner authored Feb 13, 2025
1 parent 43b1eeb commit 1fd3832
Show file tree
Hide file tree
Showing 25 changed files with 3,500 additions and 2,899 deletions.
29 changes: 0 additions & 29 deletions contracts/src/multi-payment/MultiPayment.sol

This file was deleted.

58 changes: 58 additions & 0 deletions contracts/src/multi-payment/MultiPaymentV1.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// SPDX-License-Identifier: GNU GENERAL PUBLIC LICENSE
pragma solidity ^0.8.27;

import {UUPSUpgradeable} from "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol";
import {OwnableUpgradeable} from "@openzeppelin/contracts-upgradeable/access/OwnableUpgradeable.sol";

contract MultiPaymentV1 is UUPSUpgradeable, OwnableUpgradeable {
error RecipientsAndAmountsMismatch();
error InvalidValue();

event Payment(address indexed recipient, uint256 amount, bool success);

// Initializers
function initialize() public initializer {
__Ownable_init(msg.sender);
}

// Overrides
function _authorizeUpgrade(address newImplementation) internal override onlyOwner {}

function version() external pure returns (uint256) {
return 1;
}

function pay(address payable[] calldata recipients, uint256[] calldata amounts) external payable {
if (recipients.length != amounts.length) {
revert RecipientsAndAmountsMismatch();
}

// Ensure value sent is equal to the total amount to send
uint256 total = 0;
for (uint256 i = 0; i < amounts.length; i++) {
total += amounts[i];
}
if (msg.value != total) {
revert InvalidValue();
}

if (recipients.length == 0) {
return;
}

for (uint256 i = 0; i < recipients.length; i++) {
(bool sent,) = recipients[i].call{value: amounts[i], gas: 5000}("");
if (sent) {
total -= amounts[i];
}

emit Payment(recipients[i], amounts[i], sent);
}

// Refund any remaining value due to partial payments
if (total > 0) {
(bool success,) = msg.sender.call{value: total}("");
require(success, "Refund failed");
}
}
}
46 changes: 46 additions & 0 deletions contracts/test/multi-payment/MultiPayment-Proxy.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// SPDX-License-Identifier: GNU GENERAL PUBLIC LICENSE
pragma solidity ^0.8.13;

import {Test, console} from "@forge-std/Test.sol";
import {MultiPaymentV1} from "@contracts/multi-payment/MultiPaymentV1.sol";
import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import {Initializable} from "@openzeppelin/contracts/proxy/utils/Initializable.sol";

contract MultiPaymentV2Test is MultiPaymentV1 {
function versionv2() external pure returns (uint256) {
return 2;
}
}

contract ProxyTest is Test {
MultiPaymentV1 public multiPayment;

function setUp() public {
bytes memory data = abi.encode(MultiPaymentV1.initialize.selector);
address proxy = address(new ERC1967Proxy(address(new MultiPaymentV1()), data));
multiPayment = MultiPaymentV1(proxy);
}

function test_initialize_should_revert() public {
vm.expectRevert(Initializable.InvalidInitialization.selector);
multiPayment.initialize();
}

function test_should_have_valid_UPGRADE_INTERFACE_VERSION() public view {
assertEq(multiPayment.UPGRADE_INTERFACE_VERSION(), "5.0.0");
}

function test_proxy_should_update() public {
assertEq(multiPayment.version(), 1);
assertEq(multiPayment.UPGRADE_INTERFACE_VERSION(), "5.0.0");
multiPayment.upgradeToAndCall(address(new MultiPaymentV2Test()), bytes(""));

// Cast proxy to new contract
MultiPaymentV2Test multiPaymentNew = MultiPaymentV2Test(address(multiPayment));
assertEq(multiPaymentNew.versionv2(), 2);

// Should keep old data
vm.expectRevert(Initializable.InvalidInitialization.selector);
multiPaymentNew.initialize();
}
}
191 changes: 162 additions & 29 deletions contracts/test/multi-payment/MultiPayment.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,27 @@
pragma solidity ^0.8.13;

import {Test, console} from "@forge-std/Test.sol";
import {MultiPayment} from "@contracts/multi-payment/MultiPayment.sol";
import {MultiPaymentV1} from "@contracts/multi-payment/MultiPaymentV1.sol";
import {ERC1967Proxy} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import {OwnableUpgradeable} from "@openzeppelin/contracts-upgradeable/access/OwnableUpgradeable.sol";

contract RejectPayments {
fallback() external payable {
revert("Recipient always reverts");
}

receive() external payable {
revert("Direct payments are not accepted");
revert("Recipient always reverts");
}
}

contract MultiPaymentTest is Test {
MultiPayment public multiPayment;
MultiPaymentV1 public multiPayment;

function setUp() public {
multiPayment = new MultiPayment();
bytes memory data = abi.encode(MultiPaymentV1.initialize.selector);
address proxy = address(new ERC1967Proxy(address(new MultiPaymentV1()), data));
multiPayment = MultiPaymentV1(proxy);
}

function test_pay_pass_with_zero_payment() public {
Expand Down Expand Up @@ -108,6 +116,128 @@ contract MultiPaymentTest is Test {
assertEq(sender.balance, 40 ether);
}

function test_pay_pass_with_partial_success() public {
address payable sender = payable(address(9999));
vm.deal(sender, 100 ether);
assertEq(sender.balance, 100 ether);

vm.startPrank(sender);

address payable recipient1 = payable(address(1));

RejectPayments rejectPayments = new RejectPayments();
address payable recipient2 = payable(address(rejectPayments));
assertEq(recipient2.balance, 0);

address payable recipient3 = payable(address(3));

address payable[] memory recipients = new address payable[](3);
recipients[0] = recipient1;
recipients[1] = recipient2;
recipients[2] = recipient3;

uint256[] memory amounts = new uint256[](3);
amounts[0] = 10 ether;
amounts[1] = 20 ether;
amounts[2] = 30 ether;

// Act
multiPayment.pay{value: 60 ether}(recipients, amounts);

// Assert
assertEq(recipient1.balance, 10 ether);
assertEq(recipient2.balance, 0 ether); // failed
assertEq(recipient3.balance, 30 ether);
assertEq(sender.balance, 60 ether); // refunded 20 ether

vm.stopPrank();
}

function test_pay_emitted_events() public {
address payable sender = payable(address(this));
vm.deal(sender, 100 ether);
assertEq(sender.balance, 100 ether);

address payable recipient1 = payable(address(1));
address payable recipient2 = payable(address(2));
address payable recipient3 = payable(address(3));

address payable[] memory recipients = new address payable[](3);
recipients[0] = recipient1;
recipients[1] = recipient2;
recipients[2] = recipient3;

uint256[] memory amounts = new uint256[](3);
amounts[0] = 10 ether;
amounts[1] = 20 ether;
amounts[2] = 30 ether;

// Events
vm.expectEmit();
emit MultiPaymentV1.Payment(recipient1, 10 ether, true);

vm.expectEmit();
emit MultiPaymentV1.Payment(recipient2, 20 ether, true);

vm.expectEmit();
emit MultiPaymentV1.Payment(recipient3, 30 ether, true);

// Act
multiPayment.pay{value: 60 ether}(recipients, amounts);

// Assert
assertEq(recipient1.balance, 10 ether);
assertEq(recipient2.balance, 20 ether);
assertEq(recipient3.balance, 30 ether);
assertEq(sender.balance, 40 ether);
}

function test_pay_emitted_events_with_reverts() public {
address payable sender = payable(address(9999));
vm.deal(sender, 100 ether);
assertEq(sender.balance, 100 ether);

vm.startPrank(sender);

address payable recipient1 = payable(address(1));

RejectPayments rejectPayments = new RejectPayments();
address payable recipient2 = payable(address(rejectPayments));
assertEq(recipient2.balance, 0);

address payable recipient3 = payable(address(3));

address payable[] memory recipients = new address payable[](3);
recipients[0] = recipient1;
recipients[1] = recipient2;
recipients[2] = recipient3;

uint256[] memory amounts = new uint256[](3);
amounts[0] = 10 ether;
amounts[1] = 20 ether;
amounts[2] = 30 ether;

// Force recipient2 to reject payment

vm.expectEmit();
emit MultiPaymentV1.Payment(recipient1, 10 ether, true);

vm.expectEmit();
emit MultiPaymentV1.Payment(recipient2, 20 ether, false);

vm.expectEmit();
emit MultiPaymentV1.Payment(recipient3, 30 ether, true);

// Act
multiPayment.pay{value: 60 ether}(recipients, amounts);

// Assert
assertEq(recipient1.balance, 10 ether);
assertEq(recipient2.balance, 0 ether);
assertEq(recipient3.balance, 30 ether);
assertEq(sender.balance, 60 ether);
}

function test_pay_pass_with_multiple_payments_same_address() public {
address payable sender = payable(address(this));
vm.deal(sender, 100 ether);
Expand Down Expand Up @@ -135,13 +265,14 @@ contract MultiPaymentTest is Test {
}

function test_pay_pass_with_multiple_payments_large() public {
uint256 payments = 10000;
uint256 payments = 100;
address payable[] memory recipients = new address payable[](payments);
uint256[] memory amounts = new uint256[](payments);

uint256 total = 0;
for (uint256 i = 0; i < payments; i++) {
recipients[i] = payable(address(uint160(i + 10))); // For some reason address(9) reverts // TODO: Check why
// Low addresses are reserved by foundry (Cheat Code Addresses) and cause side effects when used
recipients[i] = payable(address(uint160(1000 + i)));
amounts[i] = 1;
total += 1;
}
Expand Down Expand Up @@ -175,7 +306,7 @@ contract MultiPaymentTest is Test {
amounts[1] = 60 ether;

// Act
vm.expectRevert(MultiPayment.RecipientsAndAmountsMismatch.selector);
vm.expectRevert(MultiPaymentV1.RecipientsAndAmountsMismatch.selector);
multiPayment.pay{value: 100 ether}(recipients, amounts);
}

Expand All @@ -194,15 +325,17 @@ contract MultiPaymentTest is Test {
amounts[0] = 40 ether;

// Act
vm.expectRevert(MultiPayment.InvalidValue.selector);
vm.expectRevert(MultiPaymentV1.InvalidValue.selector);
multiPayment.pay{value: 50 ether}(recipients, amounts);
}

function test_pay_fail_with_failed_to_send_ether() public {
address payable sender = payable(address(this));
function test_pay_refund_when_failed_to_send_ether() public {
address payable sender = payable(address(999));
vm.deal(sender, 100 ether);
assertEq(sender.balance, 100 ether);

vm.startPrank(sender);

RejectPayments rejectPayments = new RejectPayments();
address payable recipient = payable(address(rejectPayments));
assertEq(recipient.balance, 0);
Expand All @@ -214,30 +347,30 @@ contract MultiPaymentTest is Test {
amounts[0] = 40 ether;

// Act
recipient = payable(address(0)); // Force recipient to be address(0)
vm.expectRevert(MultiPayment.FailedToSendEther.selector);
multiPayment.pay{value: 40 ether}(recipients, amounts);
}

// Test disabled, because of foundy updates. Check:
// https://book.getfoundry.sh/cheatcodes/expect-revert#description
assertEq(sender.balance, 100 ether);

// function test_pay_fail_if_no_enough_balance() public {
// address payable sender = payable(address(this));
// vm.deal(sender, 100 ether);
// assertEq(sender.balance, 100 ether);
vm.stopPrank();
}

// address payable recipient = payable(address(1));
// assertEq(recipient.balance, 0);
/// forge-config: default.allow_internal_expect_revert = true
function test_pay_fail_if_no_enough_balance() public {
address payable sender = payable(address(this));
vm.deal(sender, 100 ether);
assertEq(sender.balance, 100 ether);

address payable recipient = payable(address(1));
assertEq(recipient.balance, 0);

// address payable[] memory recipients = new address payable[](1);
// recipients[0] = recipient;
address payable[] memory recipients = new address payable[](1);
recipients[0] = recipient;

// uint256[] memory amounts = new uint256[](1);
// amounts[0] = 10 ether;
uint256[] memory amounts = new uint256[](1);
amounts[0] = 10 ether;

// // Act
// vm.expectRevert();
// multiPayment.pay{value: 110 ether}(recipients, amounts);
// }
// Act
vm.expectRevert();
multiPayment.pay{value: 110 ether}(recipients, amounts);
}
}
Loading

0 comments on commit 1fd3832

Please sign in to comment.