diff --git a/database/connection.go b/database/connection.go index 276e2d23..b1bac3ef 100644 --- a/database/connection.go +++ b/database/connection.go @@ -20,7 +20,6 @@ import ( "github.com/jinzhu/gorm" _ "github.com/jinzhu/gorm/dialects/postgres" _ "github.com/jinzhu/gorm/dialects/sqlite" - "fmt" "github.com/IBM/ubiquity/utils/logs" "errors" ) @@ -40,7 +39,7 @@ type ConnectionFactory interface { } type postgresFactory struct { - host string + psql string } type sqliteFactory struct { @@ -51,7 +50,7 @@ type testErrorFactory struct { } func (f *postgresFactory) newConnection() (*gorm.DB, error) { - return gorm.Open("postgres", fmt.Sprintf("host=%s user=postgres dbname=postgres sslmode=disable", f.host)) + return gorm.Open("postgres", f.psql) } func (f *sqliteFactory) newConnection() (*gorm.DB, error) { diff --git a/database/init.go b/database/init.go index 51f03880..df647b1e 100644 --- a/database/init.go +++ b/database/init.go @@ -21,12 +21,50 @@ import ( "github.com/IBM/ubiquity/utils/logs" ) -const KeyPsqlHost = "UBIQUITY_DB_PSQL_HOST" -const KeySqlitePath = "UBIQUITY_DB_SQLITE_PATH" +const ( + KeyPsqlHost = "UBIQUITY_DB_PSQL_HOST" + KeySqlitePath = "UBIQUITY_DB_SQLITE_PATH" + KeyPsqlUser = "UBIQUITY_DB_USER" + KeyPsqlPassword = "UBIQUITY_DB_PASSWORD" + KeyPsqlDbName = "UBIQUITY_DB_NAME" + KeyPsqlPort = "UBIQUITY_DB_PORT" +) + +func GetPsqlConnectionParams(hostname string) string { + // add host + str := "host=" + hostname + // add user + psqlUser := os.Getenv(KeyPsqlUser) + if psqlUser == "" { + psqlUser = "postgres" + } + str += " user=" + psqlUser + // add dbname + psqlDbName := os.Getenv(KeyPsqlDbName) + if psqlDbName == "" { + psqlDbName = "postgres" + } + str += " dbname=" + psqlDbName + // add password + psqlPassword := os.Getenv(KeyPsqlPassword) + if psqlPassword != "" { + str += " password=" + psqlPassword + } + // add port + psqlPort := os.Getenv(KeyPsqlPort) + if psqlPort != "" { + str += " port=" + psqlPort + } + return str +} + +func GetPsqlSslParams() string { + return "sslmode=disable" +} func InitPostgres(hostname string) func() { defer logs.GetLogger().Trace(logs.DEBUG)() - return initConnectionFactory(&postgresFactory{host: hostname}) + return initConnectionFactory(&postgresFactory{psql: GetPsqlConnectionParams(hostname) + " " + GetPsqlSslParams()}) } func InitSqlite(filepath string) func() { diff --git a/database/init_test.go b/database/init_test.go new file mode 100644 index 00000000..e2f9a910 --- /dev/null +++ b/database/init_test.go @@ -0,0 +1,84 @@ +/** + * Copyright 2017 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package database_test + + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/IBM/ubiquity/database" + "fmt" + "os" +) + + +var _ = Describe("Init", func() { + var ( + hostname string = "my-hostname" + defaultUser string = "postgres" + defaultDbName string = "postgres" + newUser string = "my-user" + newDbName string = "my-dbname" + newPassword string = "my-password" + newPort string = "my-port" + ) + BeforeEach(func() { + }) + + Context(".Postgres", func() { + It("only hostname", func() { + res := database.GetPsqlConnectionParams(hostname) + Expect(res).To(Equal(fmt.Sprintf("host=%s user=%s dbname=%s", hostname, defaultUser, defaultDbName))) + }) + It("hostname user", func() { + os.Setenv(database.KeyPsqlUser, newUser) + res := database.GetPsqlConnectionParams(hostname) + os.Unsetenv(database.KeyPsqlUser) + Expect(res).To(Equal(fmt.Sprintf("host=%s user=%s dbname=%s", hostname, newUser, defaultDbName))) + }) + It("hostname user dbname", func() { + os.Setenv(database.KeyPsqlUser, newUser) + os.Setenv(database.KeyPsqlDbName, newDbName) + res := database.GetPsqlConnectionParams(hostname) + os.Unsetenv(database.KeyPsqlUser) + os.Unsetenv(database.KeyPsqlDbName) + Expect(res).To(Equal(fmt.Sprintf("host=%s user=%s dbname=%s", hostname, newUser, newDbName))) + }) + It("hostname user dbname password", func() { + os.Setenv(database.KeyPsqlUser, newUser) + os.Setenv(database.KeyPsqlDbName, newDbName) + os.Setenv(database.KeyPsqlPassword, newPassword) + res := database.GetPsqlConnectionParams(hostname) + os.Unsetenv(database.KeyPsqlUser) + os.Unsetenv(database.KeyPsqlDbName) + os.Unsetenv(database.KeyPsqlPassword) + Expect(res).To(Equal(fmt.Sprintf("host=%s user=%s dbname=%s password=%s", hostname, newUser, newDbName, newPassword))) + }) + It("hostname user dbname password port", func() { + os.Setenv(database.KeyPsqlUser, newUser) + os.Setenv(database.KeyPsqlDbName, newDbName) + os.Setenv(database.KeyPsqlPassword, newPassword) + os.Setenv(database.KeyPsqlPort, newPort) + res := database.GetPsqlConnectionParams(hostname) + os.Unsetenv(database.KeyPsqlUser) + os.Unsetenv(database.KeyPsqlDbName) + os.Unsetenv(database.KeyPsqlPassword) + os.Unsetenv(database.KeyPsqlPort) + Expect(res).To(Equal(fmt.Sprintf("host=%s user=%s dbname=%s password=%s port=%s", hostname, newUser, newDbName, newPassword, newPort))) + }) + }) +})