Skip to content

Commit

Permalink
Merge pull request #13356 from dmage/noglobals
Browse files Browse the repository at this point in the history
Merged by openshift-bot
  • Loading branch information
OpenShift Bot authored Mar 15, 2017
2 parents 0009bfa + a7ca6e1 commit 1eb44c7
Show file tree
Hide file tree
Showing 18 changed files with 200 additions and 269 deletions.
18 changes: 17 additions & 1 deletion pkg/cmd/dockerregistry/dockerregistry.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"strings"

"github.com/openshift/origin/pkg/cmd/server/crypto"
"github.com/openshift/origin/pkg/cmd/util/clientcmd"
"github.com/openshift/origin/pkg/dockerregistry/server"
"github.com/openshift/origin/pkg/dockerregistry/server/audit"
)
Expand All @@ -57,11 +58,26 @@ func Execute(configFile io.Reader) {
if err != nil {
log.Fatalf("error configuring logger: %v", err)
}

registryClient := server.NewRegistryClient(clientcmd.NewConfig().BindToFile())
ctx = server.WithRegistryClient(ctx, registryClient)

log.Infof("version=%s", version.Version)
// inject a logger into the uuid library. warns us if there is a problem
// with uuid generation under low entropy.
uuid.Loggerf = context.GetLogger(ctx).Warnf

// add parameters for the auth middleware
if config.Auth.Type() == server.OpenShiftAuth {
if config.Auth[server.OpenShiftAuth] == nil {
config.Auth[server.OpenShiftAuth] = make(configuration.Parameters)
}
config.Auth[server.OpenShiftAuth][server.AccessControllerOptionParams] = server.AccessControllerParams{
Logger: context.GetLogger(ctx),
SafeClientConfig: registryClient.SafeClientConfig(),
}
}

app := handlers.NewApp(ctx, config)

// Add a token handling endpoint
Expand All @@ -70,7 +86,7 @@ func Execute(configFile io.Reader) {
if err != nil {
log.Fatalf("error setting up token auth: %s", err)
}
err = app.NewRoute().Methods("GET").PathPrefix(tokenRealm.Path).Handler(server.NewTokenHandler(ctx, server.DefaultRegistryClient)).GetError()
err = app.NewRoute().Methods("GET").PathPrefix(tokenRealm.Path).Handler(server.NewTokenHandler(ctx, registryClient)).GetError()
if err != nil {
log.Fatalf("error setting up token endpoint at %q: %v", tokenRealm.Path, err)
}
Expand Down
63 changes: 20 additions & 43 deletions pkg/dockerregistry/server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"net/url"
"strings"

log "github.com/Sirupsen/logrus"
context "github.com/docker/distribution/context"
registryauth "github.com/docker/distribution/registry/auth"

Expand Down Expand Up @@ -45,6 +44,10 @@ const (

RealmKey = "realm"
TokenRealmKey = "tokenrealm"

// AccessControllerOptionParams is an option name for passing
// AccessControllerParams to AccessController.
AccessControllerOptionParams = "_params"
)

// RegistryClient encapsulates getting access to the OpenShift API.
Expand All @@ -55,9 +58,6 @@ type RegistryClient interface {
SafeClientConfig() restclient.Config
}

// DefaultRegistryClient is exposed for testing the registry with fake client.
var DefaultRegistryClient = NewRegistryClient(clientcmd.NewConfig().BindToFile())

// registryClient implements RegistryClient
type registryClient struct {
config *clientcmd.Config
Expand Down Expand Up @@ -85,19 +85,6 @@ func init() {
registryauth.Register(OpenShiftAuth, registryauth.InitFunc(newAccessController))
}

type contextKey int

var userClientKey contextKey = 0

func WithUserClient(parent context.Context, userClient client.Interface) context.Context {
return context.WithValue(parent, userClientKey, userClient)
}

func UserClientFrom(ctx context.Context) (client.Interface, bool) {
userClient, ok := ctx.Value(userClientKey).(client.Interface)
return userClient, ok
}

// WithUserInfoLogger creates a new context with provided user infomation.
func WithUserInfoLogger(ctx context.Context, username, userid string) context.Context {
ctx = context.WithValue(ctx, audit.AuditUserEntry, username)
Expand All @@ -110,27 +97,6 @@ func WithUserInfoLogger(ctx context.Context, username, userid string) context.Co
))
}

const authPerformedKey = "openshift.auth.performed"

func WithAuthPerformed(parent context.Context) context.Context {
return context.WithValue(parent, authPerformedKey, true)
}

func AuthPerformed(ctx context.Context) bool {
authPerformed, ok := ctx.Value(authPerformedKey).(bool)
return ok && authPerformed
}

const deferredErrorsKey = "openshift.auth.deferredErrors"

func WithDeferredErrors(parent context.Context, errs deferredErrors) context.Context {
return context.WithValue(parent, deferredErrorsKey, errs)
}
func DeferredErrorsFrom(ctx context.Context) (deferredErrors, bool) {
errs, ok := ctx.Value(deferredErrorsKey).(deferredErrors)
return errs, ok
}

type AccessController struct {
realm string
tokenRealm *url.URL
Expand Down Expand Up @@ -198,8 +164,19 @@ func TokenRealm(options map[string]interface{}) (*url.URL, error) {
return tokenRealm, nil
}

// AccessControllerParams is the parameters for newAccessController
type AccessControllerParams struct {
Logger context.Logger
SafeClientConfig restclient.Config
}

func newAccessController(options map[string]interface{}) (registryauth.AccessController, error) {
log.Info("Using Origin Auth handler")
params, ok := options[AccessControllerOptionParams].(AccessControllerParams)
if !ok {
return nil, fmt.Errorf("no parameters provided to Origin Auth handler")
}

params.Logger.Info("Using Origin Auth handler")

realm, err := getStringOption("", RealmKey, "origin", options)
if err != nil {
Expand All @@ -214,7 +191,7 @@ func newAccessController(options map[string]interface{}) (registryauth.AccessCon
ac := &AccessController{
realm: realm,
tokenRealm: tokenRealm,
config: DefaultRegistryClient.SafeClientConfig(),
config: params.SafeClientConfig,
}

if audit, ok := options["audit"]; ok {
Expand Down Expand Up @@ -446,12 +423,12 @@ func (ac *AccessController) Authorized(ctx context.Context, accessRecords ...reg
// Conditionally add auth errors we want to handle later to the context
if !possibleCrossMountErrors.Empty() {
context.GetLogger(ctx).Debugf("Origin auth: deferring errors: %#v", possibleCrossMountErrors)
ctx = WithDeferredErrors(ctx, possibleCrossMountErrors)
ctx = withDeferredErrors(ctx, possibleCrossMountErrors)
}
// Always add a marker to the context so we know auth was run
ctx = WithAuthPerformed(ctx)
ctx = withAuthPerformed(ctx)

return WithUserClient(ctx, osClient), nil
return withUserClient(ctx, osClient), nil
}

func getOpenShiftAPIToken(ctx context.Context, req *http.Request) (string, error) {
Expand Down
16 changes: 8 additions & 8 deletions pkg/dockerregistry/server/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"github.com/docker/distribution/context"
"github.com/openshift/origin/pkg/api/latest"
"github.com/openshift/origin/pkg/authorization/api"
"github.com/openshift/origin/pkg/cmd/util/clientcmd"
userapi "github.com/openshift/origin/pkg/user/api"

// install all APIs
Expand Down Expand Up @@ -391,20 +390,21 @@ func TestAccessController(t *testing.T) {
if len(test.bearerToken) > 0 {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", test.bearerToken))
}
ctx := context.WithValue(context.Background(), "http.request", req)

server, actions := simulateOpenShiftMaster(test.openshiftResponses)
DefaultRegistryClient = NewRegistryClient(&clientcmd.Config{
CommonConfig: restclient.Config{
options[AccessControllerOptionParams] = AccessControllerParams{
Logger: context.GetLogger(context.Background()),
SafeClientConfig: restclient.Config{
Host: server.URL,
Insecure: true,
},
SkipEnv: true,
})
}
accessController, err := newAccessController(options)
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
ctx = context.WithRequest(ctx, req)
authCtx, err := accessController.Authorized(ctx, test.access...)
server.Close()

Expand All @@ -426,11 +426,11 @@ func TestAccessController(t *testing.T) {
t.Errorf("%s: expected auth context but got nil", k)
continue
}
if !AuthPerformed(authCtx) {
if !authPerformed(authCtx) {
t.Errorf("%s: expected AuthPerformed to be true", k)
continue
}
deferredErrors, hasDeferred := DeferredErrorsFrom(authCtx)
deferredErrors, hasDeferred := deferredErrorsFrom(authCtx)
if len(test.expectedRepoErr) > 0 {
if !hasDeferred || deferredErrors[test.expectedRepoErr] == nil {
t.Errorf("%s: expected deferred error for repo %s, got none", k, test.expectedRepoErr)
Expand Down
6 changes: 3 additions & 3 deletions pkg/dockerregistry/server/blobdescriptorservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ type blobDescriptorService struct {
// a proper repository object to be set on given context by upper openshift middleware wrappers.
func (bs *blobDescriptorService) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
context.GetLogger(ctx).Debugf("(*blobDescriptorService).Stat: starting with digest=%s", dgst.String())
repo, found := RepositoryFrom(ctx)
repo, found := repositoryFrom(ctx)
if !found || repo == nil {
err := fmt.Errorf("failed to retrieve repository from context")
context.GetLogger(ctx).Error(err)
Expand Down Expand Up @@ -90,7 +90,7 @@ func (bs *blobDescriptorService) Stat(ctx context.Context, dgst digest.Digest) (
return desc, nil
}

if err == distribution.ErrBlobUnknown && RemoteBlobAccessCheckEnabledFrom(ctx) {
if err == distribution.ErrBlobUnknown && remoteBlobAccessCheckEnabledFrom(ctx) {
// Second attempt: looking for the blob on a remote server
desc, err = repo.remoteBlobGetter.Stat(ctx, dgst)
}
Expand All @@ -99,7 +99,7 @@ func (bs *blobDescriptorService) Stat(ctx context.Context, dgst digest.Digest) (
}

func (bs *blobDescriptorService) Clear(ctx context.Context, dgst digest.Digest) error {
repo, found := RepositoryFrom(ctx)
repo, found := repositoryFrom(ctx)
if !found || repo == nil {
err := fmt.Errorf("failed to retrieve repository from context")
context.GetLogger(ctx).Error(err)
Expand Down
18 changes: 5 additions & 13 deletions pkg/dockerregistry/server/blobdescriptorservice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ func GetTestPassThroughToUpstream(ctx context.Context) bool {
// It relies on the fact that blobDescriptorService requires higher levels to set repository object on given
// context. If the object isn't given, its method will err out.
func TestBlobDescriptorServiceIsApplied(t *testing.T) {
ctx := context.Background()

// don't do any authorization check
installFakeAccessController(t)
m := fakeBlobDescriptorService(t)
Expand All @@ -64,14 +62,8 @@ func TestBlobDescriptorServiceIsApplied(t *testing.T) {
client.AddReactor("get", "imagestreams", imagetest.GetFakeImageStreamGetHandler(t, *testImageStream))
client.AddReactor("get", "images", registrytest.GetFakeImageGetHandler(t, *testImage))

// TODO: get rid of those nasty global vars
backupRegistryClient := DefaultRegistryClient
DefaultRegistryClient = makeFakeRegistryClient(client, fake.NewSimpleClientset())
defer func() {
// set it back once this test finishes to make other unit tests working
DefaultRegistryClient = backupRegistryClient
}()

ctx := context.Background()
ctx = WithRegistryClient(ctx, makeFakeRegistryClient(client, fake.NewSimpleClientset()))
app := handlers.NewApp(ctx, &configuration.Configuration{
Loglevel: "debug",
Auth: map[string]configuration.Parameters{
Expand Down Expand Up @@ -463,7 +455,7 @@ func (bs *testBlobDescriptorService) Stat(ctx context.Context, dgst digest.Diges
bs.m.methodInvoked("Stat")
if bs.m.getUnsetRepository() {
bs.t.Logf("unsetting repository from the context")
ctx = WithRepository(ctx, nil)
ctx = withRepository(ctx, nil)
}

return bs.BlobDescriptorService.Stat(ctx, dgst)
Expand All @@ -472,7 +464,7 @@ func (bs *testBlobDescriptorService) Clear(ctx context.Context, dgst digest.Dige
bs.m.methodInvoked("Clear")
if bs.m.getUnsetRepository() {
bs.t.Logf("unsetting repository from the context")
ctx = WithRepository(ctx, nil)
ctx = withRepository(ctx, nil)
}
return bs.BlobDescriptorService.Clear(ctx, dgst)
}
Expand All @@ -499,7 +491,7 @@ func (f *fakeAccessController) Authorized(ctx context.Context, access ...registr
f.t.Logf("fake authorizer: authorizing access to %s:%s:%s", access.Resource.Type, access.Resource.Name, access.Action)
}

ctx = WithAuthPerformed(ctx)
ctx = withAuthPerformed(ctx)
return ctx, nil
}

Expand Down
Loading

0 comments on commit 1eb44c7

Please sign in to comment.