Skip to content

Commit

Permalink
fix(create_table): avoid creating unintended foreign keys (#1130)
Browse files Browse the repository at this point in the history
* fix(create_table): avoid creating unintended foreign keys

Foreign keys should only be created for has-one and belongs-to relations iff:
- None of the referencing columns is a primary keys
- The table is an m2m 'junction' table an all referencing columns are primary keys

The m2m edge case is covered by TestDatabaseInspector_Inspect
  • Loading branch information
bevzzz authored Feb 17, 2025
1 parent c915415 commit 187743b
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 9 deletions.
56 changes: 55 additions & 1 deletion internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/uptrace/bun/driver/pgdriver"
"github.com/uptrace/bun/driver/sqliteshim"
"github.com/uptrace/bun/extra/bundebug"
"github.com/uptrace/bun/migrate/sqlschema"
"github.com/uptrace/bun/extra/bunexp"
"github.com/uptrace/bun/schema"

Expand Down Expand Up @@ -300,6 +301,7 @@ func TestDB(t *testing.T) {
{testRunInTxAndSavepoint},
{testDriverValuerReturnsItself},
{testNoPanicWhenReturningNullColumns},
{testNoForeignKeyForPrimaryKey},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -1831,6 +1833,59 @@ func testNoPanicWhenReturningNullColumns(t *testing.T, db *bun.DB) {
})
}

func testNoForeignKeyForPrimaryKey(t *testing.T, db *bun.DB) {
inspect := inspectDbOrSkip(t, db)

for _, tt := range []struct {
name string
model interface{}
dontWant sqlschema.ForeignKey
}{
{name: "has-one relation", model: (*struct {
bun.BaseModel `bun:"table:users"`
ID string `bun:",pk"`

Profile *struct {
bun.BaseModel `bun:"table:profiles"`
ID string `bun:",pk"`
UserID string
} `bun:"rel:has-one,join:id=user_id"`
})(nil), dontWant: sqlschema.ForeignKey{
From: sqlschema.NewColumnReference("users", "id"),
To: sqlschema.NewColumnReference("profiles", "user_id"),
}},

{name: "belongs-to relation", model: (*struct {
bun.BaseModel `bun:"table:profiles"`
ID string `bun:",pk"`

User *struct {
bun.BaseModel `bun:"table:users"`
ID string `bun:",pk"`
ProfileID string
} `bun:"rel:belongs-to,join:id=profile_id"`
})(nil), dontWant: sqlschema.ForeignKey{
From: sqlschema.NewColumnReference("profiles", "id"),
To: sqlschema.NewColumnReference("users", "profile_id"),
}},
} {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
mustDropTableOnCleanup(t, ctx, db, tt.model)

_, err := db.NewCreateTable().Model(tt.model).WithForeignKeys().Exec(ctx)
require.NoError(t, err, "create table")

state := inspect(ctx)
require.NotContainsf(t, state.ForeignKeys, tt.dontWant,
"%s.%s -> %s.%s is not inteded",
tt.dontWant.From.TableName, tt.dontWant.From.Column,
tt.dontWant.To.TableName, tt.dontWant.To.Column,
)
})
}
}

func mustResetModel(tb testing.TB, ctx context.Context, db *bun.DB, models ...interface{}) {
err := db.ResetModel(ctx, models...)
require.NoError(tb, err, "must reset model")
Expand Down Expand Up @@ -1864,7 +1919,6 @@ func TestConnResolver(t *testing.T) {
})

resolver := bunexp.NewReadWriteConnResolver(
//bunexp.WithDBReplica(rwdb),
bunexp.WithDBReplica(rodb, bunexp.DBReplicaReadOnly),
)

Expand Down
1 change: 1 addition & 0 deletions schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
)

