refactor(psbt): support custom account index, fixes #3868

This commit is contained in:
kyranjamie
2023-06-16 13:19:10 +02:00
committed by kyranjamie
parent ec9909ac0a
commit f6e5e9d30c
9 changed files with 34 additions and 42 deletions

View File

@@ -16,8 +16,10 @@ export function PsbtRequestHeader({ origin }: PsbtRequestHeaderProps) {
Sign transaction
</Title>
{caption && (
<Flag align="middle" img={<Favicon origin={origin} />} pl="tight">
<Caption wordBreak="break-word">{caption}</Caption>
<Flag align="top" img={<Favicon origin={origin} />} pl="tight">
<Caption wordBreak="break-word" lineHeight={1.3}>
{caption}
</Caption>
</Flag>
)}
</Flex>

View File

@@ -1,6 +1,6 @@
import { WarningLabel } from '@app/components/warning-label';
export function PsbtRequestWarningLabel(props: { appName?: string }) {
export function PsbtRequestAppWarningLabel(props: { appName?: string }) {
const { appName } = props;
const title = `Do not proceed unless you trust ${appName ?? 'Unknown'}!`;

View File

@@ -20,8 +20,6 @@ export function usePsbtSigner() {
return useMemo(
() => ({
nativeSegwitSigner,
taprootSigner,
signPsbtAtIndex(allowedSighash: btc.SignatureHash[], idx: number, tx: btc.Transaction) {
try {
nativeSegwitSigner?.signIndex(tx, idx, allowedSighash);
@@ -33,6 +31,17 @@ export function usePsbtSigner() {
}
}
},
signPsbt(tx: btc.Transaction) {
try {
nativeSegwitSigner?.sign(tx);
} catch (e1) {
try {
taprootSigner?.sign(tx);
} catch (e2) {
logger.error('Error signing PSBT', e1, e2);
}
}
},
getPsbtAsTransaction(psbt: string | Uint8Array) {
const bytes = isString(psbt) ? hexToBytes(psbt) : psbt;
return btc.Transaction.fromPSBT(bytes);

View File

@@ -5,7 +5,7 @@ import { useOnOriginTabClose } from '@app/routes/hooks/use-on-tab-closed';
import { PsbtDecodedRequest } from './components/psbt-decoded-request/psbt-decoded-request';
import { PsbtRequestActions } from './components/psbt-request-actions';
import { PsbtRequestHeader } from './components/psbt-request-header';
import { PsbtRequestWarningLabel } from './components/psbt-request-warning-label';
import { PsbtRequestAppWarningLabel } from './components/psbt-request-warning-label';
import { PsbtRequestLayout } from './components/psbt-request.layout';
import { DecodedPsbt } from './hooks/use-psbt-signer';
@@ -26,7 +26,7 @@ export function PsbtSigner(props: PsbtSignerProps) {
<>
<PsbtRequestLayout>
<PsbtRequestHeader origin={appName} />
<PsbtRequestWarningLabel appName={appName} />
<PsbtRequestAppWarningLabel appName={appName} />
<PsbtDecodedRequest psbt={psbt} />
</PsbtRequestLayout>
<PsbtRequestActions isLoading={false} onCancel={onCancel} onSignPsbt={onSignPsbt} />

View File

@@ -14,13 +14,7 @@ import { usePsbtRequestSearchParams } from '@app/pages/psbt-request/psbt-request
export function usePsbtRequest() {
const { requestToken, tabId } = usePsbtRequestSearchParams();
const [isLoading, setIsLoading] = useState(false);
const {
signPsbtAtIndex,
getDecodedPsbt,
nativeSegwitSigner,
taprootSigner,
getPsbtAsTransaction,
} = usePsbtSigner();
const { signPsbt, signPsbtAtIndex, getDecodedPsbt, getPsbtAsTransaction } = usePsbtSigner();
const analytics = useAnalytics();
return useMemo(() => {
if (!requestToken) throw new Error('Cannot decode psbt without request token');
@@ -56,15 +50,7 @@ export function usePsbtRequest() {
if (!isUndefined(indexOrIndexes) && !isUndefined(allowedSighash)) {
ensureArray(indexOrIndexes).forEach(idx => signPsbtAtIndex(allowedSighash, idx, tx));
} else {
try {
nativeSegwitSigner?.sign(tx);
} catch (e1) {
try {
taprootSigner?.sign(tx);
} catch (e2) {
logger.error('Error signing tx', e1, e2);
}
}
signPsbt(tx);
}
const psbt = tx.toPSBT();
@@ -85,10 +71,9 @@ export function usePsbtRequest() {
getDecodedPsbt,
getPsbtAsTransaction,
isLoading,
nativeSegwitSigner,
requestToken,
signPsbt,
signPsbtAtIndex,
tabId,
taprootSigner,
]);
}

View File

@@ -5,7 +5,6 @@ import { RpcErrorCode } from '@btckit/types';
import { bytesToHex } from '@noble/hashes/utils';
import * as btc from '@scure/btc-signer';
import { logger } from '@shared/logger';
import { makeRpcErrorResponse, makeRpcSuccessResponse } from '@shared/rpc/rpc-methods';
import { isUndefined } from '@shared/utils';
@@ -40,13 +39,7 @@ function useRpcSignPsbtParams() {
function useRpcSignPsbt() {
const { origin, tabId, requestId, psbtHex, allowedSighash, signAtIndex } = useRpcSignPsbtParams();
const {
signPsbtAtIndex,
getDecodedPsbt,
nativeSegwitSigner,
taprootSigner,
getPsbtAsTransaction,
} = usePsbtSigner();
const { signPsbt, signPsbtAtIndex, getDecodedPsbt, getPsbtAsTransaction } = usePsbtSigner();
const tx = getPsbtAsTransaction(psbtHex);
@@ -61,15 +54,7 @@ function useRpcSignPsbt() {
signPsbtAtIndex(allowedSighash, idx, tx);
});
} else {
try {
nativeSegwitSigner?.sign(tx);
} catch (e1) {
try {
taprootSigner?.sign(tx);
} catch (e2) {
logger.error('Error signing tx', e1, e2);
}
}
signPsbt(tx);
}
const psbt = tx.toPSBT();

View File

@@ -2,6 +2,8 @@ import { useSelector } from 'react-redux';
import { createSelector } from '@reduxjs/toolkit';
import { initialSearchParams } from '@app/common/initial-search-params';
import { initBigNumber } from '@app/common/math/helpers';
import { RootState } from '@app/store';
import { selectStacksChain } from '../chains/stx-chain.selectors';
@@ -19,6 +21,10 @@ export function useCurrentKeyDetails() {
}
export const selectCurrentAccountIndex = createSelector(selectStacksChain, state => {
const customAccountIndex = initialSearchParams.get('accountIndex');
if (customAccountIndex && initBigNumber(customAccountIndex).isInteger()) {
return initBigNumber(customAccountIndex).toNumber();
}
return state[defaultKeyId].currentAccountIndex;
});

View File

@@ -54,6 +54,10 @@ export async function rpcSignPsbt(message: SignPsbtRequest, port: chrome.runtime
return;
}
if (isDefined(message.params.account)) {
params.push(['accountIndex', message.params.account.toString()]);
}
if (isDefined(message.params.allowedSighash))
ensureArray(message.params.allowedSighash).forEach(hash =>
params.push(['allowedSighash', hash.toString()])

View File

@@ -9,6 +9,7 @@ interface SignPsbtRequestParams {
hex: string;
signAtIndex?: number | number[];
network?: NetworkModes;
account?: number;
}
export type SignPsbtRequest = RpcRequest<'signPsbt', SignPsbtRequestParams>;