From 83641b0e28f04b58835a0d5dde5921aecdd0abe8 Mon Sep 17 00:00:00 2001 From: Ava Chow Date: Wed, 8 Jan 2025 19:58:08 -0500 Subject: [PATCH] psbt: Add sighash types to PSBT when not DEFAULT or ALL When an atypical sighash type is specified by the user, add it to the PSBT so that further signing can enforce sighash type matching. --- src/psbt.cpp | 9 ++++++++- test/functional/rpc_psbt.py | 38 +++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/psbt.cpp b/src/psbt.cpp index d37278a919b82d..82128c8bb7550d 100644 --- a/src/psbt.cpp +++ b/src/psbt.cpp @@ -423,6 +423,12 @@ PSBTError SignPSBTInput(const SigningProvider& provider, PartiallySignedTransact if (!sighash) sighash = utxo.scriptPubKey.IsPayToTaproot() ? SIGHASH_DEFAULT : SIGHASH_ALL; if (input.sighash_type && input.sighash_type != sighash) { return PSBTError::SIGHASH_MISMATCH; + } else { + if ((!utxo.scriptPubKey.IsPayToTaproot() && (sighash != SIGHASH_ALL && sighash != SIGHASH_DEFAULT)) || + (utxo.scriptPubKey.IsPayToTaproot() && sighash != SIGHASH_DEFAULT) + ) { + input.sighash_type = sighash; + } } Assert(sighash.has_value()); @@ -502,7 +508,8 @@ bool FinalizePSBT(PartiallySignedTransaction& psbtx) bool complete = true; const PrecomputedTransactionData txdata = PrecomputePSBTData(psbtx); for (unsigned int i = 0; i < psbtx.tx->vin.size(); ++i) { - complete &= (SignPSBTInput(DUMMY_SIGNING_PROVIDER, psbtx, i, &txdata, std::nullopt, nullptr, true) == PSBTError::OK); + PSBTInput& input = psbtx.inputs.at(i); + complete &= (SignPSBTInput(DUMMY_SIGNING_PROVIDER, psbtx, i, &txdata, input.sighash_type, nullptr, true) == PSBTError::OK); } return complete; diff --git a/test/functional/rpc_psbt.py b/test/functional/rpc_psbt.py index ebf618332c38b4..4ac1d5bdeb93f9 100755 --- a/test/functional/rpc_psbt.py +++ b/test/functional/rpc_psbt.py @@ -237,6 +237,43 @@ def test_sighash_mismatch(self): wallet.unloadwallet() + def test_sighash_adding(self): + self.log.info("Test adding of sighash type field") + self.nodes[0].createwallet("sighash_adding") + wallet = self.nodes[0].get_wallet_rpc("sighash_adding") + def_wallet = self.nodes[0].get_wallet_rpc(self.default_wallet_name) + + addr = wallet.getnewaddress(address_type="bech32") + def_wallet.sendtoaddress(addr, 5) + self.generate(self.nodes[0], 6) + + # Make a PSBT + psbt = wallet.walletcreatefundedpsbt([], [{def_wallet.getnewaddress(): 1}])["psbt"] + psbt = wallet.walletprocesspsbt(psbt=psbt, sighashtype="ALL|ANYONECANPAY", finalize=False)["psbt"] + + # Check that the PSBT has a sighash field on all inputs + dec_psbt = self.nodes[0].decodepsbt(psbt) + for input in dec_psbt["inputs"]: + assert_equal(input["sighash"], "ALL|ANYONECANPAY") + + # Make sure we can still finalize the transaction + fin_res = self.nodes[0].finalizepsbt(psbt) + assert_equal(fin_res["complete"], True) + fin_hex = fin_res["hex"] + + # Change the sighash field to a different value and make sure we still finalize to the same thing + mod_psbt = PSBT.from_base64(psbt) + mod_psbt.i[0].map[PSBT_IN_SIGHASH_TYPE] = (SIGHASH_ALL).to_bytes(4, byteorder="little") + psbt = mod_psbt.to_base64() + fin_res = self.nodes[0].finalizepsbt(psbt) + assert_equal(fin_res["complete"], True) + assert_equal(fin_res["hex"], fin_hex) + + self.nodes[0].sendrawtransaction(fin_res["hex"]) + self.generate(self.nodes[0], 1) + + wallet.unloadwallet() + def assert_change_type(self, psbtx, expected_type): """Assert that the given PSBT has a change output with the given type.""" @@ -1088,6 +1125,7 @@ def test_psbt_input_keys(psbt_input, keys): assert_raises_rpc_error(-8, "'all' is not a valid sighash parameter.", self.nodes[2].descriptorprocesspsbt, psbt, [descriptor], sighashtype="all") self.test_sighash_mismatch() + self.test_sighash_adding() if __name__ == '__main__': PSBTTest(__file__).main()