From cee8f7dd22c76b109ebf3bb52fe74e74fa65175b Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Sat, 26 May 2018 00:53:05 +0200 Subject: [PATCH] auth: allow using a local server pub key --- README.md | 13 ++++ auth.go | 129 ++++++++++++++++++++++++++++++-------- auth_test.go | 172 +++++++++++++++++++++++++++++++++++++++++++++++++++ dsn.go | 27 ++++++++ dsn_test.go | 48 +++++++++++++- 5 files changed, 360 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 299198d53..f5cd486c5 100644 --- a/README.md +++ b/README.md @@ -300,6 +300,19 @@ other cases. You should ensure your application will never cause an ERROR 1290 except for `read-only` mode when enabling this option. +##### `serverPubKey` + +``` +Type: string +Valid Values: +Default: none +``` + +Server public keys can be registered with [`mysql.RegisterServerPubKey`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterServerPubKey), which can then be used by the assigned name in the DSN. +Public keys are used to transmit encrypted data, e.g. for authentication. +If the server's public key is known, it should be set manually to avoid expensive and potentially insecure transmissions of the public key from the server to the client each time it is required. + + ##### `timeout` ``` diff --git a/auth.go b/auth.go index e6b45ccae..0b59f52ee 100644 --- a/auth.go +++ b/auth.go @@ -15,8 +15,72 @@ import ( "crypto/sha256" "crypto/x509" "encoding/pem" + "sync" ) +// server pub keys registry +var ( + serverPubKeyLock sync.RWMutex + serverPubKeyRegistry map[string]*rsa.PublicKey +) + +// RegisterServerPubKey registers a server RSA public key which can be used to +// send data in a secure manner to the server without receiving the public key +// in a potentially insecure way from the server first. +// Registered keys can afterwards be used adding serverPubKey= to the DSN. +// +// Note: The provided rsa.PublicKey instance is exclusively owned by the driver +// after registering it and may not be modified. +// +// data, err := ioutil.ReadFile("mykey.pem") +// if err != nil { +// log.Fatal(err) +// } +// +// block, _ := pem.Decode(data) +// if block == nil || block.Type != "PUBLIC KEY" { +// log.Fatal("failed to decode PEM block containing public key") +// } +// +// pub, err := x509.ParsePKIXPublicKey(block.Bytes) +// if err != nil { +// log.Fatal(err) +// } +// +// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { +// mysql.RegisterServerPubKey("mykey", rsaPubKey) +// } else { +// log.Fatal("not a RSA public key") +// } +// +func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry == nil { + serverPubKeyRegistry = make(map[string]*rsa.PublicKey) + } + + serverPubKeyRegistry[name] = pubKey + serverPubKeyLock.Unlock() +} + +// DeregisterServerPubKey removes the public key registered with the given name. +func DeregisterServerPubKey(name string) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry != nil { + delete(serverPubKeyRegistry, name) + } + serverPubKeyLock.Unlock() +} + +func getServerPubKey(name string) (pubKey *rsa.PublicKey) { + serverPubKeyLock.RLock() + if v, ok := serverPubKeyRegistry[name]; ok { + pubKey = v + } + serverPubKeyLock.RUnlock() + return +} + // Hash password using pre 4.1 (old password) method // https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c type myRnd struct { @@ -154,19 +218,22 @@ func scrambleSHA256Password(scramble []byte, password string) []byte { return message1 } -func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { - plain := make([]byte, len(mc.cfg.Passwd)+1) - copy(plain, mc.cfg.Passwd) +func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { + plain := make([]byte, len(password)+1) + copy(plain, password) for i := range plain { j := i % len(seed) plain[i] ^= seed[j] } sha1 := sha1.New() - enc, err := rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) + return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) +} + +func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { + enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) if err != nil { return err } - return mc.writeAuthSwitchPacket(enc, false) } @@ -211,9 +278,16 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) // write cleartext auth packet return []byte(mc.cfg.Passwd), true, nil } - // request public key - // TODO: allow to specify a local file with the pub key via the DSN - return []byte{1}, false, nil + + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + return []byte{1}, false, nil + } + + // encrypted password + enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) + return enc, false, err default: errLog.Print("unknown auth plugin:", plugin) @@ -283,28 +357,29 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { return err } } else { - // TODO: allow to specify a local file with the pub key via - // the DSN - - // request public key - data := mc.buf.takeSmallBuffer(4 + 1) - data[4] = cachingSha2PasswordRequestPublicKey - mc.writePacket(data) - - // parse public key - data, err := mc.readPacket() - if err != nil { - return err - } - - block, _ := pem.Decode(data[1:]) - pub, err := x509.ParsePKIXPublicKey(block.Bytes) - if err != nil { - return err + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + data := mc.buf.takeSmallBuffer(4 + 1) + data[4] = cachingSha2PasswordRequestPublicKey + mc.writePacket(data) + + // parse public key + data, err := mc.readPacket() + if err != nil { + return err + } + + block, _ := pem.Decode(data[1:]) + pkix, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + pubKey = pkix.(*rsa.PublicKey) } // send encrypted password - err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey)) + err = mc.sendEncryptedPassword(oldAuthData, pubKey) if err != nil { return err } diff --git a/auth_test.go b/auth_test.go index 6197580b1..407363be4 100644 --- a/auth_test.go +++ b/auth_test.go @@ -10,7 +10,10 @@ package mysql import ( "bytes" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "encoding/pem" "fmt" "testing" ) @@ -25,6 +28,17 @@ var testPubKey = []byte("-----BEGIN PUBLIC KEY-----\n" + "WQIDAQAB\n" + "-----END PUBLIC KEY-----\n") +var testPubKeyRSA *rsa.PublicKey + +func init() { + block, _ := pem.Decode(testPubKey) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + panic(err) + } + testPubKeyRSA = pub.(*rsa.PublicKey) +} + func TestScrambleOldPass(t *testing.T) { scramble := []byte{9, 8, 7, 6, 5, 4, 3, 2} vectors := []struct { @@ -203,6 +217,59 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { } } +func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, addNUL, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, + 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, + 110, 40, 139, 124, 41} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 4, // Perform Full Authentication + } + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { conn, mc := newRWMockConn(1) mc.cfg.User = "root" @@ -558,6 +625,36 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { } } +func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // Send Client Authentication Packet + authResp, addNUL, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + if err != nil { + t.Fatal(err) + } + + // auth response (OK) + conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + func TestAuthFastSHA256PasswordSecure(t *testing.T) { conn, mc := newRWMockConn(1) mc.cfg.User = "root" @@ -720,6 +817,48 @@ func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) { } } +func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + conn.queuedReplies = [][]byte{ + // Perform Full Authentication + {2, 0, 0, 4, 1, 4}, + + // OK + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + + // 2. Packet: Encrypted Password + 0, 1, 0, 5, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { conn, mc := newRWMockConn(2) mc.cfg.Passwd = "secret" @@ -1051,6 +1190,39 @@ func TestAuthSwitchSHA256PasswordRSA(t *testing.T) { } } +func TestAuthSwitchSHA256PasswordRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Encrypted Password + 0, 1, 0, 3, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { conn, mc := newRWMockConn(2) mc.cfg.Passwd = "secret" diff --git a/dsn.go b/dsn.go index 47eab6945..be014babe 100644 --- a/dsn.go +++ b/dsn.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "crypto/rsa" "crypto/tls" "errors" "fmt" @@ -41,6 +42,8 @@ type Config struct { Collation string // Connection collation Loc *time.Location // Location for time.Time values MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key TLSConfig string // TLS configuration name tls *tls.Config // TLS configuration Timeout time.Duration // Dial timeout @@ -254,6 +257,16 @@ func (cfg *Config) FormatDSN() string { } } + if len(cfg.ServerPubKey) > 0 { + if hasParam { + buf.WriteString("&serverPubKey=") + } else { + hasParam = true + buf.WriteString("?serverPubKey=") + } + buf.WriteString(url.QueryEscape(cfg.ServerPubKey)) + } + if cfg.Timeout > 0 { if hasParam { buf.WriteString("&timeout=") @@ -512,6 +525,20 @@ func parseDSNParams(cfg *Config, params string) (err error) { return errors.New("invalid bool value: " + value) } + // Server public key + case "serverPubKey": + name, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid value for server pub key name: %v", err) + } + + if pubKey := getServerPubKey(name); pubKey != nil { + cfg.ServerPubKey = name + cfg.pubKey = pubKey + } else { + return errors.New("invalid value / unknown server pub key name: " + name) + } + // Strict mode case "strict": panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") diff --git a/dsn_test.go b/dsn_test.go index 7507d1201..1cd095496 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -135,11 +135,56 @@ func TestDSNReformat(t *testing.T) { } } +func TestDSNServerPubKey(t *testing.T) { + baseDSN := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + + RegisterServerPubKey("testKey", testPubKeyRSA) + defer DeregisterServerPubKey("testKey") + + tst := baseDSN + "testKey" + cfg, err := ParseDSN(tst) + if err != nil { + t.Error(err.Error()) + } + + if cfg.ServerPubKey != "testKey" { + t.Errorf("unexpected cfg.ServerPubKey value: %v", cfg.ServerPubKey) + } + if cfg.pubKey != testPubKeyRSA { + t.Error("pub key pointer doesn't match") + } + + // Key is missing + tst = baseDSN + "invalid_name" + cfg, err = ParseDSN(tst) + if err == nil { + t.Errorf("invalid name in DSN (%s) but did not error. Got config: %#v", tst, cfg) + } +} + +func TestDSNServerPubKeyQueryEscape(t *testing.T) { + const name = "&%!:" + dsn := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + url.QueryEscape(name) + + RegisterServerPubKey(name, testPubKeyRSA) + defer DeregisterServerPubKey(name) + + cfg, err := ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + + if cfg.pubKey != testPubKeyRSA { + t.Error("pub key pointer doesn't match") + } +} + func TestDSNWithCustomTLS(t *testing.T) { baseDSN := "User:password@tcp(localhost:5555)/dbname?tls=" tlsCfg := tls.Config{} RegisterTLSConfig("utils_test", &tlsCfg) + defer DeregisterTLSConfig("utils_test") // Custom TLS is missing tst := baseDSN + "invalid_tls" @@ -173,8 +218,6 @@ func TestDSNWithCustomTLS(t *testing.T) { } else if tlsCfg.ServerName != "" { t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst) } - - DeregisterTLSConfig("utils_test") } func TestDSNTLSConfig(t *testing.T) { @@ -212,6 +255,7 @@ func TestDSNWithCustomTLSQueryEscape(t *testing.T) { tlsCfg := tls.Config{ServerName: name} RegisterTLSConfig(configKey, &tlsCfg) + defer DeregisterTLSConfig(configKey) cfg, err := ParseDSN(dsn)