diff --git a/command/command.go b/command/command.go index b759fb9..40b88ee 100644 --- a/command/command.go +++ b/command/command.go @@ -379,11 +379,12 @@ func (cmd *BaseCommand) RemoveServo(servo Servo) error { // Servo represents a deployed Servo assembly running somewhere type Servo struct { - Name string - User string - Host string - Port string - Path string + Name string + User string + Host string + Port string + Path string + Bastion string } func (s Servo) HostAndPort() string { @@ -420,3 +421,13 @@ func (s Servo) URL() string { } return fmt.Sprintf("ssh://%s@%s:%s", s.User, s.DisplayHost(), pathComponent) } + +func (s Servo) BastionComponents() (string, string) { + components := strings.Split(s.Bastion, "@") + user := components[0] + host := components[1] + if !strings.Contains(host, ":") { + host = host + ":22" + } + return user, host +} diff --git a/command/servo.go b/command/servo.go index 8a15a01..7612d74 100644 --- a/command/servo.go +++ b/command/servo.go @@ -420,15 +420,48 @@ func (servoCmd *servoCommand) runInSSHSession(ctx context.Context, name string, HostKeyCallback: hostKeyCallback, } - // Connect to host - client, err := ssh.Dial("tcp", servo.HostAndPort(), config) - if err != nil { - log.Fatal(err) + // Support bastion hosts via redialing + var sshClient *ssh.Client + if servo.Bastion != "" { + user, host := servo.BastionComponents() + bastionConfig := &ssh.ClientConfig{ + User: user, + Auth: []ssh.AuthMethod{ + servoCmd.sshAgent(), + }, + HostKeyCallback: hostKeyCallback, + } + + // Dial the bastion host + bastionClient, err := ssh.Dial("tcp", host, bastionConfig) + if err != nil { + log.Fatal(err) + } + + // Establish a new connection thrrough the bastion + conn, err := bastionClient.Dial("tcp", servo.HostAndPort()) + if err != nil { + log.Fatal(err) + } + + // Build a new SSH connection on top of the bastion connection + ncc, chans, reqs, err := ssh.NewClientConn(conn, servo.HostAndPort(), config) + if err != nil { + log.Fatal(err) + } + + // Now connection a client on top of it + sshClient = ssh.NewClient(ncc, chans, reqs) + } else { + sshClient, err = ssh.Dial("tcp", servo.HostAndPort(), config) + if err != nil { + log.Fatal(err) + } } - defer client.Close() + defer sshClient.Close() // Create sesssion - session, err := client.NewSession() + session, err := sshClient.NewSession() if err != nil { log.Fatal("Failed to create session: ", err) } @@ -436,7 +469,7 @@ func (servoCmd *servoCommand) runInSSHSession(ctx context.Context, name string, go func() { <-ctx.Done() - client.Close() + sshClient.Close() }() return runIt(ctx, *servo, session)