Skip to content

Commit

Permalink
allow setting the collation in auth handshake
Browse files Browse the repository at this point in the history
  • Loading branch information
dvilaverde committed Apr 26, 2024
1 parent 7c31dc4 commit 5427a8d
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
12 changes: 11 additions & 1 deletion client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"encoding/binary"
"fmt"
"github.com/pingcap/tidb/pkg/parser/charset"

. "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/packet"
Expand Down Expand Up @@ -269,7 +270,16 @@ func (c *Conn) writeAuthHandshake() error {

// Charset [1 byte]
// use default collation id 33 here, is utf-8
data[12] = DEFAULT_COLLATION_ID
collationName := c.collation
if len(collationName) == 0 {
collationName = DEFAULT_COLLATION_NAME
}
collation, err := charset.GetCollationByName(collationName)
if err != nil {
return fmt.Errorf("invalid collation name %s", collationName)
}

data[12] = byte(collation.ID)

// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
Expand Down
21 changes: 20 additions & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ func TestClientSuite(t *testing.T) {
func (s *clientTestSuite) SetupSuite() {
var err error
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
s.c, err = Connect(addr, *testUser, *testPassword, "")
s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
// test the collation logic, but this is essentially a no-op since
// the collation set is the default value
_ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME)
})
require.NoError(s.T(), err)

var result *mysql.Result
Expand Down Expand Up @@ -228,6 +232,21 @@ func (s *clientTestSuite) TestConn_SetCharset() {
require.NoError(s.T(), err)
}

func (s *clientTestSuite) TestConn_SetCollationAfterConnect() {
err := s.c.SetCollation("latin1_swedish_ci")
require.Error(s.T(), err)
}

func (s *clientTestSuite) TestConn_SetCollation() {
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
_, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
// test the collation logic
_ = conn.SetCollation("invalid_collation")
})

require.Error(s.T(), err)
}

func (s *clientTestSuite) testStmt_DropTable() {
str := `drop table if exists mixer_test_stmt`

Expand Down
16 changes: 16 additions & 0 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type Conn struct {
status uint16

charset string
// sets the collation to be set on the auth handshake, this does not issue a 'set names' command
collation string

salt []byte
authPluginName string
Expand Down Expand Up @@ -357,6 +359,20 @@ func (c *Conn) SetCharset(charset string) error {
}
}

func (c *Conn) SetCollation(collation string) error {
if c.status == 0 {
c.collation = collation
} else {
return errors.Trace(errors.Errorf("cannot set collation after connection is established"))
}

return nil
}

func (c *Conn) GetCollation() string {
return c.collation
}

func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
return nil, errors.Trace(err)
Expand Down

0 comments on commit 5427a8d

Please sign in to comment.