Added 4337 tests + brought up to 100% test coverage

This commit is contained in:
Jayden Windle
2023-04-11 12:19:08 -04:00
parent 1efdcefd2d
commit a7d68c483e
7 changed files with 599 additions and 22 deletions

View File

@@ -1,9 +1,12 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.13;
import "forge-std/console.sol";
import "erc6551/interfaces/IERC6551Account.sol";
import "erc6551/lib/ERC6551AccountByteCode.sol";
import "openzeppelin-contracts/utils/cryptography/ECDSA.sol";
import "openzeppelin-contracts/utils/introspection/IERC165.sol";
import "openzeppelin-contracts/token/ERC721/IERC721.sol";
import "openzeppelin-contracts/token/ERC721/IERC721Receiver.sol";
@@ -38,6 +41,8 @@ contract Account is
UUPSUpgradeable,
BaseERC4337Account
{
using ECDSA for bytes32;
// @dev ERC-4337 entry point
address immutable _entryPoint;
@@ -174,15 +179,16 @@ contract Account is
uint256 tokenId
)
{
address self = address(this);
uint256 length = self.code.length;
if (length < 0x60) return (0, address(0), 0);
bytes memory footer = new bytes(0x60);
return
abi.decode(
Bytecode.codeAt(self, length - 0x60, length),
(uint256, address, uint256)
);
assembly {
let size := extcodesize(address())
if gt(size, 0x60) {
extcodecopy(address(), add(footer, 0x20), sub(size, 0x60), size)
}
}
return abi.decode(footer, (uint256, address, uint256));
}
function nonce() public view override returns (uint256) {
@@ -305,11 +311,10 @@ contract Account is
UserOperation calldata userOp,
bytes32 userOpHash
) internal view override returns (uint256 validationData) {
bool isValid = SignatureChecker.isValidSignatureNow(
owner(),
userOpHash,
bool isValid = this.isValidSignature(
userOpHash.toEthSignedMessageHash(),
userOp.signature
);
) == IERC1271.isValidSignature.selector;
if (isValid) {
return 0;
@@ -389,9 +394,10 @@ contract Account is
tokenId == receivedTokenId
) revert OwnershipCycle();
// Advance up the ownership chain
currentOwner = IERC721(tokenAddress).ownerOf(tokenId);
if (currentOwner == address(this)) revert OwnershipCycle();
unchecked {
depth++;
}

View File

@@ -17,6 +17,7 @@ import "../src/AccountGuardian.sol";
import "./mocks/MockERC721.sol";
import "./mocks/MockExecutor.sol";
import "./mocks/MockReverter.sol";
import "./mocks/MockAccount.sol";
contract AccountTest is Test {
Account implementation;
@@ -302,6 +303,57 @@ contract AccountTest is Test {
MockExecutor(accountAddress).fail();
}
function testCustomOverridesSupportsInterface(uint256 tokenId) public {
address user1 = vm.addr(1);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
vm.deal(accountAddress, 1 ether);
Account account = Account(payable(accountAddress));
assertEq(
account.supportsInterface(type(IERC1155Receiver).interfaceId),
true
);
assertEq(account.supportsInterface(0x12345678), false);
MockExecutor mockExecutor = new MockExecutor();
// set overrides on account
bytes4[] memory selectors = new bytes4[](1);
selectors[0] = bytes4(
abi.encodeWithSignature("supportsInterface(bytes4)")
);
address[] memory implementations = new address[](1);
implementations[0] = address(mockExecutor);
vm.prank(user1);
account.setOverrides(selectors, implementations);
// override handles extra interface support
assertEq(
Account(payable(accountAddress)).supportsInterface(0x12345678),
true
);
// cannot override default interfaces
assertEq(
Account(payable(accountAddress)).supportsInterface(
type(IERC1155Receiver).interfaceId
),
true
);
}
/**/
function testCustomPermissions(uint256 tokenId) public {
address user1 = vm.addr(1);
@@ -474,5 +526,46 @@ contract AccountTest is Test {
true
);
assertEq(account.supportsInterface(type(IERC165).interfaceId), true);
assertEq(account.supportsInterface(0x12345678), true);
}
function testAccountUpgrade() public {
uint256 tokenId = 1;
address user1 = vm.addr(1);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
Account account = Account(payable(accountAddress));
MockAccount upgradedImplementation = new MockAccount(
address(guardian),
address(entryPoint)
);
vm.expectRevert(UntrustedImplementation.selector);
vm.prank(user1);
account.upgradeTo(address(upgradedImplementation));
guardian.setTrustedImplementation(
address(upgradedImplementation),
true
);
vm.prank(user1);
account.upgradeTo(address(upgradedImplementation));
uint256 returnValue = MockAccount(payable(accountAddress))
.customFunction();
assertEq(returnValue, 12345);
}
}

View File

@@ -16,6 +16,7 @@ import "../src/AccountGuardian.sol";
import "./mocks/MockERC721.sol";
import "./mocks/MockERC1155.sol";
import "./mocks/MockExecutor.sol";
contract AccountERC1155Test is Test {
MockERC1155 public dummyERC1155;
@@ -125,4 +126,131 @@ contract AccountERC1155Test is Test {
assertEq(dummyERC1155.balanceOf(accountAddress, 1), 0);
assertEq(dummyERC1155.balanceOf(user1, 1), 10);
}
function testBatchTransferERC1155(uint256 tokenId) public {
address user1 = vm.addr(1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
dummyERC1155.mint(user1, 1, 10);
dummyERC1155.mint(user1, 2, 10);
uint256[] memory ids = new uint256[](2);
ids[0] = 1;
ids[1] = 2;
uint256[] memory amounts = new uint256[](2);
amounts[0] = 10;
amounts[1] = 10;
vm.prank(user1);
dummyERC1155.safeBatchTransferFrom(
user1,
accountAddress,
ids,
amounts,
""
);
assertEq(dummyERC1155.balanceOf(accountAddress, 1), 10);
assertEq(dummyERC1155.balanceOf(accountAddress, 2), 10);
assertEq(dummyERC1155.balanceOf(user1, 1), 0);
assertEq(dummyERC1155.balanceOf(user1, 2), 0);
}
function testOverrideERC1155Receiver(uint256 tokenId) public {
address user1 = vm.addr(1);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
Account account = Account(payable(accountAddress));
MockExecutor mockExecutor = new MockExecutor();
// set overrides on account
bytes4[] memory selectors = new bytes4[](1);
selectors[0] = bytes4(
abi.encodeWithSignature(
"onERC1155Received(address,address,uint256,uint256,bytes)"
)
);
address[] memory implementations = new address[](1);
implementations[0] = address(mockExecutor);
vm.prank(user1);
account.setOverrides(selectors, implementations);
vm.expectRevert("ERC1155: ERC1155Receiver rejected tokens");
dummyERC1155.mint(accountAddress, 1, 10);
}
function testOverrideERC1155BatchReceiver(uint256 tokenId) public {
address user1 = vm.addr(1);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
Account account = Account(payable(accountAddress));
MockExecutor mockExecutor = new MockExecutor();
// set overrides on account
bytes4[] memory selectors = new bytes4[](1);
selectors[0] = bytes4(
abi.encodeWithSignature(
"onERC1155BatchReceived(address,address,uint256[],uint256[],bytes)"
)
);
address[] memory implementations = new address[](1);
implementations[0] = address(mockExecutor);
vm.prank(user1);
account.setOverrides(selectors, implementations);
dummyERC1155.mint(user1, 1, 10);
dummyERC1155.mint(user1, 2, 10);
uint256[] memory ids = new uint256[](2);
ids[0] = 1;
ids[1] = 2;
uint256[] memory amounts = new uint256[](2);
amounts[0] = 10;
amounts[1] = 10;
vm.expectRevert("ERC1155: ERC1155Receiver rejected tokens");
vm.prank(user1);
dummyERC1155.safeBatchTransferFrom(
user1,
accountAddress,
ids,
amounts,
""
);
}
}

279
test/AccountERC4337.t.sol Normal file
View File

@@ -0,0 +1,279 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.13;
import "forge-std/Test.sol";
import "openzeppelin-contracts/utils/cryptography/ECDSA.sol";
import "openzeppelin-contracts/token/ERC20/ERC20.sol";
import "openzeppelin-contracts/proxy/Clones.sol";
import "account-abstraction/core/EntryPoint.sol";
import "erc6551/ERC6551Registry.sol";
import "erc6551/interfaces/IERC6551Account.sol";
import "../src/Account.sol";
import "../src/AccountGuardian.sol";
import "./mocks/MockERC721.sol";
contract AccountERC4337Test is Test {
using ECDSA for bytes32;
Account implementation;
AccountGuardian public guardian;
ERC6551Registry public registry;
IEntryPoint public entryPoint;
MockERC721 public tokenCollection;
function setUp() public {
entryPoint = new EntryPoint();
guardian = new AccountGuardian();
implementation = new Account(address(guardian), address(entryPoint));
registry = new ERC6551Registry();
tokenCollection = new MockERC721();
}
function testReturnsEntryPoint() public {
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
1,
0,
""
);
assertEq(
address(Account(payable(accountAddress)).entryPoint()),
address(entryPoint)
);
}
function testNonceIncrementsOnDirectCall(uint256 tokenId) public {
address user1 = vm.addr(1);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
vm.deal(accountAddress, 1 ether);
Account account = Account(payable(accountAddress));
uint256 nonce = account.nonce();
assertEq(nonce, 0);
// user1 executes transaction to send ETH from account
vm.prank(user1);
account.executeCall(payable(user1), 0.1 ether, "");
assertEq(account.nonce(), nonce + 1);
assertEq(account.nonce(), entryPoint.getNonce(accountAddress, 0));
// success!
assertEq(accountAddress.balance, 0.9 ether);
assertEq(user1.balance, 0.1 ether);
}
function test4337CallCreateAccount() public {
uint256 tokenId = 1;
address user1 = vm.addr(1);
address user2 = vm.addr(2);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.account(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0
);
bytes memory initCode = abi.encodePacked(
address(registry),
abi.encodeWithSignature(
"createAccount(address,uint256,address,uint256,uint256,bytes)",
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
)
);
bytes memory callData = abi.encodeWithSignature(
"executeCall(address,uint256,bytes)",
user2,
0.1 ether,
""
);
UserOperation memory op = UserOperation({
sender: accountAddress,
nonce: 0,
initCode: initCode,
callData: callData,
callGasLimit: 1000000,
verificationGasLimit: 1000000,
preVerificationGas: 1000000,
maxFeePerGas: block.basefee + 10,
maxPriorityFeePerGas: 10,
paymasterAndData: "",
signature: ""
});
bytes32 opHash = entryPoint.getUserOpHash(op);
(uint8 v, bytes32 r, bytes32 s) = vm.sign(
1,
opHash.toEthSignedMessageHash()
);
bytes memory signature = abi.encodePacked(r, s, v);
op.signature = signature;
vm.deal(accountAddress, 1 ether);
UserOperation[] memory ops = new UserOperation[](1);
ops[0] = op;
assertEq(entryPoint.getNonce(accountAddress, 0), 0);
entryPoint.handleOps(ops, payable(user1));
assertEq(entryPoint.getNonce(accountAddress, 0), 1);
assertEq(user2.balance, 0.1 ether);
assertTrue(accountAddress.balance < 0.9 ether);
}
function test4337CallExistingAccount() public {
uint256 tokenId = 1;
address user1 = vm.addr(1);
address user2 = vm.addr(2);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
bytes memory callData = abi.encodeWithSignature(
"executeCall(address,uint256,bytes)",
user2,
0.1 ether,
""
);
UserOperation memory op = UserOperation({
sender: accountAddress,
nonce: 0,
initCode: "",
callData: callData,
callGasLimit: 1000000,
verificationGasLimit: 1000000,
preVerificationGas: 1000000,
maxFeePerGas: block.basefee + 10,
maxPriorityFeePerGas: 10,
paymasterAndData: "",
signature: ""
});
bytes32 opHash = entryPoint.getUserOpHash(op);
(uint8 v, bytes32 r, bytes32 s) = vm.sign(
1,
opHash.toEthSignedMessageHash()
);
bytes memory signature = abi.encodePacked(r, s, v);
op.signature = signature;
vm.deal(accountAddress, 1 ether);
UserOperation[] memory ops = new UserOperation[](1);
ops[0] = op;
assertEq(entryPoint.getNonce(accountAddress, 0), 0);
entryPoint.handleOps(ops, payable(user1));
assertEq(entryPoint.getNonce(accountAddress, 0), 1);
assertEq(user2.balance, 0.1 ether);
assertTrue(accountAddress.balance < 0.9 ether);
}
function test4337CallRevertsInvalidSignature() public {
uint256 tokenId = 1;
address user1 = vm.addr(1);
address user2 = vm.addr(2);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
bytes memory callData = abi.encodeWithSignature(
"executeCall(address,uint256,bytes)",
user2,
0.1 ether,
""
);
UserOperation memory op = UserOperation({
sender: accountAddress,
nonce: 0,
initCode: "",
callData: callData,
callGasLimit: 1000000,
verificationGasLimit: 1000000,
preVerificationGas: 1000000,
maxFeePerGas: block.basefee + 10,
maxPriorityFeePerGas: 10,
paymasterAndData: "",
signature: ""
});
bytes32 opHash = entryPoint.getUserOpHash(op);
(uint8 v, bytes32 r, bytes32 s) = vm.sign(
1,
opHash.toEthSignedMessageHash()
);
// invalidate signature
bytes memory signature = abi.encodePacked(r, s, v + 1);
op.signature = signature;
vm.deal(accountAddress, 1 ether);
UserOperation[] memory ops = new UserOperation[](1);
ops[0] = op;
vm.expectRevert();
entryPoint.handleOps(ops, payable(user1));
assertEq(accountAddress.balance, 1 ether);
}
}

View File

@@ -15,6 +15,7 @@ import "../src/Account.sol";
import "../src/AccountGuardian.sol";
import "./mocks/MockERC721.sol";
import "./mocks/MockExecutor.sol";
contract AccountERC721Test is Test {
MockERC721 public dummyERC721;
@@ -246,4 +247,39 @@ contract AccountERC721Test is Test {
vm.expectRevert(OwnershipDepthLimitExceeded.selector);
tokenCollection.safeTransferFrom(owners[6], accounts[0], 7);
}
function testOverrideERC721Receiver(uint256 tokenId) public {
address user1 = vm.addr(1);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
Account account = Account(payable(accountAddress));
MockExecutor mockExecutor = new MockExecutor();
// set overrides on account
bytes4[] memory selectors = new bytes4[](1);
selectors[0] = bytes4(
abi.encodeWithSignature(
"onERC721Received(address,address,uint256,bytes)"
)
);
address[] memory implementations = new address[](1);
implementations[0] = address(mockExecutor);
vm.prank(user1);
account.setOverrides(selectors, implementations);
vm.expectRevert("ERC721: transfer to non ERC721Receiver implementer");
dummyERC721.mint(accountAddress, 1);
}
}

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.13;
import "../../src/Account.sol";
contract MockAccount is Account {
constructor(address _guardian, address entryPoint_)
Account(_guardian, entryPoint_)
{}
function customFunction() external pure returns (uint256) {
return 12345;
}
}

View File

@@ -5,14 +5,6 @@ import "./MockReverter.sol";
import "openzeppelin-contracts/interfaces/IERC1271.sol";
contract MockExecutor is MockReverter {
function isValidSignature(bytes32, bytes memory)
external
pure
returns (bytes4 magicValue)
{
return IERC1271.isValidSignature.selector;
}
function customFunction() external pure returns (uint256) {
return 12345;
}
@@ -22,6 +14,35 @@ contract MockExecutor is MockReverter {
pure
returns (bool)
{
return interfaceId == IERC1271.isValidSignature.selector;
return interfaceId == 0x12345678;
}
function onERC721Received(
address,
address,
uint256,
bytes memory
) public pure returns (bytes4) {
return bytes4("");
}
function onERC1155Received(
address,
address,
uint256,
uint256,
bytes memory
) public pure returns (bytes4) {
return bytes4("");
}
function onERC1155BatchReceived(
address,
address,
uint256[] memory,
uint256[] memory,
bytes memory
) public pure returns (bytes4) {
return bytes4("");
}
}