Skip to content

Commit 48f0943

Browse files
authored
feat: add a host-based auth cache as a fallback (#651)
This PR adds a new exported method called `NewRobustCache()` and changes the `DefaultCache` to use the robust cache. The robust cache uses scoped based auth but falls back to host based auth to better handle the situations described in #650 but retain the benefits of scoped based auth (where the token for a repo might be different than another repo in the same registry). Closes #650 Signed-off-by: Kyle M. Tarplee <[email protected]>
1 parent faaa1dd commit 48f0943

File tree

2 files changed

+206
-0
lines changed

2 files changed

+206
-0
lines changed

registry/remote/auth/cache.go

+73
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,76 @@ func (noCache) GetToken(ctx context.Context, registry string, scheme Scheme, key
157157
func (noCache) Set(ctx context.Context, registry string, scheme Scheme, key string, fetch func(context.Context) (string, error)) (string, error) {
158158
return fetch(ctx)
159159
}
160+
161+
// hostCache is an auth cache that ignores scopes. Uses only the registry's hostname to find a token.
162+
type hostCache struct {
163+
Cache
164+
}
165+
166+
// GetToken implements Cache.
167+
func (c *hostCache) GetToken(ctx context.Context, registry string, scheme Scheme, key string) (string, error) {
168+
return c.Cache.GetToken(ctx, registry, scheme, "")
169+
}
170+
171+
// Set implements Cache.
172+
func (c *hostCache) Set(ctx context.Context, registry string, scheme Scheme, key string, fetch func(context.Context) (string, error)) (string, error) {
173+
return c.Cache.Set(ctx, registry, scheme, "", fetch)
174+
}
175+
176+
// fallbackCache tries the primary cache then falls back to the secondary cache.
177+
type fallbackCache struct {
178+
primary Cache
179+
secondary Cache
180+
}
181+
182+
// GetScheme implements Cache.
183+
func (fc *fallbackCache) GetScheme(ctx context.Context, registry string) (Scheme, error) {
184+
scheme, err := fc.primary.GetScheme(ctx, registry)
185+
if err == nil {
186+
return scheme, nil
187+
}
188+
189+
// fallback
190+
return fc.secondary.GetScheme(ctx, registry)
191+
}
192+
193+
// GetToken implements Cache.
194+
func (fc *fallbackCache) GetToken(ctx context.Context, registry string, scheme Scheme, key string) (string, error) {
195+
token, err := fc.primary.GetToken(ctx, registry, scheme, key)
196+
if err == nil {
197+
return token, nil
198+
}
199+
200+
// fallback
201+
return fc.secondary.GetToken(ctx, registry, scheme, key)
202+
}
203+
204+
// Set implements Cache.
205+
func (fc *fallbackCache) Set(ctx context.Context, registry string, scheme Scheme, key string, fetch func(context.Context) (string, error)) (string, error) {
206+
token, err := fc.primary.Set(ctx, registry, scheme, key, fetch)
207+
if err != nil {
208+
return "", err
209+
}
210+
211+
return fc.secondary.Set(ctx, registry, scheme, key, func(ctx context.Context) (string, error) {
212+
return token, nil
213+
})
214+
}
215+
216+
// NewSingleContextCache creates a host-based cache for optimizing the auth flow for non-compliant registries.
217+
// It is intended to be used in a single context, such as pulling from a single repository.
218+
// This cache should not be shared.
219+
//
220+
// Note: [NewCache] should be used for compliant registries as it can be shared
221+
// across context and will generally make less re-authentication requests.
222+
func NewSingleContextCache() Cache {
223+
cache := NewCache()
224+
return &fallbackCache{
225+
primary: cache,
226+
// We can re-use the came concurrentCache here because the key space is different
227+
// (keys are always empty for the hostCache) so there is no collision.
228+
// Even if there is a collision it is not an issue.
229+
// Re-using saves a little memory.
230+
secondary: &hostCache{cache},
231+
}
232+
}

registry/remote/auth/cache_test.go

+133
Original file line numberDiff line numberDiff line change
@@ -540,3 +540,136 @@ func Test_concurrentCache_Set_Fetch_Failure(t *testing.T) {
540540
}
541541
}
542542
}
543+
544+
func Test_hostCache(t *testing.T) {
545+
base := NewCache()
546+
547+
// no entry in the cache
548+
ctx := context.Background()
549+
550+
hc := hostCache{base}
551+
552+
fetch := func(i int) func(context.Context) (string, error) {
553+
return func(context.Context) (string, error) {
554+
return strconv.Itoa(i), nil
555+
}
556+
}
557+
558+
// The key is ignored in the hostCache implementation.
559+
560+
{ // Set the token to 100
561+
gotToken, err := hc.Set(ctx, "reg.example.com", SchemeBearer, "key1", fetch(100))
562+
if err != nil {
563+
t.Fatalf("hostCache.Set() error = %v", err)
564+
}
565+
if want := strconv.Itoa(100); gotToken != want {
566+
t.Errorf("hostCache.Set() = %v, want %v", gotToken, want)
567+
}
568+
}
569+
570+
{ // Overwrite the token entry to 101
571+
gotToken, err := hc.Set(ctx, "reg.example.com", SchemeBearer, "key2", fetch(101))
572+
if err != nil {
573+
t.Fatalf("hostCache.Set() error = %v", err)
574+
}
575+
if want := strconv.Itoa(101); gotToken != want {
576+
t.Errorf("hostCache.Set() = %v, want %v", gotToken, want)
577+
}
578+
}
579+
580+
{ // Add entry for another host
581+
gotToken, err := hc.Set(ctx, "reg2.example.com", SchemeBearer, "key3", fetch(102))
582+
if err != nil {
583+
t.Fatalf("hostCache.Set() error = %v", err)
584+
}
585+
if want := strconv.Itoa(102); gotToken != want {
586+
t.Errorf("hostCache.Set() = %v, want %v", gotToken, want)
587+
}
588+
}
589+
590+
{ // Ensure the token for key1 is 101 now
591+
gotToken, err := hc.GetToken(ctx, "reg.example.com", SchemeBearer, "key1")
592+
if err != nil {
593+
t.Fatalf("hostCache.GetToken() error = %v", err)
594+
}
595+
if want := strconv.Itoa(101); gotToken != want {
596+
t.Errorf("hostCache.GetToken() = %v, want %v", gotToken, want)
597+
}
598+
}
599+
600+
{ // Make sure GetScheme still works
601+
gotScheme, err := hc.GetScheme(ctx, "reg.example.com")
602+
if err != nil {
603+
t.Fatalf("hostCache.GetScheme() error = %v", err)
604+
}
605+
if want := SchemeBearer; gotScheme != want {
606+
t.Errorf("hostCache.GetScheme() = %v, want %v", gotScheme, want)
607+
}
608+
}
609+
}
610+
611+
func Test_fallbackCache(t *testing.T) {
612+
// no entry in the cache
613+
ctx := context.Background()
614+
615+
scc := NewSingleContextCache()
616+
617+
fetch := func(i int) func(context.Context) (string, error) {
618+
return func(context.Context) (string, error) {
619+
return strconv.Itoa(i), nil
620+
}
621+
}
622+
623+
// Test that fallback works
624+
625+
{ // Set the token to 100
626+
gotToken, err := scc.Set(ctx, "reg.example.com", SchemeBearer, "key1", fetch(100))
627+
if err != nil {
628+
t.Fatalf("hostCache.Set() error = %v", err)
629+
}
630+
if want := strconv.Itoa(100); gotToken != want {
631+
t.Errorf("hostCache.Set() = %v, want %v", gotToken, want)
632+
}
633+
}
634+
635+
{ // Ensure the token for key2 falls back to 100
636+
gotToken, err := scc.GetToken(ctx, "reg.example.com", SchemeBearer, "key2")
637+
if err != nil {
638+
t.Fatalf("hostCache.GetToken() error = %v", err)
639+
}
640+
if want := strconv.Itoa(100); gotToken != want {
641+
t.Errorf("hostCache.GetToken() = %v, want %v", gotToken, want)
642+
}
643+
}
644+
645+
{ // Make sure GetScheme works as expected
646+
gotScheme, err := scc.GetScheme(ctx, "reg.example.com")
647+
if err != nil {
648+
t.Fatalf("hostCache.GetScheme() error = %v", err)
649+
}
650+
if want := SchemeBearer; gotScheme != want {
651+
t.Errorf("hostCache.GetScheme() = %v, want %v", gotScheme, want)
652+
}
653+
}
654+
655+
{ // Make sure GetScheme falls back
656+
gotScheme, err := scc.GetScheme(ctx, "reg.example.com")
657+
if err != nil {
658+
t.Fatalf("hostCache.GetScheme() error = %v", err)
659+
}
660+
if want := SchemeBearer; gotScheme != want {
661+
t.Errorf("hostCache.GetScheme() = %v, want %v", gotScheme, want)
662+
}
663+
}
664+
665+
{ // Check GetScheme fallback
666+
// scc.(*fallbackCache).primary = NewCache()
667+
gotScheme, err := scc.GetScheme(ctx, "reg2.example.com")
668+
if !errors.Is(err, errdef.ErrNotFound) {
669+
t.Fatalf("hostCache.GetScheme() error = %v", err)
670+
}
671+
if want := SchemeUnknown; gotScheme != want {
672+
t.Errorf("hostCache.GetScheme() = %v, want %v", gotScheme, want)
673+
}
674+
}
675+
}

0 commit comments

Comments
 (0)