Skip to content

Commit

Permalink
Merge pull request #2871 from dolthub/zachmu/resolve-alter-column
Browse files Browse the repository at this point in the history
[no-release-notes] Exported several analyzer methods for use by doltgres
  • Loading branch information
zachmu authored Mar 6, 2025
2 parents 14a57e0 + c856028 commit f73a318
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 43 deletions.
2 changes: 1 addition & 1 deletion sql/analyzer/rule_ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ const (
finalizeUnionsId // finalizeUnions
loadTriggersId // loadTriggers
processTruncateId // processTruncate
resolveAlterColumnId // resolveAlterColumn
ResolveAlterColumnId // ResolveAlterColumn
stripTableNameInDefaultsId // stripTableNamesFromColumnDefaults
optimizeJoinsId // optimizeJoins
pushFiltersId // pushFilters
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/ruleid_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ var OnceBeforeDefault = []Rule{
{validateAlterTableId, validateAlterTable},
{validateExprSemId, validateExprSem},
{resolveDropConstraintId, resolveDropConstraint},
{resolveAlterColumnId, resolveAlterColumn},
{ResolveAlterColumnId, resolveAlterColumn},
{validateDropTablesId, validateDropTables},
{resolveCreateSelectId, resolveCreateSelect},
{validateDropConstraintId, validateDropConstraint},
Expand Down
78 changes: 39 additions & 39 deletions sql/analyzer/validate_create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,18 @@ func validateAlterTable(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S
switch n := n.(type) {
case *plan.RenameTable:
for _, name := range n.NewNames {
err = validateIdentifier(name)
err = ValidateIdentifier(name)
if err != nil {
return false
}
}
case *plan.CreateCheck:
err = validateIdentifier(n.Check.Name)
err = ValidateIdentifier(n.Check.Name)
if err != nil {
return false
}
case *plan.CreateForeignKey:
err = validateIdentifier(n.FkDef.Name)
err = ValidateIdentifier(n.FkDef.Name)
if err != nil {
return false
}
Expand Down Expand Up @@ -177,7 +177,7 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S
validator = sv
}
}
keyedColumns, err = getTableIndexColumns(ctx, n.Table)
keyedColumns, err = GetTableIndexColumns(ctx, n.Table)
return false
case *plan.RenameColumn:
if rt, ok := n.Table.(*plan.ResolvedTable); ok {
Expand All @@ -192,7 +192,7 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S
validator = sv
}
}
keyedColumns, err = getTableIndexColumns(ctx, n.Table)
keyedColumns, err = GetTableIndexColumns(ctx, n.Table)
return false
case *plan.DropColumn:
if rt, ok := n.Table.(*plan.ResolvedTable); ok {
Expand All @@ -207,7 +207,7 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S
validator = sv
}
}
indexes, err = getTableIndexNames(ctx, a, n.Table)
indexes, err = GetTableIndexNames(ctx, a, n.Table)
default:
}
return true
Expand Down Expand Up @@ -237,7 +237,7 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S
return nil, transform.SameTree, err
}

sch, err = validateModifyColumn(ctx, initialSch, sch, n.(*plan.ModifyColumn), keyedColumns)
sch, err = ValidateModifyColumn(ctx, initialSch, sch, n.(*plan.ModifyColumn), keyedColumns)
if err != nil {
return nil, transform.SameTree, err
}
Expand All @@ -247,7 +247,7 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S
if err != nil {
return nil, transform.SameTree, err
}
sch, err = validateRenameColumn(initialSch, sch, n.(*plan.RenameColumn))
sch, err = ValidateRenameColumn(initialSch, sch, n.(*plan.RenameColumn))
if err != nil {
return nil, transform.SameTree, err
}
Expand All @@ -270,7 +270,7 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S
if err != nil {
return nil, transform.SameTree, err
}
sch, err = validateDropColumn(initialSch, sch, n.(*plan.DropColumn))
sch, err = ValidateDropColumn(initialSch, sch, n.(*plan.DropColumn))
if err != nil {
return nil, transform.SameTree, err
}
Expand All @@ -287,7 +287,7 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S
return nil, transform.SameTree, err
}

keyedColumns = updateKeyedColumns(keyedColumns, nn)
keyedColumns = UpdateKeyedColumns(keyedColumns, nn)
return n, transform.NewTree, nil
case *plan.AlterPK:
n, err := nn.WithTargetSchema(sch.Copy())
Expand All @@ -304,7 +304,7 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S
if err != nil {
return nil, transform.SameTree, err
}
sch, err = validateAlterDefault(initialSch, sch, n.(*plan.AlterDefaultSet))
sch, err = ValidateAlterDefault(initialSch, sch, n.(*plan.AlterDefaultSet))
if err != nil {
return nil, transform.SameTree, err
}
Expand All @@ -314,7 +314,7 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S
if err != nil {
return nil, transform.SameTree, err
}
sch, err = validateDropDefault(initialSch, sch, n.(*plan.AlterDefaultDrop))
sch, err = ValidateDropDefault(initialSch, sch, n.(*plan.AlterDefaultDrop))
if err != nil {
return nil, transform.SameTree, err
}
Expand Down Expand Up @@ -345,8 +345,8 @@ func resolveAlterColumn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.S
return n, same, nil
}

