From 26331fb53943525fa3ca99fed47551451ca8f621 Mon Sep 17 00:00:00 2001 From: klizhentas Date: Thu, 3 Mar 2016 18:03:25 -0800 Subject: [PATCH] fixing bugs, refs #180 --- lib/auth/api_with_roles.go | 4 +- lib/auth/tun.go | 113 ++++++++++++++++++------------------ lib/hangout/hangout_test.go | 2 +- lib/reversetunnel/srv.go | 51 +++++++++++----- lib/services/web.go | 1 - lib/srv/sess.go | 3 +- lib/srv/srv_test.go | 43 +------------- lib/srv/term.go | 1 + lib/utils/hangout.go | 9 +-- lib/web/connect.go | 10 +++- lib/web/sessions.go | 13 ++--- 11 files changed, 117 insertions(+), 133 deletions(-) diff --git a/lib/auth/api_with_roles.go b/lib/auth/api_with_roles.go index db443edef80da..a51341ec28d4a 100644 --- a/lib/auth/api_with_roles.go +++ b/lib/auth/api_with_roles.go @@ -176,12 +176,12 @@ func (socket *fakeSocket) CreateBridge(remoteAddr net.Addr, sshChan ssh.Channel) // Accept() will unblock this select case socket.connections <- connection: } - + log.Debugf("SocketOverSSH.Handle(from=%v) is accepted", remoteAddr) // wait for the connection to close: select { case <-connection.closed: } - log.Debugf("SocketOverSSH.Handle(from=%v) is done", remoteAddr) + log.Debugf("SocketOverSSH.Handle(from=%v) is closed", remoteAddr) return nil } diff --git a/lib/auth/tun.go b/lib/auth/tun.go index a09a671a22a47..08cf8d773dae5 100644 --- a/lib/auth/tun.go +++ b/lib/auth/tun.go @@ -19,6 +19,7 @@ package auth import ( "encoding/json" "fmt" + "io" "net" "net/http" "os" @@ -364,7 +365,7 @@ func (s *AuthTunnel) passwordAuth( switch ab.Type { case AuthWebPassword: if err := s.authServer.CheckPassword(conn.User(), ab.Pass, ab.HotpToken); err != nil { - log.Errorf("Password auth error: %v", err) + log.Warningf("password auth error: %v", err) return nil, trace.Wrap(err) } perms := &ssh.Permissions{ @@ -510,7 +511,7 @@ func NewTunClient(addr utils.NetAddr, user string, auth []ssh.AuthMethod) (*TunC return tc, nil } -func (c *TunClient) GetAgent() (agent.Agent, error) { +func (c *TunClient) GetAgent() (AgentCloser, error) { return c.dialer.GetAgent() } @@ -525,63 +526,60 @@ func (c *TunClient) GetDialer() AccessPointDialer { } } +type AgentCloser interface { + io.Closer + agent.Agent +} + +type tunAgent struct { + agent.Agent + client *ssh.Client +} + +func (ta *tunAgent) Close() error { + log.Infof("tunAgent.Close") + return ta.client.Close() +} + type TunDialer struct { sync.Mutex auth []ssh.AuthMethod user string - tun *ssh.Client addr utils.NetAddr } func (t *TunDialer) Close() error { - if t.tun != nil { - return t.tun.Close() - } return nil } -func (t *TunDialer) GetAgent() (agent.Agent, error) { - _, err := t.getClient(false) // we need an established connection first +func (t *TunDialer) GetAgent() (AgentCloser, error) { + client, err := t.getClient() // we need an established connection first if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap( + teleport.ConnectionProblem("failed to connect to remote API", err)) } - ch, _, err := t.tun.OpenChannel(ReqWebSessionAgent, nil) + ch, _, err := client.OpenChannel(ReqWebSessionAgent, nil) if err != nil { - // reconnecting and trying again - _, err := t.getClient(true) - if err != nil { - return nil, trace.Wrap(err) - } - ch, _, err = t.tun.OpenChannel(ReqWebSessionAgent, nil) - if err != nil { - return nil, trace.Wrap(err) - } + return nil, trace.Wrap( + teleport.ConnectionProblem("failed to connect to remote API", err)) } - log.Infof("opened agent channel") - return agent.NewClient(ch), nil + agentCloser := &tunAgent{client: client} + agentCloser.Agent = agent.NewClient(ch) + return agentCloser, nil } -func (t *TunDialer) getClient(reset bool) (*ssh.Client, error) { - t.Lock() - defer t.Unlock() - if t.tun != nil { - if !reset { - return t.tun, nil - } - go t.tun.Close() - t.tun = nil - } +func (t *TunDialer) getClient() (*ssh.Client, error) { config := &ssh.ClientConfig{ User: t.user, Auth: t.auth, } client, err := ssh.Dial(t.addr.AddrNetwork, t.addr.Addr, config) + log.Infof("TunDialer.getClient(%v)", t.addr.String()) if err != nil { log.Infof("TunDialer could not ssh.Dial: %v", err) return nil, trace.Wrap(err) } - t.tun = client - return t.tun, nil + return client, nil } const ( @@ -592,32 +590,33 @@ const ( DialerPeriodBetweenAttempts = time.Second ) +type tunConn struct { + net.Conn + client *ssh.Client +} + +func (c *tunConn) Close() error { + log.Infof("tunConn: close!") + err := c.Conn.Close() + err = c.client.Close() + return trace.Wrap(err) +} + func (t *TunDialer) Dial(network, address string) (net.Conn, error) { - var client *ssh.Client - var err error - for i := 0; i < DialerRetryAttempts; i++ { - if i == 0 { - client, err = t.getClient(false) - if err != nil { - log.Infof("TunDialer failed to get client: %v", err) - continue - } - } else { - time.Sleep(DialerPeriodBetweenAttempts) - client, err = t.getClient(true) - if err != nil { - log.Infof("TunDialer failed to get client: %v", err) - continue - } - } - conn, err := client.Dial(network, address) - if err == nil { - return conn, nil - } - log.Infof("TunDialer connection issue (%v), reconnect", err) + log.Infof("TunDialer.Dial(%v, %v)", network, address) + client, err := t.getClient() + if err != nil { + return nil, trace.Wrap( + teleport.ConnectionProblem("failed to connect to remote API", err)) } - return nil, trace.Wrap( - teleport.ConnectionProblem("failed to connect to remote API", err)) + conn, err := client.Dial(network, address) + if err != nil { + return nil, trace.Wrap( + teleport.ConnectionProblem("failed to connect to remote API", err)) + } + tc := &tunConn{client: client} + tc.Conn = conn + return tc, nil } func NewClientFromSSHClient(sshClient *ssh.Client) (*Client, error) { diff --git a/lib/hangout/hangout_test.go b/lib/hangout/hangout_test.go index 289a8af8431fa..f1a4860d1767c 100644 --- a/lib/hangout/hangout_test.go +++ b/lib/hangout/hangout_test.go @@ -71,7 +71,7 @@ type HangoutsSuite struct { var _ = Suite(&HangoutsSuite{}) func (s *HangoutsSuite) SetUpSuite(c *C) { - utils.InitLoggerDebug() + utils.InitLoggerCLI() client.KeysDir = c.MkDir() s.dir = c.MkDir() diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index b7b894ad34a35..c091c4d60cd71 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -362,6 +362,10 @@ func (s *server) upsertRegularSite(conn net.Conn, sshConn *ssh.ServerConn) (*tun return nil, trace.Wrap(teleport.BadParameter( "authDomain", fmt.Sprintf("'%v' is a bad domain name", domainName))) } + + s.Lock() + defer s.Unlock() + var site *tunnelSite for _, st := range s.tunnelSites { if st.domainName == domainName { @@ -371,9 +375,6 @@ func (s *server) upsertRegularSite(conn net.Conn, sshConn *ssh.ServerConn) (*tun } log.Infof("found authority domain: %v", domainName) - s.Lock() - defer s.Unlock() - var err error if site != nil { if err := site.addConn(conn, sshConn); err != nil { @@ -392,16 +393,24 @@ func (s *server) upsertRegularSite(conn net.Conn, sshConn *ssh.ServerConn) (*tun return site, nil } -func (s *server) upsertHangoutSite(conn net.Conn, sshConn ssh.Conn) (*tunnelSite, error) { +func (s *server) tryInsertHangoutSite(hangoutID string, remoteSite *tunnelSite) error { s.Lock() defer s.Unlock() - hangoutID := sshConn.User() for _, st := range s.tunnelSites { if st.domainName == hangoutID { - return nil, trace.Errorf("Hangout ID is already used") + return trace.Wrap( + teleport.BadParameter("hangoutID", + fmt.Sprintf("%v hangout id is already used", hangoutID))) } } + s.tunnelSites = append(s.tunnelSites, remoteSite) + return nil + +} + +func (s *server) upsertHangoutSite(conn net.Conn, sshConn ssh.Conn) (*tunnelSite, error) { + hangoutID := sshConn.User() site, err := newRemoteSite(s, hangoutID) if err != nil { @@ -441,26 +450,36 @@ func (s *server) upsertHangoutSite(conn net.Conn, sshConn ssh.Conn) (*tunnelSite } // receiving hangoutInfo using sessions just as storage - sess, err := clt.GetSessions() + sessions, err := clt.GetSessions() if err != nil { return nil, trace.Wrap(err) } - if len(sess) != 1 { - return nil, trace.Wrap(&teleport.NotFoundError{ - Message: fmt.Sprintf("hangout %v not found", hangoutID), - }) + var hangoutInfo *utils.HangoutInfo + for _, sess := range sessions { + info, err := utils.UnmarshalHangoutInfo(sess.ID) + if err != nil { + log.Infof("failed to unmarshal hangout info: %v", err) + } + if info.HangoutID == hangoutID { + hangoutInfo = info + break + } } - hangoutInfo, err := utils.UnmarshalHangoutInfo(sess[0].ID) - if err != nil { - return nil, err + if hangoutInfo == nil { + return nil, trace.Wrap(teleport.NotFound( + fmt.Sprintf("hangout %v not found", hangoutID))) } - site.domainName = hangoutInfo.HangoutID + // TODO(klizhentas) refactor this + site.domainName = hangoutInfo.HangoutID site.hangoutInfo.OSUser = hangoutInfo.OSUser site.hangoutInfo.AuthPort = hangoutInfo.AuthPort site.hangoutInfo.NodePort = hangoutInfo.NodePort - s.tunnelSites = append(s.tunnelSites, site) + if err := s.tryInsertHangoutSite(hangoutID, site); err != nil { + defer conn.Close() + return nil, trace.Wrap(err) + } return site, nil } diff --git a/lib/services/web.go b/lib/services/web.go index 76056bae3d108..0eb6d36e4d0cb 100644 --- a/lib/services/web.go +++ b/lib/services/web.go @@ -355,7 +355,6 @@ func (s *WebService) CheckPassword(user string, password []byte, hotpToken strin if err != nil { return trace.Wrap(err) } - if !otp.Scan(hotpToken, 4) { return &teleport.BadParameterError{Err: "tokens do not match", Param: "token"} } diff --git a/lib/srv/sess.go b/lib/srv/sess.go index ea6f53e888ec9..a1fd9516064d9 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -204,7 +204,6 @@ func newSession(id string, r *sessionRegistry, context *ctx) (*session, error) { login: context.login, closeC: make(chan bool), } - go sess.pollAndSyncTerm() return sess, nil } @@ -282,6 +281,7 @@ func (s *session) start(sconn *ssh.ServerConn, ch ssh.Channel, ctx *ctx) error { return trace.Wrap(err) } } + go s.pollAndSyncTerm() cmd := exec.Command(s.registry.srv.shell) // TODO(klizhentas) figure out linux user policy for launching shells, // what user and environment should we use to execute the shell? the simplest @@ -377,6 +377,7 @@ func (s *session) syncTerm(sessionServer rsession.Service) error { log.Infof("syncTerm: no session") return trace.Wrap(err) } + log.Infof("syncTerm: term: %v", s.term) winSize, err := s.term.getWinsize() if err != nil { log.Infof("syncTerm: no terminal") diff --git a/lib/srv/srv_test.go b/lib/srv/srv_test.go index 90d7a47ff238a..d229ffaa9e139 100644 --- a/lib/srv/srv_test.go +++ b/lib/srv/srv_test.go @@ -44,7 +44,6 @@ import ( "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" - "github.com/gokyle/hotp" "github.com/gravitational/trace" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" @@ -401,20 +400,8 @@ func (s *SrvSuite) TestProxy(c *C) { c.Assert(err, IsNil) c.Assert(tsrv.Start(), IsNil) - user := "user1" - pass := []byte("utndkrn") - - hotpURL, _, err := s.a.UpsertPassword(user, pass) - c.Assert(err, IsNil) - otp, _, err := hotp.FromURL(hotpURL) - c.Assert(err, IsNil) - otp.Increment() - - authMethod, err := auth.NewWebPasswordAuth(user, pass, otp.OTP()) - c.Assert(err, IsNil) - tunClt, err := auth.NewTunClient( - utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()}, user, authMethod) + utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)}) c.Assert(err, IsNil) defer tunClt.Close() @@ -606,20 +593,8 @@ func (s *SrvSuite) TestProxyRoundRobin(c *C) { c.Assert(err, IsNil) c.Assert(tsrv.Start(), IsNil) - user := "user1" - pass := []byte("utndkrn") - - hotpURL, _, err := s.a.UpsertPassword(user, pass) - c.Assert(err, IsNil) - otp, _, err := hotp.FromURL(hotpURL) - c.Assert(err, IsNil) - otp.Increment() - - authMethod, err := auth.NewWebPasswordAuth(user, pass, otp.OTP()) - c.Assert(err, IsNil) - tunClt, err := auth.NewTunClient( - utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()}, user, authMethod) + utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)}) c.Assert(err, IsNil) defer tunClt.Close() @@ -716,20 +691,8 @@ func (s *SrvSuite) TestProxyDirectAccess(c *C) { c.Assert(err, IsNil) c.Assert(tsrv.Start(), IsNil) - user := "user1" - pass := []byte("utndkrn") - - hotpURL, _, err := s.a.UpsertPassword(user, pass) - c.Assert(err, IsNil) - otp, _, err := hotp.FromURL(hotpURL) - c.Assert(err, IsNil) - otp.Increment() - - authMethod, err := auth.NewWebPasswordAuth(user, pass, otp.OTP()) - c.Assert(err, IsNil) - tunClt, err := auth.NewTunClient( - utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()}, user, authMethod) + utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)}) c.Assert(err, IsNil) defer tunClt.Close() diff --git a/lib/srv/term.go b/lib/srv/term.go index 7f74e4d8ccc84..12c6d270f255b 100644 --- a/lib/srv/term.go +++ b/lib/srv/term.go @@ -79,6 +79,7 @@ func requestPTY(req *ssh.Request) (*terminal, *rsession.TerminalParams, error) { } func (t *terminal) getWinsize() (*term.Winsize, error) { + log.Infof("pty: %v", t.pty) ws, err := term.GetWinsize(t.pty.Fd()) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/utils/hangout.go b/lib/utils/hangout.go index 17e1639e99709..1b3f5abcffb99 100644 --- a/lib/utils/hangout.go +++ b/lib/utils/hangout.go @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package utils import ( @@ -23,10 +24,10 @@ import ( ) type HangoutInfo struct { - AuthPort string - NodePort string - HangoutID string - OSUser string + AuthPort string `json:"auth_port"` + NodePort string `json:"node_port"` + HangoutID string `json:"hangout_id"` + OSUser string `json:"os_user"` } func MarshalHangoutInfo(h *HangoutInfo) (string, error) { diff --git a/lib/web/connect.go b/lib/web/connect.go index 5f8242d837fa6..161cd4039d9cf 100644 --- a/lib/web/connect.go +++ b/lib/web/connect.go @@ -107,11 +107,17 @@ func (w *connectHandler) resizePTYWindow(params session.TerminalParams) error { } func (w *connectHandler) connectUpstream() (*sshutils.Upstream, error) { - methods, err := w.ctx.GetAuthMethods() + agent, err := w.ctx.GetAgent() if err != nil { return nil, trace.Wrap(err) } - client, err := w.site.ConnectToServer(w.req.Addr, w.req.Login, methods) + defer agent.Close() + signers, err := agent.Signers() + if err != nil { + return nil, trace.Wrap(err) + } + client, err := w.site.ConnectToServer( + w.req.Addr, w.req.Login, []ssh.AuthMethod{ssh.PublicKeys(signers...)}) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/web/sessions.go b/lib/web/sessions.go index ed12a24d8b734..67962fe44c42d 100644 --- a/lib/web/sessions.go +++ b/lib/web/sessions.go @@ -31,7 +31,6 @@ import ( log "github.com/Sirupsen/logrus" "github.com/gravitational/trace" "github.com/mailgun/ttlmap" - "golang.org/x/crypto/ssh" ) // sessionContext is a context associated with users' @@ -116,18 +115,14 @@ func (c *sessionContext) CreateWebSession() (*auth.Session, error) { return sess, nil } -// GetAuthMethods returns authentication methods (credentials) that proxy -// can use to connect to servers -func (c *sessionContext) GetAuthMethods() ([]ssh.AuthMethod, error) { +// GetAgent returns agent that can we used to answer challenges +// for the web to ssh connection +func (c *sessionContext) GetAgent() (auth.AgentCloser, error) { a, err := c.clt.GetAgent() if err != nil { return nil, trace.Wrap(err) } - signers, err := a.Signers() - if err != nil { - return nil, trace.Wrap(err) - } - return []ssh.AuthMethod{ssh.PublicKeys(signers...)}, nil + return a, nil } // Close cleans up connections associated with requests