diff --git a/config/backend_plugin.go b/config/backend_plugin.go index b3e6d0d5712..a071b551442 100644 --- a/config/backend_plugin.go +++ b/config/backend_plugin.go @@ -10,6 +10,8 @@ import ( "github.com/pkg/errors" "github.com/Sirupsen/logrus" "github.com/ory/fosite" + "github.com/jmoiron/sqlx" + "time" ) type PluginConnection struct { @@ -17,14 +19,15 @@ type PluginConnection struct { plugin *plugin.Plugin didConnect bool Logger logrus.FieldLogger + db *sqlx.DB } func (c *PluginConnection) load() error { - cf := c.Config if c.plugin != nil { return nil } + cf := c.Config p, err := plugin.Open(cf.DatabasePlugin) if err != nil { return errors.WithStack(err) @@ -46,11 +49,18 @@ func (c *PluginConnection) Connect() error { if l, err := c.plugin.Lookup("Connect"); err != nil { return errors.Wrap(err, "Unable to look up `Connect`") - } else if c, ok := l.(func(url string) error); !ok { + } else if con, ok := l.(func(url string) (*sqlx.DB, error)); !ok { return errors.New("Unable to type assert `Connect`") } else { - if err := c(cf.DatabaseURL); err != nil { - return errors.Wrap(err, "Could not Connect to database") + if db, err := con(cf.DatabaseURL); err != nil { + return errors.Wrap(err, "Could not connect to database") + } else { + cf.GetLogger().Info("Successfully connected through database plugin") + c.db = db + cf.GetLogger().Debugf("Address of database plugin is: %s", c.db) + if err := db.Ping(); err != nil { + cf.GetLogger().WithError(err).Fatal("Could not ping database connection from plugin") + } } } return nil @@ -64,10 +74,10 @@ func (c *PluginConnection) NewClientManager() (client.Manager, error) { ctx := c.Config.Context() if l, err := c.plugin.Lookup("NewClientManager"); err != nil { return nil, errors.Wrap(err, "Unable to look up `NewClientManager`") - } else if m, ok := l.(func(fosite.Hasher) client.Manager); !ok { + } else if m, ok := l.(func(*sqlx.DB, fosite.Hasher) client.Manager); !ok { return nil, errors.New("Unable to type assert `NewClientManager`") } else { - return m(ctx.Hasher), nil + return m(c.db, ctx.Hasher), nil } } @@ -78,10 +88,10 @@ func (c *PluginConnection) NewGroupManager() (group.Manager, error) { if l, err := c.plugin.Lookup("NewGroupManager"); err != nil { return nil, errors.Wrap(err, "Unable to look up `NewGroupManager`") - } else if m, ok := l.(func() group.Manager); !ok { + } else if m, ok := l.(func(*sqlx.DB) group.Manager); !ok { return nil, errors.New("Unable to type assert `NewGroupManager`") } else { - return m(), nil + return m(c.db), nil } } @@ -92,10 +102,10 @@ func (c *PluginConnection) NewJWKManager() (jwk.Manager, error) { if l, err := c.plugin.Lookup("NewJWKManager"); err != nil { return nil, errors.Wrap(err, "Unable to look up `NewJWKManager`") - } else if m, ok := l.(func(*jwk.AEAD) jwk.Manager); !ok { + } else if m, ok := l.(func(*sqlx.DB, *jwk.AEAD) jwk.Manager); !ok { return nil, errors.New("Unable to type assert `NewJWKManager`") } else { - return m(&jwk.AEAD{ + return m(c.db, &jwk.AEAD{ Key: c.Config.GetSystemSecret(), }), nil } @@ -108,10 +118,10 @@ func (c *PluginConnection) NewOAuth2Manager(clientManager client.Manager) (pkg.F if l, err := c.plugin.Lookup("NewOAuth2Manager"); err != nil { return nil, errors.Wrap(err, "Unable to look up `NewOAuth2Manager`") - } else if m, ok := l.(func(client.Manager, logrus.FieldLogger) pkg.FositeStorer); !ok { + } else if m, ok := l.(func(*sqlx.DB, client.Manager, logrus.FieldLogger) pkg.FositeStorer); !ok { return nil, errors.New("Unable to type assert `NewOAuth2Manager`") } else { - return m(clientManager, c.Config.GetLogger()), nil + return m(c.db, clientManager, c.Config.GetLogger()), nil } } @@ -122,9 +132,9 @@ func (c *PluginConnection) NewPolicyManager() (ladon.Manager, error) { if l, err := c.plugin.Lookup("NewPolicyManager"); err != nil { return nil, errors.Wrap(err, "Unable to look up `NewPolicyManager`") - } else if m, ok := l.(func() ladon.Manager); !ok { + } else if m, ok := l.(func(*sqlx.DB) ladon.Manager); !ok { return nil, errors.Errorf("Unable to type assert `NewPolicyManager`, got %v", l) } else { - return m(), nil + return m(c.db), nil } }