Skip to content

Commit

Permalink
Merge branch main into authz-profile
Browse files Browse the repository at this point in the history
  • Loading branch information
aarongable committed Jan 27, 2025
2 parents af8cbe0 + 55b8cbe commit 62acf88
Show file tree
Hide file tree
Showing 39 changed files with 2,153 additions and 1,077 deletions.
7 changes: 6 additions & 1 deletion allowlist/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@ func NewList[T comparable](members []T) *List[T] {
}

// NewFromYAML reads a YAML sequence of values of type T and returns a *List[T]
// containing those values. If the data cannot be parsed, an error is returned.
// containing those values. If data is empty, an empty (deny all) list is
// returned. If data cannot be parsed, an error is returned.
func NewFromYAML[T comparable](data []byte) (*List[T], error) {
if len(data) == 0 {
return NewList([]T{}), nil
}

var entries []T
err := strictyaml.Unmarshal(data, &entries)
if err != nil {
Expand Down
60 changes: 57 additions & 3 deletions allowlist/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
)

func TestNewFromYAML(t *testing.T) {
t.Parallel()

tests := []struct {
name string
yamlData string
Expand All @@ -22,9 +24,9 @@ func TestNewFromYAML(t *testing.T) {
{
name: "empty YAML",
yamlData: "",
check: nil,
expectAnswers: nil,
expectErr: true,
check: []string{"oak", "walnut", "maple", "cherry"},
expectAnswers: []bool{false, false, false, false},
expectErr: false,
},
{
name: "invalid YAML",
Expand All @@ -37,6 +39,8 @@ func TestNewFromYAML(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

list, err := NewFromYAML[string]([]byte(tt.yamlData))
if (err != nil) != tt.expectErr {
t.Fatalf("NewFromYAML() error = %v, expectErr = %v", err, tt.expectErr)
Expand All @@ -53,3 +57,53 @@ func TestNewFromYAML(t *testing.T) {
})
}
}

func TestNewList(t *testing.T) {
t.Parallel()

tests := []struct {
name string
members []string
check []string
expectAnswers []bool
}{
{
name: "unique members",
members: []string{"oak", "maple", "cherry"},
check: []string{"oak", "walnut", "maple", "cherry"},
expectAnswers: []bool{true, false, true, true},
},
{
name: "duplicate members",
members: []string{"oak", "maple", "cherry", "oak"},
check: []string{"oak", "walnut", "maple", "cherry"},
expectAnswers: []bool{true, false, true, true},
},
{
name: "nil list",
members: nil,
check: []string{"oak", "walnut", "maple", "cherry"},
expectAnswers: []bool{false, false, false, false},
},
{
name: "empty list",
members: []string{},
check: []string{"oak", "walnut", "maple", "cherry"},
expectAnswers: []bool{false, false, false, false},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

list := NewList[string](tt.members)
for i, item := range tt.check {
got := list.Contains(item)
if got != tt.expectAnswers[i] {
t.Errorf("Contains(%q) got %v, want %v", item, got, tt.expectAnswers[i])
}
}
})
}
}
70 changes: 56 additions & 14 deletions cmd/admin/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type subcommandRevokeCert struct {
privKey string
regID uint
certFile string
crlShard int64
}

var _ subcommand = (*subcommandRevokeCert)(nil)
Expand All @@ -58,6 +59,7 @@ func (s *subcommandRevokeCert) Flags(flag *flag.FlagSet) {
flag.StringVar(&s.reasonStr, "reason", "unspecified", "Revocation reason (unspecified, keyCompromise, superseded, cessationOfOperation, or privilegeWithdrawn)")
flag.BoolVar(&s.skipBlock, "skip-block-key", false, "Skip blocking the key, if revoked for keyCompromise - use with extreme caution")
flag.BoolVar(&s.malformed, "malformed", false, "Indicates that the cert cannot be parsed - use with caution")
flag.Int64Var(&s.crlShard, "crl-shard", 0, "For malformed certs, the CRL shard the certificate belongs to")

// Flags specifying the input method for the certificates to be revoked.
flag.StringVar(&s.serial, "serial", "", "Revoke the certificate with this hex serial")
Expand Down Expand Up @@ -134,19 +136,54 @@ func (s *subcommandRevokeCert) Run(ctx context.Context, a *admin) error {
return fmt.Errorf("collecting serials to revoke: %w", err)
}

serials, err = cleanSerials(serials)
if err != nil {
return err
}

if len(serials) == 0 {
return errors.New("no serials to revoke found")
}

a.log.Infof("Found %d certificates to revoke", len(serials))

err = a.revokeSerials(ctx, serials, reasonCode, s.malformed, s.skipBlock, s.parallelism)
if s.malformed {
return s.revokeMalformed(ctx, a, serials, reasonCode)
}

err = a.revokeSerials(ctx, serials, reasonCode, s.skipBlock, s.parallelism)
if err != nil {
return fmt.Errorf("revoking serials: %w", err)
}

return nil
}

func (s *subcommandRevokeCert) revokeMalformed(ctx context.Context, a *admin, serials []string, reasonCode revocation.Reason) error {
u, err := user.Current()
if err != nil {
return fmt.Errorf("getting admin username: %w", err)
}
if s.crlShard == 0 {
return errors.New("when revoking malformed certificates, a nonzero CRL shard must be specified")
}
if len(serials) > 1 {
return errors.New("when revoking malformed certificates, only one cert at a time is allowed")
}
_, err = a.rac.AdministrativelyRevokeCertificate(
ctx,
&rapb.AdministrativelyRevokeCertificateRequest{
Serial: serials[0],
Code: int64(reasonCode),
AdminName: u.Username,
SkipBlockKey: s.skipBlock,
Malformed: true,
CrlShard: s.crlShard,
},
)
return err
}

func (a *admin) serialsFromIncidentTable(ctx context.Context, tableName string) ([]string, error) {
stream, err := a.saroc.SerialsForIncident(ctx, &sapb.SerialsForIncidentRequest{IncidentTable: tableName})
if err != nil {
Expand Down Expand Up @@ -248,7 +285,9 @@ func (a *admin) serialsFromCertPEM(_ context.Context, filename string) ([]string
return []string{core.SerialToString(cert.SerialNumber)}, nil
}

func cleanSerial(serial string) (string, error) {
// cleanSerials removes non-alphanumeric characters from the serials and checks
// that all resulting serials are valid (hex encoded, and the correct length).
func cleanSerials(serials []string) ([]string, error) {
serialStrip := func(r rune) rune {
switch {
case unicode.IsLetter(r):
Expand All @@ -258,14 +297,19 @@ func cleanSerial(serial string) (string, error) {
}
return rune(-1)
}
strippedSerial := strings.Map(serialStrip, serial)
if !core.ValidSerial(strippedSerial) {
return "", fmt.Errorf("cleaned serial %q is not valid", strippedSerial)

var ret []string
for _, s := range serials {
cleaned := strings.Map(serialStrip, s)
if !core.ValidSerial(cleaned) {
return nil, fmt.Errorf("cleaned serial %q is not valid", cleaned)
}
ret = append(ret, cleaned)
}
return strippedSerial, nil
return ret, nil
}

func (a *admin) revokeSerials(ctx context.Context, serials []string, reason revocation.Reason, malformed bool, skipBlockKey bool, parallelism uint) error {
func (a *admin) revokeSerials(ctx context.Context, serials []string, reason revocation.Reason, skipBlockKey bool, parallelism uint) error {
u, err := user.Current()
if err != nil {
return fmt.Errorf("getting admin username: %w", err)
Expand All @@ -279,19 +323,17 @@ func (a *admin) revokeSerials(ctx context.Context, serials []string, reason revo
go func() {
defer wg.Done()
for serial := range work {
cleanedSerial, err := cleanSerial(serial)
if err != nil {
a.log.Errf("skipping serial %q: %s", serial, err)
continue
}
_, err = a.rac.AdministrativelyRevokeCertificate(
ctx,
&rapb.AdministrativelyRevokeCertificateRequest{
Serial: cleanedSerial,
Serial: serial,
Code: int64(reason),
AdminName: u.Username,
SkipBlockKey: skipBlockKey,
Malformed: malformed,
// This is a well-formed certificate so send CrlShard 0
// to let the RA figure out the right shard from the cert.
Malformed: false,
CrlShard: 0,
},
)
if err != nil {
Expand Down
95 changes: 80 additions & 15 deletions cmd/admin/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"errors"
"os"
"path"
"reflect"
"slices"
"strings"
"sync"
Expand Down Expand Up @@ -198,70 +199,134 @@ func (mra *mockRARecordingRevocations) reset() {
func TestRevokeSerials(t *testing.T) {
t.Parallel()
serials := []string{
"2a:18:59:2b:7f:4b:f5:96:fb:1a:1d:f1:35:56:7a:cd:82:5a",
"03:8c:3f:63:88:af:b7:69:5d:d4:d6:bb:e3:d2:64:f1:e4:e2",
"048c3f6388afb7695dd4d6bbe3d264f1e5e5!",
"2a18592b7f4bf596fb1a1df135567acd825a",
"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
"048c3f6388afb7695dd4d6bbe3d264f1e5e5",
}
mra := mockRARecordingRevocations{}
log := blog.NewMock()
a := admin{rac: &mra, log: log}

assertRequestsContain := func(reqs []*rapb.AdministrativelyRevokeCertificateRequest, code revocation.Reason, skipBlockKey bool, malformed bool) {
assertRequestsContain := func(reqs []*rapb.AdministrativelyRevokeCertificateRequest, code revocation.Reason, skipBlockKey bool) {
t.Helper()
for _, req := range reqs {
test.AssertEquals(t, len(req.Cert), 0)
test.AssertEquals(t, req.Code, int64(code))
test.AssertEquals(t, req.SkipBlockKey, skipBlockKey)
test.AssertEquals(t, req.Malformed, malformed)
}
}

// Revoking should result in 3 gRPC requests and quiet execution.
mra.reset()
log.Clear()
a.dryRun = false
err := a.revokeSerials(context.Background(), serials, 0, false, false, 1)
err := a.revokeSerials(context.Background(), serials, 0, false, 1)
test.AssertEquals(t, len(log.GetAllMatching("invalid serial format")), 0)
test.AssertNotError(t, err, "")
test.AssertEquals(t, len(log.GetAll()), 0)
test.AssertEquals(t, len(mra.revocationRequests), 3)
assertRequestsContain(mra.revocationRequests, 0, false, false)
assertRequestsContain(mra.revocationRequests, 0, false)

// Revoking an already-revoked serial should result in one log line.
mra.reset()
log.Clear()
mra.alreadyRevoked = []string{"048c3f6388afb7695dd4d6bbe3d264f1e5e5"}
err = a.revokeSerials(context.Background(), serials, 0, false, false, 1)
err = a.revokeSerials(context.Background(), serials, 0, false, 1)
t.Logf("error: %s", err)
t.Logf("logs: %s", strings.Join(log.GetAll(), ""))
test.AssertError(t, err, "already-revoked should result in error")
test.AssertEquals(t, len(log.GetAllMatching("not revoking")), 1)
test.AssertEquals(t, len(mra.revocationRequests), 3)
assertRequestsContain(mra.revocationRequests, 0, false, false)
assertRequestsContain(mra.revocationRequests, 0, false)

// Revoking a doomed-to-fail serial should also result in one log line.
mra.reset()
log.Clear()
mra.doomedToFail = []string{"048c3f6388afb7695dd4d6bbe3d264f1e5e5"}
err = a.revokeSerials(context.Background(), serials, 0, false, false, 1)
err = a.revokeSerials(context.Background(), serials, 0, false, 1)
test.AssertError(t, err, "gRPC error should result in error")
test.AssertEquals(t, len(log.GetAllMatching("failed to revoke")), 1)
test.AssertEquals(t, len(mra.revocationRequests), 3)
assertRequestsContain(mra.revocationRequests, 0, false, false)
assertRequestsContain(mra.revocationRequests, 0, false)

// Revoking with other parameters should get carried through.
mra.reset()
log.Clear()
err = a.revokeSerials(context.Background(), serials, 1, true, true, 3)
err = a.revokeSerials(context.Background(), serials, 1, true, 3)
test.AssertNotError(t, err, "")
test.AssertEquals(t, len(mra.revocationRequests), 3)
assertRequestsContain(mra.revocationRequests, 1, true, true)
assertRequestsContain(mra.revocationRequests, 1, true)

// Revoking in dry-run mode should result in no gRPC requests and three logs.
mra.reset()
log.Clear()
a.dryRun = true
a.rac = dryRunRAC{log: log}
err = a.revokeSerials(context.Background(), serials, 0, false, false, 1)
err = a.revokeSerials(context.Background(), serials, 0, false, 1)
test.AssertNotError(t, err, "")
test.AssertEquals(t, len(log.GetAllMatching("dry-run:")), 3)
test.AssertEquals(t, len(mra.revocationRequests), 0)
assertRequestsContain(mra.revocationRequests, 0, false, false)
assertRequestsContain(mra.revocationRequests, 0, false)
}

func TestRevokeMalformed(t *testing.T) {
t.Parallel()
mra := mockRARecordingRevocations{}
log := blog.NewMock()
a := &admin{
rac: &mra,
log: log,
dryRun: false,
}

s := subcommandRevokeCert{
crlShard: 623,
}
serial := "0379c3dfdd518be45948f2dbfa6ea3e9b209"
err := s.revokeMalformed(context.Background(), a, []string{serial}, 1)
if err != nil {
t.Errorf("revokedMalformed with crlShard 623: want success, got %s", err)
}
if len(mra.revocationRequests) != 1 {
t.Errorf("revokeMalformed: want 1 revocation request to SA, got %v", mra.revocationRequests)
}
if mra.revocationRequests[0].Serial != serial {
t.Errorf("revokeMalformed: want %s to be revoked, got %s", serial, mra.revocationRequests[0])
}

s = subcommandRevokeCert{
crlShard: 0,
}
err = s.revokeMalformed(context.Background(), a, []string{"038c3f6388afb7695dd4d6bbe3d264f1e4e2"}, 1)
if err == nil {
t.Errorf("revokedMalformed with crlShard 0: want error, got none")
}

s = subcommandRevokeCert{
crlShard: 623,
}
err = s.revokeMalformed(context.Background(), a, []string{"038c3f6388afb7695dd4d6bbe3d264f1e4e2", "28a94f966eae14e525777188512ddf5a0a3b"}, 1)
if err == nil {
t.Errorf("revokedMalformed with multiple serials: want error, got none")
}
}

func TestCleanSerials(t *testing.T) {
input := []string{
"2a:18:59:2b:7f:4b:f5:96:fb:1a:1d:f1:35:56:7a:cd:82:5a",
"03:8c:3f:63:88:af:b7:69:5d:d4:d6:bb:e3:d2:64:f1:e4:e2",
"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
}
expected := []string{
"2a18592b7f4bf596fb1a1df135567acd825a",
"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
}
output, err := cleanSerials(input)
if err != nil {
t.Errorf("cleanSerials(%s): %s, want %s", input, err, expected)
}
if !reflect.DeepEqual(output, expected) {
t.Errorf("cleanSerials(%s)=%s, want %s", input, output, expected)
}
}
Loading

0 comments on commit 62acf88

Please sign in to comment.