diff --git a/src/psbt.cpp b/src/psbt.cpp index d37278a919b82..82128c8bb7550 100644 --- a/src/psbt.cpp +++ b/src/psbt.cpp @@ -423,6 +423,12 @@ PSBTError SignPSBTInput(const SigningProvider& provider, PartiallySignedTransact if (!sighash) sighash = utxo.scriptPubKey.IsPayToTaproot() ? SIGHASH_DEFAULT : SIGHASH_ALL; if (input.sighash_type && input.sighash_type != sighash) { return PSBTError::SIGHASH_MISMATCH; + } else { + if ((!utxo.scriptPubKey.IsPayToTaproot() && (sighash != SIGHASH_ALL && sighash != SIGHASH_DEFAULT)) || + (utxo.scriptPubKey.IsPayToTaproot() && sighash != SIGHASH_DEFAULT) + ) { + input.sighash_type = sighash; + } } Assert(sighash.has_value()); @@ -502,7 +508,8 @@ bool FinalizePSBT(PartiallySignedTransaction& psbtx) bool complete = true; const PrecomputedTransactionData txdata = PrecomputePSBTData(psbtx); for (unsigned int i = 0; i < psbtx.tx->vin.size(); ++i) { - complete &= (SignPSBTInput(DUMMY_SIGNING_PROVIDER, psbtx, i, &txdata, std::nullopt, nullptr, true) == PSBTError::OK); + PSBTInput& input = psbtx.inputs.at(i); + complete &= (SignPSBTInput(DUMMY_SIGNING_PROVIDER, psbtx, i, &txdata, input.sighash_type, nullptr, true) == PSBTError::OK); } return complete; diff --git a/test/functional/rpc_psbt.py b/test/functional/rpc_psbt.py index ebf618332c38b..4ac1d5bdeb93f 100755 --- a/test/functional/rpc_psbt.py +++ b/test/functional/rpc_psbt.py @@ -237,6 +237,43 @@ def test_sighash_mismatch(self): wallet.unloadwallet() + def test_sighash_adding(self): + self.log.info("Test adding of sighash type field") + self.nodes[0].createwallet("sighash_adding") + wallet = self.nodes[0].get_wallet_rpc("sighash_adding") + def_wallet = self.nodes[0].get_wallet_rpc(self.default_wallet_name) + + addr = wallet.getnewaddress(address_type="bech32") + def_wallet.sendtoaddress(addr, 5) + self.generate(self.nodes[0], 6) + + # Make a PSBT + psbt = wallet.walletcreatefundedpsbt([], [{def_wallet.getnewaddress(): 1}])["psbt"] + psbt = wallet.walletprocesspsbt(psbt=psbt, sighashtype="ALL|ANYONECANPAY", finalize=False)["psbt"] + + # Check that the PSBT has a sighash field on all inputs + dec_psbt = self.nodes[0].decodepsbt(psbt) + for input in dec_psbt["inputs"]: + assert_equal(input["sighash"], "ALL|ANYONECANPAY") + + # Make sure we can still finalize the transaction + fin_res = self.nodes[0].finalizepsbt(psbt) + assert_equal(fin_res["complete"], True) + fin_hex = fin_res["hex"] + + # Change the sighash field to a different value and make sure we still finalize to the same thing + mod_psbt = PSBT.from_base64(psbt) + mod_psbt.i[0].map[PSBT_IN_SIGHASH_TYPE] = (SIGHASH_ALL).to_bytes(4, byteorder="little") + psbt = mod_psbt.to_base64() + fin_res = self.nodes[0].finalizepsbt(psbt) + assert_equal(fin_res["complete"], True) + assert_equal(fin_res["hex"], fin_hex) + + self.nodes[0].sendrawtransaction(fin_res["hex"]) + self.generate(self.nodes[0], 1) + + wallet.unloadwallet() + def assert_change_type(self, psbtx, expected_type): """Assert that the given PSBT has a change output with the given type.""" @@ -1088,6 +1125,7 @@ def test_psbt_input_keys(psbt_input, keys): assert_raises_rpc_error(-8, "'all' is not a valid sighash parameter.", self.nodes[2].descriptorprocesspsbt, psbt, [descriptor], sighashtype="all") self.test_sighash_mismatch() + self.test_sighash_adding() if __name__ == '__main__': PSBTTest(__file__).main()