Skip to content

feat: add support for md5-digest sasl auth. #152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 100 additions & 26 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,40 @@ func (c *Conn) connect() error {
}
}

// sendRequestEx sends request directly, and handles the closed or quit scenarios.
func (c *Conn) sendRequestEx(
ctx context.Context,
opcode int32,
req interface{},
res interface{},
recvFunc func(*request, *responseHeader, error)) (bool, error) {

resChan, err := c.sendRequest(opcode, req, res, recvFunc)

if err != nil {
return true, fmt.Errorf("failed to send auth request: %v", err)
}

var resp response

select {
case resp = <-resChan:
case <-c.closeChan:
c.logger.Printf("recv routine closed")
return false, nil
case <-c.shouldQuit:
c.logger.Printf("should quit")
return false, nil
case <-ctx.Done():
return false, ctx.Err()
}

if resp.err != nil {
return true, fmt.Errorf("failed for op: %d, error: %v", opcode, resp.err)
}
return true, nil
}

func (c *Conn) sendRequest(
opcode int32,
req interface{},
Expand Down Expand Up @@ -932,8 +966,34 @@ func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc
}
}

func (c *Conn) doAddSaslAuth(auth []byte) (int64, error) {
// step 1 Ask for server informations.
resp := setSaslResponse{}

zxid, err := c.request(opSetSasl, &setSaslRequest{}, &resp, nil)
if err != nil {
return zxid, err
}

challenge, err := resp.GenSaslChallenge(auth, "")

if err != nil {
return 0, err
}

// step 2 Do the authentication.
return c.request(opSetSasl, &setSaslRequest{challenge}, &resp, nil)
}

// AddAuth adds an auth specified by <scheme> and <auth>, supported schemes
// includes "digest", "sasl", <auth> usually comes as "user:pasword".
func (c *Conn) AddAuth(scheme string, auth []byte) error {
_, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil)
var err error
if scheme == "sasl" {
_, err = c.doAddSaslAuth(auth)
} else {
_, err = c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil)
}

