Made v2 account pass v1 test suite

This commit is contained in:
Jayden Windle
2023-04-09 16:05:21 -04:00
parent 5056a8a879
commit 349d4716cf
3 changed files with 522 additions and 43 deletions

View File

@@ -11,7 +11,9 @@ import "openzeppelin-contracts/interfaces/IERC1271.sol";
import "openzeppelin-contracts/utils/cryptography/SignatureChecker.sol";
import "openzeppelin-contracts/proxy/utils/UUPSUpgradeable.sol";
import "openzeppelin-contracts/access/AccessControl.sol";
import "account-abstraction/core/BaseAccount.sol";
import {BaseAccount, IEntryPoint, UserOperation, IAccount as IERC4337Account} from "account-abstraction/core/BaseAccount.sol";
import "sstore2/utils/Bytecode.sol";
error NotAuthorized();
@@ -37,7 +39,7 @@ contract AccountV2 is
uint256 _nonce;
// @dev address that manages upgrade and cross-chain execution settings
address guardian;
address public guardian;
// @dev timestamp at which this account will be unlocked
uint256 public lockedUntil;
@@ -75,44 +77,31 @@ contract AccountV2 is
_;
}
constructor(address entryPoint_) {
constructor(address _guardian, address entryPoint_) {
_entryPoint = entryPoint_;
guardian = msg.sender;
guardian = _guardian;
}
receive() external payable {
_callOverride();
_handleOverride();
}
fallback(bytes calldata)
external
payable
onlyUnlocked
returns (bytes memory)
{
return _callOverride();
fallback() external payable onlyUnlocked {
_handleOverride();
}
function executeCall(
address to,
uint256 value,
bytes calldata data
)
external
payable
onlyAuthorized
onlyUnlocked
returns (bytes memory result)
{
) external payable onlyAuthorized onlyUnlocked returns (bytes memory) {
if (!isAuthorized(msg.sender, msg.sig)) revert NotAuthorized();
++_nonce;
result = _callOverride();
_handleOverride();
if (result.length != 0) {
result = _call(to, value, data);
}
return _call(to, value, data);
}
function setOverrides(
@@ -167,7 +156,7 @@ contract AccountV2 is
notDelegated
onlyGuardian
{
trustedImplementations[executor] = trusted;
trustedExecutors[executor] = trusted;
}
function setGuardian(address _guardian) external notDelegated onlyGuardian {
@@ -183,6 +172,8 @@ contract AccountV2 is
view
returns (bytes4 magicValue)
{
_handleOverrideStatic();
bool isValid = SignatureChecker.isValidSignatureNow(
owner(),
hash,
@@ -206,18 +197,22 @@ contract AccountV2 is
)
{
address self = address(this);
uint256 length = self.code.length;
if (length < 0x60) return (0, address(0), 0);
return
abi.decode(
Bytecode.codeAt(
self,
self.code.length - 0x60,
self.code.length
),
Bytecode.codeAt(self, length - 0x60, length),
(uint256, address, uint256)
);
}
function nonce() public view override returns (uint256) {
function nonce()
public
view
override(BaseAccount, IERC6551Account)
returns (uint256)
{
return _nonce;
}
@@ -251,7 +246,9 @@ contract AccountV2 is
if (caller == _entryPoint) return true;
// authorize trusted cross-chain executors if not on native chain
if (chainId != block.chainid && trustedExecutors[caller]) return true;
AccountV2 implementation = AccountV2(payable(_getImplementation()));
if (chainId != block.chainid && implementation.trustedExecutors(caller))
return true;
// authorize caller if owner has granted permissions for function call
if (permissions[_owner][caller][selector]) return true;
@@ -277,13 +274,14 @@ contract AccountV2 is
bool defaultSupport = interfaceId == type(IERC165).interfaceId ||
interfaceId == type(IERC1155Receiver).interfaceId ||
interfaceId == type(IERC6551Account).interfaceId ||
interfaceId == type(IAccount).interfaceId;
interfaceId == type(IERC4337Account).interfaceId;
if (defaultSupport) return true;
// if not supported by default, check override
bytes memory result = _callOverrideStatic();
return abi.decode(result, (bool));
_handleOverrideStatic();
return false;
}
function onERC721Received(
@@ -292,8 +290,7 @@ contract AccountV2 is
uint256,
bytes memory
) public view override returns (bytes4) {
bytes memory result = _callOverrideStatic();
if (result.length != 0) return abi.decode(result, (bytes4));
_handleOverrideStatic();
return this.onERC721Received.selector;
}
@@ -305,8 +302,7 @@ contract AccountV2 is
uint256,
bytes memory
) public view override returns (bytes4) {
bytes memory result = _callOverrideStatic();
if (result.length != 0) return abi.decode(result, (bytes4));
_handleOverrideStatic();
return this.onERC1155Received.selector;
}
@@ -318,8 +314,7 @@ contract AccountV2 is
uint256[] memory,
bytes memory
) public view override returns (bytes4) {
bytes memory result = _callOverrideStatic();
if (result.length != 0) return abi.decode(result, (bytes4));
_handleOverrideStatic();
return this.onERC1155BatchReceived.selector;
}
@@ -374,11 +369,14 @@ contract AccountV2 is
}
}
function _callOverride() internal returns (bytes memory result) {
function _handleOverride() internal returns (bytes memory result) {
address implementation = overrides[owner()][msg.sig];
if (implementation != address(0)) {
result = _call(implementation, msg.value, msg.data);
assembly {
return(add(result, 32), mload(result))
}
}
}
@@ -397,11 +395,18 @@ contract AccountV2 is
}
}
function _callOverrideStatic() internal view returns (bytes memory result) {
function _handleOverrideStatic()
internal
view
returns (bytes memory result)
{
address implementation = overrides[owner()][msg.sig];
if (implementation != address(0)) {
result = _callStatic(implementation, msg.data);
assembly {
return(add(result, 32), mload(result))
}
}
}
}

