From 5b8053f2ee9c68952bc07a3a109e814ac4b261d7 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 17 Jul 2023 13:35:05 +0200 Subject: [PATCH] introduce rw lock for db, ish... Signed-off-by: Kristoffer Dalby --- hscontrol/db/addresses.go | 6 +++ hscontrol/db/api_key.go | 21 ++++++++ hscontrol/db/db.go | 2 + hscontrol/db/machine.go | 102 +++++++++++++++++++++++++++++++++-- hscontrol/db/machine_test.go | 10 +--- hscontrol/db/preauth_keys.go | 23 ++++++++ hscontrol/db/routes.go | 50 +++++++++++++++++ hscontrol/db/users.go | 24 +++++++++ 8 files changed, 225 insertions(+), 13 deletions(-) diff --git a/hscontrol/db/addresses.go b/hscontrol/db/addresses.go index 1a7d35defcc..f40358289a3 100644 --- a/hscontrol/db/addresses.go +++ b/hscontrol/db/addresses.go @@ -18,6 +18,9 @@ import ( var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") func (hsdb *HSDatabase) getAvailableIPs() (types.MachineAddresses, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + var ips types.MachineAddresses var err error for _, ipPrefix := range hsdb.ipPrefixes { @@ -33,6 +36,9 @@ func (hsdb *HSDatabase) getAvailableIPs() (types.MachineAddresses, error) { } func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + usedIps, err := hsdb.getUsedIPs() if err != nil { return nil, err diff --git a/hscontrol/db/api_key.go b/hscontrol/db/api_key.go index 4e4030ebfe9..bc8dc2bbdea 100644 --- a/hscontrol/db/api_key.go +++ b/hscontrol/db/api_key.go @@ -22,6 +22,9 @@ var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") func (hsdb *HSDatabase) CreateAPIKey( expiration *time.Time, ) (string, *types.APIKey, error) { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) if err != nil { return "", nil, err @@ -55,6 +58,9 @@ func (hsdb *HSDatabase) CreateAPIKey( // ListAPIKeys returns the list of ApiKeys for a user. func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + keys := []types.APIKey{} if err := hsdb.db.Find(&keys).Error; err != nil { return nil, err @@ -65,6 +71,9 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { // GetAPIKey returns a ApiKey for a given key. func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + key := types.APIKey{} if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { return nil, result.Error @@ -75,6 +84,9 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { // GetAPIKeyByID returns a ApiKey for a given id. func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + key := types.APIKey{} if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { return nil, result.Error @@ -86,6 +98,9 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey // does not exist. func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { return result.Error } @@ -95,6 +110,9 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { // ExpireAPIKey marks a ApiKey as expired. func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -103,6 +121,9 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { } func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + prefix, hash, found := strings.Cut(keyStr, ".") if !found { return false, ErrAPIKeyFailedToParse diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index ea6ce21fcb8..19bf94259d4 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -40,6 +40,8 @@ type HSDatabase struct { db *gorm.DB notifier *notifier.Notifier + mu sync.RWMutex + ipAllocationMutex sync.Mutex ipPrefixes []netip.Prefix diff --git a/hscontrol/db/machine.go b/hscontrol/db/machine.go index 47dfaa12e0b..8a3e22815ae 100644 --- a/hscontrol/db/machine.go +++ b/hscontrol/db/machine.go @@ -36,6 +36,9 @@ var ( // ListPeers returns all peers of machine, regardless of any Policy or if the node is expired. func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + log.Trace(). Caller(). Str("machine", machine.Hostname). @@ -63,6 +66,9 @@ func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error } func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + machines := []types.Machine{} if err := hsdb.db. Preload("AuthKey"). @@ -77,6 +83,9 @@ func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { } func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machines, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + machines := types.Machines{} if err := hsdb.db. Preload("AuthKey"). @@ -92,6 +101,9 @@ func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machine // GetMachine finds a Machine by name and user and returns the Machine struct. func (hsdb *HSDatabase) GetMachine(user string, name string) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + machines, err := hsdb.ListMachinesByUser(user) if err != nil { return nil, err @@ -111,6 +123,9 @@ func (hsdb *HSDatabase) GetMachineByGivenName( user string, givenName string, ) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + machines, err := hsdb.ListMachinesByUser(user) if err != nil { return nil, err @@ -127,6 +142,9 @@ func (hsdb *HSDatabase) GetMachineByGivenName( // GetMachineByID finds a Machine by ID and returns the Machine struct. func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + mach := types.Machine{} if result := hsdb.db. Preload("AuthKey"). @@ -144,6 +162,9 @@ func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { func (hsdb *HSDatabase) GetMachineByMachineKey( machineKey key.MachinePublic, ) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + mach := types.Machine{} if result := hsdb.db. Preload("AuthKey"). @@ -161,6 +182,9 @@ func (hsdb *HSDatabase) GetMachineByMachineKey( func (hsdb *HSDatabase) GetMachineByNodeKey( nodeKey key.NodePublic, ) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + machine := types.Machine{} if result := hsdb.db. Preload("AuthKey"). @@ -179,6 +203,9 @@ func (hsdb *HSDatabase) GetMachineByNodeKey( func (hsdb *HSDatabase) GetMachineByAnyKey( machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, ) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + machine := types.Machine{} if result := hsdb.db. Preload("AuthKey"). @@ -199,6 +226,9 @@ func (hsdb *HSDatabase) GetMachineByAnyKey( // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database // and updates it with the latest data from the database. func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *types.Machine) error { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + if result := hsdb.db.Find(machine).First(&machine); result.Error != nil { return result.Error } @@ -211,6 +241,9 @@ func (hsdb *HSDatabase) SetTags( machine *types.Machine, tags []string, ) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + newTags := []string{} for _, tag := range tags { if !util.StringOrPrefixListContains(newTags, tag) { @@ -233,6 +266,9 @@ func (hsdb *HSDatabase) SetTags( // ExpireMachine takes a Machine struct and sets the expire field to now. func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + now := time.Now() machine.Expiry = &now @@ -251,6 +287,9 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error { // RenameMachine takes a Machine struct and a new GivenName for the machines // and renames it. func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + err := util.CheckForFQDNRules( newName, ) @@ -260,7 +299,8 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er Str("func", "RenameMachine"). Str("machine", machine.Hostname). Str("newName", newName). - Err(err) + Err(err). + Msg("failed to rename machine") return err } @@ -280,6 +320,9 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er // RefreshMachine takes a Machine struct and a new expiry time. func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + now := time.Now() machine.LastSuccessfulUpdate = &now @@ -302,6 +345,9 @@ func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time) // DeleteMachine softs deletes a Machine from the database. func (hsdb *HSDatabase) DeleteMachine(machine *types.Machine) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + err := hsdb.DeleteMachineRoutes(machine) if err != nil { return err @@ -315,6 +361,9 @@ func (hsdb *HSDatabase) DeleteMachine(machine *types.Machine) error { } func (hsdb *HSDatabase) TouchMachine(machine *types.Machine) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + return hsdb.db.Updates(types.Machine{ ID: machine.ID, LastSeen: machine.LastSeen, @@ -324,6 +373,9 @@ func (hsdb *HSDatabase) TouchMachine(machine *types.Machine) error { // HardDeleteMachine hard deletes a Machine from the database. func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + err := hsdb.DeleteMachineRoutes(machine) if err != nil { return err @@ -343,6 +395,9 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( machineExpiry *time.Time, registrationMethod string, ) (*types.Machine, error) { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + nodeKey := key.NodePublic{} err := nodeKey.UnmarshalText([]byte(nodeKeyStr)) if err != nil { @@ -399,6 +454,9 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. func (hsdb *HSDatabase) RegisterMachine(machine types.Machine, ) (*types.Machine, error) { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + log.Debug(). Str("machine", machine.Hostname). Str("machine_key", machine.MachineKey). @@ -456,6 +514,9 @@ func (hsdb *HSDatabase) RegisterMachine(machine types.Machine, // MachineSetNodeKey sets the node key of a machine and saves it to the database. func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.NodePublic) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + machine.NodeKey = util.NodePublicKeyStripPrefix(nodeKey) if err := hsdb.db.Save(machine).Error; err != nil { @@ -470,6 +531,9 @@ func (hsdb *HSDatabase) MachineSetMachineKey( machine *types.Machine, nodeKey key.MachinePublic, ) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + machine.MachineKey = util.MachinePublicKeyStripPrefix(nodeKey) if err := hsdb.db.Save(machine).Error; err != nil { @@ -482,6 +546,9 @@ func (hsdb *HSDatabase) MachineSetMachineKey( // MachineSave saves a machine object to the database, prefer to use a specific save method rather // than this. It is intended to be used when we are changing or. func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if err := hsdb.db.Save(machine).Error; err != nil { return err } @@ -491,6 +558,9 @@ func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error { // GetAdvertisedRoutes returns the routes that are be advertised by the given machine. func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + routes := types.Routes{} err := hsdb.db. @@ -516,6 +586,9 @@ func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Pre // GetEnabledRoutes returns the routes that are enabled for the machine. func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + routes := types.Routes{} err := hsdb.db. @@ -541,6 +614,9 @@ func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix } func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) bool { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + route, err := netip.ParsePrefix(routeStr) if err != nil { return false @@ -575,6 +651,9 @@ func OnlineMachineMap(peers types.Machines) map[tailcfg.NodeID]bool { func (hsdb *HSDatabase) ListOnlineMachines( machine *types.Machine, ) (map[tailcfg.NodeID]bool, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + peers, err := hsdb.ListPeers(machine) if err != nil { return nil, err @@ -585,6 +664,10 @@ func (hsdb *HSDatabase) ListOnlineMachines( // enableRoutes enables new routes based on a list of new routes. func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + newRoutes := make([]netip.Prefix, len(routeStrs)) for index, routeStr := range routeStrs { route, err := netip.ParsePrefix(routeStr) @@ -642,7 +725,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string return nil } -func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { +func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( suppliedName, ) @@ -669,7 +752,10 @@ func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool } func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) { - givenName, err := hsdb.generateGivenName(suppliedName, false) + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + givenName, err := generateGivenName(suppliedName, false) if err != nil { return "", err } @@ -682,7 +768,7 @@ func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string for _, machine := range machines { if machine.MachineKey != machineKey && machine.GivenName == givenName { - postfixedName, err := hsdb.generateGivenName(suppliedName, true) + postfixedName, err := generateGivenName(suppliedName, true) if err != nil { return "", err } @@ -695,6 +781,10 @@ func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string } func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Duration) { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + users, err := hsdb.ListUsers() if err != nil { log.Error().Err(err).Msg("Error listing users") @@ -744,6 +834,10 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati } func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + // use the time of the start of the function to ensure we // dont miss some machines by returning it _after_ we have // checked everything. diff --git a/hscontrol/db/machine_test.go b/hscontrol/db/machine_test.go index 0220bb8102f..319415f3ede 100644 --- a/hscontrol/db/machine_test.go +++ b/hscontrol/db/machine_test.go @@ -450,14 +450,12 @@ func TestHeadscale_generateGivenName(t *testing.T) { } tests := []struct { name string - db *HSDatabase args args want *regexp.Regexp wantErr bool }{ { name: "simple machine name generation", - db: &HSDatabase{}, args: args{ suppliedName: "testmachine", randomSuffix: false, @@ -467,7 +465,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 53 chars", - db: &HSDatabase{}, args: args{ suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", randomSuffix: false, @@ -477,7 +474,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 63 chars", - db: &HSDatabase{}, args: args{ suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", randomSuffix: false, @@ -487,7 +483,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 64 chars", - db: &HSDatabase{}, args: args{ suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", randomSuffix: false, @@ -497,7 +492,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 73 chars", - db: &HSDatabase{}, args: args{ suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", randomSuffix: false, @@ -507,7 +501,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with random suffix", - db: &HSDatabase{}, args: args{ suppliedName: "test", randomSuffix: true, @@ -517,7 +510,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 63 chars with random suffix", - db: &HSDatabase{}, args: args{ suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", randomSuffix: true, @@ -528,7 +520,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) + got, err := generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) if (err != nil) != tt.wantErr { t.Errorf( "Headscale.GenerateGivenName() error = %v, wantErr %v", diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index abb79c34c24..65285d3f808 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -28,6 +28,10 @@ func (hsdb *HSDatabase) CreatePreAuthKey( expiration *time.Time, aclTags []string, ) (*types.PreAuthKey, error) { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + user, err := hsdb.GetUser(userName) if err != nil { return nil, err @@ -92,6 +96,9 @@ func (hsdb *HSDatabase) CreatePreAuthKey( // ListPreAuthKeys returns the list of PreAuthKeys for a user. func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + user, err := hsdb.GetUser(userName) if err != nil { return nil, err @@ -107,6 +114,9 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, er // GetPreAuthKey returns a PreAuthKey for a given key. func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + pak, err := hsdb.ValidatePreAuthKey(key) if err != nil { return nil, err @@ -122,6 +132,10 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + return hsdb.db.Transaction(func(db *gorm.DB) error { if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { return result.Error @@ -137,6 +151,9 @@ func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { // MarkExpirePreAuthKey marks a PreAuthKey as expired. func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -146,6 +163,9 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { // UsePreAuthKey marks a PreAuthKey as used. func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + k.Used = true if err := hsdb.db.Save(k).Error; err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) @@ -157,6 +177,9 @@ func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node // If returns no error and a PreAuthKey, it can be used. func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + pak := types.PreAuthKey{} if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( result.Error, diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 90ec3b1d660..a62985d93c6 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -13,6 +13,9 @@ import ( var ErrRouteIsNotAvailable = errors.New("route is not available") func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + var routes types.Routes err := hsdb.db.Preload("Machine").Find(&routes).Error if err != nil { @@ -23,6 +26,9 @@ func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { } func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + var routes types.Routes err := hsdb.db. Preload("Machine"). @@ -36,6 +42,10 @@ func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (type } func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) { + // TODO(kradalby): figure out this lock + // hsdb.mu.RLock() + // defer hsdb.mu.RUnlock() + var routes types.Routes err := hsdb.db. Preload("Machine"). @@ -49,6 +59,9 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) } func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + var route types.Route err := hsdb.db.Preload("Machine").First(&route, id).Error if err != nil { @@ -59,6 +72,10 @@ func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { } func (hsdb *HSDatabase) EnableRoute(id uint64) error { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + route, err := hsdb.GetRoute(id) if err != nil { return err @@ -79,6 +96,10 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error { } func (hsdb *HSDatabase) DisableRoute(id uint64) error { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + route, err := hsdb.GetRoute(id) if err != nil { return err @@ -118,6 +139,10 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { } func (hsdb *HSDatabase) DeleteRoute(id uint64) error { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + route, err := hsdb.GetRoute(id) if err != nil { return err @@ -154,6 +179,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { } func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + routes, err := hsdb.GetMachineRoutes(m) if err != nil { return err @@ -170,6 +199,9 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { // isUniquePrefix returns if there is another machine providing the same route already. func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + var count int64 hsdb.db. Model(&types.Route{}). @@ -182,6 +214,9 @@ func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { } func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + var route types.Route err := hsdb.db. Preload("Machine"). @@ -201,6 +236,9 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // Exit nodes are not considered for this, as they are never marked as Primary. func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + var routes types.Routes err := hsdb.db. Preload("Machine"). @@ -214,6 +252,10 @@ func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, } func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + currentRoutes := types.Routes{} err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error if err != nil { @@ -264,6 +306,10 @@ func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { } func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + // first, get all the enabled routes var routes types.Routes err := hsdb.db. @@ -388,6 +434,10 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( aclPolicy *policy.ACLPolicy, machine *types.Machine, ) error { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + if len(machine.IPAddresses) == 0 { return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs } diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index ce186751a14..883b856b986 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -18,6 +18,9 @@ var ( // CreateUser creates a new User. Returns error if could not be created // or another user already exists. func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + err := util.CheckForFQDNRules(name) if err != nil { return nil, err @@ -42,6 +45,10 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { // DestroyUser destroys a User. Returns error if the User does // not exist or if there are machines associated with it. func (hsdb *HSDatabase) DestroyUser(name string) error { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + user, err := hsdb.GetUser(name) if err != nil { return ErrUserNotFound @@ -76,6 +83,10 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { // RenameUser renames a User. Returns error if the User does // not exist or if another User exists with the new name. func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + var err error oldUser, err := hsdb.GetUser(oldName) if err != nil { @@ -104,6 +115,9 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { // GetUser fetches a user by name. func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + user := types.User{} if result := hsdb.db.First(&user, "name = ?", name); errors.Is( result.Error, @@ -117,6 +131,9 @@ func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { // ListUsers gets all the existing users. func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + users := []types.User{} if err := hsdb.db.Find(&users).Error; err != nil { return nil, err @@ -127,6 +144,9 @@ func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { // ListMachinesByUser gets all the nodes in a given user. func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + err := util.CheckForFQDNRules(name) if err != nil { return nil, err @@ -146,6 +166,10 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) // SetMachineUser assigns a Machine to a user. func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + err := util.CheckForFQDNRules(username) if err != nil { return err