diff --git a/kms/apiv1/requests.go b/kms/apiv1/requests.go index 224c1dff..0862cab8 100644 --- a/kms/apiv1/requests.go +++ b/kms/apiv1/requests.go @@ -226,7 +226,8 @@ type CreateAttestationRequest struct { // Notice: This API is EXPERIMENTAL and may be changed or removed in a later // release. type CreateAttestationResponse struct { - Certificate *x509.Certificate - CertificateChain []*x509.Certificate - PublicKey crypto.PublicKey + Certificate *x509.Certificate + CertificateChain []*x509.Certificate + PublicKey crypto.PublicKey + PermanentIdentifier string } diff --git a/kms/yubikey/yubikey.go b/kms/yubikey/yubikey.go index 13de97d5..c2018379 100644 --- a/kms/yubikey/yubikey.go +++ b/kms/yubikey/yubikey.go @@ -120,9 +120,11 @@ func New(ctx context.Context, opts apiv1.Options) (*YubiKey, error) { // Attempt to locate the yubikey with the given serial. for _, name := range cards { if k, err := pivOpen(name); err == nil { - if serialNumber, err := getSerialNumber(k); err == nil && serial == serialNumber { - yk = k - break + if cert, err := k.Attest(piv.SlotAuthentication); err == nil { + if serial == getSerialNumber(cert) { + yk = k + break + } } } } @@ -321,9 +323,10 @@ func (k *YubiKey) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1 } return &apiv1.CreateAttestationResponse{ - Certificate: cert, - CertificateChain: []*x509.Certificate{intermediate}, - PublicKey: cert.PublicKey, + Certificate: cert, + CertificateChain: []*x509.Certificate{intermediate}, + PublicKey: cert.PublicKey, + PermanentIdentifier: getSerialNumber(cert), }, nil } @@ -471,22 +474,19 @@ func getPolicies(req *apiv1.CreateKeyRequest) (piv.PINPolicy, piv.TouchPolicy) { return pin, touch } -// getSerialNumber gets an attestation certificate on the given key and returns -// the serial number on it. -func getSerialNumber(yk pivKey) (string, error) { - cert, err := yk.Attest(piv.SlotAuthentication) - if err != nil { - return "", err - } +// getSerialNumber returns the serial number from an attestation certificate. It +// will return an empty string if it the serial number extension does not exists +// or if it is malformed. +func getSerialNumber(cert *x509.Certificate) string { for _, ext := range cert.Extensions { if ext.Id.Equal(oidYubicoSerialNumber) { var serialNumber int rest, err := asn1.Unmarshal(ext.Value, &serialNumber) if err != nil || len(rest) > 0 { - return "", errors.New("error parsing YubiKey serial number") + return "" } - return strconv.Itoa(serialNumber), nil + return strconv.Itoa(serialNumber) } } - return "", errors.New("failed to find YubiKey serial number") + return "" } diff --git a/kms/yubikey/yubikey_test.go b/kms/yubikey/yubikey_test.go index cf8cead4..0dc07734 100644 --- a/kms/yubikey/yubikey_test.go +++ b/kms/yubikey/yubikey_test.go @@ -215,6 +215,22 @@ func (s *stubPivKey) Close() error { return nil } +func TestRegister(t *testing.T) { + fn, ok := apiv1.LoadKeyManagerNewFunc(apiv1.YubiKey) + if !ok { + t.Fatal("YubiKey is not registered") + } + k, err := fn(context.Background(), apiv1.Options{ + Type: "YubiKey", URI: "yubikey:", + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + if k == nil { + t.Fatalf("New() = %v, want &KeyVault{}", k) + } +} + func TestNew(t *testing.T) { ctx := context.Background() pOpen := pivOpen @@ -956,9 +972,10 @@ func TestYubiKey_CreateAttestation(t *testing.T) { {"ok", fields{yk, "123456", piv.DefaultManagementKey}, args{&apiv1.CreateAttestationRequest{ Name: "yubikey:slot-id=9a", }}, &apiv1.CreateAttestationResponse{ - Certificate: yk.attestMap[piv.SlotAuthentication], - CertificateChain: []*x509.Certificate{yk.attestCA.Intermediate}, - PublicKey: yk.attestMap[piv.SlotAuthentication].PublicKey, + Certificate: yk.attestMap[piv.SlotAuthentication], + CertificateChain: []*x509.Certificate{yk.attestCA.Intermediate}, + PublicKey: yk.attestMap[piv.SlotAuthentication].PublicKey, + PermanentIdentifier: "112233", }, false}, {"fail getSlot", fields{yk, "123456", piv.DefaultManagementKey}, args{&apiv1.CreateAttestationRequest{ Name: "yubikey://:slot-id=9a", @@ -1019,61 +1036,104 @@ func TestYubiKey_Close(t *testing.T) { } func Test_getSerialNumber(t *testing.T) { - ok := newStubPivKey(t, RSA) - - failAttest := newStubPivKey(t, RSA) - delete(failAttest.attestMap, piv.SlotAuthentication) - - failParse := newStubPivKey(t, ECDSA) - serialNumer, err := asn1.Marshal("112233") + serialNumber, err := asn1.Marshal(112233) if err != nil { t.Fatal(err) } - attCertParse, err := failParse.attestCA.Sign(&x509.Certificate{ - Subject: pkix.Name{CommonName: "attested certificate"}, - PublicKey: failParse.attestSigner.Public(), - ExtraExtensions: []pkix.Extension{ - {Id: oidYubicoSerialNumber, Value: serialNumer}, - }, - }) + printableSerialNumber, err := asn1.Marshal("112233") if err != nil { t.Fatal(err) } - failMissing := newStubPivKey(t, ECDSA) - attCertMissing, err := failMissing.attestCA.Sign(&x509.Certificate{ + + yk := newStubPivKey(t, RSA) + okCert := yk.attestMap[piv.SlotAuthentication] + printableCert := &x509.Certificate{ Subject: pkix.Name{CommonName: "attested certificate"}, - PublicKey: failMissing.attestSigner.Public(), - }) - if err != nil { - t.Fatal(err) + PublicKey: okCert.PublicKey, + Extensions: []pkix.Extension{ + {Id: oidYubicoSerialNumber, Value: printableSerialNumber}, + }, + } + restCert := &x509.Certificate{ + Subject: pkix.Name{CommonName: "attested certificate"}, + PublicKey: okCert.PublicKey, + Extensions: []pkix.Extension{ + {Id: oidYubicoSerialNumber, Value: append(serialNumber, 0)}, + }, + } + missingCert := &x509.Certificate{ + Subject: pkix.Name{CommonName: "attested certificate"}, + PublicKey: okCert.PublicKey, } - failParse.attestMap[piv.SlotAuthentication] = attCertParse - failMissing.attestMap[piv.SlotAuthentication] = attCertMissing + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + args args + want string + }{ + {"ok", args{okCert}, "112233"}, + {"fail printable", args{printableCert}, ""}, + {"fail rest", args{restCert}, ""}, + {"fail missing", args{missingCert}, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getSerialNumber(tt.args.cert); got != tt.want { + t.Errorf("getSerialNumber() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getSignatureAlgorithm(t *testing.T) { + fake := apiv1.SignatureAlgorithm(1000) + t.Cleanup(func() { + delete(signatureAlgorithmMapping, fake) + }) + signatureAlgorithmMapping[fake] = "fake" type args struct { - yk pivKey + alg apiv1.SignatureAlgorithm + bits int } tests := []struct { name string args args - want string + want piv.Algorithm wantErr bool }{ - {"ok", args{ok}, "112233", false}, - {"fail attest", args{failAttest}, "", true}, - {"fail parse", args{failParse}, "", true}, - {"fail missing", args{failMissing}, "", true}, + {"default", args{apiv1.UnspecifiedSignAlgorithm, 0}, piv.AlgorithmEC256, false}, + {"SHA256WithRSA", args{apiv1.SHA256WithRSA, 0}, piv.AlgorithmRSA2048, false}, + {"SHA512WithRSA", args{apiv1.SHA512WithRSA, 0}, piv.AlgorithmRSA2048, false}, + {"SHA256WithRSAPSS", args{apiv1.SHA256WithRSAPSS, 0}, piv.AlgorithmRSA2048, false}, + {"SHA512WithRSAPSS", args{apiv1.SHA512WithRSAPSS, 0}, piv.AlgorithmRSA2048, false}, + {"ECDSAWithSHA256", args{apiv1.ECDSAWithSHA256, 0}, piv.AlgorithmEC256, false}, + {"ECDSAWithSHA384", args{apiv1.ECDSAWithSHA384, 0}, piv.AlgorithmEC384, false}, + {"PureEd25519", args{apiv1.PureEd25519, 0}, piv.AlgorithmEd25519, false}, + {"SHA256WithRSA 2048", args{apiv1.SHA256WithRSA, 2048}, piv.AlgorithmRSA2048, false}, + {"SHA512WithRSA 2048", args{apiv1.SHA512WithRSA, 2048}, piv.AlgorithmRSA2048, false}, + {"SHA256WithRSAPSS 2048", args{apiv1.SHA256WithRSAPSS, 2048}, piv.AlgorithmRSA2048, false}, + {"SHA512WithRSAPSS 2048", args{apiv1.SHA512WithRSAPSS, 2048}, piv.AlgorithmRSA2048, false}, + {"SHA256WithRSA 1024", args{apiv1.SHA256WithRSA, 1024}, piv.AlgorithmRSA1024, false}, + {"SHA512WithRSA 1024", args{apiv1.SHA512WithRSA, 1024}, piv.AlgorithmRSA1024, false}, + {"SHA256WithRSAPSS 1024", args{apiv1.SHA256WithRSAPSS, 1024}, piv.AlgorithmRSA1024, false}, + {"SHA512WithRSAPSS 1024", args{apiv1.SHA512WithRSAPSS, 1024}, piv.AlgorithmRSA1024, false}, + {"fail 4096", args{apiv1.SHA256WithRSA, 4096}, 0, true}, + {"fail unknown", args{apiv1.SignatureAlgorithm(100), 0}, 0, true}, + {"fail default case", args{apiv1.SignatureAlgorithm(1000), 0}, 0, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := getSerialNumber(tt.args.yk) + got, err := getSignatureAlgorithm(tt.args.alg, tt.args.bits) if (err != nil) != tt.wantErr { - t.Errorf("getSerialNumber() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("getSignatureAlgorithm() error = %v, wantErr %v", err, tt.wantErr) return } - if got != tt.want { - t.Errorf("getSerialNumber() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("getSignatureAlgorithm() = %v, want %v", got, tt.want) } }) }