474
test/AccountV2.t.sol Normal file
View File

@@ -0,0 +1,474 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.13;
import "forge-std/Test.sol";
import "openzeppelin-contracts/token/ERC20/ERC20.sol";
import "openzeppelin-contracts/proxy/Clones.sol";
import "erc6551/ERC6551Registry.sol";
import "erc6551/interfaces/IERC6551Account.sol";
import "../src/CrossChainExecutorList.sol";
import "../src/Account.sol";
import "../src/AccountV2.sol";
import "../src/AccountRegistry.sol";
import "./mocks/MockERC721.sol";
import "./mocks/MockExecutor.sol";
import "./mocks/MockReverter.sol";
contract AccountV2Test is Test {
AccountV2 implementation;
ERC6551Registry public registry;
MockERC721 public tokenCollection;
function setUp() public {
implementation = new AccountV2(address(this), address(0));
registry = new ERC6551Registry();
tokenCollection = new MockERC721();
}
function testNonOwnerCallsFail(uint256 tokenId) public {
address user1 = vm.addr(1);
address user2 = vm.addr(2);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
vm.deal(accountAddress, 1 ether);
AccountV2 account = AccountV2(payable(accountAddress));
// should fail if user2 tries to use account
vm.prank(user2);
vm.expectRevert(NotAuthorized.selector);
account.executeCall(payable(user2), 0.1 ether, "");
// should fail if user2 tries to set override
bytes4[] memory selectors = new bytes4[](1);
selectors[0] = AccountV2.executeCall.selector;
address[] memory implementations = new address[](1);
implementations[0] = vm.addr(1337);
vm.prank(user2);
vm.expectRevert(NotAuthorized.selector);
account.grantPermissions(selectors, implementations);
// should fail if user2 tries to lock account
vm.prank(user2);
vm.expectRevert(NotAuthorized.selector);
account.lock(364 days);
}
function testAccountOwnershipTransfer(uint256 tokenId) public {
address user1 = vm.addr(1);
address user2 = vm.addr(2);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
vm.deal(accountAddress, 1 ether);
AccountV2 account = AccountV2(payable(accountAddress));
// should fail if user2 tries to use account
vm.prank(user2);
vm.expectRevert(NotAuthorized.selector);
account.executeCall(payable(user2), 0.1 ether, "");
vm.prank(user1);
tokenCollection.safeTransferFrom(user1, user2, tokenId);
// should succeed now that user2 is owner
vm.prank(user2);
account.executeCall(payable(user2), 0.1 ether, "");
assertEq(user2.balance, 0.1 ether);
}
function testMessageVerification(uint256 tokenId) public {
address user1 = vm.addr(1);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
AccountV2 account = AccountV2(payable(accountAddress));
bytes32 hash = keccak256("This is a signed message");
(uint8 v1, bytes32 r1, bytes32 s1) = vm.sign(1, hash);
bytes memory signature1 = abi.encodePacked(r1, s1, v1);
bytes4 returnValue1 = account.isValidSignature(hash, signature1);
assertEq(returnValue1, IERC1271.isValidSignature.selector);
}
function testMessageVerificationForUnauthorizedUser(uint256 tokenId)
public
{
address user1 = vm.addr(1);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
AccountV2 account = AccountV2(payable(accountAddress));
bytes32 hash = keccak256("This is a signed message");
(uint8 v2, bytes32 r2, bytes32 s2) = vm.sign(2, hash);
bytes memory signature2 = abi.encodePacked(r2, s2, v2);
bytes4 returnValue2 = account.isValidSignature(hash, signature2);
assertEq(returnValue2, 0);
}
function testAccountLocksAndUnlocks(uint256 tokenId) public {
address user1 = vm.addr(1);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
vm.deal(accountAddress, 1 ether);
AccountV2 account = AccountV2(payable(accountAddress));
// cannot be locked for more than 365 days
vm.prank(user1);
vm.expectRevert(ExceedsMaxLockTime.selector);
account.lock(366 days);
// lock account for 10 days
uint256 unlockTimestamp = block.timestamp + 10 days;
vm.prank(user1);
account.lock(unlockTimestamp);
assertEq(account.isLocked(), true);
// transaction should revert if account is locked
vm.prank(user1);
vm.expectRevert(AccountLocked.selector);
account.executeCall(payable(user1), 1 ether, "");
// fallback calls should revert if account is locked
vm.prank(user1);
vm.expectRevert(AccountLocked.selector);
(bool success, bytes memory result) = accountAddress.call(
abi.encodeWithSignature("customFunction()")
);
// silence unused variable compiler warnings
success;
result;
// setOverrides calls should revert if account is locked
{
bytes4[] memory selectors = new bytes4[](1);
selectors[0] = AccountV2.executeCall.selector;
address[] memory implementations = new address[](1);
implementations[0] = vm.addr(1337);
vm.prank(user1);
vm.expectRevert(AccountLocked.selector);
account.setOverrides(selectors, implementations);
}
// lock calls should revert if account is locked
vm.prank(user1);
vm.expectRevert(Account.AccountLocked.selector);
account.lock(0);
// signing should fail if account is locked
bytes32 hash = keccak256("This is a signed message");
(uint8 v1, bytes32 r1, bytes32 s1) = vm.sign(2, hash);
bytes memory signature1 = abi.encodePacked(r1, s1, v1);
bytes4 returnValue = account.isValidSignature(hash, signature1);
assertEq(returnValue, 0);
// warp to timestamp after account is unlocked
vm.warp(unlockTimestamp + 1 days);
// transaction succeed now that account lock has expired
vm.prank(user1);
account.executeCall(payable(user1), 1 ether, "");
assertEq(user1.balance, 1 ether);
// signing should now that account lock has expired
bytes32 hashAfterUnlock = keccak256("This is a signed message");
(uint8 v2, bytes32 r2, bytes32 s2) = vm.sign(1, hashAfterUnlock);
bytes memory signature2 = abi.encodePacked(r2, s2, v2);
bytes4 returnValue1 = account.isValidSignature(
hashAfterUnlock,
signature2
);
assertEq(returnValue1, IERC1271.isValidSignature.selector);
}
function testCustomOverridesFallback(uint256 tokenId) public {
address user1 = vm.addr(1);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
vm.deal(accountAddress, 1 ether);
AccountV2 account = AccountV2(payable(accountAddress));
MockExecutor mockExecutor = new MockExecutor();
// calls succeed with noop if override is undefined
(bool success, bytes memory result) = accountAddress.call(
abi.encodeWithSignature("customFunction()")
);
assertEq(success, true);
assertEq(result, "");
// set overrides on account
bytes4[] memory selectors = new bytes4[](2);
selectors[0] = bytes4(abi.encodeWithSignature("customFunction()"));
selectors[1] = bytes4(abi.encodeWithSignature("fail()"));
address[] memory implementations = new address[](2);
implementations[0] = address(mockExecutor);
implementations[1] = address(mockExecutor);
vm.prank(user1);
account.setOverrides(selectors, implementations);
// execution module handles fallback calls
assertEq(MockExecutor(accountAddress).customFunction(), 12345);
// execution bubbles up errors on revert
vm.expectRevert(MockReverter.MockError.selector);
MockExecutor(accountAddress).fail();
}
/**/
function testCustomPermissions(uint256 tokenId) public {
address user1 = vm.addr(1);
address user2 = vm.addr(2);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
vm.deal(accountAddress, 1 ether);
AccountV2 account = AccountV2(payable(accountAddress));
bytes4 selector = bytes4(
abi.encodeWithSignature("executeCall(address,uint256,bytes)")
);
assertEq(account.isAuthorized(user2, selector), false);
bytes4[] memory selectors = new bytes4[](1);
selectors[0] = selector;
address[] memory implementations = new address[](1);
implementations[0] = address(user2);
vm.prank(user1);
account.grantPermissions(selectors, implementations);
assertEq(account.isAuthorized(user2, selector), true);
vm.prank(user2);
account.executeCall(user2, 0.1 ether, "");
assertEq(user2.balance, 0.1 ether);
}
function testCrossChainCalls() public {
uint256 tokenId = 1;
address user1 = vm.addr(1);
address crossChainExecutor = vm.addr(2);
uint256 chainId = block.chainid + 1;
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
chainId,
address(tokenCollection),
tokenId,
0,
""
);
vm.deal(accountAddress, 1 ether);
AccountV2 account = AccountV2(payable(accountAddress));
bytes4 selector = bytes4(
abi.encodeWithSignature("executeCall(address,uint256,bytes)")
);
assertEq(account.isAuthorized(crossChainExecutor, selector), false);
implementation.setTrustedExecutor(crossChainExecutor, true);
assertEq(account.isAuthorized(crossChainExecutor, selector), true);
vm.prank(crossChainExecutor);
account.executeCall(user1, 0.1 ether, "");
assertEq(user1.balance, 0.1 ether);
address notCrossChainExecutor = vm.addr(3);
vm.prank(notCrossChainExecutor);
vm.expectRevert(NotAuthorized.selector);
AccountV2(payable(account)).executeCall(user1, 0.1 ether, "");
assertEq(user1.balance, 0.1 ether);
address nativeAccountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
vm.prank(crossChainExecutor);
vm.expectRevert(NotAuthorized.selector);
AccountV2(payable(nativeAccountAddress)).executeCall(
user1,
0.1 ether,
""
);
assertEq(user1.balance, 0.1 ether);
}
function testExecuteCallRevert(uint256 tokenId) public {
address user1 = vm.addr(1);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
vm.deal(accountAddress, 1 ether);
AccountV2 account = AccountV2(payable(accountAddress));
MockReverter mockReverter = new MockReverter();
vm.prank(user1);
vm.expectRevert(MockReverter.MockError.selector);
account.executeCall(
payable(address(mockReverter)),
0,
abi.encodeWithSignature("fail()")
);
}
function testAccountOwnerIsNullIfContextNotSet() public {
address accountClone = Clones.clone(address(implementation));
assertEq(AccountV2(payable(accountClone)).owner(), address(0));
}
function testEIP165Support() public {
uint256 tokenId = 1;
address user1 = vm.addr(1);
tokenCollection.mint(user1, tokenId);
assertEq(tokenCollection.ownerOf(tokenId), user1);
address accountAddress = registry.createAccount(
address(implementation),
block.chainid,
address(tokenCollection),
tokenId,
0,
""
);
vm.deal(accountAddress, 1 ether);
AccountV2 account = AccountV2(payable(accountAddress));
assertEq(
account.supportsInterface(type(IERC6551Account).interfaceId),
true
);
assertEq(
account.supportsInterface(type(IERC1155Receiver).interfaceId),
true
);
assertEq(account.supportsInterface(type(IERC165).interfaceId), true);
}
}