forked from influxdata/telegraf
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This solves issue with collecting user stats from mariadb 10 See influxdata#2910 (Note - the person claiming 1.8 didn't work for them is using mariadb 5... we are using mariadb 10 which is explicitly handled in the new plugin)
- Loading branch information
1 parent
378702f
commit a303e6b
Showing
14 changed files
with
1,276 additions
and
469 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
package tls | ||
|
||
import ( | ||
"crypto/tls" | ||
"crypto/x509" | ||
"fmt" | ||
"io/ioutil" | ||
) | ||
|
||
// ClientConfig represents the standard client TLS config. | ||
type ClientConfig struct { | ||
TLSCA string `toml:"tls_ca"` | ||
TLSCert string `toml:"tls_cert"` | ||
TLSKey string `toml:"tls_key"` | ||
InsecureSkipVerify bool `toml:"insecure_skip_verify"` | ||
|
||
// Deprecated in 1.7; use TLS variables above | ||
SSLCA string `toml:"ssl_ca"` | ||
SSLCert string `toml:"ssl_cert"` | ||
SSLKey string `toml:"ssl_key"` | ||
} | ||
|
||
// ServerConfig represents the standard server TLS config. | ||
type ServerConfig struct { | ||
TLSCert string `toml:"tls_cert"` | ||
TLSKey string `toml:"tls_key"` | ||
TLSAllowedCACerts []string `toml:"tls_allowed_cacerts"` | ||
} | ||
|
||
// TLSConfig returns a tls.Config, may be nil without error if TLS is not | ||
// configured. | ||
func (c *ClientConfig) TLSConfig() (*tls.Config, error) { | ||
// Support deprecated variable names | ||
if c.TLSCA == "" && c.SSLCA != "" { | ||
c.TLSCA = c.SSLCA | ||
} | ||
if c.TLSCert == "" && c.SSLCert != "" { | ||
c.TLSCert = c.SSLCert | ||
} | ||
if c.TLSKey == "" && c.SSLKey != "" { | ||
c.TLSKey = c.SSLKey | ||
} | ||
|
||
// TODO: return default tls.Config; plugins should not call if they don't | ||
// want TLS, this will require using another option to determine. In the | ||
// case of an HTTP plugin, you could use `https`. Other plugins may need | ||
// the dedicated option `TLSEnable`. | ||
if c.TLSCA == "" && c.TLSKey == "" && c.TLSCert == "" && !c.InsecureSkipVerify { | ||
return nil, nil | ||
} | ||
|
||
tlsConfig := &tls.Config{ | ||
InsecureSkipVerify: c.InsecureSkipVerify, | ||
Renegotiation: tls.RenegotiateNever, | ||
} | ||
|
||
if c.TLSCA != "" { | ||
pool, err := makeCertPool([]string{c.TLSCA}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
tlsConfig.RootCAs = pool | ||
} | ||
|
||
if c.TLSCert != "" && c.TLSKey != "" { | ||
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey) | ||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
|
||
return tlsConfig, nil | ||
} | ||
|
||
// TLSConfig returns a tls.Config, may be nil without error if TLS is not | ||
// configured. | ||
func (c *ServerConfig) TLSConfig() (*tls.Config, error) { | ||
if c.TLSCert == "" && c.TLSKey == "" && len(c.TLSAllowedCACerts) == 0 { | ||
return nil, nil | ||
} | ||
|
||
tlsConfig := &tls.Config{} | ||
|
||
if len(c.TLSAllowedCACerts) != 0 { | ||
pool, err := makeCertPool(c.TLSAllowedCACerts) | ||
if err != nil { | ||
return nil, err | ||
} | ||
tlsConfig.ClientCAs = pool | ||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert | ||
} | ||
|
||
if c.TLSCert != "" && c.TLSKey != "" { | ||
err := loadCertificate(tlsConfig, c.TLSCert, c.TLSKey) | ||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
|
||
return tlsConfig, nil | ||
} | ||
|
||
func makeCertPool(certFiles []string) (*x509.CertPool, error) { | ||
pool := x509.NewCertPool() | ||
for _, certFile := range certFiles { | ||
pem, err := ioutil.ReadFile(certFile) | ||
if err != nil { | ||
return nil, fmt.Errorf( | ||
"could not read certificate %q: %v", certFile, err) | ||
} | ||
ok := pool.AppendCertsFromPEM(pem) | ||
if !ok { | ||
return nil, fmt.Errorf( | ||
"could not parse any PEM certificates %q: %v", certFile, err) | ||
} | ||
} | ||
return pool, nil | ||
} | ||
|
||
func loadCertificate(config *tls.Config, certFile, keyFile string) error { | ||
cert, err := tls.LoadX509KeyPair(certFile, keyFile) | ||
if err != nil { | ||
return fmt.Errorf( | ||
"could not load keypair %s:%s: %v", certFile, keyFile, err) | ||
} | ||
|
||
config.Certificates = []tls.Certificate{cert} | ||
config.BuildNameToCertificate() | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
package tls_test | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
"time" | ||
|
||
"github.com/influxdata/telegraf/internal/tls" | ||
"github.com/influxdata/telegraf/testutil" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
var pki = testutil.NewPKI("../../testutil/pki") | ||
|
||
func TestClientConfig(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
client tls.ClientConfig | ||
expNil bool | ||
expErr bool | ||
}{ | ||
{ | ||
name: "unset", | ||
client: tls.ClientConfig{}, | ||
expNil: true, | ||
}, | ||
{ | ||
name: "success", | ||
client: tls.ClientConfig{ | ||
TLSCA: pki.CACertPath(), | ||
TLSCert: pki.ClientCertPath(), | ||
TLSKey: pki.ClientKeyPath(), | ||
}, | ||
}, | ||
{ | ||
name: "invalid ca", | ||
client: tls.ClientConfig{ | ||
TLSCA: pki.ClientKeyPath(), | ||
TLSCert: pki.ClientCertPath(), | ||
TLSKey: pki.ClientKeyPath(), | ||
}, | ||
expNil: true, | ||
expErr: true, | ||
}, | ||
{ | ||
name: "missing ca is okay", | ||
client: tls.ClientConfig{ | ||
TLSCert: pki.ClientCertPath(), | ||
TLSKey: pki.ClientKeyPath(), | ||
}, | ||
}, | ||
{ | ||
name: "invalid cert", | ||
client: tls.ClientConfig{ | ||
TLSCA: pki.CACertPath(), | ||
TLSCert: pki.ClientKeyPath(), | ||
TLSKey: pki.ClientKeyPath(), | ||
}, | ||
expNil: true, | ||
expErr: true, | ||
}, | ||
{ | ||
name: "missing cert skips client keypair", | ||
client: tls.ClientConfig{ | ||
TLSCA: pki.CACertPath(), | ||
TLSKey: pki.ClientKeyPath(), | ||
}, | ||
expNil: false, | ||
expErr: false, | ||
}, | ||
{ | ||
name: "missing key skips client keypair", | ||
client: tls.ClientConfig{ | ||
TLSCA: pki.CACertPath(), | ||
TLSCert: pki.ClientCertPath(), | ||
}, | ||
expNil: false, | ||
expErr: false, | ||
}, | ||
{ | ||
name: "support deprecated ssl field names", | ||
client: tls.ClientConfig{ | ||
SSLCA: pki.CACertPath(), | ||
SSLCert: pki.ClientCertPath(), | ||
SSLKey: pki.ClientKeyPath(), | ||
}, | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
tlsConfig, err := tt.client.TLSConfig() | ||
if !tt.expNil { | ||
require.NotNil(t, tlsConfig) | ||
} else { | ||
require.Nil(t, tlsConfig) | ||
} | ||
|
||
if !tt.expErr { | ||
require.NoError(t, err) | ||
} else { | ||
require.Error(t, err) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestServerConfig(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
server tls.ServerConfig | ||
expNil bool | ||
expErr bool | ||
}{ | ||
{ | ||
name: "unset", | ||
server: tls.ServerConfig{}, | ||
expNil: true, | ||
}, | ||
{ | ||
name: "success", | ||
server: tls.ServerConfig{ | ||
TLSCert: pki.ServerCertPath(), | ||
TLSKey: pki.ServerKeyPath(), | ||
TLSAllowedCACerts: []string{pki.CACertPath()}, | ||
}, | ||
}, | ||
{ | ||
name: "invalid ca", | ||
server: tls.ServerConfig{ | ||
TLSCert: pki.ServerCertPath(), | ||
TLSKey: pki.ServerKeyPath(), | ||
TLSAllowedCACerts: []string{pki.ServerKeyPath()}, | ||
}, | ||
expNil: true, | ||
expErr: true, | ||
}, | ||
{ | ||
name: "missing allowed ca is okay", | ||
server: tls.ServerConfig{ | ||
TLSCert: pki.ServerCertPath(), | ||
TLSKey: pki.ServerKeyPath(), | ||
}, | ||
expNil: true, | ||
expErr: true, | ||
}, | ||
{ | ||
name: "invalid cert", | ||
server: tls.ServerConfig{ | ||
TLSCert: pki.ServerKeyPath(), | ||
TLSKey: pki.ServerKeyPath(), | ||
TLSAllowedCACerts: []string{pki.CACertPath()}, | ||
}, | ||
expNil: true, | ||
expErr: true, | ||
}, | ||
{ | ||
name: "missing cert", | ||
server: tls.ServerConfig{ | ||
TLSKey: pki.ServerKeyPath(), | ||
TLSAllowedCACerts: []string{pki.CACertPath()}, | ||
}, | ||
expNil: true, | ||
expErr: true, | ||
}, | ||
{ | ||
name: "missing key", | ||
server: tls.ServerConfig{ | ||
TLSCert: pki.ServerCertPath(), | ||
TLSAllowedCACerts: []string{pki.CACertPath()}, | ||
}, | ||
expNil: true, | ||
expErr: true, | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
tlsConfig, err := tt.server.TLSConfig() | ||
if !tt.expNil { | ||
require.NotNil(t, tlsConfig) | ||
} | ||
if !tt.expErr { | ||
require.NoError(t, err) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestConnect(t *testing.T) { | ||
clientConfig := tls.ClientConfig{ | ||
TLSCA: pki.CACertPath(), | ||
TLSCert: pki.ClientCertPath(), | ||
TLSKey: pki.ClientKeyPath(), | ||
} | ||
|
||
serverConfig := tls.ServerConfig{ | ||
TLSCert: pki.ServerCertPath(), | ||
TLSKey: pki.ServerKeyPath(), | ||
TLSAllowedCACerts: []string{pki.CACertPath()}, | ||
} | ||
|
||
serverTLSConfig, err := serverConfig.TLSConfig() | ||
require.NoError(t, err) | ||
|
||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
})) | ||
ts.TLS = serverTLSConfig | ||
|
||
ts.StartTLS() | ||
defer ts.Close() | ||
|
||
clientTLSConfig, err := clientConfig.TLSConfig() | ||
require.NoError(t, err) | ||
|
||
client := http.Client{ | ||
Transport: &http.Transport{ | ||
TLSClientConfig: clientTLSConfig, | ||
}, | ||
Timeout: 10 * time.Second, | ||
} | ||
|
||
resp, err := client.Get(ts.URL) | ||
require.NoError(t, err) | ||
require.Equal(t, 200, resp.StatusCode) | ||
} |
Oops, something went wrong.