type Field struct {
Table *Table // Contains this field
StructField reflect.StructField
IsPtr bool

Expand Down
51 changes: 45 additions & 6 deletions schema/relation.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ const (
)

type Relation struct {
Type int
Field *Field // Has the bun tag defining this relation.

// Base and Join can be explained with this query:
//
// SELECT * FROM base_table JOIN join_table

Type int
Field *Field
JoinTable *Table
BasePKs []*Field
JoinPKs []*Field
Expand All @@ -34,10 +34,49 @@ type Relation struct {
M2MJoinPKs []*Field
}

// References returns true if the table to which the Relation belongs needs to declare a foreign key constraint to create the relation.
// For other relations, the constraint is created in either the referencing table (1:N, 'has-many' relations) or a mapping table (N:N, 'm2m' relations).
// References returns true if the table which defines this Relation
// needs to declare a foreign key constraint, as is the case
// for 'has-one' and 'belongs-to' relations. For other relations,
// the constraint is created either in the referencing table (1:N, 'has-many' relations)
// or the junction table (N:N, 'm2m' relations).
//
// Usage of `rel:` tag does not always imply creation of foreign keys (when WithForeignKeys() is not set)
// and can be used exclusively for joining tables at query time. For example:
//
// type User struct {
// ID int64 `bun:",pk"`
// Profile *Profile `bun:",rel:has-one,join:id=user_id"`
// }
//
// Creating a FK users.id -> profiles.user_id would be confusing and incorrect,
// so for such cases References() returns false. One notable exception to this rule
// is when a Relation is defined in a junction table, in which case it is perfectly
// fine for its primary keys to reference other tables. Consider:
//
// // UsersToGroups maps users to groups they follow.
// type UsersToGroups struct {
// UserID string `bun:"user_id,pk"` // Needs FK to users.id
// GroupID string `bun:"group_id,pk"` // Needs FK to groups.id
//
// User *User `bun:"rel:belongs-to,join:user_id=id"`
// Group *Group `bun:"rel:belongs-to,join:group_id=id"`
// }
//
// Here BooksToReaders has a composite primary key, composed of other primary keys.
func (r *Relation) References() bool {
return r.Type == HasOneRelation || r.Type == BelongsToRelation
allPK := true
nonePK := true
for _, f := range r.BasePKs {
allPK = allPK && f.IsPK
nonePK = nonePK && !f.IsPK
}

// Erring on the side of caution, only create foreign keys
// if the referencing columns are part of a composite PK
// in the junction table of the m2m relationship.
effectsM2M := r.Field.Table.IsM2MTable && allPK

return (r.Type == HasOneRelation || r.Type == BelongsToRelation) && (effectsM2M || nonePK)
}

func (r *Relation) String() string {
Expand Down
11 changes: 9 additions & 2 deletions schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ type Table struct {
FieldMap map[string]*Field
StructMap map[string]*structField

Relations map[string]*Relation
Unique map[string][]*Field
IsM2MTable bool // If true, this table is the "junction table" of an m2m relation.
Relations map[string]*Relation
Unique map[string][]*Field

SoftDeleteField *Field
UpdateSoftDeleteField func(fv reflect.Value, tm time.Time) error
Expand Down Expand Up @@ -516,6 +517,7 @@ func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field {
}

field := &Field{
Table: t,
StructField: sf,
IsPtr: sf.Type.Kind() == reflect.Ptr,

Expand Down Expand Up @@ -895,6 +897,7 @@ func (t *Table) m2mRelation(field *Field) *Relation {
JoinTable: joinTable,
M2MTable: m2mTable,
}
m2mTable.markM2M()

if field.Tag.HasOption("join_on") {
rel.Condition = field.Tag.Options["join_on"]
Expand Down Expand Up @@ -940,6 +943,10 @@ func (t *Table) m2mRelation(field *Field) *Relation {
return rel
}

func (t *Table) markM2M() {
t.IsM2MTable = true
}

//------------------------------------------------------------------------------

func (t *Table) Dialect() Dialect { return t.dialect }
Expand Down

0 comments on commit 187743b

Please sign in to comment.