From 700320821bf7ec94608d9126fec9d66cddcedc2c Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Thu, 13 Feb 2025 00:16:33 +0100 Subject: [PATCH] Add test for `SetDefaultDigestAlgorithm` --- sign.go | 29 ++++++++++++---------- sign_test.go | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 13 deletions(-) diff --git a/sign.go b/sign.go index 84fc8c7..74ce50d 100644 --- a/sign.go +++ b/sign.go @@ -15,17 +15,30 @@ import ( "time" ) +func init() { + defaultMessageDigestAlgorithm.oid = OIDDigestAlgorithmSHA1 +} + +var defaultMessageDigestAlgorithm struct { + sync.RWMutex + oid asn1.ObjectIdentifier +} + // SetDefaultDigestAlgorithm sets the default digest algorithm // to be used for signing operations on [SignedData]. // // This must be called before creating a new instance of [SignedData] // using [NewSignedData]. +// +// When this function is not called, the default digest algorithm is SHA1. func SetDefaultDigestAlgorithm(d asn1.ObjectIdentifier) error { defaultMessageDigestAlgorithm.Lock() defer defaultMessageDigestAlgorithm.Unlock() + switch { - case d.Equal(OIDDigestAlgorithmSHA1), d.Equal(OIDDigestAlgorithmSHA256), d.Equal(OIDDigestAlgorithmSHA384), - d.Equal(OIDDigestAlgorithmSHA512), d.Equal(OIDDigestAlgorithmSHA224), + case d.Equal(OIDDigestAlgorithmSHA1), + d.Equal(OIDDigestAlgorithmSHA224), d.Equal(OIDDigestAlgorithmSHA256), + d.Equal(OIDDigestAlgorithmSHA384), d.Equal(OIDDigestAlgorithmSHA512), d.Equal(OIDDigestAlgorithmDSA), d.Equal(OIDDigestAlgorithmDSASHA1), d.Equal(OIDDigestAlgorithmECDSASHA1), d.Equal(OIDDigestAlgorithmECDSASHA256), d.Equal(OIDDigestAlgorithmECDSASHA384), d.Equal(OIDDigestAlgorithmECDSASHA512): @@ -39,21 +52,11 @@ func SetDefaultDigestAlgorithm(d asn1.ObjectIdentifier) error { return nil } -var defaultMessageDigestAlgorithm struct { - sync.RWMutex - oid asn1.ObjectIdentifier -} - func defaultMessageDigestAlgorithmOID() asn1.ObjectIdentifier { defaultMessageDigestAlgorithm.RLock() defer defaultMessageDigestAlgorithm.RUnlock() - oid := defaultMessageDigestAlgorithm.oid - if oid.Equal(asn1.ObjectIdentifier{}) { - return OIDDigestAlgorithmSHA1 - } - - return oid + return defaultMessageDigestAlgorithm.oid } // SignedData is an opaque data structure for creating signed data payloads diff --git a/sign_test.go b/sign_test.go index 59666d4..cc8d44d 100644 --- a/sign_test.go +++ b/sign_test.go @@ -15,6 +15,76 @@ import ( "testing" ) +func TestSignWithGlobalDefaultDigestAlgorithm(t *testing.T) { + if err := SetDefaultDigestAlgorithm(asn1.ObjectIdentifier{}); err == nil { + t.Fatal("expected an error setting invalid digest algorithm") + } + + currentDigestAlgorithm := defaultMessageDigestAlgorithmOID() + if !currentDigestAlgorithm.Equal(OIDDigestAlgorithmSHA1) { + t.Fatalf("expected digest algorithm %q, but got %q", OIDDigestAlgorithmSHA512, currentDigestAlgorithm) + } + + if err := SetDefaultDigestAlgorithm(OIDDigestAlgorithmSHA512); err != nil { + t.Fatalf("failed setting default digest algorithm: %v", err) + } + + defer func() { + currentDigestAlgorithm := defaultMessageDigestAlgorithmOID() + if !currentDigestAlgorithm.Equal(OIDDigestAlgorithmSHA512) { + t.Fatalf("expected digest algorithm %q, but got %q", OIDDigestAlgorithmSHA512, currentDigestAlgorithm) + } + if err := SetDefaultDigestAlgorithm(OIDDigestAlgorithmSHA1); err != nil { + t.Fatalf("failed resetting default digest algorithm: %v", err) + } + currentDigestAlgorithm = defaultMessageDigestAlgorithmOID() + if !currentDigestAlgorithm.Equal(OIDDigestAlgorithmSHA1) { + t.Fatalf("expected digest algorithm %q, but got %q", OIDDigestAlgorithmSHA1, currentDigestAlgorithm) + } + }() + + cert, err := createTestCertificateByIssuer("PKCS7 Test Root CA", nil, x509.ECDSAWithSHA256, true) + if err != nil { + t.Fatalf("failed cannot generate root cert: %v", err) + } + truststore := x509.NewCertPool() + truststore.AddCert(cert.Certificate) + + toBeSigned, err := NewSignedData([]byte("test")) + if err != nil { + t.Fatalf("failed creating signed data: %v", err) + } + + if !toBeSigned.digestOid.Equal(OIDDigestAlgorithmSHA512) { + t.Fatalf("expected digest algorithm %q, but got %q", OIDDigestAlgorithmSHA512, toBeSigned.digestOid) + } + + if err := toBeSigned.AddSignerChain(cert.Certificate, *cert.PrivateKey, nil, SignerInfoConfig{}); err != nil { + t.Fatalf("failed adding signer chain: %v", err) + } + + signed, err := toBeSigned.Finish() + if err != nil { + t.Fatalf("failed signing data: %v", err) + } + + pem.Encode(os.Stdout, &pem.Block{Type: "PKCS7", Bytes: signed}) + p7, err := Parse(signed) + if err != nil { + t.Fatalf("failed parsing PEM encoded signed data: %v", err) + } + if err := p7.VerifyWithChain(truststore); err != nil { + t.Fatalf("failed verifying PKCS7: %v", err) + } + if !bytes.Equal([]byte("test"), p7.Content) { + t.Fatal("parsed PKCS7 content does not equal signed content") + } + + if !p7.Signers[0].DigestAlgorithm.Algorithm.Equal(OIDDigestAlgorithmSHA512) { + t.Fatalf("expected digest algorithm %q, but got %q", OIDDigestAlgorithmSHA512, p7.Signers[0].DigestAlgorithm.Algorithm) + } +} + func TestSign(t *testing.T) { content := []byte("Hello World") sigalgs := []x509.SignatureAlgorithm{