if err != nil {
return err
Expand Down Expand Up @@ -1299,6 +1359,24 @@ func (c *Conn) Server() string {
return c.server
}

// FIXME(linsite) unify it with doAddSasl.
// resendZkSasl resends SASL auth, when the 1st return value is true indicates the connection is invalid.
func resendZkSasl(ctx context.Context, c *Conn, auth []byte) (bool, error) {
resp := setSaslResponse{}
shouldContinue, err := c.sendRequestEx(ctx, opSetSasl, &setSaslRequest{}, &resp, nil)
if err != nil {
return shouldContinue, err
}

challenge, err := resp.GenSaslChallenge(auth, "")

if err != nil {
return true, err
}

return c.sendRequestEx(ctx, opSetSasl, &setSaslRequest{challenge}, &resp, nil)
}

func resendZkAuth(ctx context.Context, c *Conn) error {
shouldCancel := func() bool {
select {
Expand All @@ -1318,40 +1396,36 @@ func resendZkAuth(ctx context.Context, c *Conn) error {
c.logger.Printf("re-submitting `%d` credentials after reconnect", len(c.creds))
}

var shouldContinue bool
var err error

for _, cred := range c.creds {
// return early before attempting to send request.
if shouldCancel() {
return nil
}
// do not use the public API for auth since it depends on the send/recv loops
// that are waiting for this to return
resChan, err := c.sendRequest(
opSetAuth,
&setAuthRequest{Type: 0,
Scheme: cred.scheme,
Auth: cred.auth,
},
&setAuthResponse{},
nil, /* recvFunc*/
)
if err != nil {
return fmt.Errorf("failed to send auth request: %v", err)
}

var res response
select {
case res = <-resChan:
case <-c.closeChan:
c.logger.Printf("recv closed, cancel re-submitting credentials")
return nil
case <-c.shouldQuit:
c.logger.Printf("should quit, cancel re-submitting credentials")
return nil
case <-ctx.Done():
return ctx.Err()
if cred.scheme == "sasl" {
shouldContinue, err = resendZkSasl(ctx, c, cred.auth)
} else {
shouldContinue, err = c.sendRequestEx(
ctx,
opSetAuth,
&setAuthRequest{Type: 0,
Scheme: cred.scheme,
Auth: cred.auth,
},
&setAuthResponse{},
nil, /* recvFunc*/
)
}
if res.err != nil {
return fmt.Errorf("failed conneciton setAuth request: %v", res.err)
if err != nil {
if shouldContinue {
continue
}
return err
}
}

Expand Down
1 change: 1 addition & 0 deletions constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ const (
opClose = -11
opSetAuth = 100
opSetWatches = 101
opSetSasl = 102
opError = -1
// Not in protocol, used internally
opWatcherEvent = -2
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/go-zookeeper/zk
module github.com/pingkui/zk

go 1.13
158 changes: 158 additions & 0 deletions sasl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package zk

import (
"crypto/md5"
"crypto/rand"
"errors"
"fmt"
"math/big"
"strings"
)

// Handle the SASL authentification.
const (
zkSaslMd5Uri = "zookeeper/zk-sasl-md5"
zkSaslAuthQop = "auth"
zkSaslAuthIntQop = "auth-int"
zkSaslAuthConfQop = "auth-conf"
)

type setSaslResponse struct {
Nonce string
Realm string
Charset string
Algorithm string
RspAuth string
}

func getHexMd5(s string) string {
bs := []byte(s)
hash := ""
sum := md5.Sum(bs)
for _, b := range sum {
hash += fmt.Sprintf("%02x", b)
}
return hash
}

func getMd5(s string) string {
bs := []byte(s)
sum := md5.Sum(bs)
return string(sum[:])
}

func doubleQuote(s string) string {
return `"` + s + `"`
}

func rmDoubleQuote(s string) string {
leng := len(s)
return s[1 : leng-1]
}

func (r setSaslResponse) getUserPassword(auth []byte) (string, string) {
userPassword := string(auth)

split := strings.SplitN(userPassword, ":", 2)

return split[0], split[1]
}

func (r setSaslResponse) genA1(user, password, cnonce string) string {
hexStr := fmt.Sprintf("%s:%s:%s", user, r.Realm, password)
hash := getMd5(hexStr)
keyHash := fmt.Sprintf("%s:%s:%s", hash, r.Nonce, cnonce)
return getHexMd5(keyHash)
}

func (r setSaslResponse) genChallenge(user, password, cnonce, qop string, nc int) string {

rawA2 := fmt.Sprintf("%s:%s", "AUTHENTICATE", zkSaslMd5Uri)
a2 := getHexMd5(rawA2)

a1 := r.genA1(user, password, cnonce)

rv := fmt.Sprintf("%s:%s:%08x:%s:%s:%s", a1, r.Nonce, nc, cnonce, qop, a2)

return getHexMd5(rv)
}

// GenSaslChallenge refers to RFC2831 to generate a md5-digest challenge.
func (r setSaslResponse) GenSaslChallenge(auth []byte, cnonce string) (string, error) {

user, password := r.getUserPassword(auth)
if user == "" || password == "" {
return "", errors.New("found invalid user&password")
}

ch := make(map[string]string, 20)

ch["digest-uri"] = doubleQuote(zkSaslMd5Uri)

// Only "auth" qop supports so far.
qop := zkSaslAuthQop
ch["qop"] = qop

nc := 1
ch["nc"] = fmt.Sprintf("%08x", nc)

ch["realm"] = doubleQuote(r.Realm)
ch["username"] = doubleQuote(user)

// for unittest.
if cnonce == "" {
n, err := rand.Int(rand.Reader, big.NewInt(65535))
if err != nil {
return "", err
}
cnonce = fmt.Sprintf("%s", n)
}
ch["cnonce"] = doubleQuote(cnonce)
ch["nonce"] = doubleQuote(r.Nonce)

ch["response"] = r.genChallenge(user, password, cnonce, qop, nc)

items := make([]string, 0, len(ch))

for k, v := range ch {
items = append(items, fmt.Sprintf("%s=%s", k, v))
}

return strings.Join(items, ","), nil
}

// Decode decodes a md5-digest ZK SASL response.
func (r *setSaslResponse) Decode(buf []byte) (int, error) {

// Discard the first 4 bytes, they are not used here.
// According to RFC, the payload is inform of k1=v,k2=v, some of the values maybe enclosure with double quote(").
payload := string(buf[4:])

splitPayload := strings.Split(payload, ",")

if len(splitPayload) == 0 {
return 0, errors.New("invalid sasl payload")
}

r.Nonce = ""
r.Realm = ""
r.RspAuth = ""

for _, item := range splitPayload {
kv := strings.SplitN(item, "=", 2)
if len(kv) != 2 {
return 0, errors.New("invalid sasl payload format")
}

key := strings.ToLower(kv[0])
if key == "nonce" {
r.Nonce = rmDoubleQuote(kv[1])
} else if key == "realm" {
r.Realm = rmDoubleQuote(kv[1])
} else if key == "rspauth" {
r.RspAuth = kv[1]
}
}

return len(buf), nil
}
Loading