diff --git a/lib/reference b/lib/reference index 7766053..c4b6c63 160000 --- a/lib/reference +++ b/lib/reference @@ -1 +1 @@ -Subproject commit 7766053f9bd75bc88f43af5559747eb5b3cb6e62 +Subproject commit c4b6c63e3e9dd2604596906000d0a8bc2f8c7230 diff --git a/src/AccountV2.sol b/src/AccountV2.sol index ffe4b09..310151d 100644 --- a/src/AccountV2.sol +++ b/src/AccountV2.sol @@ -11,7 +11,9 @@ import "openzeppelin-contracts/interfaces/IERC1271.sol"; import "openzeppelin-contracts/utils/cryptography/SignatureChecker.sol"; import "openzeppelin-contracts/proxy/utils/UUPSUpgradeable.sol"; import "openzeppelin-contracts/access/AccessControl.sol"; -import "account-abstraction/core/BaseAccount.sol"; + +import {BaseAccount, IEntryPoint, UserOperation, IAccount as IERC4337Account} from "account-abstraction/core/BaseAccount.sol"; + import "sstore2/utils/Bytecode.sol"; error NotAuthorized(); @@ -37,7 +39,7 @@ contract AccountV2 is uint256 _nonce; // @dev address that manages upgrade and cross-chain execution settings - address guardian; + address public guardian; // @dev timestamp at which this account will be unlocked uint256 public lockedUntil; @@ -75,44 +77,31 @@ contract AccountV2 is _; } - constructor(address entryPoint_) { + constructor(address _guardian, address entryPoint_) { _entryPoint = entryPoint_; - guardian = msg.sender; + guardian = _guardian; } receive() external payable { - _callOverride(); + _handleOverride(); } - fallback(bytes calldata) - external - payable - onlyUnlocked - returns (bytes memory) - { - return _callOverride(); + fallback() external payable onlyUnlocked { + _handleOverride(); } function executeCall( address to, uint256 value, bytes calldata data - ) - external - payable - onlyAuthorized - onlyUnlocked - returns (bytes memory result) - { + ) external payable onlyAuthorized onlyUnlocked returns (bytes memory) { if (!isAuthorized(msg.sender, msg.sig)) revert NotAuthorized(); ++_nonce; - result = _callOverride(); + _handleOverride(); - if (result.length != 0) { - result = _call(to, value, data); - } + return _call(to, value, data); } function setOverrides( @@ -167,7 +156,7 @@ contract AccountV2 is notDelegated onlyGuardian { - trustedImplementations[executor] = trusted; + trustedExecutors[executor] = trusted; } function setGuardian(address _guardian) external notDelegated onlyGuardian { @@ -183,6 +172,8 @@ contract AccountV2 is view returns (bytes4 magicValue) { + _handleOverrideStatic(); + bool isValid = SignatureChecker.isValidSignatureNow( owner(), hash, @@ -206,18 +197,22 @@ contract AccountV2 is ) { address self = address(this); + uint256 length = self.code.length; + if (length < 0x60) return (0, address(0), 0); + return abi.decode( - Bytecode.codeAt( - self, - self.code.length - 0x60, - self.code.length - ), + Bytecode.codeAt(self, length - 0x60, length), (uint256, address, uint256) ); } - function nonce() public view override returns (uint256) { + function nonce() + public + view + override(BaseAccount, IERC6551Account) + returns (uint256) + { return _nonce; } @@ -251,7 +246,9 @@ contract AccountV2 is if (caller == _entryPoint) return true; // authorize trusted cross-chain executors if not on native chain - if (chainId != block.chainid && trustedExecutors[caller]) return true; + AccountV2 implementation = AccountV2(payable(_getImplementation())); + if (chainId != block.chainid && implementation.trustedExecutors(caller)) + return true; // authorize caller if owner has granted permissions for function call if (permissions[_owner][caller][selector]) return true; @@ -277,13 +274,14 @@ contract AccountV2 is bool defaultSupport = interfaceId == type(IERC165).interfaceId || interfaceId == type(IERC1155Receiver).interfaceId || interfaceId == type(IERC6551Account).interfaceId || - interfaceId == type(IAccount).interfaceId; + interfaceId == type(IERC4337Account).interfaceId; if (defaultSupport) return true; // if not supported by default, check override - bytes memory result = _callOverrideStatic(); - return abi.decode(result, (bool)); + _handleOverrideStatic(); + + return false; } function onERC721Received( @@ -292,8 +290,7 @@ contract AccountV2 is uint256, bytes memory ) public view override returns (bytes4) { - bytes memory result = _callOverrideStatic(); - if (result.length != 0) return abi.decode(result, (bytes4)); + _handleOverrideStatic(); return this.onERC721Received.selector; } @@ -305,8 +302,7 @@ contract AccountV2 is uint256, bytes memory ) public view override returns (bytes4) { - bytes memory result = _callOverrideStatic(); - if (result.length != 0) return abi.decode(result, (bytes4)); + _handleOverrideStatic(); return this.onERC1155Received.selector; } @@ -318,8 +314,7 @@ contract AccountV2 is uint256[] memory, bytes memory ) public view override returns (bytes4) { - bytes memory result = _callOverrideStatic(); - if (result.length != 0) return abi.decode(result, (bytes4)); + _handleOverrideStatic(); return this.onERC1155BatchReceived.selector; } @@ -374,11 +369,14 @@ contract AccountV2 is } } - function _callOverride() internal returns (bytes memory result) { + function _handleOverride() internal returns (bytes memory result) { address implementation = overrides[owner()][msg.sig]; if (implementation != address(0)) { result = _call(implementation, msg.value, msg.data); + assembly { + return(add(result, 32), mload(result)) + } } } @@ -397,11 +395,18 @@ contract AccountV2 is } } - function _callOverrideStatic() internal view returns (bytes memory result) { + function _handleOverrideStatic() + internal + view + returns (bytes memory result) + { address implementation = overrides[owner()][msg.sig]; if (implementation != address(0)) { result = _callStatic(implementation, msg.data); + assembly { + return(add(result, 32), mload(result)) + } } } } diff --git a/test/AccountV2.t.sol b/test/AccountV2.t.sol new file mode 100644 index 0000000..4150d3c --- /dev/null +++ b/test/AccountV2.t.sol @@ -0,0 +1,474 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.13; + +import "forge-std/Test.sol"; + +import "openzeppelin-contracts/token/ERC20/ERC20.sol"; +import "openzeppelin-contracts/proxy/Clones.sol"; + +import "erc6551/ERC6551Registry.sol"; +import "erc6551/interfaces/IERC6551Account.sol"; + +import "../src/CrossChainExecutorList.sol"; +import "../src/Account.sol"; +import "../src/AccountV2.sol"; +import "../src/AccountRegistry.sol"; + +import "./mocks/MockERC721.sol"; +import "./mocks/MockExecutor.sol"; +import "./mocks/MockReverter.sol"; + +contract AccountV2Test is Test { + AccountV2 implementation; + ERC6551Registry public registry; + + MockERC721 public tokenCollection; + + function setUp() public { + implementation = new AccountV2(address(this), address(0)); + + registry = new ERC6551Registry(); + + tokenCollection = new MockERC721(); + } + + function testNonOwnerCallsFail(uint256 tokenId) public { + address user1 = vm.addr(1); + address user2 = vm.addr(2); + + tokenCollection.mint(user1, tokenId); + assertEq(tokenCollection.ownerOf(tokenId), user1); + + address accountAddress = registry.createAccount( + address(implementation), + block.chainid, + address(tokenCollection), + tokenId, + 0, + "" + ); + + vm.deal(accountAddress, 1 ether); + + AccountV2 account = AccountV2(payable(accountAddress)); + + // should fail if user2 tries to use account + vm.prank(user2); + vm.expectRevert(NotAuthorized.selector); + account.executeCall(payable(user2), 0.1 ether, ""); + + // should fail if user2 tries to set override + bytes4[] memory selectors = new bytes4[](1); + selectors[0] = AccountV2.executeCall.selector; + address[] memory implementations = new address[](1); + implementations[0] = vm.addr(1337); + vm.prank(user2); + vm.expectRevert(NotAuthorized.selector); + account.grantPermissions(selectors, implementations); + + // should fail if user2 tries to lock account + vm.prank(user2); + vm.expectRevert(NotAuthorized.selector); + account.lock(364 days); + } + + function testAccountOwnershipTransfer(uint256 tokenId) public { + address user1 = vm.addr(1); + address user2 = vm.addr(2); + + tokenCollection.mint(user1, tokenId); + assertEq(tokenCollection.ownerOf(tokenId), user1); + + address accountAddress = registry.createAccount( + address(implementation), + block.chainid, + address(tokenCollection), + tokenId, + 0, + "" + ); + + vm.deal(accountAddress, 1 ether); + + AccountV2 account = AccountV2(payable(accountAddress)); + + // should fail if user2 tries to use account + vm.prank(user2); + vm.expectRevert(NotAuthorized.selector); + account.executeCall(payable(user2), 0.1 ether, ""); + + vm.prank(user1); + tokenCollection.safeTransferFrom(user1, user2, tokenId); + + // should succeed now that user2 is owner + vm.prank(user2); + account.executeCall(payable(user2), 0.1 ether, ""); + + assertEq(user2.balance, 0.1 ether); + } + + function testMessageVerification(uint256 tokenId) public { + address user1 = vm.addr(1); + + tokenCollection.mint(user1, tokenId); + assertEq(tokenCollection.ownerOf(tokenId), user1); + + address accountAddress = registry.createAccount( + address(implementation), + block.chainid, + address(tokenCollection), + tokenId, + 0, + "" + ); + + AccountV2 account = AccountV2(payable(accountAddress)); + + bytes32 hash = keccak256("This is a signed message"); + (uint8 v1, bytes32 r1, bytes32 s1) = vm.sign(1, hash); + + bytes memory signature1 = abi.encodePacked(r1, s1, v1); + + bytes4 returnValue1 = account.isValidSignature(hash, signature1); + + assertEq(returnValue1, IERC1271.isValidSignature.selector); + } + + function testMessageVerificationForUnauthorizedUser(uint256 tokenId) + public + { + address user1 = vm.addr(1); + + tokenCollection.mint(user1, tokenId); + assertEq(tokenCollection.ownerOf(tokenId), user1); + + address accountAddress = registry.createAccount( + address(implementation), + block.chainid, + address(tokenCollection), + tokenId, + 0, + "" + ); + + AccountV2 account = AccountV2(payable(accountAddress)); + + bytes32 hash = keccak256("This is a signed message"); + + (uint8 v2, bytes32 r2, bytes32 s2) = vm.sign(2, hash); + bytes memory signature2 = abi.encodePacked(r2, s2, v2); + + bytes4 returnValue2 = account.isValidSignature(hash, signature2); + + assertEq(returnValue2, 0); + } + + function testAccountLocksAndUnlocks(uint256 tokenId) public { + address user1 = vm.addr(1); + + tokenCollection.mint(user1, tokenId); + assertEq(tokenCollection.ownerOf(tokenId), user1); + + address accountAddress = registry.createAccount( + address(implementation), + block.chainid, + address(tokenCollection), + tokenId, + 0, + "" + ); + + vm.deal(accountAddress, 1 ether); + + AccountV2 account = AccountV2(payable(accountAddress)); + + // cannot be locked for more than 365 days + vm.prank(user1); + vm.expectRevert(ExceedsMaxLockTime.selector); + account.lock(366 days); + + // lock account for 10 days + uint256 unlockTimestamp = block.timestamp + 10 days; + vm.prank(user1); + account.lock(unlockTimestamp); + + assertEq(account.isLocked(), true); + + // transaction should revert if account is locked + vm.prank(user1); + vm.expectRevert(AccountLocked.selector); + account.executeCall(payable(user1), 1 ether, ""); + + // fallback calls should revert if account is locked + vm.prank(user1); + vm.expectRevert(AccountLocked.selector); + (bool success, bytes memory result) = accountAddress.call( + abi.encodeWithSignature("customFunction()") + ); + + // silence unused variable compiler warnings + success; + result; + + // setOverrides calls should revert if account is locked + { + bytes4[] memory selectors = new bytes4[](1); + selectors[0] = AccountV2.executeCall.selector; + address[] memory implementations = new address[](1); + implementations[0] = vm.addr(1337); + vm.prank(user1); + vm.expectRevert(AccountLocked.selector); + account.setOverrides(selectors, implementations); + } + + // lock calls should revert if account is locked + vm.prank(user1); + vm.expectRevert(Account.AccountLocked.selector); + account.lock(0); + + // signing should fail if account is locked + bytes32 hash = keccak256("This is a signed message"); + (uint8 v1, bytes32 r1, bytes32 s1) = vm.sign(2, hash); + bytes memory signature1 = abi.encodePacked(r1, s1, v1); + bytes4 returnValue = account.isValidSignature(hash, signature1); + assertEq(returnValue, 0); + + // warp to timestamp after account is unlocked + vm.warp(unlockTimestamp + 1 days); + + // transaction succeed now that account lock has expired + vm.prank(user1); + account.executeCall(payable(user1), 1 ether, ""); + assertEq(user1.balance, 1 ether); + + // signing should now that account lock has expired + bytes32 hashAfterUnlock = keccak256("This is a signed message"); + (uint8 v2, bytes32 r2, bytes32 s2) = vm.sign(1, hashAfterUnlock); + bytes memory signature2 = abi.encodePacked(r2, s2, v2); + bytes4 returnValue1 = account.isValidSignature( + hashAfterUnlock, + signature2 + ); + assertEq(returnValue1, IERC1271.isValidSignature.selector); + } + + function testCustomOverridesFallback(uint256 tokenId) public { + address user1 = vm.addr(1); + + tokenCollection.mint(user1, tokenId); + assertEq(tokenCollection.ownerOf(tokenId), user1); + + address accountAddress = registry.createAccount( + address(implementation), + block.chainid, + address(tokenCollection), + tokenId, + 0, + "" + ); + + vm.deal(accountAddress, 1 ether); + + AccountV2 account = AccountV2(payable(accountAddress)); + + MockExecutor mockExecutor = new MockExecutor(); + + // calls succeed with noop if override is undefined + (bool success, bytes memory result) = accountAddress.call( + abi.encodeWithSignature("customFunction()") + ); + assertEq(success, true); + assertEq(result, ""); + + // set overrides on account + bytes4[] memory selectors = new bytes4[](2); + selectors[0] = bytes4(abi.encodeWithSignature("customFunction()")); + selectors[1] = bytes4(abi.encodeWithSignature("fail()")); + address[] memory implementations = new address[](2); + implementations[0] = address(mockExecutor); + implementations[1] = address(mockExecutor); + vm.prank(user1); + account.setOverrides(selectors, implementations); + + // execution module handles fallback calls + assertEq(MockExecutor(accountAddress).customFunction(), 12345); + + // execution bubbles up errors on revert + vm.expectRevert(MockReverter.MockError.selector); + MockExecutor(accountAddress).fail(); + } + + /**/ + function testCustomPermissions(uint256 tokenId) public { + address user1 = vm.addr(1); + address user2 = vm.addr(2); + + tokenCollection.mint(user1, tokenId); + assertEq(tokenCollection.ownerOf(tokenId), user1); + + address accountAddress = registry.createAccount( + address(implementation), + block.chainid, + address(tokenCollection), + tokenId, + 0, + "" + ); + + vm.deal(accountAddress, 1 ether); + + AccountV2 account = AccountV2(payable(accountAddress)); + + bytes4 selector = bytes4( + abi.encodeWithSignature("executeCall(address,uint256,bytes)") + ); + + assertEq(account.isAuthorized(user2, selector), false); + + bytes4[] memory selectors = new bytes4[](1); + selectors[0] = selector; + address[] memory implementations = new address[](1); + implementations[0] = address(user2); + vm.prank(user1); + account.grantPermissions(selectors, implementations); + + assertEq(account.isAuthorized(user2, selector), true); + + vm.prank(user2); + account.executeCall(user2, 0.1 ether, ""); + + assertEq(user2.balance, 0.1 ether); + } + + function testCrossChainCalls() public { + uint256 tokenId = 1; + address user1 = vm.addr(1); + address crossChainExecutor = vm.addr(2); + + uint256 chainId = block.chainid + 1; + + tokenCollection.mint(user1, tokenId); + assertEq(tokenCollection.ownerOf(tokenId), user1); + + address accountAddress = registry.createAccount( + address(implementation), + chainId, + address(tokenCollection), + tokenId, + 0, + "" + ); + + vm.deal(accountAddress, 1 ether); + + AccountV2 account = AccountV2(payable(accountAddress)); + + bytes4 selector = bytes4( + abi.encodeWithSignature("executeCall(address,uint256,bytes)") + ); + + assertEq(account.isAuthorized(crossChainExecutor, selector), false); + + implementation.setTrustedExecutor(crossChainExecutor, true); + + assertEq(account.isAuthorized(crossChainExecutor, selector), true); + + vm.prank(crossChainExecutor); + account.executeCall(user1, 0.1 ether, ""); + + assertEq(user1.balance, 0.1 ether); + + address notCrossChainExecutor = vm.addr(3); + vm.prank(notCrossChainExecutor); + vm.expectRevert(NotAuthorized.selector); + AccountV2(payable(account)).executeCall(user1, 0.1 ether, ""); + + assertEq(user1.balance, 0.1 ether); + + address nativeAccountAddress = registry.createAccount( + address(implementation), + block.chainid, + address(tokenCollection), + tokenId, + 0, + "" + ); + + vm.prank(crossChainExecutor); + vm.expectRevert(NotAuthorized.selector); + AccountV2(payable(nativeAccountAddress)).executeCall( + user1, + 0.1 ether, + "" + ); + + assertEq(user1.balance, 0.1 ether); + } + + function testExecuteCallRevert(uint256 tokenId) public { + address user1 = vm.addr(1); + + tokenCollection.mint(user1, tokenId); + assertEq(tokenCollection.ownerOf(tokenId), user1); + + address accountAddress = registry.createAccount( + address(implementation), + block.chainid, + address(tokenCollection), + tokenId, + 0, + "" + ); + + vm.deal(accountAddress, 1 ether); + + AccountV2 account = AccountV2(payable(accountAddress)); + + MockReverter mockReverter = new MockReverter(); + + vm.prank(user1); + vm.expectRevert(MockReverter.MockError.selector); + account.executeCall( + payable(address(mockReverter)), + 0, + abi.encodeWithSignature("fail()") + ); + } + + function testAccountOwnerIsNullIfContextNotSet() public { + address accountClone = Clones.clone(address(implementation)); + + assertEq(AccountV2(payable(accountClone)).owner(), address(0)); + } + + function testEIP165Support() public { + uint256 tokenId = 1; + address user1 = vm.addr(1); + + tokenCollection.mint(user1, tokenId); + assertEq(tokenCollection.ownerOf(tokenId), user1); + + address accountAddress = registry.createAccount( + address(implementation), + block.chainid, + address(tokenCollection), + tokenId, + 0, + "" + ); + + vm.deal(accountAddress, 1 ether); + + AccountV2 account = AccountV2(payable(accountAddress)); + + assertEq( + account.supportsInterface(type(IERC6551Account).interfaceId), + true + ); + assertEq( + account.supportsInterface(type(IERC1155Receiver).interfaceId), + true + ); + assertEq(account.supportsInterface(type(IERC165).interfaceId), true); + } +}