Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix information_schema.columns for databases with schemas #2596

Merged
merged 1 commit into from
Jul 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 47 additions & 37 deletions sql/information_schema/columns_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,17 @@ func (c *ColumnsTable) AllColumns(ctx *sql.Context) (sql.Schema, error) {

var allColumns sql.Schema

for _, db := range c.catalog.AllDatabases(ctx) {
err := sql.DBTableIter(ctx, db, func(t sql.Table) (cont bool, err error) {
databases, err := allDatabases(ctx, c.catalog, false)
if err != nil {
return nil, err
}

for _, db := range databases {
err := sql.DBTableIter(ctx, db.database, func(t sql.Table) (cont bool, err error) {
tableSch := t.Schema()
for i := range tableSch {
newCol := tableSch[i].Copy()
newCol.DatabaseSource = db.Name()
newCol.DatabaseSource = db.database.Name()
allColumns = append(allColumns, newCol)
}
return true, nil
Expand Down Expand Up @@ -205,7 +210,12 @@ func columnsRowIter(ctx *sql.Context, catalog sql.Catalog, allColsWithDefaultVal
}
globalPrivSetMap = getCurrentPrivSetMapForColumn(privSet.ToSlice(), globalPrivSetMap)

for _, db := range catalog.AllDatabases(ctx) {
databases, err := allDatabases(ctx, catalog, false)
if err != nil {
return nil, err
}

for _, db := range databases {
rs, err := getRowsFromDatabase(ctx, db, privSet, globalPrivSetMap, allColsWithDefaultValue)
if err != nil {
return nil, err
Expand All @@ -224,7 +234,7 @@ func columnsRowIter(ctx *sql.Context, catalog sql.Catalog, allColsWithDefaultVal
// getRowFromColumn returns a single row for given column. The arguments passed are used to define all row values.
// These include the current ordinal position, so this column will get the next position number, sql.Column object,
// database name, table name, column key and column privileges information through privileges set for the table.
func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, dbName, tblName, columnKey string, privSetTbl sql.PrivilegeSetTable, privSetMap map[string]struct{}) sql.Row {
func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, catName, schName, tblName, columnKey string, privSetTbl sql.PrivilegeSetTable, privSetMap map[string]struct{}) sql.Row {
var (
ordinalPos = uint32(curOrdPos + 1)
nullable = "NO"
Expand Down Expand Up @@ -279,8 +289,8 @@ func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, dbName,
privileges := strings.Join(curColPrivStr, ",")

return sql.Row{
"def", // table_catalog
dbName, // table_schema
catName, // table_catalog
schName, // table_schema
tblName, // table_name
col.Name, // column_name
ordinalPos, // ordinal_position
Expand All @@ -305,7 +315,7 @@ func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, dbName,
}

// getRowsFromTable returns array of rows for all accessible columns of the given table.
func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb sql.PrivilegeSetDatabase, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) {
func getRowsFromTable(ctx *sql.Context, db dbWithNames, t sql.Table, privSetDb sql.PrivilegeSetDatabase, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) {
var rows []sql.Row

privSetTbl := privSetDb.Table(t.Name())
Expand All @@ -317,7 +327,7 @@ func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb
}

tblName := t.Name()
for i, col := range schemaForTable(t, db, allColsWithDefaultValue) {
for i, col := range schemaForTable(t, db.database, allColsWithDefaultValue) {
var columnKey string
// Check column PK here first because there are PKs from table implementations that don't implement sql.IndexedTable
if col.PrimaryKey {
Expand All @@ -331,7 +341,7 @@ func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb
}
}

r := getRowFromColumn(ctx, i, col, db.Name(), tblName, columnKey, privSetTbl, curPrivSetMap)
r := getRowFromColumn(ctx, i, col, db.catalogName, db.schemaName, tblName, columnKey, privSetTbl, curPrivSetMap)
if r != nil {
rows = append(rows, r)
}
Expand All @@ -341,58 +351,58 @@ func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb
}

// getRowsFromViews returns array or rows for columns for all views for given database.
func getRowsFromViews(ctx *sql.Context, db sql.Database) ([]sql.Row, error) {
func getRowsFromViews(ctx *sql.Context, db dbWithNames) ([]sql.Row, error) {
var rows []sql.Row
// TODO: View Definition is lacking information to properly fill out these table
// TODO: Should somehow get reference to table(s) view is referencing
// TODO: Each column that view references should also show up as unique entries as well
views, err := viewsInDatabase(ctx, db)
views, err := viewsInDatabase(ctx, db.database)
if err != nil {
return nil, err
}

for _, view := range views {
rows = append(rows, sql.Row{
"def", // table_catalog
db.Name(), // table_schema
view.Name, // table_name
"", // column_name
uint32(0), // ordinal_position
nil, // column_default
"", // is_nullable
nil, // data_type
nil, // character_maximum_length
nil, // character_octet_length
nil, // numeric_precision
nil, // numeric_scale
nil, // datetime_precision
"", // character_set_name
"", // collation_name
"", // column_type
"", // column_key
"", // extra
"select", // privileges
"", // column_comment
"", // generation_expression
nil, // srs_id
db.catalogName, // table_catalog
db.schemaName, // table_schema
view.Name, // table_name
"", // column_name
uint32(0), // ordinal_position
nil, // column_default
"", // is_nullable
nil, // data_type
nil, // character_maximum_length
nil, // character_octet_length
nil, // numeric_precision
nil, // numeric_scale
nil, // datetime_precision
"", // character_set_name
"", // collation_name
"", // column_type
"", // column_key
"", // extra
"select", // privileges
"", // column_comment
"", // generation_expression
nil, // srs_id
})
}

return rows, nil
}

// getRowsFromDatabase returns array of rows for all accessible columns of accessible table of the given database.
func getRowsFromDatabase(ctx *sql.Context, db sql.Database, privSet sql.PrivilegeSet, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) {
func getRowsFromDatabase(ctx *sql.Context, db dbWithNames, privSet sql.PrivilegeSet, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) {
var rows []sql.Row
dbName := db.Name()
dbName := db.database.Name()

privSetDb := privSet.Database(dbName)
curPrivSetMap := getCurrentPrivSetMapForColumn(privSetDb.ToSlice(), privSetMap)
if dbName == sql.InformationSchemaDatabaseName {
curPrivSetMap["select"] = struct{}{}
}

err := sql.DBTableIter(ctx, db, func(t sql.Table) (cont bool, err error) {
err := sql.DBTableIter(ctx, db.database, func(t sql.Table) (cont bool, err error) {
rs, err := getRowsFromTable(ctx, db, t, privSetDb, curPrivSetMap, allColsWithDefaultValue)
if err != nil {
return false, err
Expand Down