diff --git a/api/client/credentials_test.go b/api/client/credentials_test.go index 3923d72b6d030..520ab48455a77 100644 --- a/api/client/credentials_test.go +++ b/api/client/credentials_test.go @@ -26,7 +26,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/utils/sshutils" "github.com/stretchr/testify/require" @@ -134,39 +133,24 @@ func TestLoadProfile(t *testing.T) { // Write identity file to disk. dir := t.TempDir() - name := "proxy" + name := "proxy.example.com" p := &Profile{ - WebProxyAddr: "proxy:3080", + WebProxyAddr: "proxy.example.com:3080", + SiteName: "example.com", Username: "testUser", Dir: dir, } - // Save profile to a file. - err := p.SaveToDir(dir, true) - require.NoError(t, err) - - // Write keys to disk. - keyDir := filepath.Join(dir, constants.SessionKeyDir) - err = os.MkdirAll(keyDir, 0700) - require.NoError(t, err) - userKeyDir := filepath.Join(keyDir, p.Name()) - os.MkdirAll(userKeyDir, 0700) - require.NoError(t, err) - keyPath := filepath.Join(userKeyDir, p.Username) - err = ioutil.WriteFile(keyPath, []byte(keyPEM), 0600) - require.NoError(t, err) - tlsCertPath := filepath.Join(userKeyDir, p.Username+constants.FileExtTLSCert) - err = ioutil.WriteFile(tlsCertPath, []byte(tlsCert), 0600) - require.NoError(t, err) - tlsCasPath := filepath.Join(userKeyDir, constants.FileNameTLSCerts) - err = ioutil.WriteFile(tlsCasPath, []byte(tlsCACert), 0600) - require.NoError(t, err) - sshCertPath := filepath.Join(userKeyDir, p.Username+constants.FileExtSSHCert) - err = ioutil.WriteFile(sshCertPath, []byte(sshCert), 0600) - require.NoError(t, err) - sshCasPath := filepath.Join(dir, constants.FileNameKnownHosts) - err = ioutil.WriteFile(sshCasPath, []byte(sshCACert), 0600) - require.NoError(t, err) + // Save profile and keys to disk. + require.NoError(t, p.SaveToDir(dir, true)) + require.NoError(t, os.MkdirAll(p.keyDir(), 0700)) + require.NoError(t, os.MkdirAll(p.userKeyDir(), 0700)) + require.NoError(t, os.MkdirAll(p.sshDir(), 0700)) + require.NoError(t, ioutil.WriteFile(p.keyPath(), []byte(keyPEM), 0600)) + require.NoError(t, ioutil.WriteFile(p.tlsCertPath(), []byte(tlsCert), 0600)) + require.NoError(t, ioutil.WriteFile(p.tlsCasPath(), []byte(tlsCACert), 0600)) + require.NoError(t, ioutil.WriteFile(p.sshCertPath(), []byte(sshCert), 0600)) + require.NoError(t, ioutil.WriteFile(p.sshCasPath(), []byte(sshCACert), 0600)) // Load profile from disk. creds := LoadProfile(dir, name) diff --git a/api/client/profile.go b/api/client/profile.go index 1e1e0995fbeae..52faaa9826a34 100644 --- a/api/client/profile.go +++ b/api/client/profile.go @@ -93,23 +93,18 @@ func (p *Profile) Name() string { // TLSConfig returns the profile's associated TLSConfig. func (p *Profile) TLSConfig() (*tls.Config, error) { - credsPath := filepath.Join(p.Dir, constants.SessionKeyDir, p.Name()) - - certPath := filepath.Join(credsPath, p.Username+constants.FileExtTLSCert) - keyPath := filepath.Join(credsPath, p.Username) - cert, err := tls.LoadX509KeyPair(certPath, keyPath) + cert, err := tls.LoadX509KeyPair(p.tlsCertPath(), p.keyPath()) if err != nil { return nil, trace.Wrap(err) } - certsPath := filepath.Join(credsPath, constants.FileNameTLSCerts) - certs, err := ioutil.ReadFile(certsPath) + caCerts, err := ioutil.ReadFile(p.tlsCasPath()) if err != nil { return nil, trace.Wrap(err) } pool := x509.NewCertPool() - if !pool.AppendCertsFromPEM(certs) { + if !pool.AppendCertsFromPEM(caCerts) { return nil, trace.BadParameter("invalid CA cert PEM") } @@ -121,22 +116,22 @@ func (p *Profile) TLSConfig() (*tls.Config, error) { // SSHClientConfig returns the profile's associated SSHClientConfig. func (p *Profile) SSHClientConfig() (*ssh.ClientConfig, error) { - credsPath := filepath.Join(p.Dir, constants.SessionKeyDir, p.Name()) - cert, err := ioutil.ReadFile(filepath.Join(credsPath, p.Username+constants.FileExtSSHCert)) + cert, err := ioutil.ReadFile(p.sshCertPath()) if err != nil { return nil, trace.Wrap(err) } - key, err := ioutil.ReadFile(filepath.Join(credsPath, p.Username)) + + key, err := ioutil.ReadFile(p.keyPath()) if err != nil { return nil, trace.Wrap(err) } - knownHosts, err := ioutil.ReadFile(filepath.Join(p.Dir, constants.FileNameKnownHosts)) + caCerts, err := ioutil.ReadFile(p.sshCasPath()) if err != nil { return nil, trace.Wrap(err) } - ssh, err := sshutils.SSHClientConfig(cert, key, [][]byte{knownHosts}) + ssh, err := sshutils.SSHClientConfig(cert, key, [][]byte{caCerts}) if err != nil { return nil, trace.Wrap(err) } @@ -281,3 +276,35 @@ func (p *Profile) saveToFile(filepath string) error { } return nil } + +func (p *Profile) keyDir() string { + return filepath.Join(p.Dir, constants.SessionKeyDir) +} + +func (p *Profile) userKeyDir() string { + return filepath.Join(p.keyDir(), p.Name()) +} + +func (p *Profile) keyPath() string { + return filepath.Join(p.userKeyDir(), p.Username) +} + +func (p *Profile) tlsCertPath() string { + return filepath.Join(p.userKeyDir(), p.Username+constants.FileExtTLSCert) +} + +func (p *Profile) tlsCasPath() string { + return filepath.Join(p.userKeyDir(), constants.FileNameTLSCerts) +} + +func (p *Profile) sshDir() string { + return filepath.Join(p.userKeyDir(), p.Username+constants.SSHDirSuffix) +} + +func (p *Profile) sshCertPath() string { + return filepath.Join(p.sshDir(), p.SiteName+constants.FileExtSSHCert) +} + +func (p *Profile) sshCasPath() string { + return filepath.Join(p.Dir, constants.FileNameKnownHosts) +} diff --git a/api/constants/constants.go b/api/constants/constants.go index a8699829bb5eb..d13335ff45b12 100644 --- a/api/constants/constants.go +++ b/api/constants/constants.go @@ -132,6 +132,9 @@ const ( const ( // SessionKeyDir is the sub-directory where session keys are stored (.tsh/keys). SessionKeyDir = "keys" + // SSHDirSuffix is the suffix of the sub-directory where + // ssh keys are stored (.tsh/keys/profilename/username-ssh). + SSHDirSuffix = "-ssh" // FileNameKnownHosts is a file that stores known hosts. FileNameKnownHosts = "known_hosts" // FileExtTLSCert is the filename extension/suffix of TLS certs @@ -140,7 +143,8 @@ const ( // FileNameTLSCerts is the filename of Cert Authorities stored in a // profile (./tsh/keys/profilename/certs.pem). FileNameTLSCerts = "certs.pem" - // FileExtCert is a file extension used for SSH Certificate files. + // FileExtCert is a file extension used for SSH Certificate files + // (.tsh/keys/profilename/username-ssh/clustername-cert.pub). FileExtSSHCert = "-cert.pub" // FileExtPub is a file extension used for SSH Certificate Authorities // stored in a profile (./tsh/keys/profilename/username.pub). diff --git a/lib/client/keyagent_test.go b/lib/client/keyagent_test.go index 8a60aa7caf961..133419a367b6d 100644 --- a/lib/client/keyagent_test.go +++ b/lib/client/keyagent_test.go @@ -114,7 +114,7 @@ func (s *KeyAgentTestSuite) TestAddKey(c *check.C) { s.username, // private key s.username + constants.FileExtPub, // public key s.username + constants.FileExtTLSCert, // Teleport TLS certificate - filepath.Join(s.username+sshDirSuffix, s.key.ClusterName+constants.FileExtSSHCert), // SSH certificate + filepath.Join(s.username+constants.SSHDirSuffix, s.key.ClusterName+constants.FileExtSSHCert), // SSH certificate } for _, file := range expectedFiles { _, err := os.Stat(filepath.Join(s.keyDir, "keys", s.hostname, file)) diff --git a/lib/client/keystore.go b/lib/client/keystore.go index 4f4ee3fa9aebf..9adc4b194523b 100644 --- a/lib/client/keystore.go +++ b/lib/client/keystore.go @@ -41,7 +41,6 @@ import ( ) const ( - sshDirSuffix = "-ssh" kubeDirSuffix = "-kube" dbDirSuffix = "-db" appDirSuffix = "-app" @@ -177,7 +176,7 @@ func (fs *FSLocalKeyStore) AddKey(key *Key) error { } // Store per-cluster key data. - if err := fs.writeBytes(key.Cert, inProxyHostDir(key.Username+sshDirSuffix, key.ClusterName+constants.FileExtSSHCert)); err != nil { + if err := fs.writeBytes(key.Cert, inProxyHostDir(key.Username+constants.SSHDirSuffix, key.ClusterName+constants.FileExtSSHCert)); err != nil { return trace.Wrap(err) } // TODO(awly): unit test this. @@ -385,7 +384,7 @@ var WithAllCerts = []CertOption{WithSSHCerts{}, WithKubeCerts{}, WithDBCerts{}, type WithSSHCerts struct{} func (o WithSSHCerts) relativeCertPath(idx KeyIndex) string { - components := []string{idx.ProxyHost, idx.Username + sshDirSuffix} + components := []string{idx.ProxyHost, idx.Username + constants.SSHDirSuffix} if idx.ClusterName != "" { components = append(components, idx.ClusterName+constants.FileExtSSHCert) }