Skip to content

Commit

Permalink
Add test for SetDefaultDigestAlgorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
hslatman committed Feb 12, 2025
1 parent 9b402c9 commit 7003208
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 13 deletions.
29 changes: 16 additions & 13 deletions sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,30 @@ import (
"time"
)

func init() {
defaultMessageDigestAlgorithm.oid = OIDDigestAlgorithmSHA1
}

var defaultMessageDigestAlgorithm struct {
sync.RWMutex
oid asn1.ObjectIdentifier
}

// SetDefaultDigestAlgorithm sets the default digest algorithm
// to be used for signing operations on [SignedData].
//
// This must be called before creating a new instance of [SignedData]
// using [NewSignedData].
//
// When this function is not called, the default digest algorithm is SHA1.
func SetDefaultDigestAlgorithm(d asn1.ObjectIdentifier) error {
defaultMessageDigestAlgorithm.Lock()
defer defaultMessageDigestAlgorithm.Unlock()

switch {
case d.Equal(OIDDigestAlgorithmSHA1), d.Equal(OIDDigestAlgorithmSHA256), d.Equal(OIDDigestAlgorithmSHA384),
d.Equal(OIDDigestAlgorithmSHA512), d.Equal(OIDDigestAlgorithmSHA224),
case d.Equal(OIDDigestAlgorithmSHA1),
d.Equal(OIDDigestAlgorithmSHA224), d.Equal(OIDDigestAlgorithmSHA256),
d.Equal(OIDDigestAlgorithmSHA384), d.Equal(OIDDigestAlgorithmSHA512),
d.Equal(OIDDigestAlgorithmDSA), d.Equal(OIDDigestAlgorithmDSASHA1),
d.Equal(OIDDigestAlgorithmECDSASHA1), d.Equal(OIDDigestAlgorithmECDSASHA256),
d.Equal(OIDDigestAlgorithmECDSASHA384), d.Equal(OIDDigestAlgorithmECDSASHA512):
Expand All @@ -39,21 +52,11 @@ func SetDefaultDigestAlgorithm(d asn1.ObjectIdentifier) error {
return nil
}

var defaultMessageDigestAlgorithm struct {
sync.RWMutex
oid asn1.ObjectIdentifier
}

func defaultMessageDigestAlgorithmOID() asn1.ObjectIdentifier {
defaultMessageDigestAlgorithm.RLock()
defer defaultMessageDigestAlgorithm.RUnlock()

oid := defaultMessageDigestAlgorithm.oid
if oid.Equal(asn1.ObjectIdentifier{}) {
return OIDDigestAlgorithmSHA1
}

return oid
return defaultMessageDigestAlgorithm.oid
}

// SignedData is an opaque data structure for creating signed data payloads
Expand Down
70 changes: 70 additions & 0 deletions sign_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,76 @@ import (
"testing"
)

func TestSignWithGlobalDefaultDigestAlgorithm(t *testing.T) {
if err := SetDefaultDigestAlgorithm(asn1.ObjectIdentifier{}); err == nil {
t.Fatal("expected an error setting invalid digest algorithm")
}

currentDigestAlgorithm := defaultMessageDigestAlgorithmOID()
if !currentDigestAlgorithm.Equal(OIDDigestAlgorithmSHA1) {
t.Fatalf("expected digest algorithm %q, but got %q", OIDDigestAlgorithmSHA512, currentDigestAlgorithm)
}

if err := SetDefaultDigestAlgorithm(OIDDigestAlgorithmSHA512); err != nil {
t.Fatalf("failed setting default digest algorithm: %v", err)
}

defer func() {
currentDigestAlgorithm := defaultMessageDigestAlgorithmOID()
if !currentDigestAlgorithm.Equal(OIDDigestAlgorithmSHA512) {
t.Fatalf("expected digest algorithm %q, but got %q", OIDDigestAlgorithmSHA512, currentDigestAlgorithm)
}
if err := SetDefaultDigestAlgorithm(OIDDigestAlgorithmSHA1); err != nil {
t.Fatalf("failed resetting default digest algorithm: %v", err)
}
currentDigestAlgorithm = defaultMessageDigestAlgorithmOID()
if !currentDigestAlgorithm.Equal(OIDDigestAlgorithmSHA1) {
t.Fatalf("expected digest algorithm %q, but got %q", OIDDigestAlgorithmSHA1, currentDigestAlgorithm)
}
}()

cert, err := createTestCertificateByIssuer("PKCS7 Test Root CA", nil, x509.ECDSAWithSHA256, true)
if err != nil {
t.Fatalf("failed cannot generate root cert: %v", err)
}
truststore := x509.NewCertPool()
truststore.AddCert(cert.Certificate)

toBeSigned, err := NewSignedData([]byte("test"))
if err != nil {
t.Fatalf("failed creating signed data: %v", err)
}

if !toBeSigned.digestOid.Equal(OIDDigestAlgorithmSHA512) {
t.Fatalf("expected digest algorithm %q, but got %q", OIDDigestAlgorithmSHA512, toBeSigned.digestOid)
}

if err := toBeSigned.AddSignerChain(cert.Certificate, *cert.PrivateKey, nil, SignerInfoConfig{}); err != nil {
t.Fatalf("failed adding signer chain: %v", err)
}

signed, err := toBeSigned.Finish()
if err != nil {
t.Fatalf("failed signing data: %v", err)
}

pem.Encode(os.Stdout, &pem.Block{Type: "PKCS7", Bytes: signed})
p7, err := Parse(signed)
if err != nil {
t.Fatalf("failed parsing PEM encoded signed data: %v", err)
}
if err := p7.VerifyWithChain(truststore); err != nil {
t.Fatalf("failed verifying PKCS7: %v", err)
}
if !bytes.Equal([]byte("test"), p7.Content) {
t.Fatal("parsed PKCS7 content does not equal signed content")
}

if !p7.Signers[0].DigestAlgorithm.Algorithm.Equal(OIDDigestAlgorithmSHA512) {
t.Fatalf("expected digest algorithm %q, but got %q", OIDDigestAlgorithmSHA512, p7.Signers[0].DigestAlgorithm.Algorithm)
}
}

func TestSign(t *testing.T) {
content := []byte("Hello World")
sigalgs := []x509.SignatureAlgorithm{
Expand Down

0 comments on commit 7003208

Please sign in to comment.