Skip to content

Commit

Permalink
Add high-level helpers for using Musig2 with Taproot
Browse files Browse the repository at this point in the history
When using Musig2 for a taproot key path, we can provide simpler helper
functions to collaboratively build a shared signature for the spending
transaction.

This hides all of the low-level details of how the musig2 algorithm
works, by exposing a subset of what can be done that is sufficient for
spending taproot inputs.
  • Loading branch information
t-bast committed Jan 18, 2024
1 parent fa9718d commit 88f25d0
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 73 deletions.
10 changes: 10 additions & 0 deletions src/commonMain/kotlin/fr/acinq/bitcoin/ByteVector.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package fr.acinq.bitcoin

import fr.acinq.secp256k1.Hex
import kotlin.experimental.or
import kotlin.experimental.xor
import kotlin.jvm.JvmField
import kotlin.jvm.JvmStatic

Expand Down Expand Up @@ -151,6 +152,15 @@ public class ByteVector32(bytes: ByteArray, offset: Int) : ByteVector(bytes, off

@JvmStatic
public fun fromValidHex(input: String): ByteVector32 = ByteVector32(input)

@JvmStatic
public fun xor(a: ByteVector32, b: ByteVector32): ByteVector32 {
val result = ByteArray(32)
for (i in 0..31) {
result[i] = a[i].xor(b[i])
}
return result.byteVector32()
}
}
}

Expand Down
54 changes: 36 additions & 18 deletions src/commonTest/kotlin/fr/acinq/bitcoin/musig2proto/Musig2.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ import fr.acinq.bitcoin.*
import fr.acinq.bitcoin.crypto.Pack
import fr.acinq.secp256k1.Hex
import fr.acinq.secp256k1.Secp256k1
import kotlin.experimental.xor
import kotlin.jvm.JvmStatic


