From 0674ede666762cc72fccd485f42f78dfc3ec27fb Mon Sep 17 00:00:00 2001 From: Jayden Windle Date: Wed, 3 May 2023 10:22:17 -0700 Subject: [PATCH] Simplified ownership cycle protection --- src/Account.sol | 47 ++++++++-------------------- test/AccountERC721.t.sol | 66 ---------------------------------------- 2 files changed, 12 insertions(+), 101 deletions(-) diff --git a/src/Account.sol b/src/Account.sol index 0b61a0c..2f85c24 100644 --- a/src/Account.sol +++ b/src/Account.sol @@ -287,12 +287,22 @@ contract Account is function onERC721Received( address, address, - uint256 tokenId, + uint256 receivedTokenId, bytes memory ) public view override returns (bytes4) { _handleOverrideStatic(); - _revertIfOwnershipCycle(msg.sender, tokenId); + ( + uint256 chainId, + address tokenContract, + uint256 tokenId + ) = ERC6551AccountLib.token(); + + if ( + chainId == block.chainid && + tokenContract == msg.sender && + tokenId == receivedTokenId + ) revert OwnershipCycle(); return this.onERC721Received.selector; } @@ -409,37 +419,4 @@ contract Account is } } } - - /// @dev Reverts if reception of a given ERC-721 token would cause an ownership cycle - function _revertIfOwnershipCycle( - address receivedTokenAddress, - uint256 receivedTokenId - ) internal view { - address currentOwner = address(this); - uint256 depth = 0; - - do { - try IERC6551Account(payable(currentOwner)).token() returns ( - uint256 chainId, - address tokenAddress, - uint256 tokenId - ) { - if ( - chainId == block.chainid && - tokenAddress == receivedTokenAddress && - tokenId == receivedTokenId - ) revert OwnershipCycle(); - - currentOwner = IERC721(tokenAddress).ownerOf(tokenId); - - if (currentOwner == address(this)) revert OwnershipCycle(); - - unchecked { - depth++; - } - } catch { - break; - } - } while (depth < 5 && currentOwner.code.length > 0); - } } diff --git a/test/AccountERC721.t.sol b/test/AccountERC721.t.sol index 5b6ada4..e461aee 100644 --- a/test/AccountERC721.t.sol +++ b/test/AccountERC721.t.sol @@ -148,72 +148,6 @@ contract AccountERC721Test is Test { tokenCollection.safeTransferFrom(owner, account, tokenId); } - function testCannotHaveOwnershipCycle() public { - address owner1 = vm.addr(1); - address owner2 = vm.addr(2); - address owner3 = vm.addr(3); - - MockERC721 tokenCollection1 = new MockERC721(); - MockERC721 tokenCollection2 = new MockERC721(); - MockERC721 tokenCollection3 = new MockERC721(); - - uint256 tokenId1 = 100; - uint256 tokenId2 = 100; - uint256 tokenId3 = 100; - - tokenCollection1.mint(owner1, tokenId1); - tokenCollection2.mint(owner2, tokenId2); - tokenCollection3.mint(owner3, tokenId3); - - vm.prank(owner1, owner1); - address account1 = registry.createAccount( - address(implementation), - block.chainid, - address(tokenCollection1), - tokenId1, - 0, - "" - ); - vm.prank(owner2, owner2); - address account2 = registry.createAccount( - address(implementation), - block.chainid, - address(tokenCollection2), - tokenId2, - 0, - "" - ); - vm.prank(owner3, owner3); - address account3 = registry.createAccount( - address(implementation), - block.chainid, - address(tokenCollection3), - tokenId3, - 0, - "" - ); - - // Move token that holds tokenCollection1 token1 to the wallet of tokenCollection2 token2 (this is ok) - vm.prank(owner1); - tokenCollection1.safeTransferFrom(owner1, account2, tokenId1); - - // Ensure you can't loop wallet ownership by sending tokenCollection2 token2 to the wallet of tokenCollection1 token1, - // because the wallet of tokenCollection2 token2 owns tokenCollection1 token1 and doing so would create a circular loop - vm.prank(owner2); - vm.expectRevert(OwnershipCycle.selector); - tokenCollection2.safeTransferFrom(owner2, account1, tokenId2); - - // Attempt to create a 3 token loop - vm.prank(owner2); - tokenCollection2.safeTransferFrom(owner2, account3, tokenId2); - - // Now: tokenCollection2-2's wallet owns tokenCollection1-1 token. tokenCollection3-3's wallet owns tokenCollection2-2 token. - // Try to make tokenCollection1-1's wallet own tokenCollection3-3's token - vm.prank(owner3); - vm.expectRevert(OwnershipCycle.selector); - tokenCollection3.safeTransferFrom(owner3, account1, tokenId3); - } - function testExceedsOwnershipDepthLimit() public { uint256 count = 7; address[] memory owners = new address[](count);