Skip to content

Commit

Permalink
*: Support for tidb_sm3_password authentication (#36193)
Browse files Browse the repository at this point in the history
close #36192
  • Loading branch information
CbcWestwolf authored Sep 8, 2022
1 parent abb1fd1 commit 1d482db
Show file tree
Hide file tree
Showing 24 changed files with 715 additions and 45 deletions.
1 change: 1 addition & 0 deletions executor/reload_expr_pushdown_blacklist.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ var funcName2Alias = map[string]string{
"sha1": ast.SHA1,
"sha": ast.SHA,
"sha2": ast.SHA2,
"sm3": ast.SM3,
"uncompress": ast.Uncompress,
"uncompressed_length": ast.UncompressedLength,
"validate_password_strength": ast.ValidatePasswordStrength,
Expand Down
2 changes: 1 addition & 1 deletion executor/showtest/show_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1440,7 +1440,7 @@ func TestShowBuiltin(t *testing.T) {
res := tk.MustQuery("show builtins;")
require.NotNil(t, res)
rows := res.Rows()
const builtinFuncNum = 276
const builtinFuncNum = 277
require.Equal(t, len(rows), builtinFuncNum)
require.Equal(t, rows[0][0].(string), "abs")
require.Equal(t, rows[builtinFuncNum-1][0].(string), "yearweek")
Expand Down
8 changes: 4 additions & 4 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm
}

switch authPlugin {
case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthSocket:
case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket:
default:
return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin)
}
Expand Down Expand Up @@ -1010,7 +1010,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt)
spec.AuthOpt.AuthPlugin = authplugin
}
switch spec.AuthOpt.AuthPlugin {
case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthSocket, "":
case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket, "":
default:
return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin)
}
Expand Down Expand Up @@ -1495,8 +1495,8 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error
}
var pwd string
switch authplugin {
case mysql.AuthCachingSha2Password:
pwd = auth.NewSha2Password(s.Password)
case mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password:
pwd = auth.NewHashPassword(s.Password, authplugin)
case mysql.AuthSocket:
e.ctx.GetSessionVars().StmtCtx.AppendNote(ErrSetPasswordAuthPlugin.GenWithStackByArgs(u, h))
pwd = ""
Expand Down
1 change: 1 addition & 0 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ var funcs = map[string]functionClass{
ast.SHA1: &sha1FunctionClass{baseFunctionClass{ast.SHA1, 1, 1}},
ast.SHA: &sha1FunctionClass{baseFunctionClass{ast.SHA, 1, 1}},
ast.SHA2: &sha2FunctionClass{baseFunctionClass{ast.SHA2, 2, 2}},
ast.SM3: &sm3FunctionClass{baseFunctionClass{ast.SM3, 1, 1}},
ast.Uncompress: &uncompressFunctionClass{baseFunctionClass{ast.Uncompress, 1, 1}},
ast.UncompressedLength: &uncompressedLengthFunctionClass{baseFunctionClass{ast.UncompressedLength, 1, 1}},
ast.ValidatePasswordStrength: &validatePasswordStrengthFunctionClass{baseFunctionClass{ast.ValidatePasswordStrength, 1, 1}},
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_convert_charset.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ var convertActionMap = map[funcProp][]string{
ast.ASCII, ast.BitLength, ast.Hex, ast.Length, ast.OctetLength, ast.ToBase64,
/* encrypt functions */
ast.AesDecrypt, ast.Decode, ast.Encode, ast.PasswordFunc, ast.MD5, ast.SHA, ast.SHA1,
ast.SHA2, ast.Compress, ast.AesEncrypt,
ast.SHA2, ast.SM3, ast.Compress, ast.AesEncrypt,
},
funcPropAuto: {
/* string functions */ ast.Concat, ast.ConcatWS, ast.ExportSet, ast.Field, ast.FindInSet,
Expand Down
46 changes: 46 additions & 0 deletions expression/builtin_encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,52 @@ func (b *builtinSHA2Sig) Clone() builtinFunc {
return newSig
}

type sm3FunctionClass struct {
baseFunctionClass
}

func (c *sm3FunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString)
if err != nil {
return nil, err
}
charset, collate := ctx.GetSessionVars().GetCharsetInfo()
bf.tp.SetCharset(charset)
bf.tp.SetCollate(collate)
bf.tp.SetFlen(40)
sig := &builtinSM3Sig{bf}
//sig.setPbCode(tipb.ScalarFuncSig_SM3) // TODO
return sig, nil
}

type builtinSM3Sig struct {
baseBuiltinFunc
}

func (b *builtinSM3Sig) Clone() builtinFunc {
newSig := &builtinSM3Sig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

// evalString evals Sm3Hash(str).
// The value is returned as a string of 70 hexadecimal digits, or NULL if the argument was NULL.
func (b *builtinSM3Sig) evalString(row chunk.Row) (string, bool, error) {
str, isNull, err := b.args[0].EvalString(b.ctx, row)
if isNull || err != nil {
return "", isNull, err
}
hasher := auth.NewSM3()
_, err = hasher.Write([]byte(str))
if err != nil {
return "", true, err
}
return fmt.Sprintf("%x", hasher.Sum(nil)), false, nil
}

// Supported hash length of SHA-2 family
const (
SHA0 = 0
Expand Down
33 changes: 33 additions & 0 deletions expression/builtin_encryption_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,39 @@ func (b *builtinSHA2Sig) vecEvalString(input *chunk.Chunk, result *chunk.Column)
return nil
}

func (b *builtinSM3Sig) vectorized() bool {
return true
}

// vecEvalString evals Sm3Hash(str).
func (b *builtinSM3Sig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
buf, err := b.bufAllocator.get()
if err != nil {
return err
}
defer b.bufAllocator.put(buf)
if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil {
return errors.Trace(err)
}
result.ReserveString(n)
hasher := auth.NewSM3()
for i := 0; i < n; i++ {
if buf.IsNull(i) {
result.AppendNull()
continue
}
str := buf.GetBytes(i)
_, err = hasher.Write(str)
if err != nil {
return err
}
result.AppendString(fmt.Sprintf("%x", hasher.Sum(nil)))
hasher.Reset()
}
return nil
}

func (b *builtinCompressSig) vectorized() bool {
return true
}
Expand Down
3 changes: 3 additions & 0 deletions expression/builtin_encryption_vec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ var vecBuiltinEncryptionCases = map[string][]vecExprBenchCase{
{retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETInt}, geners: []dataGenerator{newRandLenStrGener(10, 20), newRangeInt64Gener(SHA384, SHA384+1)}},
{retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETInt}, geners: []dataGenerator{newRandLenStrGener(10, 20), newRangeInt64Gener(SHA512, SHA512+1)}},
},
ast.SM3: {
{retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString}},
},
ast.Encode: {
{retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString}},
},
Expand Down
2 changes: 1 addition & 1 deletion expression/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression,
case ast.Database, ast.User, ast.CurrentUser, ast.Version, ast.CurrentRole, ast.TiDBVersion:
chs, coll := charset.GetDefaultCharsetAndCollate()
return &ExprCollation{CoercibilitySysconst, UNICODE, chs, coll}, nil
case ast.Format, ast.Space, ast.ToBase64, ast.UUID, ast.Hex, ast.MD5, ast.SHA, ast.SHA2:
case ast.Format, ast.Space, ast.ToBase64, ast.UUID, ast.Hex, ast.MD5, ast.SHA, ast.SHA2, ast.SM3:
// should return ASCII repertoire, MySQL's doc says it depends on character_set_connection, but it not true from its source code.
ec = &ExprCollation{Coer: CoercibilityCoercible, Repe: ASCII}
ec.Charset, ec.Collation = ctx.GetSessionVars().GetCharsetInfo()
Expand Down
16 changes: 16 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,22 @@ func TestEncryptionBuiltin(t *testing.T) {
result = tk.MustQuery("select sha2('123', 512), sha2(123, 512), sha2('', 512), sha2('你好', 224), sha2(NULL, 256), sha2('foo', 123)")
result.Check(testkit.Rows(`3c9909afec25354d551dae21590bb26e38d53f2173b8d3dc3eee4c047e7ab1c1eb8b85103e3be7ba613b31bb5c9c36214dc9f14a42fd7a2fdb84856bca5c44c2 3c9909afec25354d551dae21590bb26e38d53f2173b8d3dc3eee4c047e7ab1c1eb8b85103e3be7ba613b31bb5c9c36214dc9f14a42fd7a2fdb84856bca5c44c2 cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e e91f006ed4e0882de2f6a3c96ec228a6a5c715f356d00091bce842b5 <nil> <nil>`))

// for sm3
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))")
tk.MustExec(`insert into t values('2', 2, 2.3, "2017-01-01 12:01:01", "12:01:01", 0b1010, "512", "48", "tidb")`)
result = tk.MustQuery("select sm3(a), sm3(b), sm3(c), sm3(d), sm3(e), sm3(f), sm3(g), sm3(h), sm3(i) from t")
result.Check(testkit.Rows("a0dc2d74b9b0e3c87e076003dbfe472a424cb3032463cb339e351460765a822e a0dc2d74b9b0e3c87e076003dbfe472a424cb3032463cb339e351460765a822e b01f6234a2c1d98af2d8bfb79a8c95677c6e9f5750eb756890f29b33b712f804 8485b2ccde69acf41e333e8fba2f55a1b3556e1a42443095235db1d5c78b25d1 f71ab1aad211e14a47b549e8df55b627c36fa75c1aa75b9682cccae2de00babc f4051d239b766c4111e92979aa31af0b35def053646e347bc41e8b73cfd080bc d42cb1657149a8057cef0ba0ededef7f23c9a2f133bfd286ad0f4a6a8bdb5cb2 19dfccdab83e610f04c414a96edb45007b9a022af01473fccf2073b546ad092e 5e0fb8467c33dae5879fb296c9766c78b0a6fc966372f76ac000cc1fcafc2876"))
result = tk.MustQuery("select sm3('123'), sm3(123), sm3(''), sm3('你好'), sm3(NULL)")
result.Check(testkit.Rows(`6e0f9e14344c5406a0cf5a3b4dfb665f87f4a771a31f7edbb5c72874a32b2957 6e0f9e14344c5406a0cf5a3b4dfb665f87f4a771a31f7edbb5c72874a32b2957 1ab21d8355cfa17f8e61194831e81a8f22bec8c728fefb747ed035eb5082aa2b 78e5c78c5322ca174089e58dc7790acf8ce9d542bee6ae4a5a0797d5e356be61 <nil>`))
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))")
tk.MustExec(`insert into t values('2', 2, 2.3, "2017-01-01 12:01:01", "12:01:01", 0b1010, "512", "48", "tidb")`)
result = tk.MustQuery("select sm3(a), sm3(b), sm3(c), sm3(d), sm3(e), sm3(f), sm3(g), sm3(h), sm3(i) from t")
result.Check(testkit.Rows("a0dc2d74b9b0e3c87e076003dbfe472a424cb3032463cb339e351460765a822e a0dc2d74b9b0e3c87e076003dbfe472a424cb3032463cb339e351460765a822e b01f6234a2c1d98af2d8bfb79a8c95677c6e9f5750eb756890f29b33b712f804 8485b2ccde69acf41e333e8fba2f55a1b3556e1a42443095235db1d5c78b25d1 f71ab1aad211e14a47b549e8df55b627c36fa75c1aa75b9682cccae2de00babc f4051d239b766c4111e92979aa31af0b35def053646e347bc41e8b73cfd080bc d42cb1657149a8057cef0ba0ededef7f23c9a2f133bfd286ad0f4a6a8bdb5cb2 19dfccdab83e610f04c414a96edb45007b9a022af01473fccf2073b546ad092e 5e0fb8467c33dae5879fb296c9766c78b0a6fc966372f76ac000cc1fcafc2876"))
result = tk.MustQuery("select sm3('123'), sm3(123), sm3(''), sm3('你好'), sm3(NULL)")
result.Check(testkit.Rows(`6e0f9e14344c5406a0cf5a3b4dfb665f87f4a771a31f7edbb5c72874a32b2957 6e0f9e14344c5406a0cf5a3b4dfb665f87f4a771a31f7edbb5c72874a32b2957 1ab21d8355cfa17f8e61194831e81a8f22bec8c728fefb747ed035eb5082aa2b 78e5c78c5322ca174089e58dc7790acf8ce9d542bee6ae4a5a0797d5e356be61 <nil>`))