/**
* Key Aggregation Context
* Holds a public key aggregate that can optionally be tweaked
Expand All @@ -16,6 +14,9 @@ import kotlin.jvm.JvmStatic
* @param tacc tweak accumulator
*/
public data class KeyAggCtx(val Q: PublicKey, val gacc: Boolean, val tacc: ByteVector32) {
public constructor(Q: PublicKey) : this(Q, true, ByteVector32.Zeroes)
public constructor(pubkeys: List<PublicKey>) : this(Musig2.keyAgg(pubkeys))

public fun tweak(tweak: ByteVector32, isXonly: Boolean): KeyAggCtx {
require(tweak == ByteVector32.Zeroes || PrivateKey(tweak).isValid()) { "invalid tweak" }
return if (isXonly && !Q.isEven()) {
Expand All @@ -30,15 +31,37 @@ public data class KeyAggCtx(val Q: PublicKey, val gacc: Boolean, val tacc: ByteV

public object Musig2 {
@JvmStatic
public fun keyAgg(pubkeys: List<PublicKey>): KeyAggCtx {
public fun keyAgg(pubkeys: List<PublicKey>): PublicKey {
val pk2 = getSecondKey(pubkeys)
val a = pubkeys.map { keyAggCoeffInternal(pubkeys, it, pk2) }
val Q = pubkeys.zip(a).map { it.first.times(PrivateKey(it.second)) }.reduce { p1, p2 -> p1 + p2 }
return KeyAggCtx(Q, true, ByteVector32.Zeroes)
return pubkeys.zip(a).map { it.first.times(PrivateKey(it.second)) }.reduce { p1, p2 -> p1 + p2 }
}

@JvmStatic
public fun keySort(pubkeys: List<PublicKey>): List<PublicKey> = pubkeys.sortedWith { a, b -> LexicographicalOrdering.compare(a, b) }

private fun taprootSessionCtx(tx: Transaction, inputIndex: Int, inputs: List<TxOut>, pubkeys: List<PublicKey>, publicNonces: List<IndividualNonce>, scriptTree: ScriptTree?): SessionCtx {
val aggregatedNonce = IndividualNonce.aggregate(publicNonces)
val aggregatedKey = keyAgg(pubkeys).xOnly()
val tweak = when (scriptTree) {
null -> Pair(aggregatedKey.tweak(Crypto.TaprootTweak.NoScriptTweak), true)
else -> Pair(aggregatedKey.tweak(Crypto.TaprootTweak.ScriptTweak(scriptTree)), true)
}
val txHash = Transaction.hashForSigningTaprootKeyPath(tx, inputIndex, inputs, SigHash.SIGHASH_DEFAULT)
return SessionCtx(aggregatedNonce, pubkeys, listOf(tweak), txHash)
}

@JvmStatic
public fun signTaprootInput(privateKey: PrivateKey, tx: Transaction, inputIndex: Int, inputs: List<TxOut>, pubkeys: List<PublicKey>, secretNonce: SecretNonce, publicNonces: List<IndividualNonce>, scriptTree: ScriptTree?): ByteVector32? {
val ctx = taprootSessionCtx(tx, inputIndex, inputs, pubkeys, publicNonces, scriptTree)
return ctx.sign(secretNonce, privateKey)
}

@JvmStatic
public fun aggregateTaprootSignatures(partialSigs: List<ByteVector32>, tx: Transaction, inputIndex: Int, inputs: List<TxOut>, pubkeys: List<PublicKey>, publicNonces: List<IndividualNonce>, scriptTree: ScriptTree?): ByteVector64? {
val ctx = taprootSessionCtx(tx, inputIndex, inputs, pubkeys, publicNonces, scriptTree)
return ctx.partialSigAgg(partialSigs)
}
}

/**
Expand Down Expand Up @@ -70,17 +93,8 @@ public data class SecretNonce(val data: ByteVector) {
*/
@JvmStatic
public fun generate(sk: PrivateKey?, pk: PublicKey, aggpk: XonlyPublicKey?, msg: ByteArray?, extraInput: ByteArray?, randprime: ByteVector32): SecretNonce {

fun xor(a: ByteVector32, b: ByteVector32): ByteVector32 {
val result = ByteArray(32)
for (i in 0..31) {
result[i] = a[i].xor(b[i])
}
return result.byteVector32()
}

val rand = if (sk != null) {
xor(sk.value, Crypto.taggedHash(randprime.toByteArray(), "MuSig/aux"))
ByteVector32.xor(sk.value, Crypto.taggedHash(randprime.toByteArray(), "MuSig/aux"))
} else {
randprime
}
Expand All @@ -102,6 +116,11 @@ public data class SecretNonce(val data: ByteVector) {
val secnonce = SecretNonce(PrivateKey(k1).value + PrivateKey(k2).value + pk.value)
return secnonce
}

@JvmStatic
public fun generate(sk: PrivateKey, aggregatedKey: XonlyPublicKey, rand: ByteVector32): SecretNonce {
return generate(sk, sk.publicKey(), aggregatedKey, null, null, rand)
}
}
}

Expand Down Expand Up @@ -192,7 +211,6 @@ internal fun add(a: PublicKey?, b: PublicKey?): PublicKey? = when {
else -> a + b
}


internal fun mul(a: PublicKey?, b: PrivateKey): PublicKey? = a?.times(b)

/**
Expand All @@ -204,7 +222,7 @@ internal fun mul(a: PublicKey?, b: PrivateKey): PublicKey? = a?.times(b)
*/
public data class SessionCtx(val aggnonce: AggregatedNonce, val pubkeys: List<PublicKey>, val tweaks: List<Pair<ByteVector32, Boolean>>, val message: ByteVector) {
private fun build(): SessionValues {
val keyAggCtx0 = Musig2.keyAgg(pubkeys)
val keyAggCtx0 = KeyAggCtx(pubkeys)
val keyAggCtx = tweaks.fold(keyAggCtx0) { ctx, tweak -> ctx.tweak(tweak.first, tweak.second) }
val (Q, gacc, tacc) = keyAggCtx
val b = PrivateKey(Crypto.taggedHash((aggnonce.toByteArray().byteVector() + Q.xOnly().value + message).toByteArray(), "MuSig/noncecoef"))
Expand All @@ -221,7 +239,7 @@ public data class SessionCtx(val aggnonce: AggregatedNonce, val pubkeys: List<Pu
/**
* @param secnonce secret nonce
* @param sk private key
* @return a Musig2 partial signature, or null if the nonce does not match the private key or the partial signature cannot be verified
* @return a Musig2 partial signature, or null if the nonce does not match the private key or the partial signature cannot be verified
*/
public fun sign(secnonce: SecretNonce, sk: PrivateKey): ByteVector32? = runCatching {
val (Q, gacc, _, b, R, e) = build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ class Musig2TestsCommon {
val keyIndices = it.jsonObject["key_indices"]!!.jsonArray.map { it.jsonPrimitive.int }
val expected = XonlyPublicKey(ByteVector32.fromValidHex(it.jsonObject["expected"]!!.jsonPrimitive.content))
val ctx = Musig2.keyAgg(keyIndices.map { pubkeys[it] })
assertEquals(expected, ctx.Q.xOnly())
assertEquals(expected, ctx.xOnly())
}
tests.jsonObject["error_test_cases"]!!.jsonArray.forEach {
val keyIndices = it.jsonObject["key_indices"]!!.jsonArray.map { it.jsonPrimitive.int }
val tweakIndices = it.jsonObject["tweak_indices"]!!.jsonArray.map { it.jsonPrimitive.int }
val isXonly = it.jsonObject["is_xonly"]!!.jsonArray.map { it.jsonPrimitive.boolean }
assertFails {
var ctx = Musig2.keyAgg(keyIndices.map { pubkeys[it] })
var ctx = KeyAggCtx(keyIndices.map { pubkeys[it] })
tweakIndices.zip(isXonly).forEach { ctx = ctx.tweak(tweaks[it.first], it.second) }
}
}
Expand Down Expand Up @@ -267,7 +267,7 @@ class Musig2TestsCommon {
}

// aggregate public keys
val aggpub = Musig2.keyAgg(pubkeys)
val aggpub = KeyAggCtx(pubkeys)
.tweak(plainTweak, false)
.tweak(xonlyTweak, true)

Expand All @@ -277,64 +277,58 @@ class Musig2TestsCommon {

@Test
fun `use musig2 to replace multisig 2-of-2`() {
val random = Random.Default
val alicePrivKey = PrivateKey(ByteArray(32) { 1 })
val alicePubKey = alicePrivKey.publicKey()
val bobPrivKey = PrivateKey(ByteArray(32) { 2 })
val bobPubKey = bobPrivKey.publicKey()

// Alice and Bob exchange public keys and agree on a common aggregated key
val internalPubKey = Musig2.keyAgg(listOf(alicePubKey, bobPubKey)).Q.xOnly()
// we use the standard BIP86 tweak
val commonPubKey = internalPubKey.outputKey(Crypto.TaprootTweak.NoScriptTweak).first

// this tx sends to a standard p2tr(commonPubKey) script
val tx = Transaction(2, listOf(), listOf(TxOut(Satoshi(10000), Script.pay2tr(commonPubKey))), 0)
// Alice and Bob exchange public keys and agree on a common aggregated key.
val aggregatedKey = Musig2.keyAgg(listOf(alicePubKey, bobPubKey)).xOnly()
// This tx sends to a taproot script that doesn't contain any script path.
val tx = Transaction(2, listOf(), listOf(TxOut(Satoshi(10000), Script.pay2tr(aggregatedKey, scripts = null))), 0)

// this is how Alice and Bob would spend that tx
val spendingTx = Transaction(2, listOf(TxIn(OutPoint(tx, 0), sequence = 0)), listOf(TxOut(Satoshi(10000), Script.pay2wpkh(alicePubKey))), 0)

val commonSig = run {
val random = Random.Default
val aliceNonce = SecretNonce.generate(alicePrivKey, alicePubKey, commonPubKey, null, null, random.nextBytes(32).byteVector32())
val bobNonce = SecretNonce.generate(bobPrivKey, bobPubKey, commonPubKey, null, null, random.nextBytes(32).byteVector32())

val aggnonce = IndividualNonce.aggregate(listOf(aliceNonce.publicNonce(), bobNonce.publicNonce()))
val msg = Transaction.hashForSigningTaprootKeyPath(spendingTx, 0, listOf(tx.txOut[0]), SigHash.SIGHASH_DEFAULT)

// we use the same ctx for Alice and Bob, they both know all the public keys that are used here
val ctx = SessionCtx(
aggnonce,
listOf(alicePubKey, bobPubKey),
listOf(Pair(internalPubKey.tweak(Crypto.TaprootTweak.NoScriptTweak), true)),
msg
)
val aliceSig = ctx.sign(aliceNonce, alicePrivKey)!!
val bobSig = ctx.sign(bobNonce, bobPrivKey)!!
ctx.partialSigAgg(listOf(aliceSig, bobSig))!!
val sig = run {
// The first step of a musig2 signing session is to exchange nonces.
// If participants are disconnected before the end of the signing session, they must start again with fresh nonces.
val aliceNonce = SecretNonce.generate(alicePrivKey, aggregatedKey, random.nextBytes(32).byteVector32())
val bobNonce = SecretNonce.generate(bobPrivKey, aggregatedKey, random.nextBytes(32).byteVector32())

// Once they have each other's public nonce, they can produce partial signatures.
val publicNonces = listOf(aliceNonce.publicNonce(), bobNonce.publicNonce())
val aliceSig = Musig2.signTaprootInput(alicePrivKey, spendingTx, 0, tx.txOut, listOf(alicePubKey, bobPubKey), aliceNonce, publicNonces, scriptTree = null)!!
val bobSig = Musig2.signTaprootInput(bobPrivKey, spendingTx, 0, tx.txOut, listOf(alicePubKey, bobPubKey), bobNonce, publicNonces, scriptTree = null)!!

// Once they have each other's partial signature, they can aggregate them into a valid signature.
Musig2.aggregateTaprootSignatures(listOf(aliceSig, bobSig), spendingTx, 0, tx.txOut, listOf(alicePubKey, bobPubKey), publicNonces, scriptTree = null)!!
}

// this tx looks like any other tx that spends a p2tr output, with a single signature
val signedSpendingTx = spendingTx.updateWitness(0, ScriptWitness(listOf(commonSig)))
// This tx looks like any other tx that spends a p2tr output, with a single signature.
val signedSpendingTx = spendingTx.updateWitness(0, Script.witnessKeyPathPay2tr(sig))
Transaction.correctlySpends(signedSpendingTx, tx, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
}

@Test
fun `swap-in-potentiam example with musig2 and taproot`() {
val userPrivateKey = PrivateKey(ByteArray(32) { 1 })
val userPublicKey = userPrivateKey.publicKey()
val serverPrivateKey = PrivateKey(ByteArray(32) { 2 })
val serverPublicKey = serverPrivateKey.publicKey()
val userRefundPrivateKey = PrivateKey(ByteArray(32) { 3 })
val refundDelay = 25920

val random = Random.Default

// the redeem script is just the refund script. it is generated from this policy: and_v(v:pk(user),older(refundDelay))
// it does not depend upon the user's or server's key, just the user's refund key and the refund delay
val redeemScript = listOf(OP_PUSHDATA(userRefundPrivateKey.publicKey().xOnly()), OP_CHECKSIGVERIFY, OP_PUSHDATA(Script.encodeNumber(refundDelay)), OP_CHECKSEQUENCEVERIFY)
val redeemScript = listOf(OP_PUSHDATA(userRefundPrivateKey.xOnlyPublicKey()), OP_CHECKSIGVERIFY, OP_PUSHDATA(Script.encodeNumber(refundDelay)), OP_CHECKSEQUENCEVERIFY)
val scriptTree = ScriptTree.Leaf(0, redeemScript)

// the internal pubkey is the musig2 aggregation of the user's and server's public keys: it does not depend upon the user's refund's key
val internalPubKey = Musig2.keyAgg(listOf(userPrivateKey.publicKey(), serverPrivateKey.publicKey())).Q.xOnly()
val pubkeyScript = Script.pay2tr(internalPubKey, scriptTree)
val aggregatedKey = Musig2.keyAgg(listOf(userPublicKey, serverPublicKey)).xOnly()
val pubkeyScript = Script.pay2tr(aggregatedKey, scriptTree)

val swapInTx = Transaction(
version = 2,
Expand All @@ -348,27 +342,22 @@ class Musig2TestsCommon {
val tx = Transaction(
version = 2,
txIn = listOf(TxIn(OutPoint(swapInTx, 0), sequence = TxIn.SEQUENCE_FINAL)),
txOut = listOf(TxOut(Satoshi(10000), Script.pay2wpkh(userPrivateKey.publicKey()))),
txOut = listOf(TxOut(Satoshi(10000), Script.pay2wpkh(userPublicKey))),
lockTime = 0
)
// this is the beginning of an interactive musig2 signing session. if user and server are disconnected before they have exchanged partial
// signatures they will have to start again with fresh nonces
val userNonce = SecretNonce.generate(userPrivateKey, userPrivateKey.publicKey(), internalPubKey, null, null, random.nextBytes(32).byteVector32())
val serverNonce = SecretNonce.generate(serverPrivateKey, serverPrivateKey.publicKey(), internalPubKey, null, null, random.nextBytes(32).byteVector32())

val txHash = Transaction.hashForSigningTaprootKeyPath(tx, 0, swapInTx.txOut, SigHash.SIGHASH_DEFAULT)
val commonNonce = IndividualNonce.aggregate(listOf(userNonce.publicNonce(), serverNonce.publicNonce()))
val ctx = SessionCtx(
commonNonce,
listOf(userPrivateKey.publicKey(), serverPrivateKey.publicKey()),
listOf(Pair(internalPubKey.tweak(Crypto.TaprootTweak.ScriptTweak(scriptTree)), true)),
txHash
)

val userSig = ctx.sign(userNonce, userPrivateKey)!!
val serverSig = ctx.sign(serverNonce, serverPrivateKey)!!
val commonSig = ctx.partialSigAgg(listOf(userSig, serverSig))!!
val signedTx = tx.updateWitness(0, Script.witnessKeyPathPay2tr(commonSig))
// The first step of a musig2 signing session is to exchange nonces.
// If participants are disconnected before the end of the signing session, they must start again with fresh nonces.
val userNonce = SecretNonce.generate(userPrivateKey, aggregatedKey, random.nextBytes(32).byteVector32())
val serverNonce = SecretNonce.generate(serverPrivateKey, aggregatedKey, random.nextBytes(32).byteVector32())

// Once they have each other's public nonce, they can produce partial signatures.
val publicNonces = listOf(userNonce.publicNonce(), serverNonce.publicNonce())
val userSig = Musig2.signTaprootInput(userPrivateKey, tx, 0, swapInTx.txOut, listOf(userPublicKey, serverPublicKey), userNonce, publicNonces, scriptTree)!!
val serverSig = Musig2.signTaprootInput(serverPrivateKey, tx, 0, swapInTx.txOut, listOf(userPublicKey, serverPublicKey), serverNonce, publicNonces, scriptTree)!!

// Once they have each other's partial signature, they can aggregate them into a valid signature.
val sig = Musig2.aggregateTaprootSignatures(listOf(userSig, serverSig), tx, 0, swapInTx.txOut, listOf(userPublicKey, serverPublicKey), publicNonces, scriptTree)!!
val signedTx = tx.updateWitness(0, Script.witnessKeyPathPay2tr(sig))
Transaction.correctlySpends(signedTx, swapInTx, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
}

Expand All @@ -377,11 +366,11 @@ class Musig2TestsCommon {
val tx = Transaction(
version = 2,
txIn = listOf(TxIn(OutPoint(swapInTx, 0), sequence = refundDelay.toLong())),
txOut = listOf(TxOut(Satoshi(10000), Script.pay2wpkh(userPrivateKey.publicKey()))),
txOut = listOf(TxOut(Satoshi(10000), Script.pay2wpkh(userPublicKey))),
lockTime = 0
)
val sig = Crypto.signTaprootScriptPath(userRefundPrivateKey, tx, 0, swapInTx.txOut, SigHash.SIGHASH_DEFAULT, scriptTree.hash())
val witness = Script.witnessScriptPathPay2tr(internalPubKey, scriptTree, ScriptWitness(listOf(sig)), scriptTree)
val witness = Script.witnessScriptPathPay2tr(aggregatedKey, scriptTree, ScriptWitness(listOf(sig)), scriptTree)
val signedTx = tx.updateWitness(0, witness)
Transaction.correctlySpends(signedTx, swapInTx, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
}
Expand Down

0 comments on commit 88f25d0

Please sign in to comment.