diff --git a/client/auth.go b/client/auth.go index e4fa908d3..7392f8fdd 100644 --- a/client/auth.go +++ b/client/auth.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/binary" "fmt" + "github.com/pingcap/tidb/pkg/parser/charset" . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/packet" @@ -269,7 +270,16 @@ func (c *Conn) writeAuthHandshake() error { // Charset [1 byte] // use default collation id 33 here, is utf-8 - data[12] = DEFAULT_COLLATION_ID + collationName := c.collation + if len(collationName) == 0 { + collationName = DEFAULT_COLLATION_NAME + } + collation, err := charset.GetCollationByName(collationName) + if err != nil { + return fmt.Errorf("invalid collation name %s", collationName) + } + + data[12] = byte(collation.ID) // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest diff --git a/client/client_test.go b/client/client_test.go index c47c795ef..b27c4c669 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -31,7 +31,11 @@ func TestClientSuite(t *testing.T) { func (s *clientTestSuite) SetupSuite() { var err error addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) - s.c, err = Connect(addr, *testUser, *testPassword, "") + s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + // test the collation logic, but this is essentially a no-op since + // the collation set is the default value + _ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME) + }) require.NoError(s.T(), err) var result *mysql.Result @@ -228,6 +232,21 @@ func (s *clientTestSuite) TestConn_SetCharset() { require.NoError(s.T(), err) } +func (s *clientTestSuite) TestConn_SetCollationAfterConnect() { + err := s.c.SetCollation("latin1_swedish_ci") + require.Error(s.T(), err) +} + +func (s *clientTestSuite) TestConn_SetCollation() { + addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port) + _, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) { + // test the collation logic + _ = conn.SetCollation("invalid_collation") + }) + + require.Error(s.T(), err) +} + func (s *clientTestSuite) testStmt_DropTable() { str := `drop table if exists mixer_test_stmt` diff --git a/client/conn.go b/client/conn.go index b1f3e52d1..1db021762 100644 --- a/client/conn.go +++ b/client/conn.go @@ -37,6 +37,8 @@ type Conn struct { status uint16 charset string + // sets the collation to be set on the auth handshake, this does not issue a 'set names' command + collation string salt []byte authPluginName string @@ -357,6 +359,20 @@ func (c *Conn) SetCharset(charset string) error { } } +func (c *Conn) SetCollation(collation string) error { + if c.status == 0 { + c.collation = collation + } else { + return errors.Trace(errors.Errorf("cannot set collation after connection is established")) + } + + return nil +} + +func (c *Conn) GetCollation() string { + return c.collation +} + func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) { if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil { return nil, errors.Trace(err)