// updateKeyedColumns updates the keyedColumns map based on the action of the AlterIndex node
func updateKeyedColumns(keyedColumns map[string]bool, n *plan.AlterIndex) map[string]bool {
// UpdateKeyedColumns updates the keyedColumns map based on the action of the AlterIndex node
func UpdateKeyedColumns(keyedColumns map[string]bool, n *plan.AlterIndex) map[string]bool {
switch n.Action {
case plan.IndexAction_Create:
for _, col := range n.Columns {
Expand All @@ -361,16 +361,16 @@ func updateKeyedColumns(keyedColumns map[string]bool, n *plan.AlterIndex) map[st
return keyedColumns
}

// validateRenameColumn checks that a DDL RenameColumn node can be safely executed (e.g. no collision with other
// ValidateRenameColumn checks that a DDL RenameColumn node can be safely executed (e.g. no collision with other
// column names, doesn't invalidate any table check constraints).
//
// Note that schema is passed in twice, because one version is the initial version before the alter column expressions
// are applied, and the second version is the current schema that is being modified as multiple nodes are processed.
func validateRenameColumn(initialSch, sch sql.Schema, rc *plan.RenameColumn) (sql.Schema, error) {
func ValidateRenameColumn(initialSch, sch sql.Schema, rc *plan.RenameColumn) (sql.Schema, error) {
table := rc.Table
nameable := table.(sql.Nameable)

err := validateIdentifier(rc.NewColumnName)
err := ValidateIdentifier(rc.NewColumnName)
if err != nil {
return nil, err
}
Expand All @@ -387,7 +387,7 @@ func validateRenameColumn(initialSch, sch sql.Schema, rc *plan.RenameColumn) (sq
return nil, sql.ErrTableColumnNotFound.New(nameable.Name(), rc.ColumnName)
}

err = validateColumnNotUsedInCheckConstraint(rc.ColumnName, rc.Checks())
err = ValidateColumnNotUsedInCheckConstraint(rc.ColumnName, rc.Checks())
if err != nil {
return nil, err
}
Expand All @@ -402,7 +402,7 @@ func ValidateAddColumn(schema sql.Schema, ac *plan.AddColumn) (sql.Schema, error
table := ac.Table
nameable := table.(sql.Nameable)

err := validateIdentifier(ac.Column().Name)
err := ValidateIdentifier(ac.Column().Name)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -449,7 +449,7 @@ func isStrictMysqlCompatibilityEnabled(ctx *sql.Context) (bool, error) {
return i == 1, nil
}

func validateModifyColumn(ctx *sql.Context, initialSch sql.Schema, schema sql.Schema, mc *plan.ModifyColumn, keyedColumns map[string]bool) (sql.Schema, error) {
func ValidateModifyColumn(ctx *sql.Context, initialSch sql.Schema, schema sql.Schema, mc *plan.ModifyColumn, keyedColumns map[string]bool) (sql.Schema, error) {
table := mc.Table
tableName := table.(sql.Nameable).Name()

Expand All @@ -461,7 +461,7 @@ func validateModifyColumn(ctx *sql.Context, initialSch sql.Schema, schema sql.Sc
}

newCol := mc.NewColumn()
if err := validateIdentifier(newCol.Name); err != nil {
if err := ValidateIdentifier(newCol.Name); err != nil {
return nil, err
}

Expand Down Expand Up @@ -522,14 +522,14 @@ func validateModifyColumn(ctx *sql.Context, initialSch sql.Schema, schema sql.Sc
return newSch, nil
}

func validateIdentifier(name string) error {
func ValidateIdentifier(name string) error {
if len(name) > sql.MaxIdentifierLength {
return sql.ErrInvalidIdentifier.New(name)
}
return nil
}

func validateDropColumn(initialSch, sch sql.Schema, dc *plan.DropColumn) (sql.Schema, error) {
func ValidateDropColumn(initialSch, sch sql.Schema, dc *plan.DropColumn) (sql.Schema, error) {
table := dc.Table
nameable := table.(sql.Nameable)

Expand All @@ -550,9 +550,9 @@ func validateDropColumn(initialSch, sch sql.Schema, dc *plan.DropColumn) (sql.Sc
return newSch, nil
}

// validateColumnNotUsedInCheckConstraint validates that the specified column name is not referenced in any of
// ValidateColumnNotUsedInCheckConstraint validates that the specified column name is not referenced in any of
// the specified table check constraints.
func validateColumnNotUsedInCheckConstraint(columnName string, checks sql.CheckConstraints) error {
func ValidateColumnNotUsedInCheckConstraint(columnName string, checks sql.CheckConstraints) error {
var err error
for _, check := range checks {
_ = transform.InspectExpr(check.Expr, func(e sql.Expression) bool {
Expand Down Expand Up @@ -626,7 +626,7 @@ func validateColumnSafeToDropWithCheckConstraint(columnName string, checks sql.C
func validateAlterIndex(ctx *sql.Context, initialSch, sch sql.Schema, ai *plan.AlterIndex, indexes []string) ([]string, error) {
switch ai.Action {
case plan.IndexAction_Create:
err := validateIdentifier(ai.IndexName)
err := ValidateIdentifier(ai.IndexName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -662,7 +662,7 @@ func validateAlterIndex(ctx *sql.Context, initialSch, sch sql.Schema, ai *plan.A
// Remove the index from the list
return append(indexes[:savedIdx], indexes[savedIdx+1:]...), nil
case plan.IndexAction_Rename:
err := validateIdentifier(ai.IndexName)
err := ValidateIdentifier(ai.IndexName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -917,8 +917,8 @@ func validateIndex(ctx *sql.Context, colMap map[string]*sql.Column, idxDef *sql.
return nil
}

// getTableIndexColumns returns the columns over which indexes are defined
func getTableIndexColumns(ctx *sql.Context, table sql.Node) (map[string]bool, error) {
// GetTableIndexColumns returns the columns over which indexes are defined
func GetTableIndexColumns(ctx *sql.Context, table sql.Node) (map[string]bool, error) {
ia, err := newIndexAnalyzerForNode(ctx, table)
if err != nil {
return nil, err
Expand All @@ -937,8 +937,8 @@ func getTableIndexColumns(ctx *sql.Context, table sql.Node) (map[string]bool, er
return keyedColumns, nil
}

// getTableIndexNames returns the names of indexes associated with a table.
func getTableIndexNames(ctx *sql.Context, _ *Analyzer, table sql.Node) ([]string, error) {
// GetTableIndexNames returns the names of indexes associated with a table.
func GetTableIndexNames(ctx *sql.Context, _ *Analyzer, table sql.Node) ([]string, error) {
ia, err := newIndexAnalyzerForNode(ctx, table)
if err != nil {
return nil, err
Expand All @@ -951,7 +951,7 @@ func getTableIndexNames(ctx *sql.Context, _ *Analyzer, table sql.Node) ([]string
names[i] = index.ID()
}

if hasPrimaryKeys(table.Schema()) {
if HasPrimaryKeys(table.Schema()) {
names = append(names, "PRIMARY")
}

Expand All @@ -963,7 +963,7 @@ func validatePrimaryKey(ctx *sql.Context, initialSch, sch sql.Schema, ai *plan.A
tableName := getTableName(ai.Table)
switch ai.Action {
case plan.PrimaryKeyAction_Create:
if hasPrimaryKeys(sch) {
if HasPrimaryKeys(sch) {
return nil, sql.ErrMultiplePrimaryKeysDefined.New()
}

Expand Down Expand Up @@ -996,7 +996,7 @@ func validatePrimaryKey(ctx *sql.Context, initialSch, sch sql.Schema, ai *plan.A

return sch, nil
case plan.PrimaryKeyAction_Drop:
if !hasPrimaryKeys(sch) {
if !HasPrimaryKeys(sch) {
return nil, sql.ErrCantDropFieldOrKey.New("PRIMARY")
}

Expand All @@ -1012,8 +1012,8 @@ func validatePrimaryKey(ctx *sql.Context, initialSch, sch sql.Schema, ai *plan.A
}
}

// validateAlterDefault validates the addition of a default value to a column.
func validateAlterDefault(initialSch, sch sql.Schema, as *plan.AlterDefaultSet) (sql.Schema, error) {
// ValidateAlterDefault validates the addition of a default value to a column.
func ValidateAlterDefault(initialSch, sch sql.Schema, as *plan.AlterDefaultSet) (sql.Schema, error) {
idx := sch.IndexOf(as.ColumnName, getTableName(as.Table))
if idx == -1 {
return nil, sql.ErrTableColumnNotFound.New(as.ColumnName)
Expand All @@ -1029,8 +1029,8 @@ func validateAlterDefault(initialSch, sch sql.Schema, as *plan.AlterDefaultSet)
return sch, err
}

// validateDropDefault validates the dropping of a default value.
func validateDropDefault(initialSch, sch sql.Schema, ad *plan.AlterDefaultDrop) (sql.Schema, error) {
// ValidateDropDefault validates the dropping of a default value.
func ValidateDropDefault(initialSch, sch sql.Schema, ad *plan.AlterDefaultDrop) (sql.Schema, error) {
idx := sch.IndexOf(ad.ColumnName, getTableName(ad.Table))
if idx == -1 {
return nil, sql.ErrTableColumnNotFound.New(ad.ColumnName)
Expand All @@ -1041,7 +1041,7 @@ func validateDropDefault(initialSch, sch sql.Schema, ad *plan.AlterDefaultDrop)
return sch, nil
}

func hasPrimaryKeys(sch sql.Schema) bool {
func HasPrimaryKeys(sch sql.Schema) bool {
for _, c := range sch {
if c.PrimaryKey {
return true
Expand Down

0 comments on commit f73a318

Please sign in to comment.