Skip to content

Commit

Permalink
Make sign() and partialSigAgg() return null in case of failure
Browse files Browse the repository at this point in the history
We'll add better error managememt when we switch to a native implementation of musig2 (through secp256k1-kmp).
  • Loading branch information
sstone committed Dec 6, 2023
1 parent 89e84c2 commit c33d415
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 39 deletions.
17 changes: 9 additions & 8 deletions src/commonMain/kotlin/fr/acinq/bitcoin/musig2/Musig2.kt
Original file line number Diff line number Diff line change
Expand Up @@ -221,19 +221,19 @@ public data class SessionCtx(val aggnonce: AggregatedNonce, val pubkeys: List<Pu
/**
* @param secnonce secret nonce
* @param sk private key
* @return a Musig2 partial signature
* @return a Musig2 partial signature, or null if the nonce does not match the private key or the partial signature cannot be verified
*/
public fun sign(secnonce: SecretNonce, sk: PrivateKey): ByteVector32 {
public fun sign(secnonce: SecretNonce, sk: PrivateKey): ByteVector32? = runCatching {
val (Q, gacc, _, b, R, e) = build()
val (k1, k2) = if (R.isEven()) Pair(secnonce.p1, secnonce.p2) else Pair(-secnonce.p1, -secnonce.p2)
val P = sk.publicKey()
require(P == secnonce.pk)
require(P == secnonce.pk) { "nonce and private key mismatch" }
val a = getSessionKeyAggCoeff(P)
val d = if (Q.isEven() == gacc) sk else -sk
val s = k1 + b * k2 + e * a * d
require(partialSigVerify(s.value, secnonce.publicNonce(), sk.publicKey())) { "partial signature verification failed" }
return s.value
}
s.value
}.getOrNull()

/**
* @param psig Musig2 partial signature
Expand All @@ -254,17 +254,18 @@ public data class SessionCtx(val aggnonce: AggregatedNonce, val pubkeys: List<Pu
/**
* @param psigs list of partial signatures
* @return an aggregated signature, which is a valid Schnorr signature for the matching aggregated public key
* or null is one of the partial signatures is not valid
*/
public fun partialSigAgg(psigs: List<ByteVector32>): ByteVector64 {
public fun partialSigAgg(psigs: List<ByteVector32>): ByteVector64? = runCatching {
val (Q, _, tacc, _, R, e) = build()
for (i in psigs.indices) {
require(PrivateKey(psigs[i]).isValid()) { "invalid partial signature at index $i" }
}
val s = psigs.reduce { a, b -> add(a, b) }
val s1 = if (Q.isEven()) add(s, mul(e.value, tacc)) else minus(s, mul(e.value, tacc))
val sig = ByteVector64(R.xOnly().value + s1)
return sig
}
sig
}.getOrNull()

public companion object {
private data class SessionValues(val Q: PublicKey, val gacc: Boolean, val tacc: ByteVector32, val b: PrivateKey, val R: PublicKey, val e: PrivateKey)
Expand Down
58 changes: 27 additions & 31 deletions src/commonTest/kotlin/fr/acinq/bitcoin/Musig2TestsCommon.kt
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Musig2TestsCommon {
listOf(),
msgs[it.jsonObject["msg_index"]!!.jsonPrimitive.int]
)
val psig = ctx.sign(secnonces[keyIndices[signerIndex]], sk)
val psig = ctx.sign(secnonces[keyIndices[signerIndex]], sk)!!
assertEquals(expected, psig)
assertTrue {
ctx.partialSigVerify(psig, pnonces[nonceIndices[signerIndex]], pubkeys[keyIndices[signerIndex]])
Expand All @@ -110,15 +110,13 @@ class Musig2TestsCommon {

tests.jsonObject["sign_error_test_cases"]!!.jsonArray.forEach {
val keyIndices = it.jsonObject["key_indices"]!!.jsonArray.map { it.jsonPrimitive.int }
assertFails {
val ctx = SessionCtx(
aggnonces[it.jsonObject["aggnonce_index"]!!.jsonPrimitive.int],
keyIndices.map { pubkeys[it] },
listOf(),
msgs[it.jsonObject["msg_index"]!!.jsonPrimitive.int]
)
ctx.sign(secnonces[it.jsonObject["secnonce_index"]!!.jsonPrimitive.int], sk)
}
val ctx = SessionCtx(
aggnonces[it.jsonObject["aggnonce_index"]!!.jsonPrimitive.int],
keyIndices.map { pubkeys[it] },
listOf(),
msgs[it.jsonObject["msg_index"]!!.jsonPrimitive.int]
)
require(ctx.sign(secnonces[it.jsonObject["secnonce_index"]!!.jsonPrimitive.int], sk) == null)
}
}

Expand Down Expand Up @@ -146,7 +144,7 @@ class Musig2TestsCommon {
tweakIndices.zip(isXonly).map { tweaks[it.first] to it.second },
msg
)
val aggsig = ctx.partialSigAgg(psigIndices.map { psigs[it] })
val aggsig = ctx.partialSigAgg(psigIndices.map { psigs[it] })!!
assertEquals(expected, aggsig)
}
tests.jsonObject["error_test_cases"]!!.jsonArray.forEach {
Expand All @@ -157,15 +155,13 @@ class Musig2TestsCommon {
val tweakIndices = it.jsonObject["tweak_indices"]!!.jsonArray.map { it.jsonPrimitive.int }
val isXonly = it.jsonObject["is_xonly"]!!.jsonArray.map { it.jsonPrimitive.boolean }
assertEquals(AggregatedNonce(it.jsonObject["aggnonce"]!!.jsonPrimitive.content), aggnonce)
assertFails {
val ctx = SessionCtx(
aggnonce,
keyIndices.map { pubkeys[it] },
tweakIndices.zip(isXonly).map { tweaks[it.first] to it.second },
msg
)
ctx.partialSigAgg(psigIndices.map { psigs[it] })
}
val ctx = SessionCtx(
aggnonce,
keyIndices.map { pubkeys[it] },
tweakIndices.zip(isXonly).map { tweaks[it.first] to it.second },
msg
)
require(ctx.partialSigAgg(psigIndices.map { psigs[it] }) == null)
}
}

Expand Down Expand Up @@ -199,7 +195,7 @@ class Musig2TestsCommon {
tweakIndices.zip(isXonly).map { tweaks[it.first] to it.second },
msg
)
val psig = ctx.sign(secnonce, sk)
val psig = ctx.sign(secnonce, sk)!!
assertEquals(expected, psig)
assertTrue { ctx.partialSigVerify(psig, pnonces[nonceIndices[signerIndex]], pubkeys[keyIndices[signerIndex]]) }
}
Expand All @@ -217,7 +213,7 @@ class Musig2TestsCommon {
tweakIndices.zip(isXonly).map { tweaks[it.first] to it.second },
msg
)
val psig = ctx.sign(secnonce, sk)
val psig = ctx.sign(secnonce, sk)!!
ctx.partialSigVerify(psig, pnonces[nonceIndices[signerIndex]], pubkeys[keyIndices[signerIndex]])
}
}
Expand Down Expand Up @@ -258,7 +254,7 @@ class Musig2TestsCommon {

// create partial signatures
val psigs = privkeys.indices.map {
ctx.sign(secnonces[it], privkeys[it])
ctx.sign(secnonces[it], privkeys[it])!!
}

// verify partial signatures
Expand All @@ -267,7 +263,7 @@ class Musig2TestsCommon {
}

// aggregate partial signatures
ctx.partialSigAgg(psigs)
ctx.partialSigAgg(psigs)!!
}