// for AES_ENCRYPT
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))")
Expand Down
19 changes: 19 additions & 0 deletions expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,25 @@ func (s *InferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCase {
{"sha2('1234' , '256')", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 128, types.UnspecifiedLength},
{"sha2(1234 , '256')", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 128, types.UnspecifiedLength},

{"sm3(c_int_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_bigint_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_float_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_double_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_decimal )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_datetime )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_time_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_timestamp_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_char )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_varchar )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_text_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_binary )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_varbinary )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_blob_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_set )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3(c_enum )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
{"sm3('1234' )", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 40, types.UnspecifiedLength},
{"sm3(1234 )", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 40, types.UnspecifiedLength},

{"AES_ENCRYPT(c_int_d, 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength},
{"AES_ENCRYPT(c_char, 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength},
{"AES_ENCRYPT(c_varchar, 'key')", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 32, types.UnspecifiedLength},
Expand Down
1 change: 1 addition & 0 deletions parser/ast/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ const (
SHA1 = "sha1"
SHA = "sha"
SHA2 = "sha2"
SM3 = "sm3"
Uncompress = "uncompress"
UncompressedLength = "uncompressed_length"
ValidatePasswordStrength = "validate_password_strength"
Expand Down
8 changes: 6 additions & 2 deletions parser/ast/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -1332,8 +1332,8 @@ func (n *UserSpec) EncodedPassword() (string, bool) {
opt := n.AuthOpt
if opt.ByAuthString {
switch opt.AuthPlugin {
case mysql.AuthCachingSha2Password:
return auth.NewSha2Password(opt.AuthString), true
case mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password:
return auth.NewHashPassword(opt.AuthString, opt.AuthPlugin), true
case mysql.AuthSocket:
return "", true
default:
Expand All @@ -1352,6 +1352,10 @@ func (n *UserSpec) EncodedPassword() (string, bool) {
if len(opt.HashString) != mysql.SHAPWDHashLen {
return "", false
}
case mysql.AuthTiDBSM3Password:
if len(opt.HashString) != mysql.SM3PWDHashLen {
return "", false
}
case "", mysql.AuthNativePassword:
if len(opt.HashString) != (mysql.PWDHashLen+1) || !strings.HasPrefix(opt.HashString, "*") {
return "", false
Expand Down
2 changes: 2 additions & 0 deletions parser/auth/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ go_library(
"auth.go",
"caching_sha2.go",
"mysql_native_password.go",
"sm3.go",
],
importpath = "github.com/pingcap/tidb/parser/auth",
visibility = ["//visibility:public"],
Expand All @@ -22,6 +23,7 @@ go_test(
srcs = [
"caching_sha2_test.go",
"mysql_native_password_test.go",
"sm3_test.go",
],
embed = [":auth"],
flaky = True,
Expand Down
48 changes: 35 additions & 13 deletions parser/auth/caching_sha2.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ import (
"errors"
"fmt"
"strconv"

"github.com/pingcap/tidb/parser/mysql"
)

const (
Expand All @@ -60,7 +62,14 @@ func b64From24bit(b []byte, n int, buf *bytes.Buffer) {
}
}

func sha256crypt(plaintext string, salt []byte, iterations int) string {
// Sha256Hash is an util function to calculate sha256 hash.
func Sha256Hash(input []byte) []byte {
res := sha256.Sum256(input)
return res[:]
}

// 'hash' function should return an array with 32 bytes, the same as SHA-256
func hashCrypt(plaintext string, salt []byte, iterations int, hash func([]byte) []byte) string {
// Numbers in the comments refer to the description of the algorithm on https://www.akkadia.org/drepper/SHA-crypt.txt

// 1, 2, 3
Expand All @@ -73,7 +82,7 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string {
bufB.Write([]byte(plaintext))
bufB.Write(salt)
bufB.Write([]byte(plaintext))
sumB := sha256.Sum256(bufB.Bytes())
sumB := hash(bufB.Bytes())
bufB.Reset()

// 9, 10
Expand All @@ -93,15 +102,15 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string {
}

// 12
sumA := sha256.Sum256(bufA.Bytes())
sumA := hash(bufA.Bytes())
bufA.Reset()

// 13, 14, 15
bufDP := bufA
for range []byte(plaintext) {
bufDP.Write([]byte(plaintext))
}
sumDP := sha256.Sum256(bufDP.Bytes())
sumDP := hash(bufDP.Bytes())
bufDP.Reset()

// 16
Expand All @@ -119,7 +128,7 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string {
for i = 0; i < 16+int(sumA[0]); i++ {
bufDS.Write(salt)
}
sumDS := sha256.Sum256(bufDS.Bytes())
sumDS := hash(bufDS.Bytes())
bufDS.Reset()

// 20
Expand All @@ -134,7 +143,7 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string {

// 21
bufC := bufA
var sumC [32]byte
var sumC []byte
for i = 0; i < iterations; i++ {
bufC.Reset()
if i&1 != 0 {
Expand All @@ -153,7 +162,7 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string {
} else {
bufC.Write(p)
}
sumC = sha256.Sum256(bufC.Bytes())
sumC = hash(bufC.Bytes())
sumA = sumC
}

Expand All @@ -180,8 +189,8 @@ func sha256crypt(plaintext string, salt []byte, iterations int) string {
return buf.String()
}

// CheckShaPassword is to check if a MySQL style caching_sha2 authentication string matches a password
func CheckShaPassword(pwhash []byte, password string) (bool, error) {
// CheckHashingPassword checks if a caching_sha2_password or tidb_sm3_password authentication string matches a password
func CheckHashingPassword(pwhash []byte, password string, hash string) (bool, error) {
pwhashParts := bytes.Split(pwhash, []byte("$"))
if len(pwhashParts) != 4 {
return false, errors.New("failed to decode hash parts")
Expand All @@ -199,13 +208,19 @@ func CheckShaPassword(pwhash []byte, password string) (bool, error) {
iterations = iterations * ITERATION_MULTIPLIER
salt := pwhashParts[3][:SALT_LENGTH]

newHash := sha256crypt(password, salt, iterations)
var newHash string
switch hash {
case mysql.AuthCachingSha2Password:
newHash = hashCrypt(password, salt, iterations, Sha256Hash)
case mysql.AuthTiDBSM3Password:
newHash = hashCrypt(password, salt, iterations, Sm3Hash)
}

return bytes.Equal(pwhash, []byte(newHash)), nil
}

// NewSha2Password creates a new MySQL style caching_sha2 password hash
func NewSha2Password(pwd string) string {
// NewHashPassword creates a new password for caching_sha2_password or tidb_sm3_password
func NewHashPassword(pwd string, hash string) string {
salt := make([]byte, SALT_LENGTH)
rand.Read(salt)

Expand All @@ -219,5 +234,12 @@ func NewSha2Password(pwd string) string {
}
}

return sha256crypt(pwd, salt, 5*ITERATION_MULTIPLIER)
switch hash {
case mysql.AuthCachingSha2Password:
return hashCrypt(pwd, salt, 5*ITERATION_MULTIPLIER, Sha256Hash)
case mysql.AuthTiDBSM3Password:
return hashCrypt(pwd, salt, 5*ITERATION_MULTIPLIER, Sm3Hash)
default:
return ""
}
}
Loading

0 comments on commit 1d482db

Please sign in to comment.