From b14075dc2bbca1b422d7d4a081d371296c1a2100 Mon Sep 17 00:00:00 2001 From: Nathan Verzemnieks Date: Tue, 21 May 2024 14:40:29 +0200 Subject: [PATCH 1/2] chore: finish adding context handling throughout --- pkg/awsds/asyncDatasource.go | 8 +-- pkg/awsds/asyncDatasource_test.go | 4 +- pkg/awsds/types.go | 2 +- pkg/sql/api/api.go | 3 - pkg/sql/api/api_test.go | 2 +- pkg/sql/datasource/datasource.go | 85 +++++++++++++++------------ pkg/sql/datasource/datasource_test.go | 79 ++++++++++++++++--------- pkg/sql/datasource/utils_test.go | 4 +- 8 files changed, 111 insertions(+), 76 deletions(-) diff --git a/pkg/awsds/asyncDatasource.go b/pkg/awsds/asyncDatasource.go index 6a4e009..f48a51a 100644 --- a/pkg/awsds/asyncDatasource.go +++ b/pkg/awsds/asyncDatasource.go @@ -77,7 +77,7 @@ func isAsyncFlow(query backend.DataQuery) bool { } func (ds *AsyncAWSDatasource) NewDatasource(ctx context.Context, settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) { - db, err := ds.driver.GetAsyncDB(settings, nil) + db, err := ds.driver.GetAsyncDB(ctx, settings, nil) if err != nil { return nil, err } @@ -153,7 +153,7 @@ func (ds *AsyncAWSDatasource) CheckHealth(ctx context.Context, req *backend.Chec }, nil } -func (ds *AsyncAWSDatasource) getAsyncDBFromQuery(q *AsyncQuery, datasourceUID string) (AsyncDB, error) { +func (ds *AsyncAWSDatasource) getAsyncDBFromQuery(ctx context.Context, q *AsyncQuery, datasourceUID string) (AsyncDB, error) { if !ds.EnableMultipleConnections && len(q.ConnectionArgs) > 0 { return nil, sqlds.ErrorMissingMultipleConnectionsConfig } @@ -174,7 +174,7 @@ func (ds *AsyncAWSDatasource) getAsyncDBFromQuery(q *AsyncQuery, datasourceUID s } var err error - db, err := ds.driver.GetAsyncDB(dbConn.settings, q.ConnectionArgs) + db, err := ds.driver.GetAsyncDB(ctx, dbConn.settings, q.ConnectionArgs) if err != nil { return nil, err } @@ -211,7 +211,7 @@ func (ds *AsyncAWSDatasource) handleAsyncQuery(ctx context.Context, req backend. fillMode = q.FillMissing } - asyncDB, err := ds.getAsyncDBFromQuery(q, datasourceUID) + asyncDB, err := ds.getAsyncDBFromQuery(ctx, q, datasourceUID) if err != nil { return getErrorFrameFromQuery(q), err } diff --git a/pkg/awsds/asyncDatasource_test.go b/pkg/awsds/asyncDatasource_test.go index cb1eeb8..1289c44 100644 --- a/pkg/awsds/asyncDatasource_test.go +++ b/pkg/awsds/asyncDatasource_test.go @@ -42,7 +42,7 @@ type fakeDriver struct { AsyncDriver } -func (d fakeDriver) GetAsyncDB(backend.DataSourceInstanceSettings, json.RawMessage) (db AsyncDB, err error) { +func (d fakeDriver) GetAsyncDB(context.Context, backend.DataSourceInstanceSettings, json.RawMessage) (db AsyncDB, err error) { return d.openDBfn() } @@ -96,7 +96,7 @@ func Test_getDBConnectionFromQuery(t *testing.T) { ds.storeDBConnection(key, dbConnection{tt.existingDB, settings}) } - dbConn, err := ds.getAsyncDBFromQuery(&AsyncQuery{Query: sqlutil.Query{ConnectionArgs: json.RawMessage(tt.args)}}, tt.dsUID) + dbConn, err := ds.getAsyncDBFromQuery(context.Background(), &AsyncQuery{Query: sqlutil.Query{ConnectionArgs: json.RawMessage(tt.args)}}, tt.dsUID) if err != nil { t.Fatalf("unexpected error %v", err) } diff --git a/pkg/awsds/types.go b/pkg/awsds/types.go index df07bda..8c76c88 100644 --- a/pkg/awsds/types.go +++ b/pkg/awsds/types.go @@ -116,5 +116,5 @@ type AsyncDB interface { // AsyncDriver extends the driver interface to also connect to async SQL datasources type AsyncDriver interface { sqlds.Driver - GetAsyncDB(settings backend.DataSourceInstanceSettings, queryArgs json.RawMessage) (AsyncDB, error) + GetAsyncDB(ctx context.Context, settings backend.DataSourceInstanceSettings, queryArgs json.RawMessage) (AsyncDB, error) } diff --git a/pkg/sql/api/api.go b/pkg/sql/api/api.go index db0c9fc..57d69ab 100644 --- a/pkg/sql/api/api.go +++ b/pkg/sql/api/api.go @@ -7,7 +7,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/grafana/grafana-aws-sdk/pkg/awsds" - "github.com/grafana/grafana-aws-sdk/pkg/sql/models" "github.com/grafana/grafana-plugin-sdk-go/backend/log" "github.com/grafana/sqlds/v3" "github.com/jpillora/backoff" @@ -54,8 +53,6 @@ type AWSAPI interface { Resources } -type Loader func(cache *awsds.SessionCache, settings models.Settings) (AWSAPI, error) - // WaitOnQuery polls the datasource api until the query finishes, returning an error if it failed. func WaitOnQuery(ctx context.Context, api SQL, output *ExecuteQueryOutput) error { backoffInstance := backoff.Backoff{ diff --git a/pkg/sql/api/api_test.go b/pkg/sql/api/api_test.go index f04738b..676ba09 100644 --- a/pkg/sql/api/api_test.go +++ b/pkg/sql/api/api_test.go @@ -72,7 +72,7 @@ func TestWaitOnQuery(t *testing.T) { for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { - err := WaitOnQuery(context.TODO(), tc.ds, &ExecuteQueryOutput{}) + err := WaitOnQuery(context.Background(), tc.ds, &ExecuteQueryOutput{}) if tc.ds.statusCounter != len(tc.ds.status) { t.Errorf("status not called the right amount of times. Want %d got %d", len(tc.ds.status), tc.ds.statusCounter) } diff --git a/pkg/sql/datasource/datasource.go b/pkg/sql/datasource/datasource.go index 33a614e..143e7b5 100644 --- a/pkg/sql/datasource/datasource.go +++ b/pkg/sql/datasource/datasource.go @@ -1,6 +1,7 @@ package datasource import ( + "context" "database/sql" "fmt" "sync" @@ -14,28 +15,45 @@ import ( "github.com/grafana/sqlds/v3" ) -// AWSDatasource stores a cache of several instances. +// AWSClient provides creation and caching of sessions, database connections, and API clients +type AWSClient interface { + Init(config backend.DataSourceInstanceSettings) + GetDB(ctx context.Context, id int64, options sqlds.Options) (*sql.DB, error) + GetAsyncDB(ctx context.Context, id int64, options sqlds.Options) (awsds.AsyncDB, error) + GetAPI(ctx context.Context, id int64, options sqlds.Options) (api.AWSAPI, error) +} + +type Loader interface { + LoadSettings(context.Context) models.Settings + LoadAPI(context.Context, *awsds.SessionCache, models.Settings) (api.AWSAPI, error) + LoadDriver(context.Context, api.AWSAPI) (driver.Driver, error) + LoadAsyncDriver(context.Context, api.AWSAPI) (asyncDriver.Driver, error) +} + +// awsClient provides creation and caching of several types of instances. // Each Map will depend on the datasource ID (and connection options): // - sessionCache: AWS cache. This is not a Map since it does not depend on the datasource. // - config: Base configuration. It will be used as base to populate datasource settings. // It does not depend on connection options (only one per datasource) // - api: API instance with the common methods to contact the data source API. -type AWSDatasource struct { +type awsClient struct { sessionCache *awsds.SessionCache config sync.Map api sync.Map + + loader Loader } -func New() *AWSDatasource { - ds := &AWSDatasource{sessionCache: awsds.NewSessionCache()} +func New(loader Loader) AWSClient { + ds := &awsClient{sessionCache: awsds.NewSessionCache(), loader: loader} return ds } -func (ds *AWSDatasource) storeConfig(config backend.DataSourceInstanceSettings) { +func (ds *awsClient) storeConfig(config backend.DataSourceInstanceSettings) { ds.config.Store(config.ID, config) } -func (ds *AWSDatasource) createDB(dr driver.Driver) (*sql.DB, error) { +func (ds *awsClient) createDB(dr driver.Driver) (*sql.DB, error) { db, err := dr.OpenDB() if err != nil { return nil, fmt.Errorf("%w: failed to connect to database (check hostname and port?)", err) @@ -44,7 +62,7 @@ func (ds *AWSDatasource) createDB(dr driver.Driver) (*sql.DB, error) { return db, nil } -func (ds *AWSDatasource) createAsyncDB(dr asyncDriver.Driver) (awsds.AsyncDB, error) { +func (ds *awsClient) createAsyncDB(dr asyncDriver.Driver) (awsds.AsyncDB, error) { db, err := dr.GetAsyncDB() if err != nil { return nil, fmt.Errorf("%w: failed to connect to database (check hostname and port)", err) @@ -53,12 +71,12 @@ func (ds *AWSDatasource) createAsyncDB(dr asyncDriver.Driver) (awsds.AsyncDB, er return db, nil } -func (ds *AWSDatasource) storeAPI(id int64, args sqlds.Options, dsAPI api.AWSAPI) { +func (ds *awsClient) storeAPI(id int64, args sqlds.Options, dsAPI api.AWSAPI) { key := connectionKey(id, args) ds.api.Store(key, dsAPI) } -func (ds *AWSDatasource) loadAPI(id int64, args sqlds.Options) (api.AWSAPI, bool) { +func (ds *awsClient) loadAPI(id int64, args sqlds.Options) (api.AWSAPI, bool) { key := connectionKey(id, args) dsAPI, exists := ds.api.Load(key) if exists { @@ -67,8 +85,8 @@ func (ds *AWSDatasource) loadAPI(id int64, args sqlds.Options) (api.AWSAPI, bool return nil, false } -func (ds *AWSDatasource) createAPI(id int64, args sqlds.Options, settings models.Settings, loader api.Loader) (api.AWSAPI, error) { - dsAPI, err := loader(ds.sessionCache, settings) +func (ds *awsClient) createAPI(ctx context.Context, id int64, args sqlds.Options, settings models.Settings) (api.AWSAPI, error) { + dsAPI, err := ds.loader.LoadAPI(ctx, ds.sessionCache, settings) if err != nil { return nil, fmt.Errorf("%w: Failed to create client", err) } @@ -76,8 +94,8 @@ func (ds *AWSDatasource) createAPI(id int64, args sqlds.Options, settings models return dsAPI, err } -func (ds *AWSDatasource) createDriver(dsAPI api.AWSAPI, loader driver.Loader) (driver.Driver, error) { - dr, err := loader(dsAPI) +func (ds *awsClient) createDriver(ctx context.Context, dsAPI api.AWSAPI) (driver.Driver, error) { + dr, err := ds.loader.LoadDriver(ctx, dsAPI) if err != nil { return nil, fmt.Errorf("%w: Failed to create client", err) } @@ -85,8 +103,8 @@ func (ds *AWSDatasource) createDriver(dsAPI api.AWSAPI, loader driver.Loader) (d return dr, nil } -func (ds *AWSDatasource) createAsyncDriver(dsAPI api.AWSAPI, loader asyncDriver.Loader) (asyncDriver.Driver, error) { - dr, err := loader(dsAPI) +func (ds *awsClient) createAsyncDriver(ctx context.Context, dsAPI api.AWSAPI) (asyncDriver.Driver, error) { + dr, err := ds.loader.LoadAsyncDriver(ctx, dsAPI) if err != nil { return nil, fmt.Errorf("%w: Failed to create client", err) } @@ -94,7 +112,7 @@ func (ds *AWSDatasource) createAsyncDriver(dsAPI api.AWSAPI, loader asyncDriver. return dr, nil } -func (ds *AWSDatasource) parseSettings(id int64, args sqlds.Options, settings models.Settings) error { +func (ds *awsClient) parseSettings(id int64, args sqlds.Options, settings models.Settings) error { config, ok := ds.config.Load(id) if !ok { return fmt.Errorf("unable to find stored configuration for datasource %d. Initialize it first", id) @@ -108,31 +126,29 @@ func (ds *AWSDatasource) parseSettings(id int64, args sqlds.Options, settings mo } // Init stores the data source configuration. It's needed for the GetDB and GetAPI functions -func (ds *AWSDatasource) Init(config backend.DataSourceInstanceSettings) { +func (ds *awsClient) Init(config backend.DataSourceInstanceSettings) { ds.storeConfig(config) } // GetDB returns a *sql.DB. It will use the loader functions to initialize the required // settings, API and driver and finally create a DB. -func (ds *AWSDatasource) GetDB( +func (ds *awsClient) GetDB( + ctx context.Context, id int64, options sqlds.Options, - settingsLoader models.Loader, - apiLoader api.Loader, - driverLoader driver.Loader, ) (*sql.DB, error) { - settings := settingsLoader() + settings := ds.loader.LoadSettings(ctx) err := ds.parseSettings(id, options, settings) if err != nil { return nil, err } - dsAPI, err := ds.createAPI(id, options, settings, apiLoader) + dsAPI, err := ds.createAPI(ctx, id, options, settings) if err != nil { return nil, err } - dr, err := ds.createDriver(dsAPI, driverLoader) + dr, err := ds.createDriver(ctx, dsAPI) if err != nil { return nil, err } @@ -142,25 +158,23 @@ func (ds *AWSDatasource) GetDB( // GetAsyncDB returns a sqlds.AsyncDB. It will use the loader functions to initialize the required // settings, API and driver and finally create a DB. -func (ds *AWSDatasource) GetAsyncDB( +func (ds *awsClient) GetAsyncDB( + ctx context.Context, id int64, options sqlds.Options, - settingsLoader models.Loader, - apiLoader api.Loader, - driverLoader asyncDriver.Loader, ) (awsds.AsyncDB, error) { - settings := settingsLoader() + settings := ds.loader.LoadSettings(ctx) err := ds.parseSettings(id, options, settings) if err != nil { return nil, err } - dsAPI, err := ds.createAPI(id, options, settings, apiLoader) + dsAPI, err := ds.createAPI(ctx, id, options, settings) if err != nil { return nil, err } - dr, err := ds.createAsyncDriver(dsAPI, driverLoader) + dr, err := ds.createAsyncDriver(ctx, dsAPI) if err != nil { return nil, err } @@ -171,11 +185,10 @@ func (ds *AWSDatasource) GetAsyncDB( // GetAPI returns an API interface. When called multiple times with the same id and options, it // will return a cached version of the API. The first time, it will use the loader // functions to initialize the required settings and API. -func (ds *AWSDatasource) GetAPI( +func (ds *awsClient) GetAPI( + ctx context.Context, id int64, options sqlds.Options, - settingsLoader models.Loader, - apiLoader api.Loader, ) (api.AWSAPI, error) { cachedAPI, exists := ds.loadAPI(id, options) if exists { @@ -183,10 +196,10 @@ func (ds *AWSDatasource) GetAPI( } // create new api - settings := settingsLoader() + settings := ds.loader.LoadSettings(ctx) err := ds.parseSettings(id, options, settings) if err != nil { return nil, err } - return ds.createAPI(id, options, settings, apiLoader) + return ds.createAPI(ctx, id, options, settings) } diff --git a/pkg/sql/datasource/datasource_test.go b/pkg/sql/datasource/datasource_test.go index fd5be42..21f5edd 100644 --- a/pkg/sql/datasource/datasource_test.go +++ b/pkg/sql/datasource/datasource_test.go @@ -1,8 +1,10 @@ package datasource import ( + "context" "database/sql" "database/sql/driver" + asyncDriver "github.com/grafana/grafana-aws-sdk/pkg/sql/driver/async" "testing" "github.com/google/go-cmp/cmp" @@ -14,9 +16,38 @@ import ( "github.com/grafana/sqlds/v3" ) +type fakeLoader struct { + driver sqlDriver.Driver +} + +func (m fakeLoader) LoadSettings(_ context.Context) models.Settings { + return &fakeSettings{} +} + +func (m fakeLoader) LoadAPI(_ context.Context, _ *awsds.SessionCache, _ models.Settings) (sqlApi.AWSAPI, error) { + return fakeAPI{}, nil +} + +func (m fakeLoader) LoadDriver(_ context.Context, _ sqlApi.AWSAPI) (sqlDriver.Driver, error) { + return m.driver, nil +} + +func (m fakeLoader) LoadAsyncDriver(_ context.Context, _ sqlApi.AWSAPI) (asyncDriver.Driver, error) { + return nil, nil +} +func newFakeLoader(db *sql.DB) Loader { + return fakeLoader{driver: &fakeDriver{db: db}} + +} + func TestNew(t *testing.T) { - ds := New() - if ds.sessionCache == nil { + ds := New(newFakeLoader(nil)) + impl, ok := ds.(*awsClient) + if !ok { + t.Errorf("unexpected underlying type: %t", ds) + } + + if impl.sessionCache == nil { t.Errorf("missing initialization") } } @@ -25,7 +56,7 @@ func TestInit(t *testing.T) { config := backend.DataSourceInstanceSettings{ ID: 100, } - ds := &AWSDatasource{} + ds := &awsClient{loader: newFakeLoader(nil)} ds.Init(config) if _, ok := ds.config.Load(config.ID); !ok { t.Errorf("missing config") @@ -87,7 +118,7 @@ func TestLoadAPI(t *testing.T) { for _, tt := range tests { t.Run(tt.description, func(t *testing.T) { - ds := &AWSDatasource{} + ds := &awsClient{loader: newFakeLoader(nil)} key := connectionKey(tt.id, tt.args) if tt.api != nil { ds.api.Store(key, tt.api) @@ -120,7 +151,7 @@ func (f *fakeSettings) Apply(args sqlds.Options) { func TestParseSettings(t *testing.T) { id := int64(1) args := sqlds.Options{"foo": "bar"} - ds := &AWSDatasource{} + ds := &awsClient{loader: newFakeLoader(nil)} ds.config.Store(id, backend.DataSourceInstanceSettings{ID: id}) settings := &fakeSettings{} @@ -136,18 +167,15 @@ func TestParseSettings(t *testing.T) { } } -func fakeAPILoader(_ *awsds.SessionCache, _ models.Settings) (sqlApi.AWSAPI, error) { - return fakeAPI{}, nil -} - func TestCreateAPI(t *testing.T) { id := int64(1) args := sqlds.Options{"foo": "bar"} - ds := &AWSDatasource{} + ds := &awsClient{loader: newFakeLoader(nil)} key := connectionKey(id, args) settings := &fakeSettings{} + ctx := context.Background() - api, err := ds.createAPI(id, args, settings, fakeAPILoader) + api, err := ds.createAPI(ctx, id, args, settings) if err != nil { t.Errorf("unexpected error %v", err) } @@ -160,15 +188,16 @@ func TestCreateAPI(t *testing.T) { } } -func fakeDriverLoader(sqlApi.AWSAPI) (sqlDriver.Driver, error) { - return &fakeDriver{db: &sql.DB{}}, nil -} - func TestCreateDriver(t *testing.T) { - ds := &AWSDatasource{} - api := fakeAPI{} + ctx := context.Background() + loader := newFakeLoader(nil) + ds := &awsClient{loader: loader} + api, err := ds.createAPI(ctx, 0, sqlds.Options{}, loader.LoadSettings(ctx)) + if err != nil { + t.Errorf("unexpected error %v", err) + } - dr, err := ds.createDriver(api, fakeDriverLoader) + dr, err := ds.createDriver(context.Background(), api) if err != nil { t.Errorf("unexpected error %v", err) } @@ -178,9 +207,9 @@ func TestCreateDriver(t *testing.T) { } func TestCreateDB(t *testing.T) { - ds := &AWSDatasource{} db := &sql.DB{} dr := &fakeDriver{db: db} + ds := &awsClient{loader: newFakeLoader(db)} res, err := ds.createDB(dr) if err != nil { @@ -191,18 +220,14 @@ func TestCreateDB(t *testing.T) { } } -func fakeSettingsLoader() models.Settings { - return &fakeSettings{} -} - func TestGetDB(t *testing.T) { id := int64(1) args := sqlds.Options{"foo": "bar"} - ds := &AWSDatasource{} + ds := &awsClient{loader: newFakeLoader(&sql.DB{})} config := backend.DataSourceInstanceSettings{ID: id} ds.Init(config) - res, err := ds.GetDB(config.ID, args, fakeSettingsLoader, fakeAPILoader, fakeDriverLoader) + res, err := ds.GetDB(context.Background(), config.ID, args) if err != nil { t.Errorf("unexpected error %v", err) } @@ -214,12 +239,12 @@ func TestGetDB(t *testing.T) { func TestGetAPI(t *testing.T) { id := int64(1) args := sqlds.Options{"foo": "bar"} - ds := &AWSDatasource{} + ds := &awsClient{loader: fakeLoader{}} config := backend.DataSourceInstanceSettings{ID: id} ds.Init(config) key := connectionKey(id, args) - api, err := ds.GetAPI(id, args, fakeSettingsLoader, fakeAPILoader) + api, err := ds.GetAPI(context.Background(), id, args) if err != nil { t.Errorf("unexpected error %v", err) } diff --git a/pkg/sql/datasource/utils_test.go b/pkg/sql/datasource/utils_test.go index 478c789..d2b6dfb 100644 --- a/pkg/sql/datasource/utils_test.go +++ b/pkg/sql/datasource/utils_test.go @@ -9,7 +9,7 @@ func TestGetDatasourceID(t *testing.T) { // It's not possible to test that GetDatasourceID returns an actual // ID because the ctx key is not exported. This just tests the fallback // path. - if id := GetDatasourceID(context.TODO()); id != 0 { + if id := GetDatasourceID(context.Background()); id != 0 { t.Errorf("unexpected id: %d", id) } } @@ -18,7 +18,7 @@ func TestGetDatasourceLastUpdatedTime(t *testing.T) { // It's not possible to test that GetDatasourceLastUpdatedTime returns an actual // time because the ctx key is not exported. This just tests the fallback // path. - if time := GetDatasourceLastUpdatedTime(context.TODO()); time != "" { + if time := GetDatasourceLastUpdatedTime(context.Background()); time != "" { t.Errorf("unexpected time: %s", time) } } From 65047541bce5f55aa5add502200b3e2782b4a476 Mon Sep 17 00:00:00 2001 From: Nathan Verzemnieks Date: Tue, 21 May 2024 15:40:39 +0200 Subject: [PATCH 2/2] Trying to fix failing builds --- .drone.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.drone.yml b/.drone.yml index 07ed244..cad04d6 100644 --- a/.drone.yml +++ b/.drone.yml @@ -13,17 +13,17 @@ steps: - name: build image: grafana/grafana-plugin-ci:1.9.5 commands: - - mage -v build + - mage --keep -v build - name: lint image: grafana/grafana-plugin-ci:1.9.5 commands: - - mage -v lint + - mage --keep -v lint - name: test image: grafana/grafana-plugin-ci:1.9.5 commands: - - mage -v test + - mage --keep -v test - name: vuln check image: golang:1.22