-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2632 from hashicorp/cassandra-plugin
Add cassandra plugin
- Loading branch information
Showing
5 changed files
with
1,543 additions
and
3 deletions.
There are no files selected for viewing
16 changes: 16 additions & 0 deletions
16
plugins/database/cassandra/cassandra-database-plugin/main.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
package main | ||
|
||
import ( | ||
"fmt" | ||
"os" | ||
|
||
"github.com/hashicorp/vault/plugins/database/cassandra" | ||
) | ||
|
||
func main() { | ||
err := cassandra.Run() | ||
if err != nil { | ||
fmt.Println(err) | ||
os.Exit(1) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
package cassandra | ||
|
||
import ( | ||
"fmt" | ||
"strings" | ||
"time" | ||
|
||
"github.com/gocql/gocql" | ||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" | ||
"github.com/hashicorp/vault/helper/strutil" | ||
"github.com/hashicorp/vault/plugins/helper/database/connutil" | ||
"github.com/hashicorp/vault/plugins/helper/database/credsutil" | ||
"github.com/hashicorp/vault/plugins/helper/database/dbutil" | ||
) | ||
|
||
const ( | ||
defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` | ||
defaultRollbackCQL = `DROP USER '{{username}}';` | ||
cassandraTypeName = "cassandra" | ||
) | ||
|
||
type Cassandra struct { | ||
connutil.ConnectionProducer | ||
credsutil.CredentialsProducer | ||
} | ||
|
||
func New() (interface{}, error) { | ||
connProducer := &connutil.CassandraConnectionProducer{} | ||
connProducer.Type = cassandraTypeName | ||
|
||
credsProducer := &credsutil.CassandraCredentialsProducer{} | ||
|
||
dbType := &Cassandra{ | ||
ConnectionProducer: connProducer, | ||
CredentialsProducer: credsProducer, | ||
} | ||
|
||
return dbType, nil | ||
} | ||
|
||
// Run instantiates a MySQL object, and runs the RPC server for the plugin | ||
func Run() error { | ||
dbType, err := New() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
dbplugin.NewPluginServer(dbType.(*Cassandra)) | ||
|
||
return nil | ||
} | ||
|
||
func (c *Cassandra) Type() (string, error) { | ||
return cassandraTypeName, nil | ||
} | ||
|
||
func (c *Cassandra) getConnection() (*gocql.Session, error) { | ||
session, err := c.Connection() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return session.(*gocql.Session), nil | ||
} | ||
|
||
func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { | ||
// Grab the lock | ||
c.Lock() | ||
defer c.Unlock() | ||
|
||
// Get the connection | ||
session, err := c.getConnection() | ||
if err != nil { | ||
return "", "", err | ||
} | ||
|
||
creationCQL := statements.CreationStatements | ||
if creationCQL == "" { | ||
creationCQL = defaultCreationCQL | ||
} | ||
rollbackCQL := statements.RollbackStatements | ||
if rollbackCQL == "" { | ||
rollbackCQL = defaultRollbackCQL | ||
} | ||
|
||
username, err = c.GenerateUsername(usernamePrefix) | ||
if err != nil { | ||
return "", "", err | ||
} | ||
|
||
password, err = c.GeneratePassword() | ||
if err != nil { | ||
return "", "", err | ||
} | ||
|
||
// Execute each query | ||
for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") { | ||
query = strings.TrimSpace(query) | ||
if len(query) == 0 { | ||
continue | ||
} | ||
|
||
err = session.Query(dbutil.QueryHelper(query, map[string]string{ | ||
"username": username, | ||
"password": password, | ||
})).Exec() | ||
if err != nil { | ||
for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") { | ||
query = strings.TrimSpace(query) | ||
if len(query) == 0 { | ||
continue | ||
} | ||
|
||
session.Query(dbutil.QueryHelper(query, map[string]string{ | ||
"username": username, | ||
"password": password, | ||
})).Exec() | ||
} | ||
return "", "", err | ||
} | ||
} | ||
|
||
return username, password, nil | ||
} | ||
|
||
func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { | ||
// NOOP | ||
return nil | ||
} | ||
|
||
func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error { | ||
// Grab the lock | ||
c.Lock() | ||
defer c.Unlock() | ||
|
||
session, err := c.getConnection() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec() | ||
if err != nil { | ||
return fmt.Errorf("error removing user '%s': %s", username, err) | ||
} | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
package cassandra | ||
|
||
import ( | ||
"os" | ||
"strconv" | ||
"testing" | ||
"time" | ||
|
||
"fmt" | ||
|
||
"github.com/gocql/gocql" | ||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" | ||
"github.com/hashicorp/vault/plugins/helper/database/connutil" | ||
dockertest "gopkg.in/ory-am/dockertest.v3" | ||
) | ||
|
||
func prepareCassandraTestContainer(t *testing.T) (cleanup func(), retURL string) { | ||
if os.Getenv("CASSANDRA_HOST") != "" { | ||
return func() {}, os.Getenv("CASSANDRA_HOST") | ||
} | ||
|
||
pool, err := dockertest.NewPool("") | ||
if err != nil { | ||
t.Fatalf("Failed to connect to docker: %s", err) | ||
} | ||
|
||
cwd, _ := os.Getwd() | ||
cassandraMountPath := fmt.Sprintf("%s/test-fixtures/:/etc/cassandra/", cwd) | ||
|
||
ro := &dockertest.RunOptions{ | ||
Repository: "cassandra", | ||
Tag: "latest", | ||
Mounts: []string{cassandraMountPath}, | ||
} | ||
resource, err := pool.RunWithOptions(ro) | ||
if err != nil { | ||
t.Fatalf("Could not start local cassandra docker container: %s", err) | ||
} | ||
|
||
cleanup = func() { | ||
err := pool.Purge(resource) | ||
if err != nil { | ||
t.Fatalf("Failed to cleanup local container: %s", err) | ||
} | ||
} | ||
|
||
retURL = fmt.Sprintf("localhost:%s", resource.GetPort("9042/tcp")) | ||
port, _ := strconv.Atoi(resource.GetPort("9042/tcp")) | ||
|
||
// exponential backoff-retry | ||
if err = pool.Retry(func() error { | ||
clusterConfig := gocql.NewCluster(retURL) | ||
clusterConfig.Authenticator = gocql.PasswordAuthenticator{ | ||
Username: "cassandra", | ||
Password: "cassandra", | ||
} | ||
clusterConfig.ProtoVersion = 4 | ||
clusterConfig.Port = port | ||
|
||
session, err := clusterConfig.CreateSession() | ||
if err != nil { | ||
return fmt.Errorf("error creating session: %s", err) | ||
} | ||
defer session.Close() | ||
return nil | ||
}); err != nil { | ||
t.Fatalf("Could not connect to cassandra docker container: %s", err) | ||
} | ||
return | ||
} | ||
|
||
func TestCassandra_Initialize(t *testing.T) { | ||
cleanup, connURL := prepareCassandraTestContainer(t) | ||
defer cleanup() | ||
|
||
connectionDetails := map[string]interface{}{ | ||
"hosts": connURL, | ||
"username": "cassandra", | ||
"password": "cassandra", | ||
"protocol_version": 4, | ||
} | ||
|
||
dbRaw, _ := New() | ||
db := dbRaw.(*Cassandra) | ||
connProducer := db.ConnectionProducer.(*connutil.CassandraConnectionProducer) | ||
|
||
err := db.Initialize(connectionDetails, true) | ||
if err != nil { | ||
t.Fatalf("err: %s", err) | ||
} | ||
|
||
if !connProducer.Initialized { | ||
t.Fatal("Database should be initalized") | ||
} | ||
|
||
err = db.Close() | ||
if err != nil { | ||
t.Fatalf("err: %s", err) | ||
} | ||
} | ||
|
||
func TestCassandra_CreateUser(t *testing.T) { | ||
cleanup, connURL := prepareCassandraTestContainer(t) | ||
defer cleanup() | ||
|
||
connectionDetails := map[string]interface{}{ | ||
"hosts": connURL, | ||
"username": "cassandra", | ||
"password": "cassandra", | ||
"protocol_version": 4, | ||
} | ||
|
||
dbRaw, _ := New() | ||
db := dbRaw.(*Cassandra) | ||
err := db.Initialize(connectionDetails, true) | ||
if err != nil { | ||
t.Fatalf("err: %s", err) | ||
} | ||
|
||
statements := dbplugin.Statements{ | ||
CreationStatements: testCassandraRole, | ||
} | ||
|
||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) | ||
if err != nil { | ||
t.Fatalf("err: %s", err) | ||
} | ||
|
||
if err := testCredsExist(t, connURL, username, password); err != nil { | ||
t.Fatalf("Could not connect with new credentials: %s", err) | ||
} | ||
} | ||
|
||
func TestMyCassandra_RenewUser(t *testing.T) { | ||
cleanup, connURL := prepareCassandraTestContainer(t) | ||
defer cleanup() | ||
|
||
connectionDetails := map[string]interface{}{ | ||
"hosts": connURL, | ||
"username": "cassandra", | ||
"password": "cassandra", | ||
"protocol_version": 4, | ||
} | ||
|
||
dbRaw, _ := New() | ||
db := dbRaw.(*Cassandra) | ||
err := db.Initialize(connectionDetails, true) | ||
if err != nil { | ||
t.Fatalf("err: %s", err) | ||
} | ||
|
||
statements := dbplugin.Statements{ | ||
CreationStatements: testCassandraRole, | ||
} | ||
|
||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) | ||
if err != nil { | ||
t.Fatalf("err: %s", err) | ||
} | ||
|
||
if err := testCredsExist(t, connURL, username, password); err != nil { | ||
t.Fatalf("Could not connect with new credentials: %s", err) | ||
} | ||
|
||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) | ||
if err != nil { | ||
t.Fatalf("err: %s", err) | ||
} | ||
} | ||
|
||
func TestCassandra_RevokeUser(t *testing.T) { | ||
cleanup, connURL := prepareCassandraTestContainer(t) | ||
defer cleanup() | ||
|
||
connectionDetails := map[string]interface{}{ | ||
"hosts": connURL, | ||
"username": "cassandra", | ||
"password": "cassandra", | ||
"protocol_version": 4, | ||
} | ||
|
||
dbRaw, _ := New() | ||
db := dbRaw.(*Cassandra) | ||
err := db.Initialize(connectionDetails, true) | ||
if err != nil { | ||
t.Fatalf("err: %s", err) | ||
} | ||
|
||
statements := dbplugin.Statements{ | ||
CreationStatements: testCassandraRole, | ||
} | ||
|
||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) | ||
if err != nil { | ||
t.Fatalf("err: %s", err) | ||
} | ||
|
||
if err = testCredsExist(t, connURL, username, password); err != nil { | ||
t.Fatalf("Could not connect with new credentials: %s", err) | ||
} | ||
|
||
// Test default revoke statememts | ||
err = db.RevokeUser(statements, username) | ||
if err != nil { | ||
t.Fatalf("err: %s", err) | ||
} | ||
|
||
if err = testCredsExist(t, connURL, username, password); err == nil { | ||
t.Fatal("Credentials were not revoked") | ||
} | ||
} | ||
|
||
func testCredsExist(t testing.TB, connURL, username, password string) error { | ||
clusterConfig := gocql.NewCluster(connURL) | ||
clusterConfig.Authenticator = gocql.PasswordAuthenticator{ | ||
Username: username, | ||
Password: password, | ||
} | ||
clusterConfig.ProtoVersion = 4 | ||
|
||
session, err := clusterConfig.CreateSession() | ||
if err != nil { | ||
return fmt.Errorf("error creating session: %s", err) | ||
} | ||
defer session.Close() | ||
return nil | ||
} | ||
|
||
const testCassandraRole = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER; | ||
GRANT ALL PERMISSIONS ON ALL KEYSPACES TO {{username}};` |
Oops, something went wrong.