// aggregate public keys
Expand Down Expand Up @@ -312,9 +308,9 @@ class Musig2TestsCommon {
listOf(Pair(internalPubKey.tweak(Crypto.TaprootTweak.NoScriptTweak), true)),
msg
)
val aliceSig = ctx.sign(aliceNonce, alicePrivKey)
val bobSig = ctx.sign(bobNonce, bobPrivKey)
ctx.partialSigAgg(listOf(aliceSig, bobSig))
val aliceSig = ctx.sign(aliceNonce, alicePrivKey)!!
val bobSig = ctx.sign(bobNonce, bobPrivKey)!!
ctx.partialSigAgg(listOf(aliceSig, bobSig))!!
}

// this tx looks like any other tx that spends a p2tr output, with a single signature
Expand Down Expand Up @@ -373,9 +369,9 @@ class Musig2TestsCommon {
txHash
)

val userSig = ctx.sign(userNonce, userPrivateKey)
val serverSig = ctx.sign(serverNonce, serverPrivateKey)
val commonSig = ctx.partialSigAgg(listOf(userSig, serverSig))
val userSig = ctx.sign(userNonce, userPrivateKey)!!
val serverSig = ctx.sign(serverNonce, serverPrivateKey)!!
val commonSig = ctx.partialSigAgg(listOf(userSig, serverSig))!!
val signedTx = tx.updateWitness(0, ScriptWitness(listOf(commonSig)))
Transaction.correctlySpends(signedTx, swapInTx, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
}
Expand All @@ -394,7 +390,7 @@ class Musig2TestsCommon {
val txHash = Transaction.hashForSigningSchnorr(tx, 0, swapInTx.txOut, SigHash.SIGHASH_DEFAULT, SigVersion.SIGVERSION_TAPSCRIPT, executionData)

val sig = Crypto.signSchnorr(txHash, userRefundPrivateKey, Crypto.SchnorrTweak.NoTweak)
val signedTx = tx.updateWitness(0, ScriptWitness.empty.push(sig).push(redeemScript).push(controlBlock))
val signedTx = tx.updateWitness(0, ScriptWitness.empty.push(sig).push(redeemScript).push(controlBlock))
Transaction.correctlySpends(signedTx, swapInTx, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
}
}
Expand Down

0 comments on commit c33d415

Please sign in to comment.