Skip to content

Commit

Permalink
Refactor(tenants): apikey store to use varchar array (#114)
Browse files Browse the repository at this point in the history
* refactor apikey store to use varchar array

* Drop default permissons after migration
  • Loading branch information
TimVosch authored Jun 25, 2024
1 parent c051239 commit 7609362
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 135 deletions.
2 changes: 1 addition & 1 deletion services/tenants/apikeys/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (s *Service) GenerateNewApiKey(ctx context.Context, name string, tenantId i
}
existing, err := s.apiKeyStore.GetHashedAPIKeyByNameAndTenantID(name, tenantId)
if err != nil && err != ErrKeyNotFound {
return "", err
return "", fmt.Errorf("in GenerateNewApiKey, could not check for existing key due to err: %w", err)
}
if existing.ID > 0 {
return "", ErrKeyNameTenantIDCombinationNotUnique
Expand Down
238 changes: 106 additions & 132 deletions services/tenants/infrastructure/apikeys_store_psql.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package tenantsinfra

import (
"context"
"errors"
"fmt"
"time"

sq "github.com/Masterminds/squirrel"
"github.com/jackc/pgx/v5"
pgt "github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/stdlib"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/jmoiron/sqlx"

Expand Down Expand Up @@ -35,15 +40,6 @@ func (as *ApiKeyStore) List(filter apikeys.Filter, r pagination.Request) (*pagin
return nil, fmt.Errorf("could not getcursor from pagination request: %w", err)
}

keyQuery := sq.Select("*").From("api_keys keys")
if len(filter.TenantID) > 0 {
keyQuery = keyQuery.Where(sq.Eq{"keys.tenant_id": filter.TenantID})
}
keyQuery, err = pagination.Apply(keyQuery, cursor)
if err != nil {
return nil, fmt.Errorf("could not apply pagination: %w", err)
}

q := sq.
Select(
"keys.id",
Expand All @@ -52,86 +48,83 @@ func (as *ApiKeyStore) List(filter apikeys.Filter, r pagination.Request) (*pagin
"keys.created",
"keys.tenant_id",
"tenants.name",
"permissions.permission",
"keys.created",
"keys.id",
"keys.permissions",
).
FromSelect(keyQuery, "keys").
LeftJoin("tenants on keys.tenant_id = tenants.id").
LeftJoin("api_key_permissions permissions on keys.id = permissions.api_key_id")

rows, err := q.PlaceholderFormat(sq.Dollar).RunWith(as.db).Query()
From("api_keys keys").
LeftJoin("tenants on keys.tenant_id = tenants.id")
if len(filter.TenantID) > 0 {
q = q.Where(sq.Eq{"keys.tenant_id": filter.TenantID})
}
q, err = pagination.Apply(q, cursor)
if err != nil {
return nil, fmt.Errorf("error running database query: %w", err)
return nil, fmt.Errorf("could not apply pagination: %w", err)
}
defer rows.Close()

// For each key there are multiple records so all the permissions can be listed
// Track at which API key the rows are currently so we can add the correct permissions and move to the next key
list := make([]apikeys.ApiKeyDTO, 0, cursor.Limit)
currentId := -1
lastId := -1
currentPermissions := []string{}
for rows.Next() {
key := apikeys.ApiKeyDTO{}
permission := ""
err = rows.Scan(
&key.ID,
&key.Name,
&key.ExpirationDate,
&key.Created,
&key.TenantID,
&key.TenantName,
&permission,
&cursor.Columns.Created,
&cursor.Columns.KeyID,
)
// TODO: This is a hack, it grabs the underlying PGX connection to scan the row
// it is sort of required to scan the array of permission as that column is a postgres
// special type
c, err := as.db.Conn(context.Background())
if err != nil {
return nil, fmt.Errorf("in GetTenantMember, could not get raw db conn: %w", err)
}
defer c.Close()
err = c.Raw(func(driverConn any) error {
stdlibConn, ok := driverConn.(*stdlib.Conn)
if !ok {
return errors.New("in GetTenantMember, expected driverConnection to be of type stdlib.Conn")
}
conn := stdlibConn.Conn()
sql, args, err := q.PlaceholderFormat(sq.Dollar).ToSql()
if err != nil {
return nil, fmt.Errorf("error scanning row into api key w/ permission: %w", err)
return err
}
if lastId != currentId {
// Started scanning new API key record
currentPermissions = []string{}
rows, err := conn.Query(context.TODO(), sql, args...)
if err != nil {
return err
}
currentPermissions = append(currentPermissions, permission)
lastId = currentId
// TODO: These permissions should most likely be validated
if len(list) > 0 && list[len(list)-1].ID == key.ID {
// Still at the same key, append the permission to the last key
list[len(list)-1].Permissions = append(list[len(list)-1].Permissions, auth.Permission(permission))
} else {
// Otherwise the result set arrived at a new api key
key.Permissions = auth.Permissions{auth.Permission(permission)}
defer rows.Close()
for rows.Next() {
var key apikeys.ApiKeyDTO
var permissions pgt.FlatArray[auth.Permission]
err := rows.Scan(
&key.ID, &key.Name, &key.ExpirationDate, &key.Created, &key.TenantID, &key.TenantName,
&permissions,
&cursor.Columns.Created,
&cursor.Columns.KeyID,
)
if err != nil {
return err
}
key.Permissions = auth.Permissions(permissions)
if err := key.Permissions.Validate(); err != nil {
return err
}
list = append(list, key)
}
return nil
})
if err != nil {
return nil, err
}

page := pagination.CreatePageT(list, cursor)
return &page, nil
}

func (as *ApiKeyStore) AddApiKey(tenantID int64, permissions auth.Permissions, hashedKey apikeys.HashedApiKey) error {
// Create the insert statement for the permissions which must ran with the insert API key query
apiKeyPermissionsQ := sq.Insert("api_key_permissions").
Columns("permission", "api_key_id")
for _, permission := range permissions {
apiKeyPermissionsQ = apiKeyPermissionsQ.Values(permission, sq.Select("id").From("new_api_key").Prefix("(").Suffix(")"))
}

// Create the insert API key query
q := sq.Insert("api_keys").
Columns("id", "name", "created", "tenant_id", "value", "expiration_date").
Columns("id", "name", "created", "tenant_id", "value", "expiration_date", "permissions").
Values(
hashedKey.ID,
hashedKey.Name,
time.Now().UTC(),
tenantID,
hashedKey.SecretHash,
hashedKey.ExpirationDate).
Prefix("WITH new_api_key AS (").
Suffix("RETURNING \"id\")").

// Run the API key permission query along with the insert API key query
SuffixExpr(apiKeyPermissionsQ)
hashedKey.ExpirationDate,
permissions,
)
_, err := q.PlaceholderFormat(sq.Dollar).RunWith(as.db).Exec()
if err != nil {
return err
Expand Down Expand Up @@ -159,83 +152,64 @@ func (as *ApiKeyStore) DeleteApiKey(id int64) error {
// Retrieves the hashed value of an API key, if the key is not found an ErrKeyNotFound is returned.
// Only returns the API key if the given tenant confirms to any state passed in the stateFilter
func (as *ApiKeyStore) GetHashedApiKeyById(id int64, stateFilter []tenants.State) (apikeys.HashedApiKey, error) {
q := sq.
Select(
"key.id", "key.name", "key.value", "key.expiration_date",
"key.tenant_id",
"perm.permission",
).
Where(sq.Eq{"key.id": id}).
From("api_keys key").
LeftJoin("api_key_permissions perm on perm.api_key_id = key.id")
if len(stateFilter) > 0 {
q = q.
LeftJoin("tenants on tenants.id = key.tenant_id").
Where(sq.Eq{"tenants.state": stateFilter})
}

rows, err := q.PlaceholderFormat(sq.Dollar).RunWith(as.db).Query()
if err != nil {
return apikeys.HashedApiKey{}, err
}
defer rows.Close()

key := apikeys.HashedApiKey{
Permissions: auth.Permissions{},
}
for rows.Next() {
var permission auth.Permission
err = rows.Scan(
&key.ID,
&key.Name,
&key.SecretHash,
&key.ExpirationDate,
&key.TenantID,
&permission,
)
if err != nil {
return key, err
}
key.Permissions = append(key.Permissions, permission)
}
if key.ID == 0 {
return key, apikeys.ErrKeyNotFound
}
return key, nil
return as.getAPIKey(func(q sq.SelectBuilder) sq.SelectBuilder {
return q.Where(sq.Eq{"tenants.state": stateFilter})
})
}

// Retrieves the hashed value of an API key, if the key is not found an ErrKeyNotFound is returned.
func (as *ApiKeyStore) GetHashedAPIKeyByNameAndTenantID(name string, tenantID int64) (apikeys.HashedApiKey, error) {
q := sq.
Select("id, value, tenant_id, expiration_date").
From("api_keys").
Where(sq.Eq{"name": name, "tenant_id": tenantID})
rows, err := q.PlaceholderFormat(sq.Dollar).RunWith(as.db).Query()
return as.getAPIKey(func(q sq.SelectBuilder) sq.SelectBuilder {
return q.Where(sq.Eq{"keys.name": name, "keys.tenant_id": tenantID})
})
}

func (as *ApiKeyStore) getAPIKey(mod func(q sq.SelectBuilder) sq.SelectBuilder) (apikeys.HashedApiKey, error) {
var key apikeys.HashedApiKey
var permissions pgt.FlatArray[auth.Permission]
q := sq.Select(
"keys.id", "keys.value", "keys.expiration_date", "keys.tenant_id", "keys.permissions",
).From("api_keys keys").LeftJoin("tenants on keys.tenant_id = tenants.id")
q = mod(q)
// TODO: This is a hack, it grabs the underlying PGX connection to scan the row
// it is sort of required to scan the array of permission as that column is a postgres
// special type
c, err := as.db.Conn(context.Background())
if err != nil {
return apikeys.HashedApiKey{}, err
return key, fmt.Errorf("in GetTenantMember, could not get raw db conn: %w", err)
}
defer rows.Close()
k := apikeys.HashedApiKey{
Permissions: auth.Permissions{},
}
for rows.Next() {
var permission auth.Permission
err = rows.Scan(
&k.ID,
&k.SecretHash,
&k.TenantID,
&k.ExpirationDate,
&permission,
defer c.Close()
err = c.Raw(func(driverConn any) error {
stdlibConn, ok := driverConn.(*stdlib.Conn)
if !ok {
return errors.New("in GetHashedAPIKeyByNameAndTenantID, expected driverConnection to be of type stdlib.Conn")
}
conn := stdlibConn.Conn()
query, args, err := q.PlaceholderFormat(sq.Dollar).ToSql()
if err != nil {
return fmt.Errorf("in GetHashedAPIKeyByNameAndTenantID, could not build query: %w", err)
}
row := conn.QueryRow(context.TODO(), query, args...)
err = row.Scan(
&key.ID, &key.SecretHash, &key.ExpirationDate, &key.TenantID,
&permissions,
)
if errors.Is(err, pgx.ErrNoRows) {
return apikeys.ErrKeyNotFound
}
if err != nil {
return apikeys.HashedApiKey{}, err
return fmt.Errorf("in GetHashedAPIKeyByNameAndTenantID, could not scan row: %w", err)
}
k.Permissions = append(k.Permissions, permission)
}
if k.ID == 0 {
return apikeys.HashedApiKey{}, apikeys.ErrKeyNotFound
key.Permissions = auth.Permissions(permissions)
if err := key.Permissions.Validate(); err != nil {
return fmt.Errorf("in GetHashedAPIKeyByNameAndTenantID, invalid permissions: %w", err)
}
return nil
})
if err != nil {
return key, err
}
return k, nil
return key, nil
}

type ApiKeyStore struct {
Expand Down
2 changes: 1 addition & 1 deletion services/tenants/infrastructure/tenants_store_psql.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (ts *PSQLTenantStore) Update(tenant *tenants.Tenant) error {
}
}

if rb := tx.Commit(); err != nil {
if rb := tx.Commit(); rb != nil {
return fmt.Errorf("commit error: %w", rb)
}
return nil
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
INSERT INTO api_key_permissions (api_key_id, permission)
SELECT ak.id, UNNEST(ak.permissions) AS permission
FROM api_keys ak;
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
BEGIN;

alter table api_keys add column permissions VARCHAR[] NOT NULL DEFAULT '{}';

UPDATE api_keys ak
SET permissions = (
SELECT ARRAY_AGG(akp.permission)
FROM api_key_permissions akp
WHERE akp.api_key_id = ak.id
);

alter table api_keys alter column permissions drop default;

drop table api_key_permissions;

COMMIT;
2 changes: 1 addition & 1 deletion tools/docker-compose/oathkeeper_config/rules.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{
"id": "passthrough-authentication",
"match": {
"url": "http://<127.0.0.1|localhost>:3000/<(\\.ory|tenants/auth/settings|tenants/auth/login|tenants/auth/logout|tenants/static|dev)(/.+)?>",
"url": "http://<127.0.0.1|localhost>:3000/<(\\.ory|tenants/auth/settings|tenants/auth/login|tenants/auth/logout|tenants/auth/recovery|tenants/static|dev)(/.+)?>",
"methods": [
"GET","POST","PATCH","PUT","DELETE"
]
Expand Down

0 comments on commit 7609362

Please sign in to comment.