Skip to content

Commit

Permalink
Improve readability of musig2 test (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
altafan authored Jan 30, 2025
1 parent ed81573 commit 17564f8
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 130 deletions.
4 changes: 2 additions & 2 deletions common/bitcointree/musig2.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ type SignerSession interface {

type CoordinatorSession interface {
AddNonce(*btcec.PublicKey, TreeNonces)
AddSig(*btcec.PublicKey, TreePartialSigs)
AddSignatures(*btcec.PublicKey, TreePartialSigs)
AggregateNonces() (TreeNonces, error)
// SignTree combines the signatures and add them to the tree's psbts
SignTree() (tree.VtxoTree, error)
Expand Down Expand Up @@ -420,7 +420,7 @@ func (t *treeCoordinatorSession) AddNonce(pubkey *btcec.PublicKey, nonce TreeNon
t.nonces[hex.EncodeToString(schnorr.SerializePubKey(pubkey))] = nonce
}

func (t *treeCoordinatorSession) AddSig(pubkey *btcec.PublicKey, sig TreePartialSigs) {
func (t *treeCoordinatorSession) AddSignatures(pubkey *btcec.PublicKey, sig TreePartialSigs) {
t.sigs[hex.EncodeToString(schnorr.SerializePubKey(pubkey))] = sig
}

Expand Down
310 changes: 184 additions & 126 deletions common/bitcointree/musig2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ import (
"github.com/ark-network/ark/common/bitcointree"
"github.com/ark-network/ark/common/tree"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/stretchr/testify/require"
)

Expand All @@ -24,10 +22,10 @@ const (

var (
vtxoTreeExpiry = common.RelativeLocktime{Type: common.LocktimeTypeBlock, Value: 144}
testTxid, _ = chainhash.NewHashFromStr("49f8664acc899be91902f8ade781b7eeb9cbe22bdd9efbc36e56195de21bcd12")
serverPrivKey, _ = secp256k1.GeneratePrivateKey()
rootInput, _ = wire.NewOutPointFromString("49f8664acc899be91902f8ade781b7eeb9cbe22bdd9efbc36e56195de21bcd12:0")
serverPrivKey, _ = btcec.NewPrivateKey()
sweepScript, _ = (&tree.CSVMultisigClosure{
MultisigClosure: tree.MultisigClosure{PubKeys: []*secp256k1.PublicKey{serverPrivKey.PubKey()}},
MultisigClosure: tree.MultisigClosure{PubKeys: []*btcec.PublicKey{serverPrivKey.PubKey()}},
Locktime: vtxoTreeExpiry,
}).Script()
sweepRoot = txscript.NewBaseTapLeaf(sweepScript).TapHash()
Expand All @@ -37,125 +35,214 @@ var (
func TestBuildAndSignVtxoTree(t *testing.T) {
t.Parallel()

for _, tc := range generateTestCases(t) {
t.Run(tc.name, func(t *testing.T) {
sharedOutputScript, sharedOutputAmount, err := bitcointree.CraftSharedOutput(
tc.receivers,
minRelayFee,
sweepRoot[:],
testVectors, err := makeTestVectors()
require.NoError(t, err)
require.NotEmpty(t, testVectors)

for _, v := range testVectors {
t.Run(v.name, func(t *testing.T) {
sharedOutScript, sharedOutAmount, err := bitcointree.CraftSharedOutput(
v.receivers, minRelayFee, sweepRoot[:],
)
require.NoError(t, err)
require.NotNil(t, sharedOutputScript)
require.NotNil(t, sharedOutScript)
require.NotZero(t, sharedOutAmount)

vtxoTree, err := bitcointree.BuildVtxoTree(
&wire.OutPoint{
Hash: *testTxid,
Index: 0,
},
tc.receivers,
minRelayFee,
sweepRoot[:],
vtxoTreeExpiry,
rootInput, v.receivers, minRelayFee, sweepRoot[:], vtxoTreeExpiry,
)
require.NoError(t, err)
require.NotNil(t, vtxoTree)

serverCoordinator, err := bitcointree.NewTreeCoordinatorSession(
sharedOutputAmount,
vtxoTree,
sweepRoot[:],
coordinator, err := bitcointree.NewTreeCoordinatorSession(
sharedOutAmount, vtxoTree, sweepRoot[:],
)
require.NoError(t, err)
require.NotNil(t, coordinator)

// Cceate signer sessions for each receivers
signerSessions := make(map[*btcec.PublicKey]bitcointree.SignerSession)
for _, prvkey := range tc.privKeys {
session := bitcointree.NewTreeSignerSession(prvkey)
err := session.Init(sweepRoot[:], sharedOutputAmount, vtxoTree)
require.NoError(t, err)
signerSessions[prvkey.PubKey()] = session
}
signers, err := makeCosigners(v.privKeys, sharedOutAmount, vtxoTree)
require.NoError(t, err)
require.NotNil(t, signers)

// Create server's signer session
serverSession := bitcointree.NewTreeSignerSession(serverPrivKey)
err = serverSession.Init(sweepRoot[:], sharedOutputAmount, vtxoTree)
err = makeAggregatedNonces(signers, coordinator, checkNoncesRoundtrip(t))
require.NoError(t, err)
signerSessions[serverPrivKey.PubKey()] = serverSession

// generate nonces from all signers
for pubkey, session := range signerSessions {
nonces, err := session.GetNonces()
require.NoError(t, err)
var encodedNonces bytes.Buffer
err = nonces.Encode(&encodedNonces)
require.NoError(t, err)
decodedNonces, err := bitcointree.DecodeNonces(&encodedNonces)
require.NoError(t, err)
for i, nonceRow := range nonces {
for j, nonce := range nonceRow {
require.Equal(t, nonce, decodedNonces[i][j])
}
}

serverCoordinator.AddNonce(pubkey, nonces)
}
signedTree, err := makeAggregatedSignatures(signers, coordinator, checkSigsRoundtrip(t))
require.NoError(t, err)
require.NotNil(t, signedTree)

aggregatedNonce, err := serverCoordinator.AggregateNonces()
// validate signatures
err = bitcointree.ValidateTreeSigs(sweepRoot[:], sharedOutAmount, signedTree)
require.NoError(t, err)
})
}
}

func checkNoncesRoundtrip(t *testing.T) func(nonces bitcointree.TreeNonces) {
return func(nonces bitcointree.TreeNonces) {
var encodedNonces bytes.Buffer
err := nonces.Encode(&encodedNonces)
require.NoError(t, err)

// set the aggregated nonces for all signers sessions
for _, session := range signerSessions {
session.SetAggregatedNonces(aggregatedNonce)
decodedNonces, err := bitcointree.DecodeNonces(&encodedNonces)
require.NoError(t, err)
for i, nonceRow := range nonces {
for j, nonce := range nonceRow {
require.Equal(t, nonce, decodedNonces[i][j])
}
}
}
}

// get signatures from all signers sessions
for pubkey, session := range signerSessions {
sig, err := session.Sign()
require.NoError(t, err)
require.NotNil(t, sig)
var encodedSig bytes.Buffer
err = sig.Encode(&encodedSig)
require.NoError(t, err)
decodedSig, err := bitcointree.DecodeSignatures(&encodedSig)
require.NoError(t, err)
for i, sigRow := range sig {
for j, sig := range sigRow {
if sig == nil {
require.Nil(t, decodedSig[i][j])
} else {
require.Equal(t, sig.S, decodedSig[i][j].S)
}
}
func checkSigsRoundtrip(t *testing.T) func(sigs bitcointree.TreePartialSigs) {
return func(sigs bitcointree.TreePartialSigs) {
var encodedSig bytes.Buffer
err := sigs.Encode(&encodedSig)
require.NoError(t, err)
decodedSig, err := bitcointree.DecodeSignatures(&encodedSig)
require.NoError(t, err)
for i, sigRow := range sigs {
for j, sig := range sigRow {
if sig == nil {
require.Nil(t, decodedSig[i][j])
} else {
require.Equal(t, sig.S, decodedSig[i][j].S)
}

serverCoordinator.AddSig(pubkey, sig)
}
}
}
}

// aggregate signatures
signedTree, err := serverCoordinator.SignTree()
require.NoError(t, err)
require.NotNil(t, signedTree)
// validate signatures
err = bitcointree.ValidateTreeSigs(
sweepRoot[:],
sharedOutputAmount,
signedTree,
)
require.NoError(t, err)
})
func makeCosigners(
keys []*btcec.PrivateKey, sharedOutAmount int64, vtxoTree tree.VtxoTree,
) (map[string]bitcointree.SignerSession, error) {
signers := make(map[string]bitcointree.SignerSession)
for _, prvkey := range keys {
session := bitcointree.NewTreeSignerSession(prvkey)
if err := session.Init(sweepRoot[:], sharedOutAmount, vtxoTree); err != nil {
return nil, err
}
signers[keyToStr(prvkey)] = session
}

// create signer session for the server itself
serverSession := bitcointree.NewTreeSignerSession(serverPrivKey)
if err := serverSession.Init(sweepRoot[:], sharedOutAmount, vtxoTree); err != nil {
return nil, err
}
signers[keyToStr(serverPrivKey)] = serverSession
return signers, nil
}

func makeAggregatedNonces(
signers map[string]bitcointree.SignerSession, coordinator bitcointree.CoordinatorSession,
checkNoncesRoundtrip func(bitcointree.TreeNonces),
) error {
for pk, session := range signers {
buf, err := hex.DecodeString(pk)
if err != nil {
return err
}
pubkey, err := btcec.ParsePubKey(buf)
if err != nil {
return err
}

nonces, err := session.GetNonces()
if err != nil {
return err
}
checkNoncesRoundtrip(nonces)

coordinator.AddNonce(pubkey, nonces)
}

aggregatedNonce, err := coordinator.AggregateNonces()
if err != nil {
return err
}

// set the aggregated nonces for all signers sessions
for _, session := range signers {
session.SetAggregatedNonces(aggregatedNonce)
}
return nil
}

func makeAggregatedSignatures(
signers map[string]bitcointree.SignerSession, coordinator bitcointree.CoordinatorSession,
checkSigsRoundtrip func(bitcointree.TreePartialSigs),
) (tree.VtxoTree, error) {
for pk, session := range signers {
buf, err := hex.DecodeString(pk)
if err != nil {
return nil, err
}
pubkey, err := btcec.ParsePubKey(buf)
if err != nil {
return nil, err
}

sigs, err := session.Sign()
if err != nil {
return nil, err
}
checkSigsRoundtrip(sigs)

coordinator.AddSignatures(pubkey, sigs)
}

// aggregate signatures
return coordinator.SignTree()
}

type testCase struct {
name string
receivers []tree.VtxoLeaf
privKeys []*secp256k1.PrivateKey
privKeys []*btcec.PrivateKey
}

func makeTestVectors() ([]testCase, error) {
vectors := make([]testCase, 0, len(receiverCounts))
for _, count := range receiverCounts {
receivers, privKeys, err := generateMockedReceivers(count)
if err != nil {
return nil, err
}

// add mixed types test case if count is between 2 and 32
if count > 1 && count < 32 {
vectors = append(vectors, testCase{
name: fmt.Sprintf("%d receivers Mixed Signing Types", len(receivers)),
receivers: withMixedSigningTypes(receivers),
privKeys: privKeys,
})
}

// add SignAll test case if count is less than 32
if count < 32 {
vectors = append(vectors, testCase{
name: fmt.Sprintf("%d receivers SignAll", len(receivers)),
receivers: withSigningType(tree.SignAll, receivers),
privKeys: privKeys,
})
}

// always add SignBranch test case
vectors = append(vectors, testCase{
name: fmt.Sprintf("%d receivers SignBranch", len(receivers)),
receivers: withSigningType(tree.SignBranch, receivers),
privKeys: privKeys,
})
}
return vectors, nil
}

func generateReceiversFixture(count int) ([]tree.VtxoLeaf, []*secp256k1.PrivateKey, error) {
receivers := make([]tree.VtxoLeaf, 0, count)
privKeys := make([]*secp256k1.PrivateKey, 0, count)
for i := 0; i < count; i++ {
prvkey, err := secp256k1.GeneratePrivateKey()
func generateMockedReceivers(num int) ([]tree.VtxoLeaf, []*btcec.PrivateKey, error) {
receivers := make([]tree.VtxoLeaf, 0, num)
privKeys := make([]*btcec.PrivateKey, 0, num)
for i := 0; i < num; i++ {
prvkey, err := btcec.NewPrivateKey()
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -196,35 +283,6 @@ func withMixedSigningTypes(receivers []tree.VtxoLeaf) []tree.VtxoLeaf {
return append(first, second...)
}

func generateTestCases(t *testing.T) []testCase {
testCases := make([]testCase, 0)
for _, count := range receiverCounts {
receivers, privKeys, err := generateReceiversFixture(count)
require.NoError(t, err)
// add mixed types test case if count is between 2 and 32
if count > 1 && count < 32 {
testCases = append(testCases, testCase{
name: fmt.Sprintf("%d receivers Mixed Signing Types", len(receivers)),
receivers: withMixedSigningTypes(receivers),
privKeys: privKeys,
})
}

// add SignAll test case if count is less than 32
if count < 32 {
testCases = append(testCases, testCase{
name: fmt.Sprintf("%d receivers SignAll", len(receivers)),
receivers: withSigningType(tree.SignAll, receivers),
privKeys: privKeys,
})
}

// always add SignBranch test case
testCases = append(testCases, testCase{
name: fmt.Sprintf("%d receivers SignBranch", len(receivers)),
receivers: withSigningType(tree.SignBranch, receivers),
privKeys: privKeys,
})
}
return testCases
func keyToStr(key *btcec.PrivateKey) string {
return hex.EncodeToString(key.PubKey().SerializeCompressed())
}
Loading

0 comments on commit 17564f8

Please sign in to comment.