diff --git a/.gitignore b/.gitignore index 0ba8193185..3b85ecbb26 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ ignored/ +tailscale/ # Binaries for programs and plugins *.exe diff --git a/.golangci.yaml b/.golangci.yaml index cf20c0268a..65a88511f9 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -10,6 +10,8 @@ issues: linters: enable-all: true disable: + - depguard + - exhaustivestruct - revive - lll diff --git a/Dockerfile b/Dockerfile index b1ec3317e9..0e6774d72b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Builder image -FROM docker.io/golang:1.20-bullseye AS build +FROM docker.io/golang:1.21-bookworm AS build ARG VERSION=dev ENV GOPATH /go WORKDIR /go/src/headscale @@ -14,7 +14,7 @@ RUN strip /go/bin/headscale RUN test -e /go/bin/headscale # Production image -FROM docker.io/debian:bullseye-slim +FROM docker.io/debian:bookworm-slim RUN apt-get update \ && apt-get install -y ca-certificates \ diff --git a/Dockerfile.debug b/Dockerfile.debug index 7cd609cfe2..eb8a79c470 100644 --- a/Dockerfile.debug +++ b/Dockerfile.debug @@ -18,6 +18,10 @@ FROM docker.io/golang:1.20.0-bullseye COPY --from=build /go/bin/headscale /bin/headscale ENV TZ UTC +RUN apt-get update \ + && apt-get install --no-install-recommends --yes less jq \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean RUN mkdir -p /var/run/headscale # Need to reset the entrypoint or everything will run as a busybox script diff --git a/cmd/headscale/cli/root.go b/cmd/headscale/cli/root.go index f70945e691..6c788aa373 100644 --- a/cmd/headscale/cli/root.go +++ b/cmd/headscale/cli/root.go @@ -51,7 +51,7 @@ func initConfig() { cfg, err := types.GetHeadscaleConfig() if err != nil { - log.Fatal().Caller().Err(err) + log.Fatal().Caller().Err(err).Msg("Failed to get headscale configuration") } machineOutput := HasMachineOutputFlag() diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index baaf2094cf..a193d17dfb 100644 --- a/cmd/headscale/cli/utils.go +++ 
b/cmd/headscale/cli/utils.go @@ -154,17 +154,17 @@ func SuccessOutput(result interface{}, override string, outputFormat string) { case "json": jsonBytes, err = json.MarshalIndent(result, "", "\t") if err != nil { - log.Fatal().Err(err) + log.Fatal().Err(err).Msg("failed to unmarshal output") } case "json-line": jsonBytes, err = json.Marshal(result) if err != nil { - log.Fatal().Err(err) + log.Fatal().Err(err).Msg("failed to unmarshal output") } case "yaml": jsonBytes, err = yaml.Marshal(result) if err != nil { - log.Fatal().Err(err) + log.Fatal().Err(err).Msg("failed to unmarshal output") } default: //nolint diff --git a/flake.lock b/flake.lock index 505f2aecad..f9394e40eb 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1693355128, - "narHash": "sha256-+ZoAny3ZxLcfMaUoLVgL9Ywb/57wP+EtsdNGuXUJrwg=", + "lastModified": 1693844670, + "narHash": "sha256-t69F2nBB8DNQUWHD809oJZJVE+23XBrth4QZuVd6IE0=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "a63a64b593dcf2fe05f7c5d666eb395950f36bc9", + "rev": "3c15feef7770eb5500a4b8792623e2d6f598c9c1", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 49df6ffa80..f4d2358f40 100644 --- a/flake.nix +++ b/flake.nix @@ -33,7 +33,7 @@ # When updating go.mod or go.sum, a new sha will need to be calculated, # update this if you have a mismatch after doing a change to thos files. 
- vendorSha256 = "sha256-9Hol8w8HB28AlulshMYYQwOgvGzR47qxzyPrB8G0XSQ="; + vendorSha256 = "sha256-dNE5wgR3oWXlYzPNXp0v/GGwY0/hvhOB5JWCb5EIbg8="; ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"]; }; diff --git a/go.mod b/go.mod index 21070015cf..2a37c5059f 100644 --- a/go.mod +++ b/go.mod @@ -61,7 +61,7 @@ require ( github.com/containerd/console v1.0.3 // indirect github.com/containerd/continuity v0.3.0 // indirect github.com/docker/cli v23.0.5+incompatible // indirect - github.com/docker/docker v23.0.5+incompatible // indirect + github.com/docker/docker v24.0.4+incompatible // indirect github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect diff --git a/go.sum b/go.sum index e5aaa15cdd..15b7018c03 100644 --- a/go.sum +++ b/go.sum @@ -109,8 +109,8 @@ github.com/deckarep/golang-set/v2 v2.3.0 h1:qs18EKUfHm2X9fA50Mr/M5hccg2tNnVqsiBI github.com/deckarep/golang-set/v2 v2.3.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= github.com/docker/cli v23.0.5+incompatible h1:ufWmAOuD3Vmr7JP2G5K3cyuNC4YZWiAsuDEvFVVDafE= github.com/docker/cli v23.0.5+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= -github.com/docker/docker v23.0.5+incompatible h1:DaxtlTJjFSnLOXVNUBU1+6kXGz2lpDoEAH6QoxaSg8k= -github.com/docker/docker v23.0.5+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v24.0.4+incompatible h1:s/LVDftw9hjblvqIeTiGYXBCD95nOEEl7qRsRrIOuQI= +github.com/docker/docker v24.0.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= diff --git a/hscontrol/app.go b/hscontrol/app.go index 
c654d4a647..630339c136 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -8,9 +8,10 @@ import ( "io" "net" "net/http" + _ "net/http/pprof" //nolint "os" "os/signal" - "sort" + "runtime" "strconv" "strings" "sync" @@ -20,19 +21,19 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/mux" grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware" - "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/derp" derpServer "github.com/juanfont/headscale/hscontrol/derp/server" + "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/patrickmn/go-cache" zerolog "github.com/philip-bui/grpc-zerolog" "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/puzpuzpuz/xsync/v2" zl "github.com/rs/zerolog" "github.com/rs/zerolog/log" "golang.org/x/crypto/acme" @@ -84,7 +85,7 @@ type Headscale struct { ACLPolicy *policy.ACLPolicy - lastStateChange *xsync.MapOf[string, time.Time] + nodeNotifier *notifier.Notifier oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -93,12 +94,13 @@ type Headscale struct { shutdownChan chan struct{} pollNetMapStreamWG sync.WaitGroup - - stateUpdateChan chan struct{} - cancelStateUpdateChan chan struct{} } func NewHeadscale(cfg *types.Config) (*Headscale, error) { + if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile { + runtime.SetBlockProfileRate(1) + } + privateKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath) if err != nil { return nil, fmt.Errorf("failed to read or create private key: %w", err) @@ -158,19 +160,14 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { 
noisePrivateKey: noisePrivateKey, registrationCache: registrationCache, pollNetMapStreamWG: sync.WaitGroup{}, - lastStateChange: xsync.NewMapOf[time.Time](), - - stateUpdateChan: make(chan struct{}), - cancelStateUpdateChan: make(chan struct{}), + nodeNotifier: notifier.NewNotifier(), } - go app.watchStateChannel() - database, err := db.NewHeadscaleDatabase( cfg.DBtype, dbString, app.dbDebug, - app.stateUpdateChan, + app.nodeNotifier, cfg.IPPrefixes, cfg.BaseDomain) if err != nil { @@ -203,7 +200,11 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { if cfg.DERP.ServerEnabled { // TODO(kradalby): replace this key with a dedicated DERP key. - embeddedDERPServer, err := derpServer.NewDERPServer(cfg.ServerURL, key.NodePrivate(*privateKey), &cfg.DERP) + embeddedDERPServer, err := derpServer.NewDERPServer( + cfg.ServerURL, + key.NodePrivate(*privateKey), + &cfg.DERP, + ) if err != nil { return nil, err } @@ -230,10 +231,14 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { // expireExpiredMachines expires machines that have an explicit expiry set // after that expiry time has passed. 
-func (h *Headscale) expireExpiredMachines(milliSeconds int64) { - ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) +func (h *Headscale) expireExpiredMachines(intervalMs int64) { + interval := time.Duration(intervalMs) * time.Millisecond + ticker := time.NewTicker(interval) + + lastCheck := time.Unix(0, 0) + for range ticker.C { - h.db.ExpireExpiredMachines(h.getLastStateChange()) + lastCheck = h.db.ExpireExpiredMachines(lastCheck) } } @@ -258,7 +263,10 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) { h.DERPMap.Regions[region.RegionID] = ®ion } - h.setLastStateChangeToNow() + h.nodeNotifier.NotifyAll(types.StateUpdate{ + Type: types.StateDERPUpdated, + DERPMap: *h.DERPMap, + }) } } } @@ -433,8 +441,9 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error { return os.Remove(h.cfg.UnixSocket) } -func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router { +func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { router := mux.NewRouter() + router.PathPrefix("/debug/pprof/").Handler(http.DefaultServeMux) router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).Methods(http.MethodPost) @@ -541,7 +550,7 @@ func (h *Headscale) Serve() error { return fmt.Errorf("failed change permission of gRPC socket: %w", err) } - grpcGatewayMux := runtime.NewServeMux() + grpcGatewayMux := grpcRuntime.NewServeMux() // Make the grpc-gateway connect to grpc over socket grpcGatewayConn, err := grpc.Dial( @@ -722,7 +731,9 @@ func (h *Headscale) Serve() error { Str("path", aclPath). 
Msg("ACL policy successfully reloaded, notifying nodes of change") - h.setLastStateChangeToNow() + h.nodeNotifier.NotifyAll(types.StateUpdate{ + Type: types.StateFullUpdate, + }) } default: @@ -760,10 +771,6 @@ func (h *Headscale) Serve() error { // Stop listening (and unlink the socket if unix type): socketListener.Close() - <-h.cancelStateUpdateChan - close(h.stateUpdateChan) - close(h.cancelStateUpdateChan) - // Close db connections err = h.db.Close() if err != nil { @@ -775,6 +782,8 @@ func (h *Headscale) Serve() error { // And we're done: cancel() + + return } } } @@ -859,73 +868,6 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { } } -// TODO(kradalby): baby steps, make this more robust. -func (h *Headscale) watchStateChannel() { - for { - select { - case <-h.stateUpdateChan: - h.setLastStateChangeToNow() - - case <-h.cancelStateUpdateChan: - return - } - } -} - -func (h *Headscale) setLastStateChangeToNow() { - var err error - - now := time.Now().UTC() - - users, err := h.db.ListUsers() - if err != nil { - log.Error(). - Caller(). - Err(err). 
- Msg("failed to fetch all users, failing to update last changed state.") - } - - for _, user := range users { - lastStateUpdate.WithLabelValues(user.Name, "headscale").Set(float64(now.Unix())) - if h.lastStateChange == nil { - h.lastStateChange = xsync.NewMapOf[time.Time]() - } - h.lastStateChange.Store(user.Name, now) - } -} - -func (h *Headscale) getLastStateChange(users ...types.User) time.Time { - times := []time.Time{} - - // getLastStateChange takes a list of users as a "filter", if no users - // are past, then use the entier list of users and look for the last update - if len(users) > 0 { - for _, user := range users { - if lastChange, ok := h.lastStateChange.Load(user.Name); ok { - times = append(times, lastChange) - } - } - } else { - h.lastStateChange.Range(func(key string, value time.Time) bool { - times = append(times, value) - - return true - }) - } - - sort.Slice(times, func(i, j int) bool { - return times[i].After(times[j]) - }) - - log.Trace().Msgf("Latest times %#v", times) - - if len(times) == 0 { - return time.Now().UTC() - } else { - return times[0] - } -} - func notFoundHandler( writer http.ResponseWriter, req *http.Request, diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 43dfd2b08e..e9121203c5 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -86,7 +86,8 @@ func (h *Headscale) handleRegister( Caller(). Str("func", "RegistrationHandler"). Str("hostinfo.name", registerRequest.Hostinfo.Hostname). - Err(err) + Err(err). + Msg("Failed to generate given name for node") return } @@ -309,7 +310,7 @@ func (h *Headscale) handleAuthKey( machine.NodeKey = nodeKey machine.AuthKeyID = uint(pak.ID) - err := h.db.RefreshMachine(machine, registerRequest.Expiry) + err := h.db.MachineSetExpiry(machine, registerRequest.Expiry) if err != nil { log.Error(). Caller(). @@ -348,7 +349,8 @@ func (h *Headscale) handleAuthKey( Bool("noise", isNoise). Str("func", "RegistrationHandler"). Str("hostinfo.name", registerRequest.Hostinfo.Hostname). 
- Err(err) + Err(err). + Msg("Failed to generate given name for node") return } @@ -510,7 +512,8 @@ func (h *Headscale) handleMachineLogOut( Str("machine", machine.Hostname). Msg("Client requested logout") - err := h.db.ExpireMachine(&machine) + now := time.Now() + err := h.db.MachineSetExpiry(&machine, now) if err != nil { log.Error(). Caller(). @@ -552,7 +555,7 @@ func (h *Headscale) handleMachineLogOut( } if machine.IsEphemeral() { - err = h.db.HardDeleteMachine(&machine) + err = h.db.DeleteMachine(&machine) if err != nil { log.Error(). Err(err). diff --git a/hscontrol/auth_noise.go b/hscontrol/auth_noise.go index 7f6a7fd560..54bea1c515 100644 --- a/hscontrol/auth_noise.go +++ b/hscontrol/auth_noise.go @@ -23,6 +23,7 @@ func (ns *noiseServer) NoiseRegistrationHandler( log.Trace(). Any("headers", req.Header). + Caller(). Msg("Headers") body, _ := io.ReadAll(req.Body) diff --git a/hscontrol/db/addresses_test.go b/hscontrol/db/addresses_test.go index 12891480bb..888dda369b 100644 --- a/hscontrol/db/addresses_test.go +++ b/hscontrol/db/addresses_test.go @@ -63,8 +63,6 @@ func (s *Suite) TestGetUsedIps(c *check.C) { c.Assert(len(machine1.IPAddresses), check.Equals, 1) c.Assert(machine1.IPAddresses[0], check.Equals, expected) - - c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestGetMultiIp(c *check.C) { @@ -153,8 +151,6 @@ func (s *Suite) TestGetMultiIp(c *check.C) { c.Assert(len(nextIP2), check.Equals, 1) c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String()) - - c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { @@ -192,6 +188,4 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { c.Assert(len(ips2), check.Equals, 1) c.Assert(ips2[0].String(), check.Equals, expected.String()) - - c.Assert(channelUpdates, check.Equals, int32(0)) } diff --git a/hscontrol/db/api_key.go b/hscontrol/db/api_key.go index 4e4030ebfe..bc8dc2bbde 100644 --- 
a/hscontrol/db/api_key.go +++ b/hscontrol/db/api_key.go @@ -22,6 +22,9 @@ var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") func (hsdb *HSDatabase) CreateAPIKey( expiration *time.Time, ) (string, *types.APIKey, error) { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) if err != nil { return "", nil, err @@ -55,6 +58,9 @@ func (hsdb *HSDatabase) CreateAPIKey( // ListAPIKeys returns the list of ApiKeys for a user. func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + keys := []types.APIKey{} if err := hsdb.db.Find(&keys).Error; err != nil { return nil, err @@ -65,6 +71,9 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { // GetAPIKey returns a ApiKey for a given key. func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + key := types.APIKey{} if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { return nil, result.Error @@ -75,6 +84,9 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { // GetAPIKeyByID returns a ApiKey for a given id. func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + key := types.APIKey{} if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { return nil, result.Error @@ -86,6 +98,9 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey // does not exist. func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { return result.Error } @@ -95,6 +110,9 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { // ExpireAPIKey marks a ApiKey as expired. 
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -103,6 +121,9 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { } func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + prefix, hash, found := strings.Cut(keyStr, ".") if !found { return false, ErrAPIKeyFailedToParse diff --git a/hscontrol/db/api_key_test.go b/hscontrol/db/api_key_test.go index 0fc42c5a50..c0b4e98845 100644 --- a/hscontrol/db/api_key_test.go +++ b/hscontrol/db/api_key_test.go @@ -22,8 +22,6 @@ func (*Suite) TestCreateAPIKey(c *check.C) { keys, err := db.ListAPIKeys() c.Assert(err, check.IsNil) c.Assert(len(keys), check.Equals, 1) - - c.Assert(channelUpdates, check.Equals, int32(0)) } func (*Suite) TestAPIKeyDoesNotExist(c *check.C) { @@ -41,8 +39,6 @@ func (*Suite) TestValidateAPIKeyOk(c *check.C) { valid, err := db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(valid, check.Equals, true) - - c.Assert(channelUpdates, check.Equals, int32(0)) } func (*Suite) TestValidateAPIKeyNotOk(c *check.C) { @@ -71,8 +67,6 @@ func (*Suite) TestValidateAPIKeyNotOk(c *check.C) { validWithErr, err := db.ValidateAPIKey("produceerrorkey") c.Assert(err, check.NotNil) c.Assert(validWithErr, check.Equals, false) - - c.Assert(channelUpdates, check.Equals, int32(0)) } func (*Suite) TestExpireAPIKey(c *check.C) { @@ -92,6 +86,4 @@ func (*Suite) TestExpireAPIKey(c *check.C) { notValid, err := db.ValidateAPIKey(apiKeyStr) c.Assert(err, check.IsNil) c.Assert(notValid, check.Equals, false) - - c.Assert(channelUpdates, check.Equals, int32(0)) } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 5cff786878..b309f593cc 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -9,6 +9,7 @@ import ( "time" "github.com/glebarez/sqlite" + 
"github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -36,8 +37,10 @@ type KV struct { } type HSDatabase struct { - db *gorm.DB - notifyStateChan chan<- struct{} + db *gorm.DB + notifier *notifier.Notifier + + mu sync.RWMutex ipAllocationMutex sync.Mutex @@ -50,7 +53,7 @@ type HSDatabase struct { func NewHeadscaleDatabase( dbType, connectionAddr string, debug bool, - notifyStateChan chan<- struct{}, + notifier *notifier.Notifier, ipPrefixes []netip.Prefix, baseDomain string, ) (*HSDatabase, error) { @@ -60,8 +63,8 @@ func NewHeadscaleDatabase( } db := HSDatabase{ - db: dbConn, - notifyStateChan: notifyStateChan, + db: dbConn, + notifier: notifier, ipPrefixes: ipPrefixes, baseDomain: baseDomain, @@ -211,7 +214,7 @@ func NewHeadscaleDatabase( Msg("Failed to normalize machine hostname in DB migration") } - err = db.RenameMachine(&machines[item], normalizedHostname) + err = db.RenameMachine(machines[item], normalizedHostname) if err != nil { log.Error(). Caller(). @@ -297,10 +300,6 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { ) } -func (hsdb *HSDatabase) notifyStateChange() { - hsdb.notifyStateChan <- struct{}{} -} - // getValue returns the value for the given key in KV. func (hsdb *HSDatabase) getValue(key string) (string, error) { var row KV diff --git a/hscontrol/db/machine.go b/hscontrol/db/machine.go index f2139abbf1..c079f67779 100644 --- a/hscontrol/db/machine.go +++ b/hscontrol/db/machine.go @@ -13,6 +13,7 @@ import ( "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "gorm.io/gorm" + "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -35,6 +36,13 @@ var ( // ListPeers returns all peers of machine, regardless of any Policy or if the node is expired. 
func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.listPeers(machine) +} + +func (hsdb *HSDatabase) listPeers(machine *types.Machine) (types.Machines, error) { log.Trace(). Caller(). Str("machine", machine.Hostname). @@ -62,6 +70,13 @@ func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error } func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.listMachines() +} + +func (hsdb *HSDatabase) listMachines() ([]types.Machine, error) { machines := []types.Machine{} if err := hsdb.db. Preload("AuthKey"). @@ -76,6 +91,13 @@ func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { } func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machines, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.listMachinesByGivenName(givenName) +} + +func (hsdb *HSDatabase) listMachinesByGivenName(givenName string) (types.Machines, error) { machines := types.Machines{} if err := hsdb.db. Preload("AuthKey"). @@ -91,6 +113,9 @@ func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machine // GetMachine finds a Machine by name and user and returns the Machine struct. 
func (hsdb *HSDatabase) GetMachine(user string, name string) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + machines, err := hsdb.ListMachinesByUser(user) if err != nil { return nil, err @@ -98,7 +123,7 @@ func (hsdb *HSDatabase) GetMachine(user string, name string) (*types.Machine, er for _, m := range machines { if m.Hostname == name { - return &m, nil + return m, nil } } @@ -110,15 +135,17 @@ func (hsdb *HSDatabase) GetMachineByGivenName( user string, givenName string, ) (*types.Machine, error) { - machines, err := hsdb.ListMachinesByUser(user) - if err != nil { - return nil, err - } + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() - for _, m := range machines { - if m.GivenName == givenName { - return &m, nil - } + machine := types.Machine{} + if err := hsdb.db. + Preload("AuthKey"). + Preload("AuthKey.User"). + Preload("User"). + Preload("Routes"). + Where("given_name = ?", givenName).First(&machine).Error; err != nil { + return nil, err } return nil, ErrMachineNotFound @@ -126,6 +153,9 @@ func (hsdb *HSDatabase) GetMachineByGivenName( // GetMachineByID finds a Machine by ID and returns the Machine struct. func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + mach := types.Machine{} if result := hsdb.db. Preload("AuthKey"). @@ -143,6 +173,9 @@ func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { func (hsdb *HSDatabase) GetMachineByMachineKey( machineKey key.MachinePublic, ) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + mach := types.Machine{} if result := hsdb.db. Preload("AuthKey"). @@ -160,6 +193,9 @@ func (hsdb *HSDatabase) GetMachineByMachineKey( func (hsdb *HSDatabase) GetMachineByNodeKey( nodeKey key.NodePublic, ) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + machine := types.Machine{} if result := hsdb.db. Preload("AuthKey"). 
@@ -178,6 +214,9 @@ func (hsdb *HSDatabase) GetMachineByNodeKey( func (hsdb *HSDatabase) GetMachineByAnyKey( machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, ) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + machine := types.Machine{} if result := hsdb.db. Preload("AuthKey"). @@ -194,10 +233,10 @@ func (hsdb *HSDatabase) GetMachineByAnyKey( return &machine, nil } -// TODO(kradalby): rename this, it sounds like a mix of getting and setting to db -// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database -// and updates it with the latest data from the database. -func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *types.Machine) error { +func (hsdb *HSDatabase) MachineReloadFromDatabase(machine *types.Machine) error { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + if result := hsdb.db.Find(machine).First(&machine); result.Error != nil { return result.Error } @@ -210,33 +249,26 @@ func (hsdb *HSDatabase) SetTags( machine *types.Machine, tags []string, ) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + newTags := []string{} for _, tag := range tags { if !util.StringOrPrefixListContains(newTags, tag) { newTags = append(newTags, tag) } } - machine.ForcedTags = newTags - - hsdb.notifyStateChange() - if err := hsdb.db.Save(machine).Error; err != nil { + if err := hsdb.db.Model(machine).Updates(types.Machine{ + ForcedTags: newTags, + }).Error; err != nil { return fmt.Errorf("failed to update tags for machine in the database: %w", err) } - return nil -} - -// ExpireMachine takes a Machine struct and sets the expire field to now. 
-func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error { - now := time.Now() - machine.Expiry = &now - - hsdb.notifyStateChange() - - if err := hsdb.db.Save(machine).Error; err != nil { - return fmt.Errorf("failed to expire machine in the database: %w", err) - } + hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: types.Machines{machine}, + }, machine.MachineKey) return nil } @@ -244,6 +276,9 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error { // RenameMachine takes a Machine struct and a new GivenName for the machines // and renames it. func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + err := util.CheckForFQDNRules( newName, ) @@ -253,100 +288,90 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er Str("func", "RenameMachine"). Str("machine", machine.Hostname). Str("newName", newName). - Err(err) + Err(err). + Msg("failed to rename machine") return err } machine.GivenName = newName - hsdb.notifyStateChange() - - if err := hsdb.db.Save(machine).Error; err != nil { + if err := hsdb.db.Model(machine).Updates(types.Machine{ + GivenName: newName, + }).Error; err != nil { return fmt.Errorf("failed to rename machine in the database: %w", err) } + hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: types.Machines{machine}, + }, machine.MachineKey) + return nil } -// RefreshMachine takes a Machine struct and a new expiry time. -func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time) error { - now := time.Now() +// MachineSetExpiry takes a Machine struct and a new expiry time. 
+func (hsdb *HSDatabase) MachineSetExpiry(machine *types.Machine, expiry time.Time) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() - machine.LastSuccessfulUpdate = &now - machine.Expiry = &expiry - - hsdb.notifyStateChange() + return hsdb.machineSetExpiry(machine, expiry) +} - if err := hsdb.db.Save(machine).Error; err != nil { +func (hsdb *HSDatabase) machineSetExpiry(machine *types.Machine, expiry time.Time) error { + if err := hsdb.db.Model(machine).Updates(types.Machine{ + Expiry: &expiry, + }).Error; err != nil { return fmt.Errorf( "failed to refresh machine (update expiration) in the database: %w", err, ) } + hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: types.Machines{machine}, + }, machine.MachineKey) + return nil } -// DeleteMachine softs deletes a Machine from the database. +// DeleteMachine deletes a Machine from the database. func (hsdb *HSDatabase) DeleteMachine(machine *types.Machine) error { - err := hsdb.DeleteMachineRoutes(machine) - if err != nil { - return err - } + hsdb.mu.Lock() + defer hsdb.mu.Unlock() - if err := hsdb.db.Delete(&machine).Error; err != nil { - return err - } - - return nil + return hsdb.deleteMachine(machine) } -func (hsdb *HSDatabase) TouchMachine(machine *types.Machine) error { - return hsdb.db.Updates(types.Machine{ - ID: machine.ID, - LastSeen: machine.LastSeen, - LastSuccessfulUpdate: machine.LastSuccessfulUpdate, - }).Error -} - -// HardDeleteMachine hard deletes a Machine from the database. -func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error { - err := hsdb.DeleteMachineRoutes(machine) +func (hsdb *HSDatabase) deleteMachine(machine *types.Machine) error { + err := hsdb.deleteMachineRoutes(machine) if err != nil { return err } + // Unscoped causes the machine to be fully removed from the database. 
if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil { return err } + hsdb.notifier.NotifyAll(types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: []tailcfg.NodeID{tailcfg.NodeID(machine.ID)}, + }) + return nil } -func (hsdb *HSDatabase) IsOutdated(machine *types.Machine, lastChange time.Time) bool { - if err := hsdb.UpdateMachineFromDatabase(machine); err != nil { - // It does not seem meaningful to propagate this error as the end result - // will have to be that the machine has to be considered outdated. - return true - } - - // Get the last update from all headscale users to compare with our nodes - // last update. - // TODO(kradalby): Only request updates from users where we can talk to nodes - // This would mostly be for a bit of performance, and can be calculated based on - // ACLs. - lastUpdate := machine.CreatedAt - if machine.LastSuccessfulUpdate != nil { - lastUpdate = *machine.LastSuccessfulUpdate - } - log.Trace(). - Caller(). - Str("machine", machine.Hostname). - Time("last_successful_update", lastChange). - Time("last_state_change", lastUpdate). - Msgf("Checking if %s is missing updates", machine.Hostname) - - return lastUpdate.Before(lastChange) +// UpdateLastSeen sets a machine's last seen field indicating that we +// have recently communicating with this machine. +// This is mostly used to indicate if a machine is online and is not +// extremely important to make sure is fully correct and to avoid +// holding up the hot path, does not contain any locks and isnt +// concurrency safe. But that should be ok. 
+func (hsdb *HSDatabase) UpdateLastSeen(machine *types.Machine) error { + return hsdb.db.Model(machine).Updates(types.Machine{ + LastSeen: machine.LastSeen, + }).Error } func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( @@ -356,6 +381,9 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( machineExpiry *time.Time, registrationMethod string, ) (*types.Machine, error) { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + nodeKey := key.NodePublic{} err := nodeKey.UnmarshalText([]byte(nodeKeyStr)) if err != nil { @@ -371,7 +399,7 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( if machineInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok { if registrationMachine, ok := machineInterface.(types.Machine); ok { - user, err := hsdb.GetUser(userName) + user, err := hsdb.getUser(userName) if err != nil { return nil, fmt.Errorf( "failed to find user in register machine from auth callback, %w", @@ -392,7 +420,7 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( registrationMachine.Expiry = machineExpiry } - machine, err := hsdb.RegisterMachine( + machine, err := hsdb.registerMachine( registrationMachine, ) @@ -410,8 +438,14 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( } // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. -func (hsdb *HSDatabase) RegisterMachine(machine types.Machine, -) (*types.Machine, error) { +func (hsdb *HSDatabase) RegisterMachine(machine types.Machine) (*types.Machine, error) { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.registerMachine(machine) +} + +func (hsdb *HSDatabase) registerMachine(machine types.Machine) (*types.Machine, error) { log.Debug(). Str("machine", machine.Hostname). Str("machine_key", machine.MachineKey). @@ -469,9 +503,12 @@ func (hsdb *HSDatabase) RegisterMachine(machine types.Machine, // MachineSetNodeKey sets the node key of a machine and saves it to the database. 
func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.NodePublic) error { - machine.NodeKey = util.NodePublicKeyStripPrefix(nodeKey) + hsdb.mu.Lock() + defer hsdb.mu.Unlock() - if err := hsdb.db.Save(machine).Error; err != nil { + if err := hsdb.db.Model(machine).Updates(types.Machine{ + NodeKey: util.NodePublicKeyStripPrefix(nodeKey), + }).Error; err != nil { return err } @@ -481,11 +518,14 @@ func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.No // MachineSetMachineKey sets the machine key of a machine and saves it to the database. func (hsdb *HSDatabase) MachineSetMachineKey( machine *types.Machine, - nodeKey key.MachinePublic, + machineKey key.MachinePublic, ) error { - machine.MachineKey = util.MachinePublicKeyStripPrefix(nodeKey) + hsdb.mu.Lock() + defer hsdb.mu.Unlock() - if err := hsdb.db.Save(machine).Error; err != nil { + if err := hsdb.db.Model(machine).Updates(types.Machine{ + MachineKey: util.MachinePublicKeyStripPrefix(machineKey), + }).Error; err != nil { return err } @@ -495,6 +535,9 @@ func (hsdb *HSDatabase) MachineSetMachineKey( // MachineSave saves a machine object to the database, prefer to use a specific save method rather // than this. It is intended to be used when we are changing or. func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if err := hsdb.db.Save(machine).Error; err != nil { return err } @@ -504,6 +547,13 @@ func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error { // GetAdvertisedRoutes returns the routes that are be advertised by the given machine. func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getAdvertisedRoutes(machine) +} + +func (hsdb *HSDatabase) getAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) { routes := types.Routes{} err := hsdb.db. 
@@ -529,6 +579,13 @@ func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Pre // GetEnabledRoutes returns the routes that are enabled for the machine. func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getEnabledRoutes(machine) +} + +func (hsdb *HSDatabase) getEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) { routes := types.Routes{} err := hsdb.db. @@ -554,12 +611,15 @@ func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix } func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) bool { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + route, err := netip.ParsePrefix(routeStr) if err != nil { return false } - enabledRoutes, err := hsdb.GetEnabledRoutes(machine) + enabledRoutes, err := hsdb.getEnabledRoutes(machine) if err != nil { log.Error().Err(err).Msg("Could not get enabled routes") @@ -575,6 +635,30 @@ func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) return false } +func OnlineMachineMap(peers types.Machines) map[tailcfg.NodeID]bool { + ret := make(map[tailcfg.NodeID]bool) + + for _, peer := range peers { + ret[tailcfg.NodeID(peer.ID)] = peer.IsOnline() + } + + return ret +} + +func (hsdb *HSDatabase) ListOnlineMachines( + machine *types.Machine, +) (map[tailcfg.NodeID]bool, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + peers, err := hsdb.listPeers(machine) + if err != nil { + return nil, err + } + + return OnlineMachineMap(peers), nil +} + // enableRoutes enables new routes based on a list of new routes. 
func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error { newRoutes := make([]netip.Prefix, len(routeStrs)) @@ -587,7 +671,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string newRoutes[index] = route } - advertisedRoutes, err := hsdb.GetAdvertisedRoutes(machine) + advertisedRoutes, err := hsdb.getAdvertisedRoutes(machine) if err != nil { return err } @@ -626,12 +710,15 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string } } - hsdb.notifyStateChange() + hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: types.Machines{machine}, + }, machine.MachineKey) return nil } -func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { +func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( suppliedName, ) @@ -658,20 +745,23 @@ func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool } func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) { - givenName, err := hsdb.generateGivenName(suppliedName, false) + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + givenName, err := generateGivenName(suppliedName, false) if err != nil { return "", err } // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ - machines, err := hsdb.ListMachinesByGivenName(givenName) + machines, err := hsdb.listMachinesByGivenName(givenName) if err != nil { return "", err } for _, machine := range machines { if machine.MachineKey != machineKey && machine.GivenName == givenName { - postfixedName, err := hsdb.generateGivenName(suppliedName, true) + postfixedName, err := generateGivenName(suppliedName, true) if err != nil { return "", err } @@ -684,7 +774,10 @@ func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string } func (hsdb 
*HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Duration) { - users, err := hsdb.ListUsers() + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + users, err := hsdb.listUsers() if err != nil { log.Error().Err(err).Msg("Error listing users") @@ -692,7 +785,7 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati } for _, user := range users { - machines, err := hsdb.ListMachinesByUser(user.Name) + machines, err := hsdb.listMachinesByUser(user.Name) if err != nil { log.Error(). Err(err). @@ -702,17 +795,18 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati return } - expiredFound := false + expired := make([]tailcfg.NodeID, 0) for idx, machine := range machines { if machine.IsEphemeral() && machine.LastSeen != nil && time.Now(). After(machine.LastSeen.Add(inactivityThreshhold)) { - expiredFound = true + expired = append(expired, tailcfg.NodeID(machine.ID)) + log.Info(). Str("machine", machine.Hostname). Msg("Ephemeral client removed from database") - err = hsdb.HardDeleteMachine(&machines[idx]) + err = hsdb.deleteMachine(machines[idx]) if err != nil { log.Error(). Err(err). @@ -722,38 +816,50 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati } } - if expiredFound { - hsdb.notifyStateChange() + if len(expired) > 0 { + hsdb.notifier.NotifyAll(types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: expired, + }) } } } -func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) { - users, err := hsdb.ListUsers() +func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + // use the time of the start of the function to ensure we + // don't miss some machines by returning it _after_ we have + // checked everything. 
+ started := time.Now() + + users, err := hsdb.listUsers() if err != nil { log.Error().Err(err).Msg("Error listing users") - return + return time.Unix(0, 0) } for _, user := range users { - machines, err := hsdb.ListMachinesByUser(user.Name) + machines, err := hsdb.listMachinesByUser(user.Name) if err != nil { log.Error(). Err(err). Str("user", user.Name). Msg("Error listing machines in user") - return + return time.Unix(0, 0) } - expiredFound := false + expired := make([]tailcfg.NodeID, 0) for index, machine := range machines { if machine.IsExpired() && - machine.Expiry.After(lastChange) { - expiredFound = true + machine.Expiry.After(lastCheck) { + expired = append(expired, tailcfg.NodeID(machine.ID)) - err := hsdb.ExpireMachine(&machines[index]) + now := time.Now() + err := hsdb.machineSetExpiry(machines[index], now) if err != nil { log.Error(). Err(err). @@ -769,8 +875,13 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) { } } - if expiredFound { - hsdb.notifyStateChange() + if len(expired) > 0 { + hsdb.notifier.NotifyAll(types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: expired, + }) } } + + return started } diff --git a/hscontrol/db/machine_test.go b/hscontrol/db/machine_test.go index 2786a0dcac..0115ee901e 100644 --- a/hscontrol/db/machine_test.go +++ b/hscontrol/db/machine_test.go @@ -39,8 +39,6 @@ func (s *Suite) TestGetMachine(c *check.C) { _, err = db.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) - - c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestGetMachineByID(c *check.C) { @@ -67,8 +65,6 @@ func (s *Suite) TestGetMachineByID(c *check.C) { _, err = db.GetMachineByID(0) c.Assert(err, check.IsNil) - - c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestGetMachineByNodeKey(c *check.C) { @@ -98,8 +94,6 @@ func (s *Suite) TestGetMachineByNodeKey(c *check.C) { _, err = db.GetMachineByNodeKey(nodeKey.Public()) c.Assert(err, check.IsNil) - - c.Assert(channelUpdates, 
check.Equals, int32(0)) } func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { @@ -131,32 +125,6 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { _, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) c.Assert(err, check.IsNil) - - c.Assert(channelUpdates, check.Equals, int32(0)) -} - -func (s *Suite) TestDeleteMachine(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - machine := types.Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(1), - } - db.db.Save(&machine) - - err = db.DeleteMachine(&machine) - c.Assert(err, check.IsNil) - - _, err = db.GetMachine(user.Name, "testmachine") - c.Assert(err, check.NotNil) - - c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestHardDeleteMachine(c *check.C) { @@ -174,13 +142,11 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) { } db.db.Save(&machine) - err = db.HardDeleteMachine(&machine) + err = db.DeleteMachine(&machine) c.Assert(err, check.IsNil) _, err = db.GetMachine(user.Name, "testmachine3") c.Assert(err, check.NotNil) - - c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestListPeers(c *check.C) { @@ -217,8 +183,6 @@ func (s *Suite) TestListPeers(c *check.C) { c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2") c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7") c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10") - - c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestGetACLFilteredPeers(c *check.C) { @@ -312,8 +276,6 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2") c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4") c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7") - - 
c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestExpireMachine(c *check.C) { @@ -345,12 +307,11 @@ func (s *Suite) TestExpireMachine(c *check.C) { c.Assert(machineFromDB.IsExpired(), check.Equals, false) - err = db.ExpireMachine(machineFromDB) + now := time.Now() + err = db.MachineSetExpiry(machineFromDB, now) c.Assert(err, check.IsNil) c.Assert(machineFromDB.IsExpired(), check.Equals, true) - - c.Assert(channelUpdates, check.Equals, int32(1)) } func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { @@ -372,8 +333,6 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { for i := range deserialized { c.Assert(deserialized[i], check.Equals, input[i]) } - - c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestGenerateGivenName(c *check.C) { @@ -418,8 +377,6 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { comment = check.Commentf("Unique users, unique machines, same hostname, conflict") c.Assert(err, check.IsNil, comment) c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) - - c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestSetTags(c *check.C) { @@ -463,8 +420,6 @@ func (s *Suite) TestSetTags(c *check.C) { check.DeepEquals, types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), ) - - c.Assert(channelUpdates, check.Equals, int32(2)) } func TestHeadscale_generateGivenName(t *testing.T) { @@ -474,14 +429,12 @@ func TestHeadscale_generateGivenName(t *testing.T) { } tests := []struct { name string - db *HSDatabase args args want *regexp.Regexp wantErr bool }{ { name: "simple machine name generation", - db: &HSDatabase{}, args: args{ suppliedName: "testmachine", randomSuffix: false, @@ -491,7 +444,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 53 chars", - db: &HSDatabase{}, args: args{ suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", randomSuffix: false, @@ -501,7 
+453,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 63 chars", - db: &HSDatabase{}, args: args{ suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", randomSuffix: false, @@ -511,7 +462,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 64 chars", - db: &HSDatabase{}, args: args{ suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", randomSuffix: false, @@ -521,7 +471,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 73 chars", - db: &HSDatabase{}, args: args{ suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", randomSuffix: false, @@ -531,7 +480,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with random suffix", - db: &HSDatabase{}, args: args{ suppliedName: "test", randomSuffix: true, @@ -541,7 +489,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 63 chars with random suffix", - db: &HSDatabase{}, args: args{ suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", randomSuffix: true, @@ -552,7 +499,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) + got, err := generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) if (err != nil) != tt.wantErr { t.Errorf( "Headscale.GenerateGivenName() error = %v, wantErr %v", @@ -643,7 +590,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { db.db.Save(&machine) - err = db.ProcessMachineRoutes(&machine) + err = db.SaveMachineRoutes(&machine) c.Assert(err, check.IsNil) machine0ByID, err := db.GetMachineByID(0) @@ -655,6 +602,4 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { enabledRoutes, err := db.GetEnabledRoutes(machine0ByID) c.Assert(err, 
check.IsNil) c.Assert(enabledRoutes, check.HasLen, 4) - - c.Assert(channelUpdates, check.Equals, int32(4)) } diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index abb79c34c2..ec7ab232f4 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -28,6 +28,10 @@ func (hsdb *HSDatabase) CreatePreAuthKey( expiration *time.Time, aclTags []string, ) (*types.PreAuthKey, error) { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + user, err := hsdb.GetUser(userName) if err != nil { return nil, err @@ -92,7 +96,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey( // ListPreAuthKeys returns the list of PreAuthKeys for a user. func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { - user, err := hsdb.GetUser(userName) + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.listPreAuthKeys(userName) +} + +func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) { + user, err := hsdb.getUser(userName) if err != nil { return nil, err } @@ -107,6 +118,9 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, er // GetPreAuthKey returns a PreAuthKey for a given key. func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + pak, err := hsdb.ValidatePreAuthKey(key) if err != nil { return nil, err @@ -122,6 +136,13 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. 
func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.destroyPreAuthKey(pak) +} + +func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error { return hsdb.db.Transaction(func(db *gorm.DB) error { if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { return result.Error @@ -137,6 +158,9 @@ func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { // MarkExpirePreAuthKey marks a PreAuthKey as expired. func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -146,6 +170,9 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { // UsePreAuthKey marks a PreAuthKey as used. func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + k.Used = true if err := hsdb.db.Save(k).Error; err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) @@ -157,6 +184,9 @@ func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node // If returns no error and a PreAuthKey, it can be used. func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + pak := types.PreAuthKey{} if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( result.Error, @@ -174,7 +204,10 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) } machines := types.Machines{} - if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { + if err := hsdb.db. + Preload("AuthKey"). 
+ Where(&types.Machine{AuthKeyID: uint(pak.ID)}). + Find(&machines).Error; err != nil { return nil, err } diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index e4a9773a03..247ebd206f 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -161,8 +161,6 @@ func (*Suite) TestEphemeralKey(c *check.C) { // The machine record should have been deleted _, err = db.GetMachine("test7", "testest") c.Assert(err, check.NotNil) - - c.Assert(channelUpdates, check.Equals, int32(1)) } func (*Suite) TestExpirePreauthKey(c *check.C) { diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 74e0afe463..2d51d18311 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -13,6 +13,13 @@ import ( var ErrRouteIsNotAvailable = errors.New("route is not available") func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getRoutes() +} + +func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { var routes types.Routes err := hsdb.db.Preload("Machine").Find(&routes).Error if err != nil { @@ -23,6 +30,13 @@ func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { } func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getMachineAdvertisedRoutes(machine) +} + +func (hsdb *HSDatabase) getMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { var routes types.Routes err := hsdb.db. Preload("Machine"). @@ -36,6 +50,13 @@ func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (type } func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getMachineRoutes(m) +} + +func (hsdb *HSDatabase) getMachineRoutes(m *types.Machine) (types.Routes, error) { var routes types.Routes err := hsdb.db. Preload("Machine"). 
@@ -49,6 +70,13 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) } func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getRoute(id) +} + +func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { var route types.Route err := hsdb.db.Preload("Machine").First(&route, id).Error if err != nil { @@ -59,7 +87,14 @@ func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { } func (hsdb *HSDatabase) EnableRoute(id uint64) error { - route, err := hsdb.GetRoute(id) + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.enableRoute(id) +} + +func (hsdb *HSDatabase) enableRoute(id uint64) error { + route, err := hsdb.getRoute(id) if err != nil { return err } @@ -79,7 +114,10 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error { } func (hsdb *HSDatabase) DisableRoute(id uint64) error { - route, err := hsdb.GetRoute(id) + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + route, err := hsdb.getRoute(id) if err != nil { return err } @@ -95,10 +133,10 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { return err } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } - routes, err := hsdb.GetMachineRoutes(&route.Machine) + routes, err := hsdb.getMachineRoutes(&route.Machine) if err != nil { return err } @@ -114,11 +152,14 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { } } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } func (hsdb *HSDatabase) DeleteRoute(id uint64) error { - route, err := hsdb.GetRoute(id) + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + route, err := hsdb.getRoute(id) if err != nil { return err } @@ -131,10 +172,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { return err } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } - routes, err := hsdb.GetMachineRoutes(&route.Machine) + routes, err 
:= hsdb.getMachineRoutes(&route.Machine) if err != nil { return err } @@ -150,11 +191,11 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { return err } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } -func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { - routes, err := hsdb.GetMachineRoutes(m) +func (hsdb *HSDatabase) deleteMachineRoutes(m *types.Machine) error { + routes, err := hsdb.getMachineRoutes(m) if err != nil { return err } @@ -165,7 +206,7 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { } } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } // isUniquePrefix returns if there is another machine providing the same route already. @@ -200,11 +241,14 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // Exit nodes are not considered for this, as they are never marked as Primary. -func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) { +func (hsdb *HSDatabase) GetMachinePrimaryRoutes(machine *types.Machine) (types.Routes, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + var routes types.Routes err := hsdb.db. Preload("Machine"). - Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true). + Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", machine.ID, true, true, true). Find(&routes).Error if err != nil { return nil, err @@ -213,7 +257,16 @@ func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, return routes, nil } -func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { +// SaveMachineRoutes takes a machine and updates the database with +// the new routes. 
+func (hsdb *HSDatabase) SaveMachineRoutes(machine *types.Machine) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.saveMachineRoutes(machine) +} + +func (hsdb *HSDatabase) saveMachineRoutes(machine *types.Machine) error { currentRoutes := types.Routes{} err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error if err != nil { @@ -225,6 +278,12 @@ func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { advertisedRoutes[prefix] = false } + log.Trace(). + Str("machine", machine.Hostname). + Interface("advertisedRoutes", advertisedRoutes). + Interface("currentRoutes", currentRoutes). + Msg("updating routes") + for pos, route := range currentRoutes { if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok { if !route.Advertised { @@ -264,6 +323,13 @@ func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { } func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.handlePrimarySubnetFailover() +} + +func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { // first, get all the enabled routes var routes types.Routes err := hsdb.db. 
@@ -274,12 +340,14 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { log.Error().Err(err).Msg("error getting routes") } - routesChanged := false + changedMachines := make(types.Machines, 0) for pos, route := range routes { if route.IsExitRoute() { continue } + machine := &route.Machine + if !route.IsPrimary { _, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix)) if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) { @@ -295,7 +363,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { return err } - routesChanged = true + changedMachines = append(changedMachines, machine) continue } @@ -369,12 +437,15 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { return err } - routesChanged = true + changedMachines = append(changedMachines, machine) } } - if routesChanged { - hsdb.notifyStateChange() + if len(changedMachines) > 0 { + hsdb.notifier.NotifyAll(types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: changedMachines, + }) } return nil @@ -385,11 +456,14 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( aclPolicy *policy.ACLPolicy, machine *types.Machine, ) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if len(machine.IPAddresses) == 0 { return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs } - routes, err := hsdb.GetMachineAdvertisedRoutes(machine) + routes, err := hsdb.getMachineAdvertisedRoutes(machine) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { log.Error(). Caller(). @@ -424,7 +498,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( approvedRoutes = append(approvedRoutes, advertisedRoute) } else { // TODO(kradalby): figure out how to get this to depend on less stuff - approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias) + approvedIps, err := aclPolicy.ExpandAlias(types.Machines{machine}, approvedAlias) if err != nil { log.Err(err). Str("alias", approvedAlias). 
@@ -442,7 +516,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( } for _, approvedRoute := range approvedRoutes { - err := hsdb.EnableRoute(uint64(approvedRoute.ID)) + err := hsdb.enableRoute(uint64(approvedRoute.ID)) if err != nil { log.Err(err). Str("approvedRoute", approvedRoute.String()). diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 4e91a2cb09..4698be28fa 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -40,7 +40,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { } db.db.Save(&machine) - err = db.ProcessMachineRoutes(&machine) + err = db.SaveMachineRoutes(&machine) c.Assert(err, check.IsNil) advertisedRoutes, err := db.GetAdvertisedRoutes(&machine) @@ -52,8 +52,6 @@ func (s *Suite) TestGetRoutes(c *check.C) { err = db.enableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) - - c.Assert(channelUpdates, check.Equals, int32(0)) } func (s *Suite) TestGetEnableRoutes(c *check.C) { @@ -93,7 +91,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { } db.db.Save(&machine) - err = db.ProcessMachineRoutes(&machine) + err = db.SaveMachineRoutes(&machine) c.Assert(err, check.IsNil) availableRoutes, err := db.GetAdvertisedRoutes(&machine) @@ -129,8 +127,6 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2) - - c.Assert(channelUpdates, check.Equals, int32(3)) } func (s *Suite) TestIsUniquePrefix(c *check.C) { @@ -169,7 +165,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { } db.db.Save(&machine1) - err = db.ProcessMachineRoutes(&machine1) + err = db.SaveMachineRoutes(&machine1) c.Assert(err, check.IsNil) err = db.enableRoutes(&machine1, route.String()) @@ -194,7 +190,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { } db.db.Save(&machine2) - err = db.ProcessMachineRoutes(&machine2) + err = db.SaveMachineRoutes(&machine2) 
c.Assert(err, check.IsNil) err = db.enableRoutes(&machine2, route2.String()) @@ -215,8 +211,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { routes, err = db.GetMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) - - c.Assert(channelUpdates, check.Equals, int32(3)) } func (s *Suite) TestSubnetFailover(c *check.C) { @@ -258,7 +252,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) { } db.db.Save(&machine1) - err = db.ProcessMachineRoutes(&machine1) + err = db.SaveMachineRoutes(&machine1) c.Assert(err, check.IsNil) err = db.enableRoutes(&machine1, prefix.String()) @@ -295,7 +289,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) { } db.db.Save(&machine2) - err = db.ProcessMachineRoutes(&machine2) + err = db.SaveMachineRoutes(&machine2) c.Assert(err, check.IsNil) err = db.enableRoutes(&machine2, prefix2.String()) @@ -343,7 +337,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) { err = db.db.Save(&machine2).Error c.Assert(err, check.IsNil) - err = db.ProcessMachineRoutes(&machine2) + err = db.SaveMachineRoutes(&machine2) c.Assert(err, check.IsNil) err = db.enableRoutes(&machine2, prefix.String()) @@ -359,8 +353,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) { routes, err = db.GetMachinePrimaryRoutes(&machine2) c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 2) - - c.Assert(channelUpdates, check.Equals, int32(6)) } func (s *Suite) TestDeleteRoutes(c *check.C) { @@ -402,7 +394,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { } db.db.Save(&machine1) - err = db.ProcessMachineRoutes(&machine1) + err = db.SaveMachineRoutes(&machine1) c.Assert(err, check.IsNil) err = db.enableRoutes(&machine1, prefix.String()) @@ -420,6 +412,4 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { enabledRoutes1, err := db.GetEnabledRoutes(&machine1) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 1) - - c.Assert(channelUpdates, check.Equals, int32(2)) } diff --git a/hscontrol/db/suite_test.go 
b/hscontrol/db/suite_test.go index 495a936399..ff1a095ce8 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -3,9 +3,9 @@ package db import ( "net/netip" "os" - "sync/atomic" "testing" + "github.com/juanfont/headscale/hscontrol/notifier" "gopkg.in/check.v1" ) @@ -20,14 +20,9 @@ type Suite struct{} var ( tmpDir string db *HSDatabase - - // channelUpdates counts the number of times - // either of the channels was notified. - channelUpdates int32 ) func (s *Suite) SetUpTest(c *check.C) { - atomic.StoreInt32(&channelUpdates, 0) s.ResetDB(c) } @@ -35,13 +30,6 @@ func (s *Suite) TearDownTest(c *check.C) { os.RemoveAll(tmpDir) } -func notificationSink(c <-chan struct{}) { - for { - <-c - atomic.AddInt32(&channelUpdates, 1) - } -} - func (s *Suite) ResetDB(c *check.C) { if len(tmpDir) != 0 { os.RemoveAll(tmpDir) @@ -52,15 +40,11 @@ func (s *Suite) ResetDB(c *check.C) { c.Fatal(err) } - sink := make(chan struct{}) - - go notificationSink(sink) - db, err = NewHeadscaleDatabase( "sqlite3", tmpDir+"/headscale_test.db", false, - sink, + notifier.NewNotifier(), []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index ce186751a1..5af4660b7c 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -18,6 +18,9 @@ var ( // CreateUser creates a new User. Returns error if could not be created // or another user already exists. func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + err := util.CheckForFQDNRules(name) if err != nil { return nil, err @@ -42,12 +45,15 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { // DestroyUser destroys a User. Returns error if the User does // not exist or if there are machines associated with it. 
func (hsdb *HSDatabase) DestroyUser(name string) error { - user, err := hsdb.GetUser(name) + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + user, err := hsdb.getUser(name) if err != nil { return ErrUserNotFound } - machines, err := hsdb.ListMachinesByUser(name) + machines, err := hsdb.listMachinesByUser(name) if err != nil { return err } @@ -55,12 +61,12 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { return ErrUserStillHasNodes } - keys, err := hsdb.ListPreAuthKeys(name) + keys, err := hsdb.listPreAuthKeys(name) if err != nil { return err } for _, key := range keys { - err = hsdb.DestroyPreAuthKey(key) + err = hsdb.destroyPreAuthKey(key) if err != nil { return err } @@ -76,8 +82,11 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { // RenameUser renames a User. Returns error if the User does // not exist or if another User exists with the new name. func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + var err error - oldUser, err := hsdb.GetUser(oldName) + oldUser, err := hsdb.getUser(oldName) if err != nil { return err } @@ -85,7 +94,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { if err != nil { return err } - _, err = hsdb.GetUser(newName) + _, err = hsdb.getUser(newName) if err == nil { return ErrUserExists } @@ -104,6 +113,13 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { // GetUser fetches a user by name. func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getUser(name) +} + +func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { user := types.User{} if result := hsdb.db.First(&user, "name = ?", name); errors.Is( result.Error, @@ -117,6 +133,13 @@ func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { // ListUsers gets all the existing users. 
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.listUsers() +} + +func (hsdb *HSDatabase) listUsers() ([]types.User, error) { users := []types.User{} if err := hsdb.db.Find(&users).Error; err != nil { return nil, err @@ -127,11 +150,18 @@ func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { // ListMachinesByUser gets all the nodes in a given user. func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.listMachinesByUser(name) +} + +func (hsdb *HSDatabase) listMachinesByUser(name string) (types.Machines, error) { err := util.CheckForFQDNRules(name) if err != nil { return nil, err } - user, err := hsdb.GetUser(name) + user, err := hsdb.getUser(name) if err != nil { return nil, err } @@ -144,13 +174,16 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) return machines, nil } -// SetMachineUser assigns a Machine to a user. -func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error { +// AssignMachineToUser assigns a Machine to a user. 
+func (hsdb *HSDatabase) AssignMachineToUser(machine *types.Machine, username string) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + err := util.CheckForFQDNRules(username) if err != nil { return err } - user, err := hsdb.GetUser(username) + user, err := hsdb.getUser(username) if err != nil { return err } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index bc468b2361..97b3e6d7f6 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -114,15 +114,15 @@ func (s *Suite) TestSetMachineUser(c *check.C) { db.db.Save(&machine) c.Assert(machine.UserID, check.Equals, oldUser.ID) - err = db.SetMachineUser(&machine, newUser.Name) + err = db.AssignMachineToUser(&machine, newUser.Name) c.Assert(err, check.IsNil) c.Assert(machine.UserID, check.Equals, newUser.ID) c.Assert(machine.User.Name, check.Equals, newUser.Name) - err = db.SetMachineUser(&machine, "non-existing-user") + err = db.AssignMachineToUser(&machine, "non-existing-user") c.Assert(err, check.Equals, ErrUserNotFound) - err = db.SetMachineUser(&machine, newUser.Name) + err = db.AssignMachineToUser(&machine, newUser.Name) c.Assert(err, check.IsNil) c.Assert(machine.UserID, check.Equals, newUser.ID) c.Assert(machine.User.Name, check.Equals, newUser.Name) diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index cd25d49138..d59966b68c 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -39,7 +39,7 @@ func NewDERPServer( cfg *types.DERPConfig, ) (*DERPServer, error) { log.Trace().Caller().Msg("Creating new embedded DERP server") - server := derp.NewServer(derpKey, log.Debug().Msgf) + server := derp.NewServer(derpKey, log.Debug().Msgf) // nolint // zerolinter complains return &DERPServer{ serverURL: serverURL, diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 74950c207f..d516ab9448 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -275,8 +275,11 @@ func (api 
headscaleV1APIServer) ExpireMachine( return nil, err } - api.h.db.ExpireMachine( + now := time.Now() + + api.h.db.MachineSetExpiry( machine, + now, ) log.Trace(). @@ -339,7 +342,7 @@ func (api headscaleV1APIServer) ListMachines( for index, machine := range machines { m := machine.Proto() validTags, invalidTags := api.h.ACLPolicy.TagsOfMachine( - machine, + &machine, ) m.InvalidTags = invalidTags m.ValidTags = validTags @@ -358,7 +361,7 @@ func (api headscaleV1APIServer) MoveMachine( return nil, err } - err = api.h.db.SetMachineUser(machine, request.GetUser()) + err = api.h.db.AssignMachineToUser(machine, request.GetUser()) if err != nil { return nil, err } @@ -529,9 +532,8 @@ func (api headscaleV1APIServer) DebugCreateMachine( GivenName: givenName, User: *user, - Expiry: &time.Time{}, - LastSeen: &time.Time{}, - LastSuccessfulUpdate: &time.Time{}, + Expiry: &time.Time{}, + LastSeen: &time.Time{}, HostInfo: types.HostInfo(hostinfo), } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 5d5509b9ac..0d34ebe162 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -4,9 +4,14 @@ import ( "encoding/binary" "encoding/json" "fmt" + "io/fs" "net/url" + "os" + "path" + "sort" "strings" "sync" + "sync/atomic" "time" mapset "github.com/deckarep/golang-set/v2" @@ -17,6 +22,7 @@ import ( "github.com/klauspost/compress/zstd" "github.com/rs/zerolog/log" "github.com/samber/lo" + "tailscale.com/envknob" "tailscale.com/smallzstd" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" @@ -26,11 +32,23 @@ import ( const ( nextDNSDoHPrefix = "https://dns.nextdns.io" reservedResponseHeaderSize = 4 + mapperIDLength = 8 + debugMapResponsePerm = 0o755 ) -type Mapper struct { - db *db.HSDatabase +var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH") + +// TODO: Optimise +// As this work continues, the idea is that there will be one Mapper instance +// per node, attached to the open stream between the control and 
client. +// This means that this can hold a state per machine and we can use that to +// improve the mapresponses sent. +// We could: +// - Keep information about the previous mapresponse so we can send a diff +// - Store hashes +// - Create a "minifier" that removes info not needed for the node +type Mapper struct { privateKey2019 *key.MachinePrivate isNoise bool @@ -41,10 +59,20 @@ type Mapper struct { dnsCfg *tailcfg.DNSConfig logtail bool randomClientPort bool + + uid string + created time.Time + seq uint64 + + // Map isnt concurrency safe, so we need to ensure + // only one func is accessing it over time. + mu sync.Mutex + peers map[uint64]*types.Machine } func NewMapper( - db *db.HSDatabase, + machine *types.Machine, + peers types.Machines, privateKey *key.MachinePrivate, isNoise bool, derpMap *tailcfg.DERPMap, @@ -53,9 +81,15 @@ func NewMapper( logtail bool, randomClientPort bool, ) *Mapper { - return &Mapper{ - db: db, + log.Debug(). + Caller(). + Bool("noise", isNoise). + Str("machine", machine.Hostname). + Msg("creating new mapper") + + uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) + return &Mapper{ privateKey2019: privateKey, isNoise: isNoise, @@ -64,118 +98,18 @@ func NewMapper( dnsCfg: dnsCfg, logtail: logtail, randomClientPort: randomClientPort, - } -} - -// TODO: Optimise -// As this work continues, the idea is that there will be one Mapper instance -// per node, attached to the open stream between the control and client. -// This means that this can hold a state per machine and we can use that to -// improve the mapresponses sent. -// We could: -// - Keep information about the previous mapresponse so we can send a diff -// - Store hashes -// - Create a "minifier" that removes info not needed for the node - -// fullMapResponse is the internal function for generating a MapResponse -// for a machine. 
-func fullMapResponse( - pol *policy.ACLPolicy, - machine *types.Machine, - peers types.Machines, - - baseDomain string, - dnsCfg *tailcfg.DNSConfig, - derpMap *tailcfg.DERPMap, - logtail bool, - randomClientPort bool, -) (*tailcfg.MapResponse, error) { - tailnode, err := tailNode(*machine, pol, dnsCfg, baseDomain) - if err != nil { - return nil, err - } - - rules, sshPolicy, err := policy.GenerateFilterAndSSHRules( - pol, - machine, - peers, - ) - if err != nil { - return nil, err - } - - // Filter out peers that have expired. - peers = lo.Filter(peers, func(item types.Machine, index int) bool { - return !item.IsExpired() - }) - - // If there are filter rules present, see if there are any machines that cannot - // access eachother at all and remove them from the peers. - if len(rules) > 0 { - peers = policy.FilterMachinesByACL(machine, peers, rules) - } - - profiles := generateUserProfiles(machine, peers, baseDomain) - - dnsConfig := generateDNSConfig( - dnsCfg, - baseDomain, - *machine, - peers, - ) - - tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain) - if err != nil { - return nil, err - } - - now := time.Now() - - resp := tailcfg.MapResponse{ - KeepAlive: false, - Node: tailnode, - - // TODO: Only send if updated - DERPMap: derpMap, - // TODO: Only send if updated - Peers: tailPeers, + uid: uid, + created: time.Now(), + seq: 0, - // TODO(kradalby): Implement: - // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374 - // PeersChanged - // PeersRemoved - // PeersChangedPatch - // PeerSeenChange - // OnlineChange - - // TODO: Only send if updated - DNSConfig: dnsConfig, - - // TODO: Only send if updated - Domain: baseDomain, - - // Do not instruct clients to collect services, we do not - // support or do anything with them - CollectServices: "false", - - // TODO: Only send if updated - PacketFilter: policy.ReduceFilterRules(machine, rules), - - UserProfiles: profiles, - - // TODO: Only send if updated - SSHPolicy: sshPolicy, 
- - ControlTime: &now, - - Debug: &tailcfg.Debug{ - DisableLogTail: !logtail, - RandomizeClientPort: randomClientPort, - }, + // TODO: populate + peers: peers.IDMap(), } +} - return &resp, nil +func (m *Mapper) String() string { + return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created) } func generateUserProfiles( @@ -211,7 +145,7 @@ func generateUserProfiles( func generateDNSConfig( base *tailcfg.DNSConfig, baseDomain string, - machine types.Machine, + machine *types.Machine, peers types.Machines, ) *tailcfg.DNSConfig { dnsConfig := base.Clone() @@ -254,7 +188,7 @@ func generateDNSConfig( // // This will produce a resolver like: // `https://dns.nextdns.io/?device_name=node-name&device_model=linux&device_ip=100.64.0.1` -func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) { +func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine *types.Machine) { for _, resolver := range resolvers { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { attrs := url.Values{ @@ -271,115 +205,177 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) { } } -// CreateMapResponse returns a MapResponse for the given machine. -func (m Mapper) CreateMapResponse( - mapRequest tailcfg.MapRequest, +// fullMapResponse creates a complete MapResponse for a node. +// It is a separate function to make testing easier. +func (m *Mapper) fullMapResponse( machine *types.Machine, pol *policy.ACLPolicy, -) ([]byte, error) { - peers, err := m.db.ListPeers(machine) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot fetch peers") +) (*tailcfg.MapResponse, error) { + peers := machineMapToList(m.peers) + resp, err := m.baseWithConfigMapResponse(machine, pol) + if err != nil { return nil, err } - mapResponse, err := fullMapResponse( + // TODO(kradalby): Move this into appendPeerChanges? 
+ resp.OnlineChange = db.OnlineMachineMap(peers) + + err = appendPeerChanges( + resp, pol, machine, peers, + peers, m.baseDomain, m.dnsCfg, - m.derpMap, - m.logtail, - m.randomClientPort, ) if err != nil { return nil, err } - if m.isNoise { - return m.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress) - } + return resp, nil +} - var machineKey key.MachinePublic - err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot parse client key") +// FullMapResponse returns a MapResponse for the given machine. +func (m *Mapper) FullMapResponse( + mapRequest tailcfg.MapRequest, + machine *types.Machine, + pol *policy.ACLPolicy, +) ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + resp, err := m.fullMapResponse(machine, pol) + if err != nil { return nil, err } - return m.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress) + if m.isNoise { + return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress) + } + + return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress) } -func (m Mapper) CreateKeepAliveResponse( +// LiteMapResponse returns a MapResponse for the given machine. +// Lite means that the peers has been omitted, this is intended +// to be used to answer MapRequests with OmitPeers set to true. 
+func (m *Mapper) LiteMapResponse( mapRequest tailcfg.MapRequest, machine *types.Machine, + pol *policy.ACLPolicy, ) ([]byte, error) { - keepAliveResponse := tailcfg.MapResponse{ - KeepAlive: true, + resp, err := m.baseWithConfigMapResponse(machine, pol) + if err != nil { + return nil, err } if m.isNoise { - return m.marshalMapResponse( - keepAliveResponse, - key.MachinePublic{}, - mapRequest.Compress, - ) + return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress) } - var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot parse client key") + return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress) +} - return nil, err - } +func (m *Mapper) KeepAliveResponse( + mapRequest tailcfg.MapRequest, + machine *types.Machine, +) ([]byte, error) { + resp := m.baseMapResponse() + resp.KeepAlive = true - return m.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress) + return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress) } -// MarshalResponse takes an Tailscale Response, marhsal it to JSON. -// If isNoise is set, then the JSON body will be returned -// If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box. -func MarshalResponse( - resp interface{}, - isNoise bool, - privateKey2019 *key.MachinePrivate, - machineKey key.MachinePublic, +func (m *Mapper) DERPMapResponse( + mapRequest tailcfg.MapRequest, + machine *types.Machine, + derpMap tailcfg.DERPMap, ) ([]byte, error) { - jsonBody, err := json.Marshal(resp) - if err != nil { - log.Error(). - Caller(). - Err(err). 
- Msg("Cannot marshal response") + resp := m.baseMapResponse() + resp.DERPMap = &derpMap + + return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress) +} +func (m *Mapper) PeerChangedResponse( + mapRequest tailcfg.MapRequest, + machine *types.Machine, + changed types.Machines, + pol *policy.ACLPolicy, +) ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + + lastSeen := make(map[tailcfg.NodeID]bool) + + // Update our internal map. + for _, machine := range changed { + m.peers[machine.ID] = machine + + // We have just seen the node, let the peers update their list. + lastSeen[tailcfg.NodeID(machine.ID)] = true + } + + resp := m.baseMapResponse() + + err := appendPeerChanges( + &resp, + pol, + machine, + machineMapToList(m.peers), + changed, + m.baseDomain, + m.dnsCfg, + ) + if err != nil { return nil, err } - if !isNoise && privateKey2019 != nil { - return privateKey2019.SealTo(machineKey, jsonBody), nil + // resp.PeerSeenChange = lastSeen + + return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress) +} + +func (m *Mapper) PeerRemovedResponse( + mapRequest tailcfg.MapRequest, + machine *types.Machine, + removed []tailcfg.NodeID, +) ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + + // remove from our internal map + for _, id := range removed { + delete(m.peers, uint64(id)) } - return jsonBody, nil + resp := m.baseMapResponse() + resp.PeersRemoved = removed + + return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress) } -func (m Mapper) marshalMapResponse( - resp interface{}, - machineKey key.MachinePublic, +func (m *Mapper) marshalMapResponse( + mapRequest tailcfg.MapRequest, + resp *tailcfg.MapResponse, + machine *types.Machine, compression string, ) ([]byte, error) { + atomic.AddUint64(&m.seq, 1) + + var machineKey key.MachinePublic + err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) + if err != nil { + log.Error(). + Caller(). + Err(err). 
+ Msg("Cannot parse client key") + + return nil, err + } + jsonBody, err := json.Marshal(resp) if err != nil { log.Error(). @@ -388,6 +384,41 @@ func (m Mapper) marshalMapResponse( Msg("Cannot marshal map response") } + if debugDumpMapResponsePath != "" { + data := map[string]interface{}{ + "MapRequest": mapRequest, + "MapResponse": resp, + } + + body, err := json.Marshal(data) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot marshal map response") + } + + perms := fs.FileMode(debugMapResponsePerm) + mPath := path.Join(debugDumpMapResponsePath, machine.Hostname) + err = os.MkdirAll(mPath, perms) + if err != nil { + panic(err) + } + + now := time.Now().UnixNano() + + mapResponsePath := path.Join( + mPath, + fmt.Sprintf("%d-%s-%d.json", now, m.uid, atomic.LoadUint64(&m.seq)), + ) + + log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) + err = os.WriteFile(mapResponsePath, body, perms) + if err != nil { + panic(err) + } + } + var respBody []byte if compression == util.ZstdCompression { respBody = zstdEncode(jsonBody) @@ -409,6 +440,32 @@ func (m Mapper) marshalMapResponse( return data, nil } +// MarshalResponse takes an Tailscale Response, marhsal it to JSON. +// If isNoise is set, then the JSON body will be returned +// If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box. +func MarshalResponse( + resp interface{}, + isNoise bool, + privateKey2019 *key.MachinePrivate, + machineKey key.MachinePublic, +) ([]byte, error) { + jsonBody, err := json.Marshal(resp) + if err != nil { + log.Error(). + Caller(). + Err(err). 
+ Msg("Cannot marshal response") + + return nil, err + } + + if !isNoise && privateKey2019 != nil { + return privateKey2019.SealTo(machineKey, jsonBody), nil + } + + return jsonBody, nil +} + func zstdEncode(in []byte) []byte { encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder) if !ok { @@ -433,3 +490,133 @@ var zstdEncoderPool = &sync.Pool{ return encoder }, } + +// baseMapResponse returns a tailcfg.MapResponse with +// KeepAlive false and ControlTime set to now. +func (m *Mapper) baseMapResponse() tailcfg.MapResponse { + now := time.Now() + + resp := tailcfg.MapResponse{ + KeepAlive: false, + ControlTime: &now, + } + + return resp +} + +// baseWithConfigMapResponse returns a tailcfg.MapResponse struct +// with the basic configuration from headscale set. +// It is used in for bigger updates, such as full and lite, not +// incremental. +func (m *Mapper) baseWithConfigMapResponse( + machine *types.Machine, + pol *policy.ACLPolicy, +) (*tailcfg.MapResponse, error) { + resp := m.baseMapResponse() + + tailnode, err := tailNode(machine, pol, m.dnsCfg, m.baseDomain) + if err != nil { + return nil, err + } + resp.Node = tailnode + + resp.DERPMap = m.derpMap + + resp.Domain = m.baseDomain + + // Do not instruct clients to collect services we do not + // support or do anything with them + resp.CollectServices = "false" + + resp.KeepAlive = false + + resp.Debug = &tailcfg.Debug{ + DisableLogTail: !m.logtail, + RandomizeClientPort: m.randomClientPort, + } + + return &resp, nil +} + +func machineMapToList(machines map[uint64]*types.Machine) types.Machines { + ret := make(types.Machines, 0) + + for _, machine := range machines { + ret = append(ret, machine) + } + + return ret +} + +func filterExpiredAndNotReady(peers types.Machines) types.Machines { + return lo.Filter(peers, func(item *types.Machine, index int) bool { + // Filter out nodes that are expired OR + // nodes that has no endpoints, this typically means they have + // registered, but are not configured. 
+ return !item.IsExpired() || len(item.Endpoints) > 0 + }) +} + +// appendPeerChanges mutates a tailcfg.MapResponse with all the +// necessary changes when peers have changed. +func appendPeerChanges( + resp *tailcfg.MapResponse, + + pol *policy.ACLPolicy, + machine *types.Machine, + peers types.Machines, + changed types.Machines, + baseDomain string, + dnsCfg *tailcfg.DNSConfig, +) error { + fullChange := len(peers) == len(changed) + + rules, sshPolicy, err := policy.GenerateFilterAndSSHRules( + pol, + machine, + peers, + ) + if err != nil { + return err + } + + // Filter out peers that have expired. + changed = filterExpiredAndNotReady(changed) + + // If there are filter rules present, see if there are any machines that cannot + // access eachother at all and remove them from the peers. + if len(rules) > 0 { + changed = policy.FilterMachinesByACL(machine, changed, rules) + } + + profiles := generateUserProfiles(machine, changed, baseDomain) + + dnsConfig := generateDNSConfig( + dnsCfg, + baseDomain, + machine, + peers, + ) + + tailPeers, err := tailNodes(changed, pol, dnsCfg, baseDomain) + if err != nil { + return err + } + + // Peers is always returned sorted by Node.ID. 
+ sort.SliceStable(tailPeers, func(x, y int) bool { + return tailPeers[x].ID < tailPeers[y].ID + }) + + if fullChange { + resp.Peers = tailPeers + } else { + resp.PeersChanged = tailPeers + } + resp.DNSConfig = dnsConfig + resp.PacketFilter = policy.ReduceFilterRules(machine, rules) + resp.UserProfiles = profiles + resp.SSHPolicy = sshPolicy + + return nil +} diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 0ca9633e8d..c0857f2687 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -18,8 +18,8 @@ import ( ) func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { - mach := func(hostname, username string, userid uint) types.Machine { - return types.Machine{ + mach := func(hostname, username string, userid uint) *types.Machine { + return &types.Machine{ Hostname: hostname, UserID: userid, User: types.User{ @@ -34,7 +34,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { machine2InShared1 := mach("test_get_shared_nodes_4", "user1", 1) userProfiles := generateUserProfiles( - &machineInShared1, + machineInShared1, types.Machines{ machineInShared2, machineInShared3, machine2InShared1, }, @@ -91,8 +91,8 @@ func TestDNSConfigMapResponse(t *testing.T) { for _, tt := range tests { t.Run(fmt.Sprintf("with-magicdns-%v", tt.magicDNS), func(t *testing.T) { - mach := func(hostname, username string, userid uint) types.Machine { - return types.Machine{ + mach := func(hostname, username string, userid uint) *types.Machine { + return &types.Machine{ Hostname: hostname, UserID: userid, User: types.User{ @@ -243,7 +243,7 @@ func Test_fullMapResponse(t *testing.T) { }, } - peer1 := types.Machine{ + peer1 := &types.Machine{ ID: 1, MachineKey: "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", NodeKey: "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", @@ -295,7 +295,7 @@ func Test_fullMapResponse(t *testing.T) { }, } - peer2 := types.Machine{ + peer2 := 
&types.Machine{ ID: 2, MachineKey: "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", NodeKey: "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", @@ -341,7 +341,7 @@ func Test_fullMapResponse(t *testing.T) { name: "no-pol-no-peers-map-response", pol: &policy.ACLPolicy{}, machine: mini, - peers: []types.Machine{}, + peers: types.Machines{}, baseDomain: "", dnsConfig: &tailcfg.DNSConfig{}, derpMap: &tailcfg.DERPMap{}, @@ -369,7 +369,7 @@ func Test_fullMapResponse(t *testing.T) { name: "no-pol-with-peer-map-response", pol: &policy.ACLPolicy{}, machine: mini, - peers: []types.Machine{ + peers: types.Machines{ peer1, }, baseDomain: "", @@ -387,6 +387,7 @@ func Test_fullMapResponse(t *testing.T) { DNSConfig: &tailcfg.DNSConfig{}, Domain: "", CollectServices: "false", + OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false}, PacketFilter: []tailcfg.FilterRule{}, UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, @@ -409,7 +410,7 @@ func Test_fullMapResponse(t *testing.T) { }, }, machine: mini, - peers: []types.Machine{ + peers: types.Machines{ peer1, peer2, }, @@ -428,6 +429,10 @@ func Test_fullMapResponse(t *testing.T) { DNSConfig: &tailcfg.DNSConfig{}, Domain: "", CollectServices: "false", + OnlineChange: map[tailcfg.NodeID]bool{ + tailPeer1.ID: false, + tailcfg.NodeID(peer2.ID): false, + }, PacketFilter: []tailcfg.FilterRule{ { SrcIPs: []string{"100.64.0.2/32"}, @@ -436,9 +441,11 @@ func Test_fullMapResponse(t *testing.T) { }, }, }, - UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, - SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, - ControlTime: &time.Time{}, + UserProfiles: []tailcfg.UserProfile{ + {LoginName: "mini", DisplayName: "mini"}, + }, + SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, + ControlTime: &time.Time{}, Debug: &tailcfg.Debug{ DisableLogTail: true, }, @@ 
-449,17 +456,23 @@ func Test_fullMapResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := fullMapResponse( - tt.pol, + mappy := NewMapper( tt.machine, tt.peers, + nil, + false, + tt.derpMap, tt.baseDomain, tt.dnsConfig, - tt.derpMap, tt.logtail, tt.randomClientPort, ) + got, err := mappy.fullMapResponse( + tt.machine, + tt.pol, + ) + if (err != nil) != tt.wantErr { t.Errorf("fullMapResponse() error = %v, wantErr %v", err, tt.wantErr) diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 92bd5c965b..250b26fe76 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -41,7 +41,7 @@ func tailNodes( // tailNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // as per the expected behaviour in the official SaaS. func tailNode( - machine types.Machine, + machine *types.Machine, pol *policy.ACLPolicy, dnsConfig *tailcfg.DNSConfig, baseDomain string, diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 9874a7792d..28bc8e6585 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -45,7 +45,7 @@ func TestTailNode(t *testing.T) { tests := []struct { name string - machine types.Machine + machine *types.Machine pol *policy.ACLPolicy dnsConfig *tailcfg.DNSConfig baseDomain string @@ -54,7 +54,7 @@ func TestTailNode(t *testing.T) { }{ { name: "empty-machine", - machine: types.Machine{}, + machine: &types.Machine{}, pol: &policy.ACLPolicy{}, dnsConfig: &tailcfg.DNSConfig{}, baseDomain: "", @@ -63,7 +63,7 @@ func TestTailNode(t *testing.T) { }, { name: "minimal-machine", - machine: types.Machine{ + machine: &types.Machine{ ID: 0, MachineKey: "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", NodeKey: "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", diff --git a/hscontrol/metrics.go b/hscontrol/metrics.go index 087ce3028b..724464c4df 100644 --- a/hscontrol/metrics.go 
+++ b/hscontrol/metrics.go @@ -10,32 +10,16 @@ const prometheusNamespace = "headscale" var ( // This is a high cardinality metric (user x machines), we might want to make this // configurable/opt-in in the future. - lastStateUpdate = promauto.NewGaugeVec(prometheus.GaugeOpts{ - Namespace: prometheusNamespace, - Name: "last_update_seconds", - Help: "Time stamp in unix time when a machine or headscale was updated", - }, []string{"user", "machine"}) - machineRegistrations = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: prometheusNamespace, Name: "machine_registrations_total", Help: "The total amount of registered machine attempts", }, []string{"action", "auth", "status", "user"}) - updateRequestsFromNode = promauto.NewCounterVec(prometheus.CounterOpts{ - Namespace: prometheusNamespace, - Name: "update_request_from_node_total", - Help: "The number of updates requested by a node/update function", - }, []string{"user", "machine", "state"}) updateRequestsSentToNode = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: prometheusNamespace, Name: "update_request_sent_to_node_total", Help: "The number of calls/messages issued on a specific nodes update channel", }, []string{"user", "machine", "status"}) // TODO(kradalby): This is very debugging, we might want to remove it. 
- updateRequestsReceivedOnChannel = promauto.NewCounterVec(prometheus.CounterOpts{ - Namespace: prometheusNamespace, - Name: "update_request_received_on_channel_total", - Help: "The number of update requests received on an update channel", - }, []string{"user", "machine"}) ) diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go new file mode 100644 index 0000000000..32f426ad72 --- /dev/null +++ b/hscontrol/notifier/notifier.go @@ -0,0 +1,80 @@ +package notifier + +import ( + "sync" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" +) + +type Notifier struct { + l sync.RWMutex + nodes map[string]chan<- types.StateUpdate +} + +func NewNotifier() *Notifier { + return &Notifier{} +} + +func (n *Notifier) AddNode(machineKey string, c chan<- types.StateUpdate) { + log.Trace().Caller().Str("key", machineKey).Msg("acquiring lock to add node") + defer log.Trace().Caller().Str("key", machineKey).Msg("releasing lock to add node") + + n.l.Lock() + defer n.l.Unlock() + + if n.nodes == nil { + n.nodes = make(map[string]chan<- types.StateUpdate) + } + + n.nodes[machineKey] = c + + log.Trace(). + Str("machine_key", machineKey). + Int("open_chans", len(n.nodes)). + Msg("Added new channel") +} + +func (n *Notifier) RemoveNode(machineKey string) { + log.Trace().Caller().Str("key", machineKey).Msg("acquiring lock to remove node") + defer log.Trace().Caller().Str("key", machineKey).Msg("releasing lock to remove node") + + n.l.Lock() + defer n.l.Unlock() + + if n.nodes == nil { + return + } + + delete(n.nodes, machineKey) + + log.Trace(). + Str("machine_key", machineKey). + Int("open_chans", len(n.nodes)). 
+ Msg("Removed channel") +} + +func (n *Notifier) NotifyAll(update types.StateUpdate) { + n.NotifyWithIgnore(update) +} + +func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) { + log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify") + defer log.Trace(). + Caller(). + Interface("type", update.Type). + Msg("releasing lock, finished notifing") + + n.l.RLock() + defer n.l.RUnlock() + + for key, c := range n.nodes { + if util.IsStringInSlice(ignore, key) { + continue + } + + log.Trace().Caller().Str("machine", key).Strs("ignoring", ignore).Msg("sending update") + c <- update + } +} diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 663838381b..010bcb15c5 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -523,7 +523,7 @@ func (h *Headscale) validateMachineForOIDCCallback( Str("machine", machine.Hostname). Msg("machine already registered, reauthenticating") - err := h.db.RefreshMachine(machine, expiry) + err := h.db.MachineSetExpiry(machine, expiry) if err != nil { util.LogErr(err, "Failed to refresh machine") http.Error( diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index d4e2494411..90befd55eb 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -157,7 +157,7 @@ func (pol *ACLPolicy) generateFilterRules( peers types.Machines, ) ([]tailcfg.FilterRule, error) { rules := []tailcfg.FilterRule{} - machines := append(peers, *machine) + machines := append(peers, machine) for index, acl := range pol.ACLs { if acl.Action != "accept" { @@ -293,7 +293,7 @@ func (pol *ACLPolicy) generateSSHRules( for index, sshACL := range pol.SSHs { var dest netipx.IPSetBuilder for _, src := range sshACL.Destinations { - expanded, err := pol.ExpandAlias(append(peers, *machine), src) + expanded, err := pol.ExpandAlias(append(peers, machine), src) if err != nil { return nil, err } @@ -875,7 +875,7 @@ func isTag(str string) bool { // Invalid tags are tags added by a user on a node, and that 
user doesn't have authority to add this tag. // Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. func (pol *ACLPolicy) TagsOfMachine( - machine types.Machine, + machine *types.Machine, ) ([]string, []string) { validTags := make([]string, 0) invalidTags := make([]string, 0) @@ -935,7 +935,7 @@ func FilterMachinesByACL( continue } - if machine.CanAccess(filter, &machines[index]) || peer.CanAccess(filter, machine) { + if machine.CanAccess(filter, machines[index]) || peer.CanAccess(filter, machine) { result = append(result, peer) } } diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index ab7cfb4786..9afb58a546 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -394,7 +394,7 @@ acls: netip.MustParseAddr("100.100.100.100"), }, }, types.Machines{ - types.Machine{ + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("200.200.200.200"), }, @@ -909,38 +909,38 @@ func Test_listMachinesInUser(t *testing.T) { name: "1 machine in user", args: args{ machines: types.Machines{ - {User: types.User{Name: "joe"}}, + &types.Machine{User: types.User{Name: "joe"}}, }, user: "joe", }, want: types.Machines{ - {User: types.User{Name: "joe"}}, + &types.Machine{User: types.User{Name: "joe"}}, }, }, { name: "3 machines, 2 in user", args: args{ machines: types.Machines{ - {ID: 1, User: types.User{Name: "joe"}}, - {ID: 2, User: types.User{Name: "marc"}}, - {ID: 3, User: types.User{Name: "marc"}}, + &types.Machine{ID: 1, User: types.User{Name: "joe"}}, + &types.Machine{ID: 2, User: types.User{Name: "marc"}}, + &types.Machine{ID: 3, User: types.User{Name: "marc"}}, }, user: "marc", }, want: types.Machines{ - {ID: 2, User: types.User{Name: "marc"}}, - {ID: 3, User: types.User{Name: "marc"}}, + &types.Machine{ID: 2, User: types.User{Name: "marc"}}, + &types.Machine{ID: 3, User: types.User{Name: "marc"}}, }, }, { name: "5 machines, 0 in user", args: args{ machines: types.Machines{ - 
{ID: 1, User: types.User{Name: "joe"}}, - {ID: 2, User: types.User{Name: "marc"}}, - {ID: 3, User: types.User{Name: "marc"}}, - {ID: 4, User: types.User{Name: "marc"}}, - {ID: 5, User: types.User{Name: "marc"}}, + &types.Machine{ID: 1, User: types.User{Name: "joe"}}, + &types.Machine{ID: 2, User: types.User{Name: "marc"}}, + &types.Machine{ID: 3, User: types.User{Name: "marc"}}, + &types.Machine{ID: 4, User: types.User{Name: "marc"}}, + &types.Machine{ID: 5, User: types.User{Name: "marc"}}, }, user: "mickael", }, @@ -998,8 +998,10 @@ func Test_expandAlias(t *testing.T) { args: args{ alias: "*", machines: types.Machines{ - {IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}}, - { + &types.Machine{ + IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, + }, + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.78.84.227"), }, @@ -1022,25 +1024,25 @@ func Test_expandAlias(t *testing.T) { args: args{ alias: "group:accountant", machines: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, User: types.User{Name: "marc"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, @@ -1063,25 +1065,25 @@ func Test_expandAlias(t *testing.T) { args: args{ alias: "group:hr", machines: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, 
User: types.User{Name: "marc"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, @@ -1128,7 +1130,7 @@ func Test_expandAlias(t *testing.T) { args: args{ alias: "10.0.0.1", machines: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("10.0.0.1"), }, @@ -1149,7 +1151,7 @@ func Test_expandAlias(t *testing.T) { args: args{ alias: "10.0.0.1", machines: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), @@ -1171,7 +1173,7 @@ func Test_expandAlias(t *testing.T) { args: args{ alias: "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", machines: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), @@ -1240,7 +1242,7 @@ func Test_expandAlias(t *testing.T) { args: args{ alias: "tag:hr-webserver", machines: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, @@ -1251,7 +1253,7 @@ func Test_expandAlias(t *testing.T) { RequestTags: []string{"tag:hr-webserver"}, }, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, @@ -1262,13 +1264,13 @@ func Test_expandAlias(t *testing.T) { RequestTags: []string{"tag:hr-webserver"}, }, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, User: types.User{Name: "marc"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, @@ -1294,25 +1296,25 @@ func Test_expandAlias(t *testing.T) { args: args{ alias: "tag:hr-webserver", machines: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ IPAddresses: 
types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, User: types.User{Name: "marc"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, @@ -1331,27 +1333,27 @@ func Test_expandAlias(t *testing.T) { args: args{ alias: "tag:hr-webserver", machines: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, ForcedTags: []string{"tag:hr-webserver"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "joe"}, ForcedTags: []string{"tag:hr-webserver"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, User: types.User{Name: "marc"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, @@ -1374,14 +1376,14 @@ func Test_expandAlias(t *testing.T) { args: args{ alias: "tag:hr-webserver", machines: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, ForcedTags: []string{"tag:hr-webserver"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, @@ -1392,13 +1394,13 @@ func Test_expandAlias(t *testing.T) { RequestTags: []string{"tag:hr-webserver"}, }, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, User: types.User{Name: "marc"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, @@ -1419,7 +1421,7 @@ func Test_expandAlias(t *testing.T) { args: args{ alias: "joe", machines: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, @@ -1430,7 +1432,7 @@ func 
Test_expandAlias(t *testing.T) { RequestTags: []string{"tag:accountant-webserver"}, }, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, @@ -1441,13 +1443,13 @@ func Test_expandAlias(t *testing.T) { RequestTags: []string{"tag:accountant-webserver"}, }, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), }, User: types.User{Name: "marc"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, @@ -1496,7 +1498,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, }, nodes: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, @@ -1507,7 +1509,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { RequestTags: []string{"tag:accountant-webserver"}, }, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, @@ -1518,7 +1520,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { RequestTags: []string{"tag:accountant-webserver"}, }, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, @@ -1528,7 +1530,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { user: "joe", }, want: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.4")}, User: types.User{Name: "joe"}, }, @@ -1546,7 +1548,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { }, }, nodes: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, @@ -1557,7 +1559,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { RequestTags: []string{"tag:accountant-webserver"}, }, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, @@ -1568,7 +1570,7 @@ func Test_excludeCorrectlyTaggedNodes(t 
*testing.T) { RequestTags: []string{"tag:accountant-webserver"}, }, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, @@ -1578,7 +1580,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { user: "joe", }, want: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.4")}, User: types.User{Name: "joe"}, }, @@ -1591,7 +1593,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, }, nodes: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, @@ -1602,14 +1604,14 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { RequestTags: []string{"tag:accountant-webserver"}, }, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "joe"}, ForcedTags: []string{"tag:accountant-webserver"}, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, @@ -1619,7 +1621,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { user: "joe", }, want: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.4")}, User: types.User{Name: "joe"}, }, @@ -1632,7 +1634,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, }, nodes: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, @@ -1643,7 +1645,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { RequestTags: []string{"tag:hr-webserver"}, }, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, @@ -1654,7 +1656,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { RequestTags: []string{"tag:hr-webserver"}, }, }, - { + &types.Machine{ IPAddresses: 
types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, @@ -1664,7 +1666,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { user: "joe", }, want: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, @@ -1675,7 +1677,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { RequestTags: []string{"tag:hr-webserver"}, }, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, @@ -1686,7 +1688,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { RequestTags: []string{"tag:hr-webserver"}, }, }, - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.4"), }, @@ -1714,7 +1716,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { pol ACLPolicy } type args struct { - machine types.Machine + machine *types.Machine peers types.Machines } tests := []struct { @@ -1745,7 +1747,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { }, }, args: args{ - machine: types.Machine{ + machine: &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), @@ -1790,7 +1792,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { }, }, args: args{ - machine: types.Machine{ + machine: &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), @@ -1798,7 +1800,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { User: types.User{Name: "mickael"}, }, peers: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), @@ -1837,7 +1839,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.field.pol.generateFilterRules( - &tt.args.machine, + 
tt.args.machine, tt.args.peers, ) if (err != nil) != tt.wantErr { @@ -1857,7 +1859,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { func TestReduceFilterRules(t *testing.T) { tests := []struct { name string - machine types.Machine + machine *types.Machine peers types.Machines pol ACLPolicy want []tailcfg.FilterRule @@ -1873,7 +1875,7 @@ func TestReduceFilterRules(t *testing.T) { }, }, }, - machine: types.Machine{ + machine: &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), @@ -1881,7 +1883,7 @@ func TestReduceFilterRules(t *testing.T) { User: types.User{Name: "mickael"}, }, peers: types.Machines{ - { + &types.Machine{ IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), @@ -1896,11 +1898,11 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { rules, _ := tt.pol.generateFilterRules( - &tt.machine, + tt.machine, tt.peers, ) - got := ReduceFilterRules(&tt.machine, rules) + got := ReduceFilterRules(tt.machine, rules) if diff := cmp.Diff(tt.want, got); diff != "" { log.Trace().Interface("got", got).Msg("result") @@ -1913,7 +1915,7 @@ func TestReduceFilterRules(t *testing.T) { func Test_getTags(t *testing.T) { type args struct { aclPolicy *ACLPolicy - machine types.Machine + machine *types.Machine } tests := []struct { name string @@ -1929,7 +1931,7 @@ func Test_getTags(t *testing.T) { "tag:valid": []string{"joe"}, }, }, - machine: types.Machine{ + machine: &types.Machine{ User: types.User{ Name: "joe", }, @@ -1949,7 +1951,7 @@ func Test_getTags(t *testing.T) { "tag:valid": []string{"joe"}, }, }, - machine: types.Machine{ + machine: &types.Machine{ User: types.User{ Name: "joe", }, @@ -1969,7 +1971,7 @@ func Test_getTags(t *testing.T) { "tag:valid": []string{"joe"}, }, }, - machine: types.Machine{ + machine: &types.Machine{ 
User: types.User{ Name: "joe", }, @@ -1993,7 +1995,7 @@ func Test_getTags(t *testing.T) { "tag:valid": []string{"joe"}, }, }, - machine: types.Machine{ + machine: &types.Machine{ User: types.User{ Name: "joe", }, @@ -2009,7 +2011,7 @@ func Test_getTags(t *testing.T) { name: "empty ACLPolicy should return empty tags and should not panic", args: args{ aclPolicy: &ACLPolicy{}, - machine: types.Machine{ + machine: &types.Machine{ User: types.User{ Name: "joe", }, @@ -2072,21 +2074,21 @@ func Test_getFilteredByACLPeers(t *testing.T) { name: "all hosts can talk to each other", args: args{ machines: types.Machines{ // list of all machines in the database - { + &types.Machine{ ID: 1, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ ID: 2, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "marc"}, }, - { + &types.Machine{ ID: 3, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), @@ -2109,12 +2111,12 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, }, want: types.Machines{ - { + &types.Machine{ ID: 2, IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, User: types.User{Name: "marc"}, }, - { + &types.Machine{ ID: 3, IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.3")}, User: types.User{Name: "mickael"}, @@ -2125,21 +2127,21 @@ func Test_getFilteredByACLPeers(t *testing.T) { name: "One host can talk to another, but not all hosts", args: args{ machines: types.Machines{ // list of all machines in the database - { + &types.Machine{ ID: 1, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ ID: 2, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "marc"}, }, - { + &types.Machine{ ID: 3, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), @@ 
-2162,7 +2164,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, }, want: types.Machines{ - { + &types.Machine{ ID: 2, IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")}, User: types.User{Name: "marc"}, @@ -2173,21 +2175,21 @@ func Test_getFilteredByACLPeers(t *testing.T) { name: "host cannot directly talk to destination, but return path is authorized", args: args{ machines: types.Machines{ // list of all machines in the database - { + &types.Machine{ ID: 1, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ ID: 2, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "marc"}, }, - { + &types.Machine{ ID: 3, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), @@ -2210,7 +2212,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, }, want: types.Machines{ - { + &types.Machine{ ID: 3, IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.3")}, User: types.User{Name: "mickael"}, @@ -2221,21 +2223,21 @@ func Test_getFilteredByACLPeers(t *testing.T) { name: "rules allows all hosts to reach one destination", args: args{ machines: types.Machines{ // list of all machines in the database - { + &types.Machine{ ID: 1, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ ID: 2, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "marc"}, }, - { + &types.Machine{ ID: 3, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), @@ -2260,7 +2262,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, }, want: types.Machines{ - { + &types.Machine{ ID: 2, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), @@ -2273,21 +2275,21 @@ func Test_getFilteredByACLPeers(t *testing.T) { name: "rules allows all hosts to reach one destination, destination can 
reach all hosts", args: args{ machines: types.Machines{ // list of all machines in the database - { + &types.Machine{ ID: 1, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ ID: 2, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "marc"}, }, - { + &types.Machine{ ID: 3, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), @@ -2312,14 +2314,14 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, }, want: types.Machines{ - { + &types.Machine{ ID: 1, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ ID: 3, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), @@ -2332,21 +2334,21 @@ func Test_getFilteredByACLPeers(t *testing.T) { name: "rule allows all hosts to reach all destinations", args: args{ machines: types.Machines{ // list of all machines in the database - { + &types.Machine{ ID: 1, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ ID: 2, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "marc"}, }, - { + &types.Machine{ ID: 3, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), @@ -2369,14 +2371,14 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, }, want: types.Machines{ - { + &types.Machine{ ID: 1, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ ID: 3, IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.3")}, User: types.User{Name: "mickael"}, @@ -2387,21 +2389,21 @@ func Test_getFilteredByACLPeers(t *testing.T) { name: "without rule all communications are forbidden", args: args{ machines: types.Machines{ // list of all machines in the database - { + &types.Machine{ 
ID: 1, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.1"), }, User: types.User{Name: "joe"}, }, - { + &types.Machine{ ID: 2, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), }, User: types.User{Name: "marc"}, }, - { + &types.Machine{ ID: 3, IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.3"), @@ -2427,7 +2429,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { name: "issue-699-broken-star", args: args{ machines: types.Machines{ // - { + &types.Machine{ ID: 1, Hostname: "ts-head-upcrmb", IPAddresses: types.MachineAddresses{ @@ -2436,7 +2438,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, User: types.User{Name: "user1"}, }, - { + &types.Machine{ ID: 2, Hostname: "ts-unstable-rlwpvr", IPAddresses: types.MachineAddresses{ @@ -2445,7 +2447,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, User: types.User{Name: "user1"}, }, - { + &types.Machine{ ID: 3, Hostname: "ts-head-8w6paa", IPAddresses: types.MachineAddresses{ @@ -2454,7 +2456,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, User: types.User{Name: "user2"}, }, - { + &types.Machine{ ID: 4, Hostname: "ts-unstable-lys2ib", IPAddresses: types.MachineAddresses{ @@ -2489,7 +2491,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, }, want: types.Machines{ - { + &types.Machine{ ID: 1, Hostname: "ts-head-upcrmb", IPAddresses: types.MachineAddresses{ @@ -2498,7 +2500,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { }, User: types.User{Name: "user1"}, }, - { + &types.Machine{ ID: 2, Hostname: "ts-unstable-rlwpvr", IPAddresses: types.MachineAddresses{ @@ -2512,7 +2514,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { { name: "failing-edge-case-during-p3-refactor", args: args{ - machines: []types.Machine{ + machines: []*types.Machine{ { ID: 1, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, @@ -2542,7 +2544,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { User: types.User{Name: "mini"}, }, }, - want: 
[]types.Machine{ + want: []*types.Machine{ { ID: 2, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, @@ -2554,7 +2556,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { { name: "p4-host-in-netmap-user2-dest-bug", args: args{ - machines: []types.Machine{ + machines: []*types.Machine{ { ID: 1, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, @@ -2611,7 +2613,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { User: types.User{Name: "user2"}, }, }, - want: []types.Machine{ + want: []*types.Machine{ { ID: 1, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, @@ -2635,7 +2637,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { { name: "p4-host-in-netmap-user1-dest-bug", args: args{ - machines: []types.Machine{ + machines: []*types.Machine{ { ID: 1, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, @@ -2692,7 +2694,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { User: types.User{Name: "user1"}, }, }, - want: []types.Machine{ + want: []*types.Machine{ { ID: 1, IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, @@ -2747,7 +2749,7 @@ func TestSSHRules(t *testing.T) { }, }, peers: types.Machines{ - types.Machine{ + &types.Machine{ Hostname: "testmachine2", IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")}, UserID: 0, @@ -2855,7 +2857,7 @@ func TestSSHRules(t *testing.T) { }, }, peers: types.Machines{ - types.Machine{ + &types.Machine{ Hostname: "testmachine2", IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.99.42")}, UserID: 0, @@ -2982,7 +2984,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { RequestTags: []string{"tag:test"}, } - machine := types.Machine{ + machine := &types.Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", @@ -3009,7 +3011,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRules(pol, &machine, types.Machines{}) + got, _, err := GenerateFilterAndSSHRules(pol, machine, types.Machines{}) 
assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3037,7 +3039,7 @@ func TestInvalidTagValidUser(t *testing.T) { RequestTags: []string{"tag:foo"}, } - machine := types.Machine{ + machine := &types.Machine{ ID: 1, MachineKey: "12345", NodeKey: "bar", @@ -3063,7 +3065,7 @@ func TestInvalidTagValidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRules(pol, &machine, types.Machines{}) + got, _, err := GenerateFilterAndSSHRules(pol, machine, types.Machines{}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3091,7 +3093,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { RequestTags: []string{"tag:test"}, } - machine := types.Machine{ + machine := &types.Machine{ ID: 1, MachineKey: "12345", NodeKey: "bar", @@ -3125,7 +3127,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { // c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - got, _, err := GenerateFilterAndSSHRules(pol, &machine, types.Machines{}) + got, _, err := GenerateFilterAndSSHRules(pol, machine, types.Machines{}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3155,7 +3157,7 @@ func TestValidTagInvalidUser(t *testing.T) { RequestTags: []string{"tag:webapp"}, } - machine := types.Machine{ + machine := &types.Machine{ ID: 1, MachineKey: "12345", NodeKey: "bar", @@ -3175,7 +3177,7 @@ func TestValidTagInvalidUser(t *testing.T) { Hostname: "Hostname", } - machine2 := types.Machine{ + machine2 := &types.Machine{ ID: 2, MachineKey: "56789", NodeKey: "bar2", @@ -3201,7 +3203,7 @@ func TestValidTagInvalidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRules(pol, &machine, types.Machines{machine2}) + got, _, err := GenerateFilterAndSSHRules(pol, machine, types.Machines{machine2}) assert.NoError(t, err) want := []tailcfg.FilterRule{ diff --git a/hscontrol/poll.go b/hscontrol/poll.go index caf522ed41..45abe56424 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -21,8 +21,42 
@@ type contextKey string const machineNameContextKey = contextKey("machineName") +type UpdateNode func() + +func logPollFunc( + mapRequest tailcfg.MapRequest, + machine *types.Machine, + isNoise bool, +) (func(string), func(error, string)) { + return func(msg string) { + log.Info(). + Caller(). + Bool("noise", isNoise). + Bool("readOnly", mapRequest.ReadOnly). + Bool("omitPeers", mapRequest.OmitPeers). + Bool("stream", mapRequest.Stream). + Str("node_key", machine.NodeKey). + Str("machine", machine.Hostname). + Msg(msg) + }, + func(err error, msg string) { + log.Error(). + Caller(). + Bool("noise", isNoise). + Bool("readOnly", mapRequest.ReadOnly). + Bool("omitPeers", mapRequest.OmitPeers). + Bool("stream", mapRequest.Stream). + Str("node_key", machine.NodeKey). + Str("machine", machine.Hostname). + Err(err). + Msg(msg) + } +} + // handlePoll is the common code for the legacy and Noise protocols to // managed the poll loop. +// +//nolint:gocyclo func (h *Headscale) handlePoll( writer http.ResponseWriter, ctx context.Context, @@ -30,8 +64,112 @@ func (h *Headscale) handlePoll( mapRequest tailcfg.MapRequest, isNoise bool, ) { + logInfo, logErr := logPollFunc(mapRequest, machine, isNoise) + + // This is the mechanism where the node gives us information about its + // current configuration. + // + // If OmitPeers is true, Stream is false, and ReadOnly is false, + // then the server will let clients update their endpoints without + // breaking existing long-polling (Stream == true) connections. + // In this case, the server can omit the entire response; the client + // only checks the HTTP response status code. + if mapRequest.OmitPeers && !mapRequest.Stream && !mapRequest.ReadOnly { + log.Info(). + Caller(). + Bool("noise", isNoise). + Bool("readOnly", mapRequest.ReadOnly). + Bool("omitPeers", mapRequest.OmitPeers). + Bool("stream", mapRequest.Stream). + Str("node_key", machine.NodeKey). + Str("machine", machine.Hostname). + Strs("endpoints", machine.Endpoints). 
+ Msg("Received endpoint update") + + now := time.Now().UTC() + machine.LastSeen = &now + machine.Hostname = mapRequest.Hostinfo.Hostname + machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo) + machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) + machine.Endpoints = mapRequest.Endpoints + + if err := h.db.MachineSave(machine); err != nil { + logErr(err, "Failed to persist/update machine in the database") + http.Error(writer, "", http.StatusInternalServerError) + + return + } + + err := h.db.SaveMachineRoutes(machine) + if err != nil { + logErr(err, "Error processing machine routes") + http.Error(writer, "", http.StatusInternalServerError) + + return + } + + h.nodeNotifier.NotifyWithIgnore( + types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: types.Machines{machine}, + }, + machine.MachineKey) + + writer.WriteHeader(http.StatusOK) + if f, ok := writer.(http.Flusher); ok { + f.Flush() + } + + return + + // ReadOnly is whether the client just wants to fetch the + // MapResponse, without updating their Endpoints. The + // Endpoints field will be ignored and LastSeen will not be + // updated and peers will not be notified of changes. + // + // The intended use is for clients to discover the DERP map at + // start-up before their first real endpoint update. + } else if mapRequest.OmitPeers && !mapRequest.Stream && mapRequest.ReadOnly { + h.handleLiteRequest(writer, machine, mapRequest, isNoise) + + return + } else if mapRequest.OmitPeers && mapRequest.Stream { + logErr(nil, "Ignoring request, don't know how to handle it") + + return + } + + // Handle requests not related to continuous updates immediately. + // TODO(kradalby): I am not sure if this has any function based on + // incoming requests from clients. 
+ if mapRequest.ReadOnly && !mapRequest.Stream { + h.handleReadOnly(writer, machine, mapRequest, isNoise) + + return + } + + now := time.Now().UTC() + machine.LastSeen = &now + machine.Hostname = mapRequest.Hostinfo.Hostname + machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo) + machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) + machine.Endpoints = mapRequest.Endpoints + + // When a node connects to control, list the peers it has at + // that given point, further updates are kept in memory in + // the Mapper, which lives for the duration of the polling + // session. + peers, err := h.db.ListPeers(machine) + if err != nil { + logErr(err, "Failed to list peers when opening poller") + http.Error(writer, "", http.StatusInternalServerError) + + return + } + mapp := mapper.NewMapper( - h.db, + machine, + peers, h.privateKey2019, isNoise, h.DERPMap, @@ -41,18 +179,9 @@ func (h *Headscale) handlePoll( h.cfg.RandomizeClientPort, ) - machine.Hostname = mapRequest.Hostinfo.Hostname - machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo) - machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) - now := time.Now().UTC() - - err := h.db.ProcessMachineRoutes(machine) + err = h.db.SaveMachineRoutes(machine) if err != nil { - log.Error(). - Caller(). - Err(err). - Str("machine", machine.Hostname). - Msg("Error processing machine routes") + logErr(err, "Error processing machine routes") } // update ACLRules with peer informations (to update server tags if necessary) @@ -60,592 +189,220 @@ func (h *Headscale) handlePoll( // update routes with peer information err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, machine) if err != nil { - log.Error(). - Caller(). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Err(err). 
- Msg("Error running auto approved routes") + logErr(err, "Error running auto approved routes") } } - // From Tailscale client: - // - // ReadOnly is whether the client just wants to fetch the MapResponse, - // without updating their Endpoints. The Endpoints field will be ignored and - // LastSeen will not be updated and peers will not be notified of changes. - // - // The intended use is for clients to discover the DERP map at start-up - // before their first real endpoint update. - if !mapRequest.ReadOnly { - machine.Endpoints = mapRequest.Endpoints - machine.LastSeen = &now - } - + // TODO(kradalby): Save specific stuff, not whole object. if err := h.db.MachineSave(machine); err != nil { - log.Error(). - Str("handler", "PollNetMap"). - Bool("noise", isNoise). - Str("node_key", machine.NodeKey). - Str("machine", machine.Hostname). - Err(err). - Msg("Failed to persist/update machine in the database") + logErr(err, "Failed to persist/update machine in the database") http.Error(writer, "", http.StatusInternalServerError) return } - mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy) + logInfo("Sending initial map") + + mapResp, err := mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy) if err != nil { - log.Error(). - Str("handler", "PollNetMap"). - Bool("noise", isNoise). - Str("node_key", machine.NodeKey). - Str("machine", machine.Hostname). - Err(err). - Msg("Failed to get Map response") + logErr(err, "Failed to create MapResponse") http.Error(writer, "", http.StatusInternalServerError) return } - // We update our peers if the client is not sending ReadOnly in the MapRequest - // so we don't distribute its initial request (it comes with - // empty endpoints to peers) - - // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 - log.Debug(). - Str("handler", "PollNetMap"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Bool("readOnly", mapRequest.ReadOnly). 
- Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Msg("Client map request processed") - - if mapRequest.ReadOnly { - log.Info(). - Str("handler", "PollNetMap"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Msg("Client is starting up. Probably interested in a DERP map") - - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err := writer.Write(mapResp) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - if f, ok := writer.(http.Flusher); ok { - f.Flush() - } + // Send the client an update to make sure we send an initial mapresponse + _, err = writer.Write(mapResp) + if err != nil { + logErr(err, "Could not write the map response") return } - // There has been an update to _any_ of the nodes that the other nodes would - // need to know about - h.setLastStateChangeToNow() - - // The request is not ReadOnly, so we need to set up channels for updating - // peers via longpoll - - // Only create update channel if it has not been created - log.Trace(). - Caller(). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Msg("Loading or creating update channel") - - const chanSize = 8 - updateChan := make(chan struct{}, chanSize) - - pollDataChan := make(chan []byte, chanSize) - defer closeChanWithLog(pollDataChan, machine.Hostname, "pollDataChan") - - keepAliveChan := make(chan []byte) - - if mapRequest.OmitPeers && !mapRequest.Stream { - log.Info(). - Str("handler", "PollNetMap"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Msg("Client sent endpoint update and is ok with a response without peer list") - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err := writer.Write(mapResp) - if err != nil { - log.Error(). - Caller(). - Err(err). 
- Msg("Failed to write response") - } - // It sounds like we should update the nodes when we have received a endpoint update - // even tho the comments in the tailscale code dont explicitly say so. - updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "endpoint-update"). - Inc() - updateChan <- struct{}{} - - return - } else if mapRequest.OmitPeers && mapRequest.Stream { - log.Warn(). - Str("handler", "PollNetMap"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Msg("Ignoring request, don't know how to handle it") - http.Error(writer, "", http.StatusBadRequest) - + if flusher, ok := writer.(http.Flusher); ok { + flusher.Flush() + } else { return } - log.Info(). - Str("handler", "PollNetMap"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Msg("Client is ready to access the tailnet") - log.Info(). - Str("handler", "PollNetMap"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Msg("Sending initial map") - pollDataChan <- mapResp - - log.Info(). - Str("handler", "PollNetMap"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Msg("Notifying peers") - updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "full-update"). - Inc() - updateChan <- struct{}{} - - h.pollNetMapStream( - writer, - ctx, - machine, - mapRequest, - pollDataChan, - keepAliveChan, - updateChan, - isNoise, - ) - - log.Trace(). - Str("handler", "PollNetMap"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Msg("Finished stream, closing PollNetMap session") -} - -// pollNetMapStream stream logic for /machine/map, -// ensuring we communicate updates and data to the connected clients. 
-func (h *Headscale) pollNetMapStream( - writer http.ResponseWriter, - ctxReq context.Context, - machine *types.Machine, - mapRequest tailcfg.MapRequest, - pollDataChan chan []byte, - keepAliveChan chan []byte, - updateChan chan struct{}, - isNoise bool, -) { - // TODO(kradalby): This is a stepping stone, mapper should be initiated once - // per client or something similar - mapp := mapper.NewMapper(h.db, - h.privateKey2019, - isNoise, - h.DERPMap, - h.cfg.BaseDomain, - h.cfg.DNSConfig, - h.cfg.LogTail.Enabled, - h.cfg.RandomizeClientPort, - ) + h.nodeNotifier.NotifyWithIgnore( + types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: types.Machines{machine}, + }, + machine.MachineKey) + // Set up the client stream h.pollNetMapStreamWG.Add(1) defer h.pollNetMapStreamWG.Done() - ctx := context.WithValue(ctxReq, machineNameContextKey, machine.Hostname) + updateChan := make(chan types.StateUpdate) + defer closeChanWithLog(updateChan, machine.Hostname, "updateChan") - ctx, cancel := context.WithCancel(ctx) - defer cancel() + // Register the node's update channel + h.nodeNotifier.AddNode(machine.MachineKey, updateChan) + defer h.nodeNotifier.RemoveNode(machine.MachineKey) - go h.scheduledPollWorker( - ctx, - updateChan, - keepAliveChan, - mapRequest, - machine, - isNoise, - ) + keepAliveTicker := time.NewTicker(keepAliveInterval) - log.Trace(). - Str("handler", "pollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Msg("Waiting for data to stream...") + ctx = context.WithValue(ctx, machineNameContextKey, machine.Hostname) - log.Trace(). - Str("handler", "pollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan) + ctx, cancel := context.WithCancel(ctx) + defer cancel() for { + logInfo("Waiting for update on stream channel") select { - case data := <-pollDataChan: - log.Trace(). 
- Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "pollData"). - Int("bytes", len(data)). - Msg("Sending data received via pollData channel") - _, err := writer.Write(data) + case <-keepAliveTicker.C: + data, err := mapp.KeepAliveResponse(mapRequest, machine) if err != nil { - log.Error(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "pollData"). - Err(err). - Msg("Cannot write data") + logErr(err, "Error generating the keep alive msg") return } + _, err = writer.Write(data) + if err != nil { + logErr(err, "Cannot write keep alive message") - flusher, ok := writer.(http.Flusher) - if !ok { - log.Error(). - Caller(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "pollData"). - Msg("Cannot cast writer to http.Flusher") - } else { - flusher.Flush() + return } + if flusher, ok := writer.(http.Flusher); ok { + flusher.Flush() + } else { + log.Error().Msg("Failed to create http flusher") - log.Trace(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "pollData"). - Int("bytes", len(data)). - Msg("Data from pollData channel written successfully") - // TODO(kradalby): Abstract away all the database calls, this can cause race conditions - // when an outdated machine object is kept alive, e.g. db is update from - // command line, but then overwritten. - err = h.db.UpdateMachineFromDatabase(machine) - if err != nil { - log.Error(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "pollData"). - Err(err). - Msg("Cannot update machine from database") - - // client has been removed from database - // since the stream opened, terminate connection. 
return } - now := time.Now().UTC() - machine.LastSeen = &now - - lastStateUpdate.WithLabelValues(machine.User.Name, machine.Hostname). - Set(float64(now.Unix())) - machine.LastSuccessfulUpdate = &now - err = h.db.TouchMachine(machine) - if err != nil { - log.Error(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "pollData"). - Err(err). - Msg("Cannot update machine LastSuccessfulUpdate") + // This goroutine is not ideal, but we have a potential issue here + // where it blocks too long and that holds up updates. + // One alternative is to split these different channels into + // goroutines, but then you might have a problem without a lock + // if a keepalive is written at the same time as an update. + go func() { + err = h.db.UpdateLastSeen(machine) + if err != nil { + logErr(err, "Cannot update machine LastSeen") - return + return + } + }() + + case update := <-updateChan: + logInfo("Received update") + now := time.Now() + + var data []byte + var err error + + switch update.Type { + case types.StatePeerChanged: + logInfo("Sending PeerChanged MapResponse") + data, err = mapp.PeerChangedResponse(mapRequest, machine, update.Changed, h.ACLPolicy) + case types.StatePeerRemoved: + logInfo("Sending PeerRemoved MapResponse") + data, err = mapp.PeerRemovedResponse(mapRequest, machine, update.Removed) + case types.StateDERPUpdated: + logInfo("Sending DERPUpdate MapResponse") + data, err = mapp.DERPMapResponse(mapRequest, machine, update.DERPMap) + case types.StateFullUpdate: + logInfo("Sending Full MapResponse") + data, err = mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy) } - log.Trace(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "pollData"). - Int("bytes", len(data)). - Msg("Machine entry in database updated successfully after sending data") - - case data := <-keepAliveChan: - log.Trace(). 
- Str("handler", "PollNetMapStream"). - Str("machine", machine.Hostname). - Str("channel", "keepAlive"). - Int("bytes", len(data)). - Msg("Sending keep alive message") - _, err := writer.Write(data) if err != nil { - log.Error(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "keepAlive"). - Err(err). - Msg("Cannot write keep alive message") + logErr(err, "Could not get the create map update") return } - flusher, ok := writer.(http.Flusher) - if !ok { - log.Error(). - Caller(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "keepAlive"). - Msg("Cannot cast writer to http.Flusher") - } else { - flusher.Flush() - } - log.Trace(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "keepAlive"). - Int("bytes", len(data)). - Msg("Keep alive sent successfully") - // TODO(kradalby): Abstract away all the database calls, this can cause race conditions - // when an outdated machine object is kept alive, e.g. db is update from - // command line, but then overwritten. - err = h.db.UpdateMachineFromDatabase(machine) - if err != nil { - log.Error(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "keepAlive"). - Err(err). - Msg("Cannot update machine from database") - - // client has been removed from database - // since the stream opened, terminate connection. - return - } - now := time.Now().UTC() - machine.LastSeen = &now - err = h.db.TouchMachine(machine) + _, err = writer.Write(data) if err != nil { - log.Error(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "keepAlive"). - Err(err). 
- Msg("Cannot update machine LastSeen") + logErr(err, "Could not write the map response") + + updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "failed"). + Inc() return } - log.Trace(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "keepAlive"). - Int("bytes", len(data)). - Msg("Machine updated successfully after sending keep alive") + if flusher, ok := writer.(http.Flusher); ok { + flusher.Flush() + } else { + log.Error().Msg("Failed to create http flusher") - case <-updateChan: - log.Trace(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "update"). - Msg("Received a request for update") - updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname). - Inc() - - if h.db.IsOutdated(machine, h.getLastStateChange()) { - var lastUpdate time.Time - if machine.LastSuccessfulUpdate != nil { - lastUpdate = *machine.LastSuccessfulUpdate - } - log.Debug(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Time("last_successful_update", lastUpdate). - Time("last_state_change", h.getLastStateChange(machine.User)). - Msgf("There has been updates since the last successful update to %s", machine.Hostname) - data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy) - if err != nil { - log.Error(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "update"). - Err(err). - Msg("Could not get the map update") + return + } - return - } - _, err = writer.Write(data) + // See comment in keepAliveTicker + go func() { + err = h.db.UpdateLastSeen(machine) if err != nil { - log.Error(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "update"). - Err(err). 
- Msg("Could not write the map response") - updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "failed"). - Inc() + logErr(err, "Cannot update machine LastSeen") return } + }() - flusher, ok := writer.(http.Flusher) - if !ok { - log.Error(). - Caller(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "update"). - Msg("Cannot cast writer to http.Flusher") - } else { - flusher.Flush() - } - - log.Trace(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "update"). - Msg("Updated Map has been sent") - updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "success"). - Inc() - - // Keep track of the last successful update, - // we sometimes end in a state were the update - // is not picked up by a client and we use this - // to determine if we should "force" an update. - // TODO(kradalby): Abstract away all the database calls, this can cause race conditions - // when an outdated machine object is kept alive, e.g. db is update from - // command line, but then overwritten. - err = h.db.UpdateMachineFromDatabase(machine) - if err != nil { - log.Error(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "update"). - Err(err). - Msg("Cannot update machine from database") - - // client has been removed from database - // since the stream opened, terminate connection. - return - } - now := time.Now().UTC() - - lastStateUpdate.WithLabelValues(machine.User.Name, machine.Hostname). - Set(float64(now.Unix())) - machine.LastSuccessfulUpdate = &now + log.Info(). + Caller(). + Bool("noise", isNoise). + Bool("readOnly", mapRequest.ReadOnly). + Bool("omitPeers", mapRequest.OmitPeers). + Bool("stream", mapRequest.Stream). + Str("node_key", machine.NodeKey). + Str("machine", machine.Hostname). + TimeDiff("timeSpent", time.Now(), now). 
+ Msg("update sent") + case <-ctx.Done(): + logInfo("The client has closed the connection") - err = h.db.TouchMachine(machine) + go func() { + err = h.db.UpdateLastSeen(machine) if err != nil { - log.Error(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "update"). - Err(err). - Msg("Cannot update machine LastSuccessfulUpdate") + logErr(err, "Cannot update machine LastSeen") return } - } else { - var lastUpdate time.Time - if machine.LastSuccessfulUpdate != nil { - lastUpdate = *machine.LastSuccessfulUpdate - } - log.Trace(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Time("last_successful_update", lastUpdate). - Time("last_state_change", h.getLastStateChange(machine.User)). - Msgf("%s is up to date", machine.Hostname) - } - - case <-ctx.Done(): - log.Info(). - Str("handler", "PollNetMapStream"). - Str("machine", machine.Hostname). - Msg("The client has closed the connection") - // TODO: Abstract away all the database calls, this can cause race conditions - // when an outdated machine object is kept alive, e.g. db is update from - // command line, but then overwritten. - err := h.db.UpdateMachineFromDatabase(machine) - if err != nil { - log.Error(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "Done"). - Err(err). - Msg("Cannot update machine from database") - - // client has been removed from database - // since the stream opened, terminate connection. - return - } - now := time.Now().UTC() - machine.LastSeen = &now - err = h.db.TouchMachine(machine) - if err != nil { - log.Error(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Str("channel", "Done"). - Err(err). - Msg("Cannot update machine LastSeen") - } + }() // The connection has been closed, so we can stop polling. 
return case <-h.shutdownChan: - log.Info(). - Str("handler", "PollNetMapStream"). - Bool("noise", isNoise). - Str("machine", machine.Hostname). - Msg("The long-poll handler is shutting down") + logInfo("The long-poll handler is shutting down") return } } } -func (h *Headscale) scheduledPollWorker( - ctx context.Context, - updateChan chan struct{}, - keepAliveChan chan []byte, - mapRequest tailcfg.MapRequest, +func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, machine, name string) { + log.Trace(). + Str("handler", "PollNetMap"). + Str("machine", machine). + Str("channel", "Done"). + Msg(fmt.Sprintf("Closing %s channel", name)) + + close(channel) +} + +// TODO(kradalby): This might not actually be used, +// observing incoming client requests indicates it +// is not. +func (h *Headscale) handleReadOnly( + writer http.ResponseWriter, machine *types.Machine, + mapRequest tailcfg.MapRequest, isNoise bool, ) { - // TODO(kradalby): This is a stepping stone, mapper should be initiated once - // per client or something similar - mapp := mapper.NewMapper(h.db, + logInfo, logErr := logPollFunc(mapRequest, machine, isNoise) + + mapp := mapper.NewMapper( + machine, + // TODO(kradalby): It might not be acceptable to send + // an empty peer list here. + types.Machines{}, h.privateKey2019, isNoise, h.DERPMap, @@ -654,72 +411,64 @@ func (h *Headscale) scheduledPollWorker( h.cfg.LogTail.Enabled, h.cfg.RandomizeClientPort, ) + logInfo("Client is starting up. 
Probably interested in a DERP map") - keepAliveTicker := time.NewTicker(keepAliveInterval) - updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval) + mapResp, err := mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy) + if err != nil { + logErr(err, "Failed to create MapResponse") + http.Error(writer, "", http.StatusInternalServerError) - defer closeChanWithLog( - updateChan, - fmt.Sprint(ctx.Value(machineNameContextKey)), - "updateChan", - ) - defer closeChanWithLog( - keepAliveChan, - fmt.Sprint(ctx.Value(machineNameContextKey)), - "keepAliveChan", - ) + return + } - for { - select { - case <-ctx.Done(): - return + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err = writer.Write(mapResp) + if err != nil { + logErr(err, "Failed to write response") + } - case <-keepAliveTicker.C: - data, err := mapp.CreateKeepAliveResponse(mapRequest, machine) - if err != nil { - log.Error(). - Str("func", "keepAlive"). - Bool("noise", isNoise). - Err(err). - Msg("Error generating the keep alive msg") + if f, ok := writer.(http.Flusher); ok { + f.Flush() + } +} - return - } +func (h *Headscale) handleLiteRequest( + writer http.ResponseWriter, + machine *types.Machine, + mapRequest tailcfg.MapRequest, + isNoise bool, +) { + logInfo, logErr := logPollFunc(mapRequest, machine, isNoise) - log.Debug(). - Str("func", "keepAlive"). - Str("machine", machine.Hostname). - Bool("noise", isNoise). - Msg("Sending keepalive") - select { - case keepAliveChan <- data: - case <-ctx.Done(): - return - } + mapp := mapper.NewMapper( + machine, + // TODO(kradalby): It might not be acceptable to send + // an empty peer list here. + types.Machines{}, + h.privateKey2019, + isNoise, + h.DERPMap, + h.cfg.BaseDomain, + h.cfg.DNSConfig, + h.cfg.LogTail.Enabled, + h.cfg.RandomizeClientPort, + ) - case <-updateCheckerTicker.C: - log.Debug(). - Str("func", "scheduledPollWorker"). - Str("machine", machine.Hostname). 
- Bool("noise", isNoise). - Msg("Sending update request") - updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "scheduled-update"). - Inc() - select { - case updateChan <- struct{}{}: - case <-ctx.Done(): - return - } - } - } -} + logInfo("Client asked for a lite update, responding without peers") -func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) { - log.Trace(). - Str("handler", "PollNetMap"). - Str("machine", machine). - Str("channel", "Done"). - Msg(fmt.Sprintf("Closing %s channel", name)) + mapResp, err := mapp.LiteMapResponse(mapRequest, machine, h.ACLPolicy) + if err != nil { + logErr(err, "Failed to create MapResponse") + http.Error(writer, "", http.StatusInternalServerError) - close(channel) + return + } + + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err = writer.Write(mapResp) + if err != nil { + logErr(err, "Failed to write response") + } } diff --git a/hscontrol/poll_legacy.go b/hscontrol/poll_legacy.go index e175df9270..90cb13201f 100644 --- a/hscontrol/poll_legacy.go +++ b/hscontrol/poll_legacy.go @@ -91,7 +91,7 @@ func (h *Headscale) PollNetMapHandler( Str("handler", "PollNetMap"). Str("id", machineKeyStr). Str("machine", machine.Hostname). - Msg("A machine is entering polling via the legacy protocol") + Msg("A machine is sending a MapRequest via legacy protocol") h.handlePoll(writer, req.Context(), machine, mapRequest, false) } diff --git a/hscontrol/poll_noise.go b/hscontrol/poll_noise.go index e512213a14..3d672f0b6b 100644 --- a/hscontrol/poll_noise.go +++ b/hscontrol/poll_noise.go @@ -31,6 +31,7 @@ func (ns *noiseServer) NoisePollNetMapHandler( log.Trace(). Any("headers", req.Header). + Caller(). Msg("Headers") body, _ := io.ReadAll(req.Body) @@ -72,7 +73,7 @@ func (ns *noiseServer) NoisePollNetMapHandler( log.Debug(). Str("handler", "NoisePollNetMap"). Str("machine", machine.Hostname). 
- Msg("A machine is entering polling via the Noise protocol") + Msg("A machine sending a MapRequest with Noise protocol") ns.headscale.handlePoll(writer, req.Context(), machine, mapRequest, true) } diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 96ad1b782e..e80ef8f4cb 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -106,3 +106,32 @@ func (i StringList) Value() (driver.Value, error) { return string(bytes), err } + +type StateUpdateType int + +const ( + StateFullUpdate StateUpdateType = iota + StatePeerChanged + StatePeerRemoved + StateDERPUpdated +) + +// StateUpdate is an internal message containing information about +// a state change that has happened to the network. +type StateUpdate struct { + // The type of update + Type StateUpdateType + + // Changed must be set when Type is StatePeerChanged and + // contain the Machine IDs of machines that has changed. + Changed Machines + + // Removed must be set when Type is StatePeerRemoved and + // contain a list of the nodes that has been removed from + // the network. + Removed []tailcfg.NodeID + + // DERPMap must be set when Type is StateDERPUpdated and + // contain the new DERP Map. 
+ DERPMap tailcfg.DERPMap +} diff --git a/hscontrol/types/machine.go b/hscontrol/types/machine.go index 4e5a940f3a..534a0d13a5 100644 --- a/hscontrol/types/machine.go +++ b/hscontrol/types/machine.go @@ -53,9 +53,8 @@ type Machine struct { AuthKeyID uint AuthKey *PreAuthKey - LastSeen *time.Time - LastSuccessfulUpdate *time.Time - Expiry *time.Time + LastSeen *time.Time + Expiry *time.Time HostInfo HostInfo Endpoints StringList @@ -68,8 +67,7 @@ type Machine struct { } type ( - Machines []Machine - MachinesP []*Machine + Machines []*Machine ) type MachineAddresses []netip.Addr @@ -247,12 +245,6 @@ func (machine *Machine) Proto() *v1.Machine { machineProto.LastSeen = timestamppb.New(*machine.LastSeen) } - if machine.LastSuccessfulUpdate != nil { - machineProto.LastSuccessfulUpdate = timestamppb.New( - *machine.LastSuccessfulUpdate, - ) - } - if machine.Expiry != nil { machineProto.Expiry = timestamppb.New(*machine.Expiry) } @@ -343,13 +335,12 @@ func (machines Machines) String() string { return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) } -// TODO(kradalby): Remove when we have generics... 
-func (machines MachinesP) String() string { - temp := make([]string, len(machines)) +func (machines Machines) IDMap() map[uint64]*Machine { + ret := map[uint64]*Machine{} - for index, machine := range machines { - temp[index] = machine.Hostname + for _, machine := range machines { + ret[machine.ID] = machine } - return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) + return ret } diff --git a/integration/cli_test.go b/integration/cli_test.go index 5a27d51726..ef0ed70175 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -413,14 +413,12 @@ func TestEnablingRoutes(t *testing.T) { // advertise routes using the up command for i, client := range allClients { routeStr := fmt.Sprintf("10.0.%d.0/24", i) - hostname, _ := client.FQDN() - _, _, err = client.Execute([]string{ + command := []string{ "tailscale", - "up", - fmt.Sprintf("--advertise-routes=%s", routeStr), - "-login-server", headscale.GetEndpoint(), - "--hostname", hostname, - }) + "set", + "--advertise-routes=" + routeStr, + } + _, _, err := client.Execute(command) assertNoErrf(t, "failed to advertise route: %s", err) } @@ -474,6 +472,7 @@ func TestEnablingRoutes(t *testing.T) { &enablingRoutes, ) assertNoErr(t, err) + assert.Len(t, enablingRoutes, 3) for _, route := range enablingRoutes { assert.Equal(t, route.Advertised, true) diff --git a/integration/dockertestutil/execute.go b/integration/dockertestutil/execute.go index 71afa69872..5a8e92b334 100644 --- a/integration/dockertestutil/execute.go +++ b/integration/dockertestutil/execute.go @@ -9,7 +9,7 @@ import ( "github.com/ory/dockertest/v3" ) -const dockerExecuteTimeout = time.Second * 10 +const dockerExecuteTimeout = time.Second * 30 var ( ErrDockertestCommandFailed = errors.New("dockertest command failed") diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index 24bd898ef6..669faf577c 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -197,7 +197,9 @@ 
func (s *EmbeddedDERPServerScenario) CreateTailscaleIsolatedNodesInUser( ) } + s.mu.Lock() user.Clients[tsClient.Hostname()] = tsClient + s.mu.Unlock() return nil }) diff --git a/integration/general_test.go b/integration/general_test.go index a3e32f71bf..4de121c73e 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -407,9 +407,8 @@ func TestResolveMagicDNS(t *testing.T) { defer scenario.Shutdown() spec := map[string]int{ - // Omit 1.16.2 (-1) because it does not have the FQDN field - "magicdns1": len(MustTestVersions) - 1, - "magicdns2": len(MustTestVersions) - 1, + "magicdns1": len(MustTestVersions), + "magicdns2": len(MustTestVersions), } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns")) diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index e13b727354..7d1bbceb3a 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -212,6 +212,7 @@ func New( env := []string{ "HEADSCALE_PROFILING_ENABLED=1", "HEADSCALE_PROFILING_PATH=/tmp/profile", + "HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH=/tmp/mapresponses", } for key, value := range hsic.env { env = append(env, fmt.Sprintf("%s=%s", key, value)) @@ -339,6 +340,14 @@ func (t *HeadscaleInContainer) Shutdown() error { ) } + err = t.SaveMapResponses("/tmp/control") + if err != nil { + log.Printf( + "Failed to save mapresponses from control: %s", + fmt.Errorf("failed to save mapresponses from control: %w", err), + ) + } + return t.pool.Purge(t.container) } @@ -366,6 +375,24 @@ func (t *HeadscaleInContainer) SaveProfile(savePath string) error { return nil } +func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error { + tarFile, err := t.FetchPath("/tmp/mapresponses") + if err != nil { + return err + } + + err = os.WriteFile( + path.Join(savePath, t.hostname+".maps.tar"), + tarFile, + os.ModePerm, + ) + if err != nil { + return err + } + + return nil +} + // Execute runs a command inside the Headscale container and returns 
the // result of stdout as a string. func (t *HeadscaleInContainer) Execute( diff --git a/integration/scenario.go b/integration/scenario.go index 2bf30d0f63..a7e84acbf9 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -7,7 +7,6 @@ import ( "net/netip" "os" "sync" - "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/util" @@ -21,7 +20,6 @@ import ( const ( scenarioHashLength = 6 - maxWait = 60 * time.Second ) var ( @@ -114,7 +112,7 @@ type Scenario struct { pool *dockertest.Pool network *dockertest.Network - headscaleLock sync.Mutex + mu sync.Mutex } // NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with @@ -130,7 +128,7 @@ func NewScenario() (*Scenario, error) { return nil, fmt.Errorf("could not connect to docker: %w", err) } - pool.MaxWait = maxWait + pool.MaxWait = dockertestMaxWait() networkName := fmt.Sprintf("hs-%s", hash) if overrideNetworkName := os.Getenv("HEADSCALE_TEST_NETWORK_NAME"); overrideNetworkName != "" { @@ -214,8 +212,8 @@ func (s *Scenario) Users() []string { // will be return, otherwise a new instance will be created. // TODO(kradalby): make port and headscale configurable, multiple instances support? 
func (s *Scenario) Headscale(opts ...hsic.Option) (ControlServer, error) { - s.headscaleLock.Lock() - defer s.headscaleLock.Unlock() + s.mu.Lock() + defer s.mu.Unlock() if headscale, ok := s.controlServers.Load("headscale"); ok { return headscale, nil @@ -328,7 +326,9 @@ func (s *Scenario) CreateTailscaleNodesInUser( ) } + s.mu.Lock() user.Clients[tsClient.Hostname()] = tsClient + s.mu.Unlock() return nil }) diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index 04ad2f3a13..efe9c904a8 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -11,7 +11,6 @@ import ( "strings" "time" - "github.com/cenkalti/backoff/v4" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/integrationutil" @@ -21,10 +20,12 @@ import ( ) const ( - tsicHashLength = 6 - defaultPingCount = 10 - dockerContextPath = "../." - headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt" + tsicHashLength = 6 + defaultPingTimeout = 300 * time.Millisecond + defaultPingCount = 10 + dockerContextPath = "../." + headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt" + dockerExecuteTimeout = 60 * time.Second ) var ( @@ -181,7 +182,7 @@ func New( withEntrypoint: []string{ "/bin/sh", "-c", - "/bin/sleep 3 ; update-ca-certificates ; tailscaled --tun=tsdev", + "/bin/sleep 3 ; update-ca-certificates ; tailscaled --tun=tsdev --verbose=10", }, } @@ -361,7 +362,7 @@ func (t *TailscaleInContainer) Login( ) } - if _, _, err := t.Execute(command); err != nil { + if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil { return fmt.Errorf( "%s failed to join tailscale client (%s): %w", t.hostname, @@ -592,7 +593,7 @@ func WithPingUntilDirect(direct bool) PingOption { // TODO(kradalby): Make multiping, go routine magic. 
func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) error { args := pingArgs{ - timeout: time.Second, + timeout: defaultPingTimeout, count: defaultPingCount, direct: true, } @@ -610,42 +611,40 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err command = append(command, hostnameOrIP) - return t.pool.Retry(func() error { - result, _, err := t.Execute( - command, - dockertestutil.ExecuteCommandTimeout( - time.Duration(int64(args.timeout)*int64(args.count)), - ), + result, _, err := t.Execute( + command, + dockertestutil.ExecuteCommandTimeout( + time.Duration(int64(args.timeout)*int64(args.count)), + ), + ) + if err != nil { + log.Printf( + "failed to run ping command from %s to %s, err: %s", + t.Hostname(), + hostnameOrIP, + err, ) - if err != nil { - log.Printf( - "failed to run ping command from %s to %s, err: %s", - t.Hostname(), - hostnameOrIP, - err, - ) - return err - } + return err + } - if strings.Contains(result, "is local") { - return nil - } + if strings.Contains(result, "is local") { + return nil + } - if !strings.Contains(result, "pong") { - return backoff.Permanent(errTailscalePingFailed) - } + if !strings.Contains(result, "pong") { + return errTailscalePingFailed + } - if !args.direct { - if strings.Contains(result, "via DERP") { - return nil - } else { - return backoff.Permanent(errTailscalePingNotDERP) - } + if !args.direct { + if strings.Contains(result, "via DERP") { + return nil + } else { + return errTailscalePingNotDERP } + } - return nil - }) + return nil } type ( @@ -720,24 +719,19 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err } var result string - err := t.pool.Retry(func() error { - var err error - result, _, err = t.Execute(command) - if err != nil { - log.Printf( - "failed to run curl command from %s to %s, err: %s", - t.Hostname(), - url, - err, - ) - - return err - } + result, _, err := t.Execute(command) + if err != nil { + log.Printf( + 
"failed to run curl command from %s to %s, err: %s", + t.Hostname(), + url, + err, + ) - return nil - }) + return result, err + } - return result, err + return result, nil } // WriteFile save file inside the Tailscale container. diff --git a/integration/utils.go b/integration/utils.go index 0b55654d75..91e274b1ac 100644 --- a/integration/utils.go +++ b/integration/utils.go @@ -1,6 +1,7 @@ package integration import ( + "os" "strings" "testing" "time" @@ -131,6 +132,38 @@ func isSelfClient(client TailscaleClient, addr string) bool { return false } +func isCI() bool { + if _, ok := os.LookupEnv("CI"); ok { + return true + } + + if _, ok := os.LookupEnv("GITHUB_RUN_ID"); ok { + return true + } + + return false +} + +func dockertestMaxWait() time.Duration { + wait := 60 * time.Second //nolint + + if isCI() { + wait = 300 * time.Second //nolint + } + + return wait +} + +// func dockertestCommandTimeout() time.Duration { +// timeout := 10 * time.Second //nolint +// +// if isCI() { +// timeout = 60 * time.Second //nolint +// } +// +// return timeout +// } + // pingAllNegativeHelper is intended to have 1 or more nodes timeing out from the ping, // it counts failures instead of successes. // func pingAllNegativeHelper(t *testing.T, clients []TailscaleClient, addrs []string) int {