Skip to content

Commit

Permalink
Restore foreign keys and add constraints (#1562)
Browse files Browse the repository at this point in the history
* fix #1482, restore foregin keys, add constraints

* #1562, fix tests, fix formatting

* #1562: fix tests

* #1562: fix local run of test_integration
  • Loading branch information
vsychov authored May 16, 2024
1 parent 2bac80c commit 7fd2485
Show file tree
Hide file tree
Showing 15 changed files with 149 additions and 61 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/
- Add command to backfill IP addresses for nodes missing IPs from configured prefixes. [#1869](https://github.com/juanfont/headscale/pull/1869)
- Log available update as warning [#1877](https://github.com/juanfont/headscale/pull/1877)
- Add `autogroup:internet` to Policy [#1917](https://github.com/juanfont/headscale/pull/1917)
- Restore foreign keys and add constraints [#1562](https://github.com/juanfont/headscale/pull/1562)

## 0.22.3 (2023-05-12)

Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ test_integration:
--name headscale-test-suite \
-v $$PWD:$$PWD -w $$PWD/integration \
-v /var/run/docker.sock:/var/run/docker.sock \
-v $$PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- -failfast ./... -timeout 120m -parallel 8

Expand Down
11 changes: 9 additions & 2 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,11 @@ func (h *Headscale) handleAuthKey(
Msg("node was already registered before, refreshing with new auth key")

node.NodeKey = nodeKey
node.AuthKeyID = uint(pak.ID)
pakID := uint(pak.ID)
if pakID != 0 {
node.AuthKeyID = &pakID
}

node.Expiry = &registerRequest.Expiry
node.User = pak.User
node.UserID = pak.UserID
Expand Down Expand Up @@ -373,7 +377,6 @@ func (h *Headscale) handleAuthKey(
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
ForcedTags: pak.Proto().GetAclTags(),
}

Expand All @@ -389,6 +392,10 @@ func (h *Headscale) handleAuthKey(
return
}

pakID := uint(pak.ID)
if pakID != 0 {
nodeToRegister.AuthKeyID = &pakID
}
node, err = h.db.RegisterNode(
nodeToRegister,
ipv4, ipv6,
Expand Down
9 changes: 4 additions & 5 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ func NewHeadscaleDatabase(
_ = tx.Migrator().
RenameColumn(&types.Node{}, "nickname", "given_name")

// If the Node table has a column for registered,
dbConn.Model(&types.Node{}).Where("auth_key_id = ?", 0).Update("auth_key_id", nil)
// If the Node table has a column for registered,
// find all occourences of "false" and drop them. Then
// remove the column.
if tx.Migrator().HasColumn(&types.Node{}, "registered") {
Expand Down Expand Up @@ -441,8 +442,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
db, err := gorm.Open(
sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"),
&gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: dbLogger,
Logger: dbLogger,
},
)

Expand Down Expand Up @@ -488,8 +488,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
}

db, err := gorm.Open(postgres.Open(dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: dbLogger,
Logger: dbLogger,
})
if err != nil {
return nil, err
Expand Down
26 changes: 26 additions & 0 deletions hscontrol/db/ip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ func TestIPAllocatorSequential(t *testing.T) {
name: "simple-with-db",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-with-db")
user := types.User{Name: ""}
db.DB.Save(&user)

db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"),
})
Expand All @@ -112,8 +115,11 @@ func TestIPAllocatorSequential(t *testing.T) {
name: "before-after-free-middle-in-db",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "before-after-free-middle-in-db")
user := types.User{Name: ""}
db.DB.Save(&user)

db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.2"),
IPv6: nap("fd7a:115c:a1e0::2"),
})
Expand Down Expand Up @@ -307,8 +313,11 @@ func TestBackfillIPAddresses(t *testing.T) {
name: "simple-backfill-ipv6",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-ipv6")
user := types.User{Name: ""}
db.DB.Save(&user)

db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.1"),
})

Expand Down Expand Up @@ -337,8 +346,11 @@ func TestBackfillIPAddresses(t *testing.T) {
name: "simple-backfill-ipv4",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-ipv4")
user := types.User{Name: ""}
db.DB.Save(&user)

db.DB.Save(&types.Node{
User: user,
IPv6: nap("fd7a:115c:a1e0::1"),
})

Expand Down Expand Up @@ -367,8 +379,11 @@ func TestBackfillIPAddresses(t *testing.T) {
name: "simple-backfill-remove-ipv6",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-remove-ipv6")
user := types.User{Name: ""}
db.DB.Save(&user)

db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"),
})
Expand All @@ -392,8 +407,11 @@ func TestBackfillIPAddresses(t *testing.T) {
name: "simple-backfill-remove-ipv4",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-remove-ipv4")
user := types.User{Name: ""}
db.DB.Save(&user)

db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"),
})
Expand All @@ -417,17 +435,23 @@ func TestBackfillIPAddresses(t *testing.T) {
name: "multi-backfill-ipv6",
dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-ipv6")
user := types.User{Name: ""}
db.DB.Save(&user)

db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.1"),
})
db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.2"),
})
db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.3"),
})
db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.4"),
})

