diff --git a/cert.go b/cert.go deleted file mode 100644 index bbd29c6d4..000000000 --- a/cert.go +++ /dev/null @@ -1,163 +0,0 @@ -package nebula - -import ( - "errors" - "fmt" - "io/ioutil" - "strings" - "time" - - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/config" -) - -type CertState struct { - certificate *cert.NebulaCertificate - rawCertificate []byte - rawCertificateNoKey []byte - publicKey []byte - privateKey []byte -} - -func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { - // Marshal the certificate to ensure it is valid - rawCertificate, err := certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) - } - - publicKey := certificate.Details.PublicKey - cs := &CertState{ - rawCertificate: rawCertificate, - certificate: certificate, // PublicKey has been set to nil above - privateKey: privateKey, - publicKey: publicKey, - } - - cs.certificate.Details.PublicKey = nil - rawCertNoKey, err := cs.certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("error marshalling certificate no key: %s", err) - } - cs.rawCertificateNoKey = rawCertNoKey - // put public key back - cs.certificate.Details.PublicKey = cs.publicKey - return cs, nil -} - -func NewCertStateFromConfig(c *config.C) (*CertState, error) { - var pemPrivateKey []byte - var err error - - privPathOrPEM := c.GetString("pki.key", "") - - if privPathOrPEM == "" { - return nil, errors.New("no pki.key path or PEM data provided") - } - - if strings.Contains(privPathOrPEM, "-----BEGIN") { - pemPrivateKey = []byte(privPathOrPEM) - privPathOrPEM = "" - } else { - pemPrivateKey, err = ioutil.ReadFile(privPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) - } - } - - rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey) - if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) - } - - var rawCert []byte - - pubPathOrPEM := c.GetString("pki.cert", "") - - if pubPathOrPEM == "" { - return nil, errors.New("no pki.cert path or PEM data provided") - } - - if strings.Contains(pubPathOrPEM, "-----BEGIN") { - rawCert = []byte(pubPathOrPEM) - pubPathOrPEM = "" - } else { - rawCert, err = ioutil.ReadFile(pubPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err) - } - } - - nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) - if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) - } - - if nebulaCert.Expired(time.Now()) { - return nil, fmt.Errorf("nebula certificate for this host is expired") - } - - if len(nebulaCert.Details.Ips) == 0 { - return nil, fmt.Errorf("no IPs encoded in certificate") - } - - if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { - return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") - } - - return NewCertState(nebulaCert, rawKey) -} - -func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) { - var rawCA []byte - var err error - - caPathOrPEM := c.GetString("pki.ca", "") - if caPathOrPEM == "" { - return nil, errors.New("no pki.ca path or PEM data provided") - } - - if strings.Contains(caPathOrPEM, "-----BEGIN") { - rawCA = []byte(caPathOrPEM) - - } else { - rawCA, err = ioutil.ReadFile(caPathOrPEM) - if err != nil { - return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) - } - } - - CAs, err := cert.NewCAPoolFromBytes(rawCA) - if errors.Is(err, cert.ErrExpired) { - var expired int - for _, cert := range CAs.CAs { - if cert.Expired(time.Now()) { - expired++ - l.WithField("cert", cert).Warn("expired certificate present in CA pool") - } - } - - if expired >= len(CAs.CAs) { - return nil, errors.New("no valid CA certificates present") - } - - } else if err != nil { - return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) - } - - for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) { - l.WithField("fingerprint", fp).Info("Blocklisting cert") - CAs.BlocklistFingerprint(fp) - } - - // Support deprecated config for at least one minor release to allow for migrations - //TODO: remove in 2022 or later - for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) { - l.WithField("fingerprint", fp).Info("Blocklisting cert") - l.Warn("pki.blacklist is deprecated and will not be supported in a future release. Please migrate your config to use pki.blocklist") - CAs.BlocklistFingerprint(fp) - } - - return CAs, nil -} diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index c1de26722..5616cd4b9 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -59,14 +59,15 @@ func main() { } ctrl, err := nebula.Main(c, *configTest, Build, l, nil) - - switch v := err.(type) { - case util.ContextualError: - v.Log(l) - os.Exit(1) - case error: - l.WithError(err).Error("Failed to start") - os.Exit(1) + if err != nil { + switch v := err.(type) { + case *util.ContextualError: + v.Log(l) + os.Exit(1) + case error: + l.WithError(err).Error("Failed to start") + os.Exit(1) + } } if !*configTest { diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index 9461035b1..d59ccd347 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -53,14 +53,15 @@ func main() { } ctrl, err := nebula.Main(c, *configTest, Build, l, nil) - - switch v := err.(type) { - case util.ContextualError: - v.Log(l) - os.Exit(1) - case error: - l.WithError(err).Error("Failed to start") - os.Exit(1) + if err != nil { + switch v := err.(type) { + case *util.ContextualError: + v.Log(l) + os.Exit(1) + case error: + l.WithError(err).Error("Failed to start") + os.Exit(1) + } } if !*configTest { diff --git a/connection_manager.go b/connection_manager.go index 528cf1b66..62a8dd234 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -405,8 +405,8 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { return false } - certState := n.intf.certState.Load() - return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) + certState := n.intf.pki.GetCertState() + return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) } func (n *connectionManager) swapPrimary(current, primary *HostInfo) { @@ -427,7 +427,7 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn return false } - valid, err := remoteCert.VerifyWithCache(now, n.intf.caPool) + valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool()) if valid { return false } @@ -464,8 +464,8 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { - certState := n.intf.certState.Load() - if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) { + certState := n.intf.pki.GetCertState() + if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) { return } diff --git a/connection_manager_test.go b/connection_manager_test.go index 642e0554c..a489bf2bc 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -44,10 +44,10 @@ func Test_NewConnectionManagerTest(t *testing.T) { // Very incomplete mock objects hostMap := NewHostMap(l, vpncidr, preferredRanges) cs := &CertState{ - rawCertificate: []byte{}, - privateKey: []byte{}, - certificate: &cert.NebulaCertificate{}, - rawCertificateNoKey: []byte{}, + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, } lh := newTestLighthouse() @@ -57,10 +57,11 @@ func Test_NewConnectionManagerTest(t *testing.T) { outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, + pki: &PKI{}, handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, } - ifce.certState.Store(cs) + ifce.pki.cs.Store(cs) // Create manager ctx, cancel := context.WithCancel(context.Background()) @@ -123,10 +124,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // Very incomplete mock objects hostMap := NewHostMap(l, vpncidr, preferredRanges) cs := &CertState{ - rawCertificate: []byte{}, - privateKey: []byte{}, - certificate: &cert.NebulaCertificate{}, - rawCertificateNoKey: []byte{}, + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, } lh := newTestLighthouse() @@ -136,10 +137,11 @@ func Test_NewConnectionManagerTest2(t *testing.T) { outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, + pki: &PKI{}, handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, } - ifce.certState.Store(cs) + ifce.pki.cs.Store(cs) // Create manager ctx, cancel := context.WithCancel(context.Background()) @@ -242,10 +244,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { peerCert.Sign(cert.Curve_CURVE25519, privCA) cs := &CertState{ - rawCertificate: []byte{}, - privateKey: []byte{}, - certificate: &cert.NebulaCertificate{}, - rawCertificateNoKey: []byte{}, + RawCertificate: []byte{}, + PrivateKey: []byte{}, + Certificate: &cert.NebulaCertificate{}, + RawCertificateNoKey: []byte{}, } lh := newTestLighthouse() @@ -258,9 +260,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, disconnectInvalid: true, - caPool: ncp, + pki: &PKI{}, } - ifce.certState.Store(cs) + ifce.pki.cs.Store(cs) + ifce.pki.caPool.Store(ncp) // Create manager ctx, cancel := context.WithCancel(context.Background()) diff --git a/connection_state.go b/connection_state.go index ab818c97d..163e4bc74 100644 --- a/connection_state.go +++ b/connection_state.go @@ -30,15 +30,15 @@ type ConnectionState struct { func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { var dhFunc noise.DHFunc - curCertState := f.certState.Load() + curCertState := f.pki.GetCertState() - switch curCertState.certificate.Details.Curve { + switch curCertState.Certificate.Details.Curve { case cert.Curve_CURVE25519: dhFunc = noise.DH25519 case cert.Curve_P256: dhFunc = noiseutil.DHP256 default: - l.Errorf("invalid curve: %s", curCertState.certificate.Details.Curve) + l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve) return nil } cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) @@ -46,7 +46,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) } - static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey} + static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey} b := NewBits(ReplayWindow) // Clear out bit 0, we never transmit it and we don't want it showing as packet loss diff --git a/control_tester.go b/control_tester.go index dd1a77418..a26c8bb23 100644 --- a/control_tester.go +++ b/control_tester.go @@ -161,7 +161,7 @@ func (c *Control) GetHostmap() *HostMap { } func (c *Control) GetCert() *cert.NebulaCertificate { - return c.f.certState.Load().certificate + return c.f.pki.GetCertState().Certificate } func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { diff --git a/handshake_ix.go b/handshake_ix.go index 70263b96a..52efdf5e6 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -33,7 +33,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) { hsProto := &NebulaHandshakeDetails{ InitiatorIndex: hostinfo.localIndexId, Time: uint64(time.Now().UnixNano()), - Cert: ci.certState.rawCertificateNoKey, + Cert: ci.certState.RawCertificateNoKey, } hsBytes := []byte{} @@ -91,7 +91,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by return } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool) + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) if err != nil { f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). @@ -155,7 +155,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by Info("Handshake message received") hs.Details.ResponderIndex = myIndex - hs.Details.Cert = ci.certState.rawCertificateNoKey + hs.Details.Cert = ci.certState.RawCertificateNoKey // Update the time in case their clock is way off from ours hs.Details.Time = uint64(time.Now().UnixNano()) @@ -399,7 +399,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H return true } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool) + remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) if err != nil { f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). diff --git a/inside.go b/inside.go index 0d4392666..0fac833a6 100644 --- a/inside.go +++ b/inside.go @@ -69,7 +69,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet ci.queueLock.Unlock() } - dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache) + dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, packet, nb, out, q) @@ -183,7 +183,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.caPool, nil) + dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). diff --git a/interface.go b/interface.go index 771aed0e6..fbf610a9b 100644 --- a/interface.go +++ b/interface.go @@ -13,7 +13,6 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -28,7 +27,7 @@ type InterfaceConfig struct { HostMap *HostMap Outside udp.Conn Inside overlay.Device - certState *CertState + pki *PKI Cipher string Firewall *Firewall ServeDns bool @@ -41,7 +40,6 @@ type InterfaceConfig struct { routines int MessageMetrics *MessageMetrics version string - caPool *cert.NebulaCAPool disconnectInvalid bool relayManager *relayManager punchy *Punchy @@ -58,7 +56,7 @@ type Interface struct { hostMap *HostMap outside udp.Conn inside overlay.Device - certState atomic.Pointer[CertState] + pki *PKI cipher string firewall *Firewall connectionManager *connectionManager @@ -71,7 +69,6 @@ type Interface struct { dropLocalBroadcast bool dropMulticast bool routines int - caPool *cert.NebulaCAPool disconnectInvalid bool closed atomic.Bool relayManager *relayManager @@ -152,15 +149,17 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { if c.Inside == nil { return nil, errors.New("no inside interface (tun)") } - if c.certState == nil { + if c.pki == nil { return nil, errors.New("no certificate state") } if c.Firewall == nil { return nil, errors.New("no firewall rules") } - myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP) + certificate := c.pki.GetCertState().Certificate + myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP) ifce := &Interface{ + pki: c.pki, hostMap: c.HostMap, outside: c.Outside, inside: c.Inside, @@ -170,14 +169,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, - localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask), + localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask), dropLocalBroadcast: c.DropLocalBroadcast, dropMulticast: c.DropMulticast, routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), - caPool: c.caPool, disconnectInvalid: c.disconnectInvalid, myVpnIp: myVpnIp, relayManager: c.relayManager, @@ -198,7 +196,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) - ifce.certState.Store(c.certState) ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy) return ifce, nil @@ -295,8 +292,6 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { } func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { - c.RegisterReloadCallback(f.reloadCA) - c.RegisterReloadCallback(f.reloadCertKey) c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) c.RegisterReloadCallback(f.reloadMisc) @@ -305,40 +300,6 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { } } -func (f *Interface) reloadCA(c *config.C) { - // reload and check regardless - // todo: need mutex? - newCAs, err := loadCAFromConfig(f.l, c) - if err != nil { - f.l.WithError(err).Error("Could not refresh trusted CA certificates") - return - } - - f.caPool = newCAs - f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed") -} - -func (f *Interface) reloadCertKey(c *config.C) { - // reload and check in all cases - cs, err := NewCertStateFromConfig(c) - if err != nil { - f.l.WithError(err).Error("Could not refresh client cert") - return - } - - // did IP in cert change? if so, don't set - currentCert := f.certState.Load().certificate - oldIPs := currentCert.Details.Ips - newIPs := cs.certificate.Details.Ips - if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { - f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old") - return - } - - f.certState.Store(cs) - f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk") -} - func (f *Interface) reloadFirewall(c *config.C) { //TODO: need to trigger/detect if the certificate changed too if c.HasChanged("firewall") == false { @@ -346,7 +307,7 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c) + fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return @@ -438,7 +399,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { f.firewall.EmitStats() f.handshakeManager.EmitStats() udpStats() - certExpirationGauge.Update(int64(f.certState.Load().certificate.Details.NotAfter.Sub(time.Now()) / time.Second)) + certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second)) } } } diff --git a/lighthouse.go b/lighthouse.go index 6c46663c9..9b3b837e0 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -132,7 +132,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, c.RegisterReloadCallback(func(c *config.C) { err := h.reload(c, false) switch v := err.(type) { - case util.ContextualError: + case *util.ContextualError: v.Log(l) case error: l.WithError(err).Error("failed to reload lighthouse") diff --git a/main.go b/main.go index d050db226..22a5edab8 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package nebula import ( "context" "encoding/binary" - "errors" "fmt" "net" "time" @@ -56,28 +55,21 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } }) - caPool, err := loadCAFromConfig(l, c) + pki, err := NewPKIFromConfig(l, c) if err != nil { - //The errors coming out of loadCA are already nicely formatted - return nil, util.NewContextualError("Failed to load ca from config", nil, err) - } - l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") - - cs, err := NewCertStateFromConfig(c) - if err != nil { - //The errors coming out of NewCertStateFromConfig are already nicely formatted - return nil, util.NewContextualError("Failed to load certificate from config", nil, err) + //The errors coming out of NewPKIFromConfig are already nicely formatted + return nil, err } - l.WithField("cert", cs.certificate).Debug("Client nebula certificate") - fw, err := NewFirewallFromConfig(l, cs.certificate, c) + certificate := pki.GetCertState().Certificate + fw, err := NewFirewallFromConfig(l, certificate, c) if err != nil { return nil, util.NewContextualError("Error while loading firewall rules", nil, err) } l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") // TODO: make sure mask is 4 bytes - tunCidr := cs.certificate.Details.Ips[0] + tunCidr := certificate.Details.Ips[0] ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) wireSSHReload(l, ssh, c) @@ -222,11 +214,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg punchy := NewPunchyFromConfig(l, c) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) - switch { - case errors.As(err, &util.ContextualError{}): - return nil, err - case err != nil: - return nil, util.NewContextualError("Failed to initialize lighthouse handler", nil, err) + if err != nil { + switch v := err.(type) { + case *util.ContextualError: + return nil, err + case error: + return nil, util.NewContextualError("Failed to initialize lighthouse handler", nil, v) + } } var messageMetrics *MessageMetrics @@ -266,7 +260,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg HostMap: hostMap, Inside: tun, Outside: udpConns[0], - certState: cs, + pki: pki, Cipher: c.GetString("cipher", "aes"), Firewall: fw, ServeDns: serveDns, @@ -282,7 +276,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg routines: routines, MessageMetrics: messageMetrics, version: buildVersion, - caPool: caPool, disconnectInvalid: c.GetBool("pki.disconnect_invalid", false), relayManager: NewRelayManager(ctx, l, hostMap, c), punchy: punchy, diff --git a/outside.go b/outside.go index 19a980bfa..a9dcdc860 100644 --- a/outside.go +++ b/outside.go @@ -404,7 +404,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache) + dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, out, q) if f.l.Level >= logrus.DebugLevel { diff --git a/pki.go b/pki.go new file mode 100644 index 000000000..43a88bb2f --- /dev/null +++ b/pki.go @@ -0,0 +1,250 @@ +package nebula + +import ( + "errors" + "fmt" + "os" + "strings" + "sync/atomic" + "time" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" +) + +type PKI struct { + cs atomic.Pointer[CertState] + caPool atomic.Pointer[cert.NebulaCAPool] + l *logrus.Logger +} + +type CertState struct { + Certificate *cert.NebulaCertificate + RawCertificate []byte + RawCertificateNoKey []byte + PublicKey []byte + PrivateKey []byte +} + +func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { + pki := &PKI{l: l} + err := pki.reload(c, true) + if err != nil { + l.WithError(err).Error("WHAT THE HELL") + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + cErr := pki.reload(c, false) + if cErr != nil { + cErr.Log(l) + } + }) + + l.Error("SO NO ERROR THEN") + return pki, nil +} + +func (p *PKI) GetCertState() *CertState { + return p.cs.Load() +} + +func (p *PKI) GetCAPool() *cert.NebulaCAPool { + return p.caPool.Load() +} + +func (p *PKI) reload(c *config.C, initial bool) *util.ContextualError { + err := p.reloadCert(c, initial) + if err != nil { + if initial { + return err + } + err.Log(p.l) + } + + err = p.reloadCAPool(c) + if err != nil { + if initial { + return err + } + err.Log(p.l) + } + + return nil +} + +func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { + cs, err := newCertStateFromConfig(c) + if err != nil { + return util.NewContextualError("Could not load client cert", nil, err) + } + + if !initial { + // did IP in cert change? if so, don't set + currentCert := p.cs.Load().Certificate + oldIPs := currentCert.Details.Ips + newIPs := cs.Certificate.Details.Ips + if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { + return util.NewContextualError( + "IP in new cert was different from old", + m{"new_ip": newIPs[0], "old_ip": oldIPs[0]}, + nil, + ) + } + } + + p.cs.Store(cs) + if initial { + p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate") + } else { + p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk") + } + return nil +} + +func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { + caPool, err := loadCAPoolFromConfig(p.l, c) + if err != nil { + return util.NewContextualError("Failed to load ca from config", nil, err) + } + + p.caPool.Store(caPool) + p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") + return nil +} + +func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { + // Marshal the certificate to ensure it is valid + rawCertificate, err := certificate.Marshal() + if err != nil { + return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) + } + + publicKey := certificate.Details.PublicKey + cs := &CertState{ + RawCertificate: rawCertificate, + Certificate: certificate, // PublicKey has been set to nil above + PrivateKey: privateKey, + PublicKey: publicKey, + } + + cs.Certificate.Details.PublicKey = nil + rawCertNoKey, err := cs.Certificate.Marshal() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate no key: %s", err) + } + cs.RawCertificateNoKey = rawCertNoKey + // put public key back + cs.Certificate.Details.PublicKey = cs.PublicKey + return cs, nil +} + +func newCertStateFromConfig(c *config.C) (*CertState, error) { + var pemPrivateKey []byte + var err error + + privPathOrPEM := c.GetString("pki.key", "") + if privPathOrPEM == "" { + return nil, errors.New("no pki.key path or PEM data provided") + } + + if strings.Contains(privPathOrPEM, "-----BEGIN") { + pemPrivateKey = []byte(privPathOrPEM) + privPathOrPEM = "" + + } else { + pemPrivateKey, err = os.ReadFile(privPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) + } + } + + rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey) + if err != nil { + return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + + var rawCert []byte + + pubPathOrPEM := c.GetString("pki.cert", "") + if pubPathOrPEM == "" { + return nil, errors.New("no pki.cert path or PEM data provided") + } + + if strings.Contains(pubPathOrPEM, "-----BEGIN") { + rawCert = []byte(pubPathOrPEM) + pubPathOrPEM = "" + + } else { + rawCert, err = os.ReadFile(pubPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err) + } + } + + nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert) + if err != nil { + return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) + } + + if nebulaCert.Expired(time.Now()) { + return nil, fmt.Errorf("nebula certificate for this host is expired") + } + + if len(nebulaCert.Details.Ips) == 0 { + return nil, fmt.Errorf("no IPs encoded in certificate") + } + + if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + + return newCertState(nebulaCert, rawKey) +} + +func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) { + var rawCA []byte + var err error + + caPathOrPEM := c.GetString("pki.ca", "") + if caPathOrPEM == "" { + return nil, errors.New("no pki.ca path or PEM data provided") + } + + if strings.Contains(caPathOrPEM, "-----BEGIN") { + rawCA = []byte(caPathOrPEM) + + } else { + rawCA, err = os.ReadFile(caPathOrPEM) + if err != nil { + return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) + } + } + + caPool, err := cert.NewCAPoolFromBytes(rawCA) + if errors.Is(err, cert.ErrExpired) { + var expired int + for _, crt := range caPool.CAs { + if crt.Expired(time.Now()) { + expired++ + l.WithField("cert", crt).Warn("expired certificate present in CA pool") + } + } + + if expired >= len(caPool.CAs) { + return nil, errors.New("no valid CA certificates present") + } + + } else if err != nil { + return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) + } + + for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) { + l.WithField("fingerprint", fp).Info("Blocklisting cert") + caPool.BlocklistFingerprint(fp) + } + + return caPool, nil +} diff --git a/ssh.go b/ssh.go index 0f624dbe5..44286c89a 100644 --- a/ssh.go +++ b/ssh.go @@ -754,7 +754,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return nil } - cert := ifce.certState.Load().certificate + cert := ifce.pki.GetCertState().Certificate if len(a) > 0 { parsedIp := net.ParseIP(a[0]) if parsedIp == nil { diff --git a/util/error.go b/util/error.go index 7f9bc4792..53322d02b 100644 --- a/util/error.go +++ b/util/error.go @@ -12,18 +12,18 @@ type ContextualError struct { Context string } -func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError { - return ContextualError{Context: msg, Fields: fields, RealError: realError} +func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError { + return &ContextualError{Context: msg, Fields: fields, RealError: realError} } -func (ce ContextualError) Error() string { +func (ce *ContextualError) Error() string { if ce.RealError == nil { return ce.Context } return ce.RealError.Error() } -func (ce ContextualError) Unwrap() error { +func (ce *ContextualError) Unwrap() error { if ce.RealError == nil { return errors.New(ce.Context) }