diff --git a/conn.go b/conn.go index db6d542d..9bb0c8d7 100644 --- a/conn.go +++ b/conn.go @@ -394,6 +394,40 @@ func (c *Conn) connect() error { } } +// sendRequestEx sends request directly, and handles the closed or quit scenarios. +func (c *Conn) sendRequestEx( + ctx context.Context, + opcode int32, + req interface{}, + res interface{}, + recvFunc func(*request, *responseHeader, error)) (bool, error) { + + resChan, err := c.sendRequest(opcode, req, res, recvFunc) + + if err != nil { + return true, fmt.Errorf("failed to send auth request: %v", err) + } + + var resp response + + select { + case resp = <-resChan: + case <-c.closeChan: + c.logger.Printf("recv routine closed") + return false, nil + case <-c.shouldQuit: + c.logger.Printf("should quit") + return false, nil + case <-ctx.Done(): + return false, ctx.Err() + } + + if resp.err != nil { + return true, fmt.Errorf("failed for op: %d, error: %v", opcode, resp.err) + } + return true, nil +} + func (c *Conn) sendRequest( opcode int32, req interface{}, @@ -932,8 +966,34 @@ func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc } } +func (c *Conn) doAddSaslAuth(auth []byte) (int64, error) { + // step 1 Ask for server informations. + resp := setSaslResponse{} + + zxid, err := c.request(opSetSasl, &setSaslRequest{}, &resp, nil) + if err != nil { + return zxid, err + } + + challenge, err := resp.GenSaslChallenge(auth, "") + + if err != nil { + return 0, err + } + + // step 2 Do the authentication. + return c.request(opSetSasl, &setSaslRequest{challenge}, &resp, nil) +} + +// AddAuth adds an auth specified by and , supported schemes +// includes "digest", "sasl", usually comes as "user:pasword". func (c *Conn) AddAuth(scheme string, auth []byte) error { - _, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil) + var err error + if scheme == "sasl" { + _, err = c.doAddSaslAuth(auth) + } else { + _, err = c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil) + } if err != nil { return err @@ -1299,6 +1359,24 @@ func (c *Conn) Server() string { return c.server } +// FIXME(linsite) unify it with doAddSasl. +// resendZkSasl resends SASL auth, when the 1st return value is true indicates the connection is invalid. +func resendZkSasl(ctx context.Context, c *Conn, auth []byte) (bool, error) { + resp := setSaslResponse{} + shouldContinue, err := c.sendRequestEx(ctx, opSetSasl, &setSaslRequest{}, &resp, nil) + if err != nil { + return shouldContinue, err + } + + challenge, err := resp.GenSaslChallenge(auth, "") + + if err != nil { + return true, err + } + + return c.sendRequestEx(ctx, opSetSasl, &setSaslRequest{challenge}, &resp, nil) +} + func resendZkAuth(ctx context.Context, c *Conn) error { shouldCancel := func() bool { select { @@ -1318,6 +1396,9 @@ func resendZkAuth(ctx context.Context, c *Conn) error { c.logger.Printf("re-submitting `%d` credentials after reconnect", len(c.creds)) } + var shouldContinue bool + var err error + for _, cred := range c.creds { // return early before attempting to send request. if shouldCancel() { @@ -1325,33 +1406,26 @@ func resendZkAuth(ctx context.Context, c *Conn) error { } // do not use the public API for auth since it depends on the send/recv loops // that are waiting for this to return - resChan, err := c.sendRequest( - opSetAuth, - &setAuthRequest{Type: 0, - Scheme: cred.scheme, - Auth: cred.auth, - }, - &setAuthResponse{}, - nil, /* recvFunc*/ - ) - if err != nil { - return fmt.Errorf("failed to send auth request: %v", err) - } - var res response - select { - case res = <-resChan: - case <-c.closeChan: - c.logger.Printf("recv closed, cancel re-submitting credentials") - return nil - case <-c.shouldQuit: - c.logger.Printf("should quit, cancel re-submitting credentials") - return nil - case <-ctx.Done(): - return ctx.Err() + if cred.scheme == "sasl" { + shouldContinue, err = resendZkSasl(ctx, c, cred.auth) + } else { + shouldContinue, err = c.sendRequestEx( + ctx, + opSetAuth, + &setAuthRequest{Type: 0, + Scheme: cred.scheme, + Auth: cred.auth, + }, + &setAuthResponse{}, + nil, /* recvFunc*/ + ) } - if res.err != nil { - return fmt.Errorf("failed conneciton setAuth request: %v", res.err) + if err != nil { + if shouldContinue { + continue + } + return err } } diff --git a/constants.go b/constants.go index d914301f..49108a8b 100644 --- a/constants.go +++ b/constants.go @@ -32,6 +32,7 @@ const ( opClose = -11 opSetAuth = 100 opSetWatches = 101 + opSetSasl = 102 opError = -1 // Not in protocol, used internally opWatcherEvent = -2 diff --git a/go.mod b/go.mod index a2662730..5ebd7c2e 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/go-zookeeper/zk +module github.com/pingkui/zk go 1.13 diff --git a/sasl.go b/sasl.go new file mode 100644 index 00000000..2cd33c21 --- /dev/null +++ b/sasl.go @@ -0,0 +1,158 @@ +package zk + +import ( + "crypto/md5" + "crypto/rand" + "errors" + "fmt" + "math/big" + "strings" +) + +// Handle the SASL authentification. +const ( + zkSaslMd5Uri = "zookeeper/zk-sasl-md5" + zkSaslAuthQop = "auth" + zkSaslAuthIntQop = "auth-int" + zkSaslAuthConfQop = "auth-conf" +) + +type setSaslResponse struct { + Nonce string + Realm string + Charset string + Algorithm string + RspAuth string +} + +func getHexMd5(s string) string { + bs := []byte(s) + hash := "" + sum := md5.Sum(bs) + for _, b := range sum { + hash += fmt.Sprintf("%02x", b) + } + return hash +} + +func getMd5(s string) string { + bs := []byte(s) + sum := md5.Sum(bs) + return string(sum[:]) +} + +func doubleQuote(s string) string { + return `"` + s + `"` +} + +func rmDoubleQuote(s string) string { + leng := len(s) + return s[1 : leng-1] +} + +func (r setSaslResponse) getUserPassword(auth []byte) (string, string) { + userPassword := string(auth) + + split := strings.SplitN(userPassword, ":", 2) + + return split[0], split[1] +} + +func (r setSaslResponse) genA1(user, password, cnonce string) string { + hexStr := fmt.Sprintf("%s:%s:%s", user, r.Realm, password) + hash := getMd5(hexStr) + keyHash := fmt.Sprintf("%s:%s:%s", hash, r.Nonce, cnonce) + return getHexMd5(keyHash) +} + +func (r setSaslResponse) genChallenge(user, password, cnonce, qop string, nc int) string { + + rawA2 := fmt.Sprintf("%s:%s", "AUTHENTICATE", zkSaslMd5Uri) + a2 := getHexMd5(rawA2) + + a1 := r.genA1(user, password, cnonce) + + rv := fmt.Sprintf("%s:%s:%08x:%s:%s:%s", a1, r.Nonce, nc, cnonce, qop, a2) + + return getHexMd5(rv) +} + +// GenSaslChallenge refers to RFC2831 to generate a md5-digest challenge. +func (r setSaslResponse) GenSaslChallenge(auth []byte, cnonce string) (string, error) { + + user, password := r.getUserPassword(auth) + if user == "" || password == "" { + return "", errors.New("found invalid user&password") + } + + ch := make(map[string]string, 20) + + ch["digest-uri"] = doubleQuote(zkSaslMd5Uri) + + // Only "auth" qop supports so far. + qop := zkSaslAuthQop + ch["qop"] = qop + + nc := 1 + ch["nc"] = fmt.Sprintf("%08x", nc) + + ch["realm"] = doubleQuote(r.Realm) + ch["username"] = doubleQuote(user) + + // for unittest. + if cnonce == "" { + n, err := rand.Int(rand.Reader, big.NewInt(65535)) + if err != nil { + return "", err + } + cnonce = fmt.Sprintf("%s", n) + } + ch["cnonce"] = doubleQuote(cnonce) + ch["nonce"] = doubleQuote(r.Nonce) + + ch["response"] = r.genChallenge(user, password, cnonce, qop, nc) + + items := make([]string, 0, len(ch)) + + for k, v := range ch { + items = append(items, fmt.Sprintf("%s=%s", k, v)) + } + + return strings.Join(items, ","), nil +} + +// Decode decodes a md5-digest ZK SASL response. +func (r *setSaslResponse) Decode(buf []byte) (int, error) { + + // Discard the first 4 bytes, they are not used here. + // According to RFC, the payload is inform of k1=v,k2=v, some of the values maybe enclosure with double quote("). + payload := string(buf[4:]) + + splitPayload := strings.Split(payload, ",") + + if len(splitPayload) == 0 { + return 0, errors.New("invalid sasl payload") + } + + r.Nonce = "" + r.Realm = "" + r.RspAuth = "" + + for _, item := range splitPayload { + kv := strings.SplitN(item, "=", 2) + if len(kv) != 2 { + return 0, errors.New("invalid sasl payload format") + } + + key := strings.ToLower(kv[0]) + if key == "nonce" { + r.Nonce = rmDoubleQuote(kv[1]) + } else if key == "realm" { + r.Realm = rmDoubleQuote(kv[1]) + } else if key == "rspauth" { + r.RspAuth = kv[1] + } + } + + return len(buf), nil +} diff --git a/sasl_test.go b/sasl_test.go new file mode 100644 index 00000000..4a5ffd8d --- /dev/null +++ b/sasl_test.go @@ -0,0 +1,68 @@ +package zk + +import ( + "strings" + "testing" +) + +func TestDecode(t *testing.T) { + + // Response is expected. + buf := `0000nonce="1x1",realm="2x2",rspauth="3x3"` + resp := &setSaslResponse{} + + length, err := resp.Decode([]byte(buf)) + + if err != nil || length != len(buf) { + t.Errorf("failed to check Decode, %v", resp) + } + + if resp.Nonce != "1x1" || resp.Realm != "2x2" { + t.Errorf("failed to check Decode, %v", resp) + } + + if resp.RspAuth != `"3x3"` { + t.Errorf("failed to check Decode, %v", resp) + } + + // Response is not expected. + buf = `0000nonce"1x1",realm="2x2",rspauth="3x3"` + resp = &setSaslResponse{} + + _, err = resp.Decode([]byte(buf)) + + if err == nil { + t.Errorf("failed to check abnormal Decode, %v", resp) + } + +} + +func TestGenA1(t *testing.T) { + resp := setSaslResponse{} + resp.Realm = "test" + resp.Nonce = "1111" + + hash := resp.genA1("super", "password", "1111") + + if hash == "" { + t.Errorf("failed to genA1, %v", resp) + } +} + +func TestGenSaslChallenge(t *testing.T) { + resp := setSaslResponse{} + resp.Realm = "zk-sasl-md5" + resp.Nonce = "qWkHmx+rW9vYQNysvUOCA3gWLks3u9cL5rc9JJFi" + + auth := "super:admin" + hash, err := resp.GenSaslChallenge([]byte(auth), "140741146289") + + if hash == "" || err != nil { + t.Errorf("failed to genA1, %v, error: %v", resp, err) + } + + expect := "08125d12f8b89ca7dd8b5028b5cd7c3b" + if !strings.Contains(hash, expect) { + t.Errorf("failed to gen hash %s, expect %s.", hash, expect) + } +} diff --git a/server_help_test.go b/server_help_test.go index a7a0c138..163692cb 100644 --- a/server_help_test.go +++ b/server_help_test.go @@ -107,6 +107,8 @@ func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCl return nil, err } + cfg.AuthProvider = "org.apache.zookeeper.server.auth.SASLAuthenticationProvider" + cluster.Servers = append(cluster.Servers, TestServer{ Path: srvPath, Port: cfg.ClientPort, diff --git a/server_java_test.go b/server_java_test.go index dcada4a9..e3542d90 100644 --- a/server_java_test.go +++ b/server_java_test.go @@ -93,6 +93,7 @@ type ServerConfig struct { AutoPurgeSnapRetainCount int // Number of snapshots to retain in dataDir AutoPurgePurgeInterval int // Purge task internal in hours (0 to disable auto purge) Servers []ServerConfigServer + AuthProvider string } func (sc ServerConfig) Marshall(w io.Writer) error { @@ -131,6 +132,10 @@ func (sc ServerConfig) Marshall(w io.Writer) error { fmt.Fprintln(w, "reconfigEnabled=true") fmt.Fprintln(w, "4lw.commands.whitelist=*") + if sc.AuthProvider != "" { + fmt.Fprintf(w, "authProvider.1=%s\n", sc.AuthProvider) + } + if len(sc.Servers) < 2 { // if we dont have more than 2 servers we just dont specify server list to start in standalone mode // see https://zookeeper.apache.org/doc/current/zookeeperStarted.html#sc_InstallingSingleMode for more details. @@ -148,6 +153,7 @@ func (sc ServerConfig) Marshall(w io.Writer) error { } fmt.Fprintf(w, "server.%d=%s:%d:%d\n", srv.ID, srv.Host, srv.PeerPort, srv.LeaderElectionPort) } + return nil } diff --git a/structs.go b/structs.go index e41d8c52..e4dfc560 100644 --- a/structs.go +++ b/structs.go @@ -250,10 +250,6 @@ type setSaslRequest struct { Token string } -type setSaslResponse struct { - Token string -} - type setWatchesRequest struct { RelativeZxid int64 DataWatches []string @@ -633,6 +629,8 @@ func requestStructForOp(op int32) interface{} { return &multiRequest{} case opReconfig: return &reconfigRequest{} + case opSetSasl: + return &setSaslResponse{} } return nil } diff --git a/util.go b/util.go index 5a92b66b..c8f75899 100644 --- a/util.go +++ b/util.go @@ -24,6 +24,13 @@ func WorldACL(perms int32) []ACL { return []ACL{{perms, "world", "anyone"}} } +// SaslACL produces an ACL list containing a single ACL which uses the +// provided permissions, with the scheme "sasl", and ID , which is used +// by Zookeeper to represent a SASL authenticated user. +func SaslACL(user string, perms int32) []ACL { + return []ACL{{perms, "sasl", user}} +} + func DigestACL(perms int32, user, password string) []ACL { userPass := []byte(fmt.Sprintf("%s:%s", user, password)) h := sha1.New() diff --git a/zk_test.go b/zk_test.go index 9129c766..65011e09 100644 --- a/zk_test.go +++ b/zk_test.go @@ -604,6 +604,118 @@ func TestAuth(t *testing.T) { } } +func TestSasl(t *testing.T) { + tmpPath, err := ioutil.TempDir("", "gozk") + requireNoError(t, err, "failed to create tmp dir for test server setup") + defer os.RemoveAll(tmpPath) + + startPort := int(rand.Int31n(6000) + 10000) + + srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv1")) + if err := os.Mkdir(srvPath, 0700); err != nil { + requireNoError(t, err, "failed to make server path") + } + testSrvConfig := ServerConfigServer{ + ID: 1, + Host: "127.0.0.1", + PeerPort: startPort + 1, + LeaderElectionPort: startPort + 2, + } + cfg := ServerConfig{ + ClientPort: startPort, + DataDir: srvPath, + Servers: []ServerConfigServer{testSrvConfig}, + } + + cfgPath := filepath.Join(srvPath, _testConfigName) + fi, err := os.Create(cfgPath) + requireNoError(t, err) + + user, password := "admin", "super" + jaasCfgTpl := `Server { + org.apache.zookeeper.server.auth.DigestLoginModule required + user_%s="%s"; + }; + ` + jaasCfg := fmt.Sprintf(jaasCfgTpl, user, password) + jaasPath := filepath.Join(tmpPath, "jaas.conf") + err = ioutil.WriteFile(jaasPath, []byte(jaasCfg), 0644) + requireNoError(t, err) + + cfg.AuthProvider = "org.apache.zookeeper.server.auth.SASLAuthenticationProvider" + + requireNoError(t, cfg.Marshall(fi)) + fi.Close() + + fi, err = os.Create(filepath.Join(srvPath, _testMyIDFileName)) + requireNoError(t, err) + + _, err = fmt.Fprintln(fi, "1") + fi.Close() + requireNoError(t, err) + + testServer, err := NewIntegrationTestServer(t, cfgPath, nil, nil) + requireNoError(t, err) + authEnv := fmt.Sprintf(" -Djava.security.auth.login.config=%s", jaasPath) + testServer.cmdEnv[0] += authEnv + err = testServer.Start() + requireNoError(t, err) + + defer testServer.Stop() + + waitCtx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + servers := []string{fmt.Sprintf("127.0.0.1:%d", startPort)} + conn, eChan, err := Connect(servers, time.Second*2) + + requireNoError(t, err) + waitForSession(waitCtx, eChan) + defer conn.Close() + + // Wait for connection to be ready. + for i := 0; i < 20; i++ { + _, _, err = conn.Children("/") + if err == nil { + break + } + time.Sleep(time.Second * 1) + } + + requireNoError(t, err) + + // AddAuth should work. + authData := fmt.Sprintf("%s:%s", user, password) + err = conn.AddAuth("sasl", []byte(authData)) + requireNoError(t, err) + + data := []byte("sasl") + _, err = conn.Create("/sasl", data, 0, SaslACL(user, PermAll)) + requireNoError(t, err) + + conn2, _, err := Connect(servers, time.Second*2) + defer conn2.Close() + _, err = conn2.Create("/sasl/test", data, 0, WorldACL(PermAll)) + + // Expect it to fail. + if err == nil { + t.Errorf("Auth doesn't work.") + } + + // Check resend work. + obj := authCreds{ + scheme: "sasl", + auth: []byte(authData), + } + conn2.creds = append(conn2.creds, obj) + + err = resendZkAuth(waitCtx, conn2) + requireNoError(t, err) + + _, err = conn2.Create("/sasl/test", data, 0, WorldACL(PermAll)) + requireNoError(t, err) +} + func TestChildren(t *testing.T) { ts, err := StartTestCluster(t, 1, nil, logWriter{t: t, p: "[ZKERR] "}) if err != nil {