Expand All @@ -451,6 +475,8 @@ func TestBackfillIPAddresses(t *testing.T) {
"MachineKeyDatabaseField",
"NodeKeyDatabaseField",
"DiscoKeyDatabaseField",
"User",
"UserID",
"Endpoints",
"HostinfoDatabaseField",
"Hostinfo",
Expand Down
2 changes: 1 addition & 1 deletion hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ func DeleteNode(tx *gorm.DB,
}

// Unscoped causes the node to be fully removed from the database.
if err := tx.Unscoped().Delete(&node).Error; err != nil {
if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil {
return changed, err
}

Expand Down
57 changes: 38 additions & 19 deletions hscontrol/db/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func (s *Suite) TestGetNode(c *check.C) {

nodeKey := key.NewNode()
machineKey := key.NewMachine()
pakID := uint(pak.ID)

node := &types.Node{
ID: 0,
Expand All @@ -37,9 +38,10 @@ func (s *Suite) TestGetNode(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.DB.Save(node)
trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil)

_, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
Expand All @@ -58,16 +60,18 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
nodeKey := key.NewNode()
machineKey := key.NewMachine()

pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.DB.Save(&node)
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)

_, err = db.GetNodeByID(0)
c.Assert(err, check.IsNil)
Expand All @@ -88,16 +92,18 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {

machineKey := key.NewMachine()

pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.DB.Save(&node)
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)

_, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil)
Expand All @@ -117,9 +123,9 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
Hostname: "testnode3",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(1),
}
db.DB.Save(&node)
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)

_, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
c.Assert(err, check.IsNil)
Expand All @@ -138,6 +144,7 @@ func (s *Suite) TestListPeers(c *check.C) {
_, err = db.GetNodeByID(0)
c.Assert(err, check.NotNil)

pakID := uint(pak.ID)
for index := 0; index <= 10; index++ {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
Expand All @@ -149,9 +156,10 @@ func (s *Suite) TestListPeers(c *check.C) {
Hostname: "testnode" + strconv.Itoa(index),
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.DB.Save(&node)
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
}

node0ByID, err := db.GetNodeByID(0)
Expand Down Expand Up @@ -188,6 +196,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
for index := 0; index <= 10; index++ {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
pakID := uint(stor[index%2].key.ID)

v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)))
node := types.Node{
Expand All @@ -198,9 +207,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
Hostname: "testnode" + strconv.Itoa(index),
UserID: stor[index%2].user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(stor[index%2].key.ID),
AuthKeyID: &pakID,
}
db.DB.Save(&node)
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
}

aclPolicy := &policy.ACLPolicy{
Expand Down Expand Up @@ -272,6 +282,7 @@ func (s *Suite) TestExpireNode(c *check.C) {

nodeKey := key.NewNode()
machineKey := key.NewMachine()
pakID := uint(pak.ID)

node := &types.Node{
ID: 0,
Expand All @@ -280,7 +291,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
Expiry: &time.Time{},
}
db.DB.Save(node)
Expand Down Expand Up @@ -316,6 +327,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {

machineKey2 := key.NewMachine()

pakID := uint(pak.ID)
node := &types.Node{
ID: 0,
MachineKey: machineKey.Public(),
Expand All @@ -324,9 +336,11 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
GivenName: "hostname-1",
UserID: user1.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.DB.Save(node)

trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil)

givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2")
comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
Expand Down Expand Up @@ -357,16 +371,19 @@ func (s *Suite) TestSetTags(c *check.C) {
nodeKey := key.NewNode()
machineKey := key.NewMachine()

pakID := uint(pak.ID)
node := &types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.DB.Save(node)

trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil)

// assign simple tags
sTags := []string{"tag:test", "tag:foo"}
Expand Down Expand Up @@ -548,22 +565,24 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
route2 := netip.MustParsePrefix("10.11.0.0/24")

v4 := netip.MustParseAddr("100.64.0.1")
pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "test",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:exit"},
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
},
IPv4: &v4,
}

db.DB.Save(&node)
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)

sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil)
Expand Down
Loading

0 comments on commit 7fd2485

Please sign in to comment.