From 8d99c2ae866e47d09bf2fb1de8b47c69ad0d1eb0 Mon Sep 17 00:00:00 2001 From: Taylor Bantle Date: Wed, 17 Jul 2024 13:01:53 -0700 Subject: [PATCH] Fix information_schema.columns for databases with schemas --- sql/information_schema/columns_table.go | 84 ++++++++++++++----------- 1 file changed, 47 insertions(+), 37 deletions(-) diff --git a/sql/information_schema/columns_table.go b/sql/information_schema/columns_table.go index a4679b0680..acf813d5dd 100644 --- a/sql/information_schema/columns_table.go +++ b/sql/information_schema/columns_table.go @@ -138,12 +138,17 @@ func (c *ColumnsTable) AllColumns(ctx *sql.Context) (sql.Schema, error) { var allColumns sql.Schema - for _, db := range c.catalog.AllDatabases(ctx) { - err := sql.DBTableIter(ctx, db, func(t sql.Table) (cont bool, err error) { + databases, err := allDatabases(ctx, c.catalog, false) + if err != nil { + return nil, err + } + + for _, db := range databases { + err := sql.DBTableIter(ctx, db.database, func(t sql.Table) (cont bool, err error) { tableSch := t.Schema() for i := range tableSch { newCol := tableSch[i].Copy() - newCol.DatabaseSource = db.Name() + newCol.DatabaseSource = db.database.Name() allColumns = append(allColumns, newCol) } return true, nil @@ -205,7 +210,12 @@ func columnsRowIter(ctx *sql.Context, catalog sql.Catalog, allColsWithDefaultVal } globalPrivSetMap = getCurrentPrivSetMapForColumn(privSet.ToSlice(), globalPrivSetMap) - for _, db := range catalog.AllDatabases(ctx) { + databases, err := allDatabases(ctx, catalog, false) + if err != nil { + return nil, err + } + + for _, db := range databases { rs, err := getRowsFromDatabase(ctx, db, privSet, globalPrivSetMap, allColsWithDefaultValue) if err != nil { return nil, err @@ -224,7 +234,7 @@ func columnsRowIter(ctx *sql.Context, catalog sql.Catalog, allColsWithDefaultVal // getRowFromColumn returns a single row for given column. The arguments passed are used to define all row values. // These include the current ordinal position, so this column will get the next position number, sql.Column object, // database name, table name, column key and column privileges information through privileges set for the table. -func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, dbName, tblName, columnKey string, privSetTbl sql.PrivilegeSetTable, privSetMap map[string]struct{}) sql.Row { +func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, catName, schName, tblName, columnKey string, privSetTbl sql.PrivilegeSetTable, privSetMap map[string]struct{}) sql.Row { var ( ordinalPos = uint32(curOrdPos + 1) nullable = "NO" @@ -279,8 +289,8 @@ func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, dbName, privileges := strings.Join(curColPrivStr, ",") return sql.Row{ - "def", // table_catalog - dbName, // table_schema + catName, // table_catalog + schName, // table_schema tblName, // table_name col.Name, // column_name ordinalPos, // ordinal_position @@ -305,7 +315,7 @@ func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, dbName, } // getRowsFromTable returns array of rows for all accessible columns of the given table. -func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb sql.PrivilegeSetDatabase, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) { +func getRowsFromTable(ctx *sql.Context, db dbWithNames, t sql.Table, privSetDb sql.PrivilegeSetDatabase, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) { var rows []sql.Row privSetTbl := privSetDb.Table(t.Name()) @@ -317,7 +327,7 @@ func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb } tblName := t.Name() - for i, col := range schemaForTable(t, db, allColsWithDefaultValue) { + for i, col := range schemaForTable(t, db.database, allColsWithDefaultValue) { var columnKey string // Check column PK here first because there are PKs from table implementations that don't implement sql.IndexedTable if col.PrimaryKey { @@ -331,7 +341,7 @@ func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb } } - r := getRowFromColumn(ctx, i, col, db.Name(), tblName, columnKey, privSetTbl, curPrivSetMap) + r := getRowFromColumn(ctx, i, col, db.catalogName, db.schemaName, tblName, columnKey, privSetTbl, curPrivSetMap) if r != nil { rows = append(rows, r) } @@ -341,40 +351,40 @@ func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb } // getRowsFromViews returns array or rows for columns for all views for given database. -func getRowsFromViews(ctx *sql.Context, db sql.Database) ([]sql.Row, error) { +func getRowsFromViews(ctx *sql.Context, db dbWithNames) ([]sql.Row, error) { var rows []sql.Row // TODO: View Definition is lacking information to properly fill out these table // TODO: Should somehow get reference to table(s) view is referencing // TODO: Each column that view references should also show up as unique entries as well - views, err := viewsInDatabase(ctx, db) + views, err := viewsInDatabase(ctx, db.database) if err != nil { return nil, err } for _, view := range views { rows = append(rows, sql.Row{ - "def", // table_catalog - db.Name(), // table_schema - view.Name, // table_name - "", // column_name - uint32(0), // ordinal_position - nil, // column_default - "", // is_nullable - nil, // data_type - nil, // character_maximum_length - nil, // character_octet_length - nil, // numeric_precision - nil, // numeric_scale - nil, // datetime_precision - "", // character_set_name - "", // collation_name - "", // column_type - "", // column_key - "", // extra - "select", // privileges - "", // column_comment - "", // generation_expression - nil, // srs_id + db.catalogName, // table_catalog + db.schemaName, // table_schema + view.Name, // table_name + "", // column_name + uint32(0), // ordinal_position + nil, // column_default + "", // is_nullable + nil, // data_type + nil, // character_maximum_length + nil, // character_octet_length + nil, // numeric_precision + nil, // numeric_scale + nil, // datetime_precision + "", // character_set_name + "", // collation_name + "", // column_type + "", // column_key + "", // extra + "select", // privileges + "", // column_comment + "", // generation_expression + nil, // srs_id }) } @@ -382,9 +392,9 @@ func getRowsFromViews(ctx *sql.Context, db sql.Database) ([]sql.Row, error) { } // getRowsFromDatabase returns array of rows for all accessible columns of accessible table of the given database. -func getRowsFromDatabase(ctx *sql.Context, db sql.Database, privSet sql.PrivilegeSet, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) { +func getRowsFromDatabase(ctx *sql.Context, db dbWithNames, privSet sql.PrivilegeSet, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) { var rows []sql.Row - dbName := db.Name() + dbName := db.database.Name() privSetDb := privSet.Database(dbName) curPrivSetMap := getCurrentPrivSetMapForColumn(privSetDb.ToSlice(), privSetMap) @@ -392,7 +402,7 @@ func getRowsFromDatabase(ctx *sql.Context, db sql.Database, privSet sql.Privileg curPrivSetMap["select"] = struct{}{} } - err := sql.DBTableIter(ctx, db, func(t sql.Table) (cont bool, err error) { + err := sql.DBTableIter(ctx, db.database, func(t sql.Table) (cont bool, err error) { rs, err := getRowsFromTable(ctx, db, t, privSetDb, curPrivSetMap, allColsWithDefaultValue) if err != nil { return false, err