Skip to content

Commit

Permalink
Add method to serialize/deserialize tapscript trees
Browse files Browse the repository at this point in the history
  • Loading branch information
sstone committed May 27, 2024
1 parent c3e7932 commit ab01f39
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
42 changes: 41 additions & 1 deletion src/commonMain/kotlin/fr/acinq/bitcoin/ScriptTree.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,22 @@
*/
package fr.acinq.bitcoin

import fr.acinq.bitcoin.io.ByteArrayInput
import fr.acinq.bitcoin.io.ByteArrayOutput
import fr.acinq.bitcoin.io.Input
import fr.acinq.bitcoin.io.Output
import kotlin.jvm.JvmStatic

/** Simple binary tree structure containing taproot spending scripts. */
public sealed class ScriptTree {
public abstract fun write(output: Output): Output

public fun write(): ByteArray {
val output = ByteArrayOutput()
write(output)
return output.toByteArray()
}

/**
* Multiple spending scripts can be placed in the leaves of a taproot tree. When using one of those scripts to spend
* funds, we only need to reveal that specific script and a merkle proof that it is a leaf of the tree.
Expand All @@ -30,9 +42,24 @@ public sealed class ScriptTree {
public data class Leaf(val id: Int, val script: ByteVector, val leafVersion: Int) : ScriptTree() {
public constructor(id: Int, script: List<ScriptElt>) : this(id, script, Script.TAPROOT_LEAF_TAPSCRIPT)
public constructor(id: Int, script: List<ScriptElt>, leafVersion: Int) : this(id, Script.write(script).byteVector(), leafVersion)

public override fun write(output: Output): Output {
output.write(0)
BtcSerializer.writeVarint(id, output)
BtcSerializer.writeScript(script, output)
output.write(leafVersion)
return output
}
}

public data class Branch(val left: ScriptTree, val right: ScriptTree) : ScriptTree()
public data class Branch(val left: ScriptTree, val right: ScriptTree) : ScriptTree() {
public override fun write(output: Output): Output {
output.write(1)
left.write(output)
right.write(output)
return output
}
}

/** Compute the merkle root of the script tree. */
public fun hash(): ByteVector32 = when (this) {
Expand All @@ -42,6 +69,7 @@ public sealed class ScriptTree {
BtcSerializer.writeScript(this.script, buffer)
Crypto.taggedHash(buffer.toByteArray(), "TapLeaf")
}

is Branch -> {
val h1 = this.left.hash()
val h2 = this.right.hash()
Expand All @@ -68,4 +96,16 @@ public sealed class ScriptTree {
}
return loop(this, ByteArray(0))
}

public companion object {
@JvmStatic
public fun read(input: Input): ScriptTree = when (val tag = input.read()) {
0 -> Leaf(BtcSerializer.varint(input).toInt(), BtcSerializer.script(input).byteVector(), input.read())
1 -> Branch(read(input), read(input))
else -> error("cannot deserialize script tree: invalid tag $tag")
}

@JvmStatic
public fun read(input: ByteArray): ScriptTree = read(ByteArrayInput(input))
}
}
29 changes: 29 additions & 0 deletions src/commonTest/kotlin/fr/acinq/bitcoin/TaprootTestsCommon.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package fr.acinq.bitcoin
import fr.acinq.bitcoin.Bech32.hrp
import fr.acinq.bitcoin.Bitcoin.addressToPublicKeyScript
import fr.acinq.bitcoin.Transaction.Companion.hashForSigningSchnorr
import fr.acinq.bitcoin.io.ByteArrayInput
import fr.acinq.bitcoin.io.ByteArrayOutput
import fr.acinq.bitcoin.reference.TransactionTestsCommon.Companion.resourcesDir
import fr.acinq.secp256k1.Hex
import fr.acinq.secp256k1.Secp256k1
Expand Down Expand Up @@ -420,4 +422,31 @@ class TaprootTestsCommon {
val serializedTx = Transaction.write(tx)
assertContentEquals(buffer, serializedTx)
}

@Test
fun `serialize script trees`() {
val random = kotlin.random.Random.Default

fun randomLeaf(): ScriptTree.Leaf = ScriptTree.Leaf(random.nextInt(), random.nextBytes(random.nextInt(0, 2000)).byteVector(), random.nextInt(255))

fun randomTree(maxLevel: Int): ScriptTree = when {
maxLevel == 0 -> randomLeaf()
random.nextBoolean() -> randomLeaf()
else -> ScriptTree.Branch(randomTree(maxLevel - 1), randomTree(maxLevel - 1))
}

fun serde(input: ScriptTree): ScriptTree {
val output = ByteArrayOutput()
input.write(output)
return ScriptTree.read(ByteArrayInput(output.toByteArray()))
}

val leaf = randomLeaf()
assertEquals(leaf, serde(leaf))

(0 until 1000).forEach { _ ->
val tree = randomTree(10)
assertEquals(tree, serde(tree))
}
}
}

0 comments on commit ab01f39

Please sign in to comment.