Skip to content

Commit

Permalink
tsh Profile SSH certs fix (#6214)
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger authored and awly committed Mar 30, 2021
1 parent f3af7a5 commit aa42cd6
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 47 deletions.
42 changes: 13 additions & 29 deletions api/client/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/utils/sshutils"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -134,39 +133,24 @@ func TestLoadProfile(t *testing.T) {

// Write identity file to disk.
dir := t.TempDir()
name := "proxy"
name := "proxy.example.com"
p := &Profile{
WebProxyAddr: "proxy:3080",
WebProxyAddr: "proxy.example.com:3080",
SiteName: "example.com",
Username: "testUser",
Dir: dir,
}

// Save profile to a file.
err := p.SaveToDir(dir, true)
require.NoError(t, err)

// Write keys to disk.
keyDir := filepath.Join(dir, constants.SessionKeyDir)
err = os.MkdirAll(keyDir, 0700)
require.NoError(t, err)
userKeyDir := filepath.Join(keyDir, p.Name())
os.MkdirAll(userKeyDir, 0700)
require.NoError(t, err)
keyPath := filepath.Join(userKeyDir, p.Username)
err = ioutil.WriteFile(keyPath, []byte(keyPEM), 0600)
require.NoError(t, err)
tlsCertPath := filepath.Join(userKeyDir, p.Username+constants.FileExtTLSCert)
err = ioutil.WriteFile(tlsCertPath, []byte(tlsCert), 0600)
require.NoError(t, err)
tlsCasPath := filepath.Join(userKeyDir, constants.FileNameTLSCerts)
err = ioutil.WriteFile(tlsCasPath, []byte(tlsCACert), 0600)
require.NoError(t, err)
sshCertPath := filepath.Join(userKeyDir, p.Username+constants.FileExtSSHCert)
err = ioutil.WriteFile(sshCertPath, []byte(sshCert), 0600)
require.NoError(t, err)
sshCasPath := filepath.Join(dir, constants.FileNameKnownHosts)
err = ioutil.WriteFile(sshCasPath, []byte(sshCACert), 0600)
require.NoError(t, err)
// Save profile and keys to disk.
require.NoError(t, p.SaveToDir(dir, true))
require.NoError(t, os.MkdirAll(p.keyDir(), 0700))
require.NoError(t, os.MkdirAll(p.userKeyDir(), 0700))
require.NoError(t, os.MkdirAll(p.sshDir(), 0700))
require.NoError(t, ioutil.WriteFile(p.keyPath(), []byte(keyPEM), 0600))
require.NoError(t, ioutil.WriteFile(p.tlsCertPath(), []byte(tlsCert), 0600))
require.NoError(t, ioutil.WriteFile(p.tlsCasPath(), []byte(tlsCACert), 0600))
require.NoError(t, ioutil.WriteFile(p.sshCertPath(), []byte(sshCert), 0600))
require.NoError(t, ioutil.WriteFile(p.sshCasPath(), []byte(sshCACert), 0600))

// Load profile from disk.
creds := LoadProfile(dir, name)
Expand Down
53 changes: 40 additions & 13 deletions api/client/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,23 +93,18 @@ func (p *Profile) Name() string {

// TLSConfig returns the profile's associated TLSConfig.
func (p *Profile) TLSConfig() (*tls.Config, error) {
credsPath := filepath.Join(p.Dir, constants.SessionKeyDir, p.Name())

certPath := filepath.Join(credsPath, p.Username+constants.FileExtTLSCert)
keyPath := filepath.Join(credsPath, p.Username)
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
cert, err := tls.LoadX509KeyPair(p.tlsCertPath(), p.keyPath())
if err != nil {
return nil, trace.Wrap(err)
}

certsPath := filepath.Join(credsPath, constants.FileNameTLSCerts)
certs, err := ioutil.ReadFile(certsPath)
caCerts, err := ioutil.ReadFile(p.tlsCasPath())
if err != nil {
return nil, trace.Wrap(err)
}

pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(certs) {
if !pool.AppendCertsFromPEM(caCerts) {
return nil, trace.BadParameter("invalid CA cert PEM")
}

Expand All @@ -121,22 +116,22 @@ func (p *Profile) TLSConfig() (*tls.Config, error) {

// SSHClientConfig returns the profile's associated SSHClientConfig.
func (p *Profile) SSHClientConfig() (*ssh.ClientConfig, error) {
credsPath := filepath.Join(p.Dir, constants.SessionKeyDir, p.Name())
cert, err := ioutil.ReadFile(filepath.Join(credsPath, p.Username+constants.FileExtSSHCert))
cert, err := ioutil.ReadFile(p.sshCertPath())
if err != nil {
return nil, trace.Wrap(err)
}
key, err := ioutil.ReadFile(filepath.Join(credsPath, p.Username))

key, err := ioutil.ReadFile(p.keyPath())
if err != nil {
return nil, trace.Wrap(err)
}

knownHosts, err := ioutil.ReadFile(filepath.Join(p.Dir, constants.FileNameKnownHosts))
caCerts, err := ioutil.ReadFile(p.sshCasPath())
if err != nil {
return nil, trace.Wrap(err)
}

ssh, err := sshutils.SSHClientConfig(cert, key, [][]byte{knownHosts})
ssh, err := sshutils.SSHClientConfig(cert, key, [][]byte{caCerts})
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -281,3 +276,35 @@ func (p *Profile) saveToFile(filepath string) error {
}
return nil
}

func (p *Profile) keyDir() string {
return filepath.Join(p.Dir, constants.SessionKeyDir)
}

func (p *Profile) userKeyDir() string {
return filepath.Join(p.keyDir(), p.Name())
}

func (p *Profile) keyPath() string {
return filepath.Join(p.userKeyDir(), p.Username)
}

func (p *Profile) tlsCertPath() string {
return filepath.Join(p.userKeyDir(), p.Username+constants.FileExtTLSCert)
}

func (p *Profile) tlsCasPath() string {
return filepath.Join(p.userKeyDir(), constants.FileNameTLSCerts)
}

func (p *Profile) sshDir() string {
return filepath.Join(p.userKeyDir(), p.Username+constants.SSHDirSuffix)
}

func (p *Profile) sshCertPath() string {
return filepath.Join(p.sshDir(), p.SiteName+constants.FileExtSSHCert)
}

func (p *Profile) sshCasPath() string {
return filepath.Join(p.Dir, constants.FileNameKnownHosts)
}
6 changes: 5 additions & 1 deletion api/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ const (
const (
// SessionKeyDir is the sub-directory where session keys are stored (.tsh/keys).
SessionKeyDir = "keys"
// SSHDirSuffix is the suffix of the sub-directory where
// ssh keys are stored (.tsh/keys/profilename/username-ssh).
SSHDirSuffix = "-ssh"
// FileNameKnownHosts is a file that stores known hosts.
FileNameKnownHosts = "known_hosts"
// FileExtTLSCert is the filename extension/suffix of TLS certs
Expand All @@ -140,7 +143,8 @@ const (
// FileNameTLSCerts is the filename of Cert Authorities stored in a
// profile (./tsh/keys/profilename/certs.pem).
FileNameTLSCerts = "certs.pem"
// FileExtCert is a file extension used for SSH Certificate files.
// FileExtCert is a file extension used for SSH Certificate files
// (.tsh/keys/profilename/username-ssh/clustername-cert.pub).
FileExtSSHCert = "-cert.pub"
// FileExtPub is a file extension used for SSH Certificate Authorities
// stored in a profile (./tsh/keys/profilename/username.pub).
Expand Down
2 changes: 1 addition & 1 deletion lib/client/keyagent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (s *KeyAgentTestSuite) TestAddKey(c *check.C) {
s.username, // private key
s.username + constants.FileExtPub, // public key
s.username + constants.FileExtTLSCert, // Teleport TLS certificate
filepath.Join(s.username+sshDirSuffix, s.key.ClusterName+constants.FileExtSSHCert), // SSH certificate
filepath.Join(s.username+constants.SSHDirSuffix, s.key.ClusterName+constants.FileExtSSHCert), // SSH certificate
}
for _, file := range expectedFiles {
_, err := os.Stat(filepath.Join(s.keyDir, "keys", s.hostname, file))
Expand Down
5 changes: 2 additions & 3 deletions lib/client/keystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ import (
)

const (
sshDirSuffix = "-ssh"
kubeDirSuffix = "-kube"
dbDirSuffix = "-db"
appDirSuffix = "-app"
Expand Down Expand Up @@ -177,7 +176,7 @@ func (fs *FSLocalKeyStore) AddKey(key *Key) error {
}

// Store per-cluster key data.
if err := fs.writeBytes(key.Cert, inProxyHostDir(key.Username+sshDirSuffix, key.ClusterName+constants.FileExtSSHCert)); err != nil {
if err := fs.writeBytes(key.Cert, inProxyHostDir(key.Username+constants.SSHDirSuffix, key.ClusterName+constants.FileExtSSHCert)); err != nil {
return trace.Wrap(err)
}
// TODO(awly): unit test this.
Expand Down Expand Up @@ -385,7 +384,7 @@ var WithAllCerts = []CertOption{WithSSHCerts{}, WithKubeCerts{}, WithDBCerts{},
type WithSSHCerts struct{}

func (o WithSSHCerts) relativeCertPath(idx KeyIndex) string {
components := []string{idx.ProxyHost, idx.Username + sshDirSuffix}
components := []string{idx.ProxyHost, idx.Username + constants.SSHDirSuffix}
if idx.ClusterName != "" {
components = append(components, idx.ClusterName+constants.FileExtSSHCert)
}
Expand Down

0 comments on commit aa42cd6

Please sign in to comment.