diff --git a/pkg/executor/BUILD.bazel b/pkg/executor/BUILD.bazel index b6a11412a5d39..73c60f967563f 100644 --- a/pkg/executor/BUILD.bazel +++ b/pkg/executor/BUILD.bazel @@ -288,6 +288,7 @@ go_library( "@org_golang_google_grpc//credentials", "@org_golang_google_grpc//credentials/insecure", "@org_golang_google_grpc//status", + "@org_golang_x_exp//maps", "@org_golang_x_sync//errgroup", "@org_uber_go_atomic//:atomic", "@org_uber_go_zap//:zap", diff --git a/pkg/executor/infoschema_reader.go b/pkg/executor/infoschema_reader.go index 2262c118db8f5..e66009f7d35db 100644 --- a/pkg/executor/infoschema_reader.go +++ b/pkg/executor/infoschema_reader.go @@ -47,6 +47,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/charset" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" plannercore "github.com/pingcap/tidb/pkg/planner/core" "github.com/pingcap/tidb/pkg/planner/core/base" "github.com/pingcap/tidb/pkg/privilege" @@ -84,6 +85,7 @@ import ( "github.com/tikv/client-go/v2/txnkv/txnlock" pd "github.com/tikv/pd/client/http" "go.uber.org/zap" + "golang.org/x/exp/maps" ) type memtableRetriever struct { @@ -112,31 +114,42 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex if !e.initialized { is := sctx.GetInfoSchema().(infoschema.InfoSchema) e.is = is - dbs := is.AllSchemaNames() - slices.SortFunc(dbs, func(a, b model.CIStr) int { - return strings.Compare(a.L, b.L) - }) + + var getAllSchemas = func() []model.CIStr { + dbs := is.AllSchemaNames() + slices.SortFunc(dbs, func(a, b model.CIStr) int { + return strings.Compare(a.L, b.L) + }) + return dbs + } + var err error switch e.table.Name.O { case infoschema.TableSchemata: + dbs := getAllSchemas() e.setDataFromSchemata(sctx, dbs) case infoschema.TableStatistics: + dbs := getAllSchemas() err = e.setDataForStatistics(ctx, sctx, dbs) case infoschema.TableTables: - err = e.setDataFromTables(ctx, sctx, dbs) + err = e.setDataFromTables(ctx, sctx) case infoschema.TableReferConst: + dbs := getAllSchemas() err = e.setDataFromReferConst(ctx, sctx, dbs) case infoschema.TableSequences: + dbs := getAllSchemas() err = e.setDataFromSequences(ctx, sctx, dbs) case infoschema.TablePartitions: - err = e.setDataFromPartitions(ctx, sctx, dbs) + err = e.setDataFromPartitions(ctx, sctx) case infoschema.TableClusterInfo: err = e.dataForTiDBClusterInfo(sctx) case infoschema.TableAnalyzeStatus: err = e.setDataForAnalyzeStatus(ctx, sctx) case infoschema.TableTiDBIndexes: + dbs := getAllSchemas() err = e.setDataFromIndexes(ctx, sctx, dbs) case infoschema.TableViews: + dbs := getAllSchemas() err = e.setDataFromViews(ctx, sctx, dbs) case infoschema.TableEngines: e.setDataFromEngines() @@ -145,6 +158,7 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex case infoschema.TableCollations: e.setDataFromCollations() case infoschema.TableKeyColumn: + dbs := getAllSchemas() err = e.setDataFromKeyColumnUsage(ctx, sctx, dbs) case infoschema.TableMetricTables: e.setDataForMetricTables() @@ -163,12 +177,14 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex case infoschema.TableTiDBHotRegions: err = e.setDataForTiDBHotRegions(ctx, sctx) case infoschema.TableConstraints: + dbs := getAllSchemas() err = e.setDataFromTableConstraints(ctx, sctx, dbs) case infoschema.TableSessionVar: e.rows, err = infoschema.GetDataFromSessionVariables(ctx, sctx) case infoschema.TableTiDBServersInfo: err = e.setDataForServersInfo(sctx) case 
infoschema.TableTiFlashReplica: + dbs := getAllSchemas() err = e.dataForTableTiFlashReplica(ctx, sctx, dbs) case infoschema.TableTiKVStoreStatus: err = e.dataForTiKVStoreStatus(ctx, sctx) @@ -201,14 +217,18 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex case infoschema.TableRunawayWatches: err = e.setDataFromRunawayWatches(sctx) case infoschema.TableCheckConstraints: + dbs := getAllSchemas() err = e.setDataFromCheckConstraints(ctx, sctx, dbs) case infoschema.TableTiDBCheckConstraints: + dbs := getAllSchemas() err = e.setDataFromTiDBCheckConstraints(ctx, sctx, dbs) case infoschema.TableKeywords: err = e.setDataFromKeywords() case infoschema.TableTiDBIndexUsage: + dbs := getAllSchemas() err = e.setDataFromIndexUsage(ctx, sctx, dbs) case infoschema.ClusterTableTiDBIndexUsage: + dbs := getAllSchemas() err = e.setDataForClusterIndexUsage(ctx, sctx, dbs) } if err != nil { @@ -235,8 +255,13 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex return adjustColumns(ret, e.columns, e.table), nil } -func getAutoIncrementID(ctx context.Context, sctx sessionctx.Context, schema model.CIStr, tblInfo *model.TableInfo) (int64, error) { - is := sctx.GetInfoSchema().(infoschema.InfoSchema) +func getAutoIncrementID( + ctx context.Context, + is infoschema.InfoSchema, + sctx sessionctx.Context, + schema model.CIStr, + tblInfo *model.TableInfo, +) (int64, error) { tbl, err := is.TableByName(ctx, schema, tblInfo.Name) if err != nil { return 0, err @@ -563,151 +588,266 @@ func (e *memtableRetriever) updateStatsCacheIfNeed() bool { return false } -func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { - useStatsCache := e.updateStatsCacheIfNeed() - checker := privilege.GetPrivilegeManager(sctx) - - var rows [][]types.Datum - createTimeTp := mysql.TypeDatetime - loc := sctx.GetSessionVars().TimeZone - if loc == nil { - loc = time.Local +func getMatchSchemas( + extractor base.MemTablePredicateExtractor, + is infoschema.InfoSchema, +) []model.CIStr { + ex, ok := extractor.(plannercore.TableSchemaSelector) + if ok { + if schemas := ex.SelectedSchemaNames(); len(schemas) > 0 { + ret := schemas[:0] + for _, s := range schemas { + if n, ok := is.SchemaByName(s); ok { + ret = append(ret, n.Name) + } + } + return ret + } } - extractor, ok := e.extractor.(*plannercore.InfoSchemaTablesExtractor) - if ok && extractor.SkipRequest { - return nil + schemas := is.AllSchemaNames() + slices.SortFunc(schemas, func(a, b model.CIStr) int { + return strings.Compare(a.L, b.L) + }) + return schemas +} + +func getMatchTableInfosForPartitions( + ctx context.Context, + extractor base.MemTablePredicateExtractor, + schema model.CIStr, + is infoschema.InfoSchema, +) ([]*model.TableInfo, error) { + ex, ok := extractor.(plannercore.TableSchemaSelector) + if !ok || !ex.HasTables() { + // There is no specified table in predicate. + return is.SchemaTableInfos(ctx, schema) + } + tables := make(map[int64]*model.TableInfo, 8) + // Find all table infos from predicate. 
+ for _, n := range ex.SelectedTableNames() { + tbl, err := is.TableByName(ctx, schema, n) + if err != nil { + if terror.ErrorEqual(err, infoschema.ErrTableNotExists) { + continue + } + return nil, errors.Trace(err) + } + tblInfo := tbl.Meta() + tables[tblInfo.ID] = tblInfo } - for _, schema := range schemas { - if ok && extractor.Filter("table_schema", schema.L) { + for _, pid := range ex.SelectedPartitionIDs() { + tbl, db, _ := is.FindTableByPartitionID(pid) + if tbl == nil { continue } - tables, err := e.is.SchemaTableInfos(ctx, schema) - if err != nil { - return errors.Trace(err) + if db.Name.L != schema.L { + continue } - for _, table := range tables { - if ok && extractor.Filter("table_name", table.Name.L) { + tblInfo := tbl.Meta() + tables[tblInfo.ID] = tblInfo + } + return maps.Values(tables), nil +} + +func getMatchTableInfos( + ctx context.Context, + extractor base.MemTablePredicateExtractor, + schema model.CIStr, + is infoschema.InfoSchema, +) ([]*model.TableInfo, error) { + ex, ok := extractor.(plannercore.TableSchemaSelector) + if !ok || !ex.HasTables() { + // There is no specified table in predicate. + return is.SchemaTableInfos(ctx, schema) + } + tables := make(map[int64]*model.TableInfo, 8) + // Find all table infos from predicate. + for _, n := range ex.SelectedTableNames() { + tbl, err := is.TableByName(ctx, schema, n) + if err != nil { + if terror.ErrorEqual(err, infoschema.ErrTableNotExists) { continue } - collation := table.Collate - if collation == "" { - collation = mysql.DefaultCollationName + return nil, errors.Trace(err) + } + tblInfo := tbl.Meta() + tables[tblInfo.ID] = tblInfo + } + for _, id := range ex.SelectedTableIDs() { + tbl, ok := is.TableByID(id) + if !ok { + continue + } + _, err := is.TableByName(ctx, schema, tbl.Meta().Name) + if err != nil { + if terror.ErrorEqual(err, infoschema.ErrTableNotExists) { + continue } - createTime := types.NewTime(types.FromGoTime(table.GetUpdateTime().In(loc)), createTimeTp, types.DefaultFsp) + return nil, errors.Trace(err) + } + tblInfo := tbl.Meta() + tables[tblInfo.ID] = tblInfo + } + return maps.Values(tables), nil +} + +func (e *memtableRetriever) setDataFromOneTable( + ctx context.Context, + sctx sessionctx.Context, + loc *time.Location, + checker privilege.Manager, + schema model.CIStr, + table *model.TableInfo, + rows [][]types.Datum, + useStatsCache bool, +) ([][]types.Datum, error) { + collation := table.Collate + if collation == "" { + collation = mysql.DefaultCollationName + } + createTime := types.NewTime(types.FromGoTime(table.GetUpdateTime().In(loc)), mysql.TypeDatetime, types.DefaultFsp) - createOptions := "" + createOptions := "" - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { - continue + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { + return rows, nil + } + pkType := "NONCLUSTERED" + if !table.IsView() { + if table.GetPartitionInfo() != nil { + createOptions = "partitioned" + } else if table.TableCacheStatusType == model.TableCacheStatusEnable { + createOptions = "cached=on" + } + var err error + var autoIncID any + hasAutoIncID, _ := infoschema.HasAutoIncrementColumn(table) + if hasAutoIncID { + autoIncID, err = getAutoIncrementID(ctx, e.is, sctx, schema, table) + if err != nil { + return rows, err } - pkType := "NONCLUSTERED" - if !table.IsView() { - if table.GetPartitionInfo() != nil { - createOptions = "partitioned" - } else 
if table.TableCacheStatusType == model.TableCacheStatusEnable { - createOptions = "cached=on" + } + tableType := "BASE TABLE" + if util.IsSystemView(schema.L) { + tableType = "SYSTEM VIEW" + } + if table.IsSequence() { + tableType = "SEQUENCE" + } + if table.HasClusteredIndex() { + pkType = "CLUSTERED" + } + shardingInfo := infoschema.GetShardingInfo(schema, table) + var policyName any + if table.PlacementPolicyRef != nil { + policyName = table.PlacementPolicyRef.Name.O + } + + var rowCount, avgRowLength, dataLength, indexLength uint64 + if useStatsCache { + if table.GetPartitionInfo() == nil { + err := cache.TableRowStatsCache.UpdateByID(sctx, table.ID) + if err != nil { + return rows, err } - var err error - var autoIncID any - hasAutoIncID, _ := infoschema.HasAutoIncrementColumn(table) - if hasAutoIncID { - autoIncID, err = getAutoIncrementID(ctx, sctx, schema, table) + } else { + // needs to update all partitions for partition table. + for _, pi := range table.GetPartitionInfo().Definitions { + err := cache.TableRowStatsCache.UpdateByID(sctx, pi.ID) if err != nil { - return err + return rows, err } } - tableType := "BASE TABLE" - if util.IsSystemView(schema.L) { - tableType = "SYSTEM VIEW" - } - if table.IsSequence() { - tableType = "SEQUENCE" - } - if table.HasClusteredIndex() { - pkType = "CLUSTERED" - } - shardingInfo := infoschema.GetShardingInfo(schema, table) - var policyName any - if table.PlacementPolicyRef != nil { - policyName = table.PlacementPolicyRef.Name.O - } + } + rowCount, avgRowLength, dataLength, indexLength = cache.TableRowStatsCache.EstimateDataLength(table) + } - var rowCount, avgRowLength, dataLength, indexLength uint64 - if useStatsCache { - if table.GetPartitionInfo() == nil { - err := cache.TableRowStatsCache.UpdateByID(sctx, table.ID) - if err != nil { - return err - } - } else { - // needs to update all partitions for partition table. 
- for _, pi := range table.GetPartitionInfo().Definitions { - err := cache.TableRowStatsCache.UpdateByID(sctx, pi.ID) - if err != nil { - return err - } - } - } - rowCount, avgRowLength, dataLength, indexLength = cache.TableRowStatsCache.EstimateDataLength(table) - } + record := types.MakeDatums( + infoschema.CatalogVal, // TABLE_CATALOG + schema.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + tableType, // TABLE_TYPE + "InnoDB", // ENGINE + uint64(10), // VERSION + "Compact", // ROW_FORMAT + rowCount, // TABLE_ROWS + avgRowLength, // AVG_ROW_LENGTH + dataLength, // DATA_LENGTH + uint64(0), // MAX_DATA_LENGTH + indexLength, // INDEX_LENGTH + uint64(0), // DATA_FREE + autoIncID, // AUTO_INCREMENT + createTime, // CREATE_TIME + nil, // UPDATE_TIME + nil, // CHECK_TIME + collation, // TABLE_COLLATION + nil, // CHECKSUM + createOptions, // CREATE_OPTIONS + table.Comment, // TABLE_COMMENT + table.ID, // TIDB_TABLE_ID + shardingInfo, // TIDB_ROW_ID_SHARDING_INFO + pkType, // TIDB_PK_TYPE + policyName, // TIDB_PLACEMENT_POLICY_NAME + ) + rows = append(rows, record) + } else { + record := types.MakeDatums( + infoschema.CatalogVal, // TABLE_CATALOG + schema.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + "VIEW", // TABLE_TYPE + nil, // ENGINE + nil, // VERSION + nil, // ROW_FORMAT + nil, // TABLE_ROWS + nil, // AVG_ROW_LENGTH + nil, // DATA_LENGTH + nil, // MAX_DATA_LENGTH + nil, // INDEX_LENGTH + nil, // DATA_FREE + nil, // AUTO_INCREMENT + createTime, // CREATE_TIME + nil, // UPDATE_TIME + nil, // CHECK_TIME + nil, // TABLE_COLLATION + nil, // CHECKSUM + nil, // CREATE_OPTIONS + "VIEW", // TABLE_COMMENT + table.ID, // TIDB_TABLE_ID + nil, // TIDB_ROW_ID_SHARDING_INFO + pkType, // TIDB_PK_TYPE + nil, // TIDB_PLACEMENT_POLICY_NAME + ) + rows = append(rows, record) + } + return rows, nil +} - record := types.MakeDatums( - infoschema.CatalogVal, // TABLE_CATALOG - schema.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - tableType, // TABLE_TYPE - "InnoDB", // ENGINE - uint64(10), // VERSION - "Compact", // ROW_FORMAT - rowCount, // TABLE_ROWS - avgRowLength, // AVG_ROW_LENGTH - dataLength, // DATA_LENGTH - uint64(0), // MAX_DATA_LENGTH - indexLength, // INDEX_LENGTH - uint64(0), // DATA_FREE - autoIncID, // AUTO_INCREMENT - createTime, // CREATE_TIME - nil, // UPDATE_TIME - nil, // CHECK_TIME - collation, // TABLE_COLLATION - nil, // CHECKSUM - createOptions, // CREATE_OPTIONS - table.Comment, // TABLE_COMMENT - table.ID, // TIDB_TABLE_ID - shardingInfo, // TIDB_ROW_ID_SHARDING_INFO - pkType, // TIDB_PK_TYPE - policyName, // TIDB_PLACEMENT_POLICY_NAME - ) - rows = append(rows, record) - } else { - record := types.MakeDatums( - infoschema.CatalogVal, // TABLE_CATALOG - schema.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - "VIEW", // TABLE_TYPE - nil, // ENGINE - nil, // VERSION - nil, // ROW_FORMAT - nil, // TABLE_ROWS - nil, // AVG_ROW_LENGTH - nil, // DATA_LENGTH - nil, // MAX_DATA_LENGTH - nil, // INDEX_LENGTH - nil, // DATA_FREE - nil, // AUTO_INCREMENT - createTime, // CREATE_TIME - nil, // UPDATE_TIME - nil, // CHECK_TIME - nil, // TABLE_COLLATION - nil, // CHECKSUM - nil, // CREATE_OPTIONS - "VIEW", // TABLE_COMMENT - table.ID, // TIDB_TABLE_ID - nil, // TIDB_ROW_ID_SHARDING_INFO - pkType, // TIDB_PK_TYPE - nil, // TIDB_PLACEMENT_POLICY_NAME - ) - rows = append(rows, record) +func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionctx.Context) error { + useStatsCache := e.updateStatsCacheIfNeed() + checker := privilege.GetPrivilegeManager(sctx) + + var rows 
[][]types.Datum + loc := sctx.GetSessionVars().TimeZone + if loc == nil { + loc = time.Local + } + ex := e.extractor.(*plannercore.InfoSchemaTablesExtractor) + if ex != nil && ex.SkipRequest { + return nil + } + + schemas := getMatchSchemas(e.extractor, e.is) + for _, schema := range schemas { + tables, err := getMatchTableInfos(ctx, e.extractor, schema, e.is) + if err != nil { + return errors.Trace(err) + } + for _, table := range tables { + rows, err = e.setDataFromOneTable(ctx, sctx, loc, checker, schema, table, rows, useStatsCache) + if err != nil { + return errors.Trace(err) } } } @@ -1112,29 +1252,23 @@ func calcCharOctLength(lenInChar int, cs string) int { return lenInBytes } -func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { +func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sessionctx.Context) error { useStatsCache := e.updateStatsCacheIfNeed() checker := privilege.GetPrivilegeManager(sctx) var rows [][]types.Datum createTimeTp := mysql.TypeDatetime - extractor, ok := e.extractor.(*plannercore.InfoSchemaTablesExtractor) - if ok && extractor.SkipRequest { + ex, ok := e.extractor.(*plannercore.InfoSchemaTablesExtractor) + if ok && ex.SkipRequest { return nil } - + schemas := getMatchSchemas(e.extractor, e.is) for _, schema := range schemas { - if ok && extractor.Filter("table_schema", schema.L) { - continue - } - tables, err := e.is.SchemaTableInfos(ctx, schema) + tables, err := getMatchTableInfosForPartitions(ctx, e.extractor, schema, e.is) if err != nil { return errors.Trace(err) } for _, table := range tables { - if ok && extractor.Filter("table_name", table.Name.L) { - continue - } if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.SelectPriv) { continue } @@ -1150,7 +1284,7 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess } else { // needs to update needed partitions for partition table. 
                for _, pi := range table.GetPartitionInfo().Definitions {
-                    if ok && extractor.Filter("partition_name", pi.Name.L) {
+                    if ok && ex.Filter("partition_name", pi.Name.L) {
                        continue
                    }
                    err := cache.TableRowStatsCache.UpdateByID(sctx, pi.ID)
@@ -1199,7 +1333,7 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess
                rows = append(rows, record)
            } else {
                for i, pi := range table.GetPartitionInfo().Definitions {
-                    if ok && extractor.Filter("partition_name", pi.Name.L) {
+                    if ok && ex.Filter("partition_name", pi.Name.L) {
                        continue
                    }
                    rowCount = cache.TableRowStatsCache.GetTableRows(pi.ID)
diff --git a/pkg/planner/core/memtable_predicate_extractor.go b/pkg/planner/core/memtable_predicate_extractor.go
index 7187771bb67c5..4eb59b7bf42e5 100644
--- a/pkg/planner/core/memtable_predicate_extractor.go
+++ b/pkg/planner/core/memtable_predicate_extractor.go
@@ -30,6 +30,7 @@ import (
    "github.com/pingcap/tidb/pkg/infoschema"
    "github.com/pingcap/tidb/pkg/kv"
    "github.com/pingcap/tidb/pkg/parser/ast"
+    "github.com/pingcap/tidb/pkg/parser/model"
    "github.com/pingcap/tidb/pkg/parser/mysql"
    "github.com/pingcap/tidb/pkg/planner/core/base"
    "github.com/pingcap/tidb/pkg/planner/util"
@@ -1814,7 +1815,6 @@ type InfoSchemaTablesExtractor struct {
    // SkipRequest means the where clause always false, we don't need to request any component
    SkipRequest bool
-    colNames []string
    ColPredicates map[string]set.StringSet
}
@@ -1825,10 +1825,11 @@ func (e *InfoSchemaTablesExtractor) Extract(ctx base.PlanContext,
    predicates []expression.Expression,
) (remained []expression.Expression) {
    var resultSet, resultSet1 set.StringSet
-    e.colNames = []string{"table_schema", "constraint_schema", "table_name", "constraint_name", "sequence_schema", "sequence_name", "partition_name", "schema_name", "index_name"}
+    colNames := []string{"table_schema", "constraint_schema", "table_name", "constraint_name",
+        "sequence_schema", "sequence_name", "partition_name", "schema_name", "index_name", "tidb_table_id"}
    e.ColPredicates = make(map[string]set.StringSet)
    remained = predicates
-    for _, colName := range e.colNames {
+    for _, colName := range colNames {
        remained, e.SkipRequest, resultSet = e.extractColWithLower(ctx, schema, names, remained, colName)
        if e.SkipRequest {
            break
@@ -1896,3 +1897,75 @@ func (e *InfoSchemaTablesExtractor) Filter(colName string, val string) bool {
    // No need to filter records since no predicate for the column exists.
    return false
}
+
+var _ TableSchemaSelector = (*InfoSchemaTablesExtractor)(nil)
+
+// TableSchemaSelector is used to help determine whether a specified table/schema name is contained in the predicate,
+// and to return all table/schema names specified in the predicate.
+type TableSchemaSelector interface {
+    SelectedSchemaNames() []model.CIStr
+
+    HasTables() bool
+    SelectedTableNames() []model.CIStr
+    SelectedTableIDs() []int64
+    SelectedPartitionIDs() []int64
+}
+
+// HasTables returns true if there are table names or table IDs specified in the predicate.
+func (e *InfoSchemaTablesExtractor) HasTables() bool {
+    _, hasTableName := e.ColPredicates["table_name"]
+    _, hasTableID := e.ColPredicates["tidb_table_id"]
+    _, hasPartID := e.ColPredicates["tidb_partition_id"]
+    return hasTableName || hasTableID || hasPartID
+}
+
+// SelectedTableNames gets the table names specified in the predicate.
+func (e *InfoSchemaTablesExtractor) SelectedTableNames() []model.CIStr {
+    return e.getSchemaObjectNames("table_name")
+}
+
+// SelectedSchemaNames gets the schema names specified in the predicate.
+func (e *InfoSchemaTablesExtractor) SelectedSchemaNames() []model.CIStr {
+    return e.getSchemaObjectNames("table_schema")
+}
+
+// SelectedTableIDs gets the table IDs specified in the predicate.
+func (e *InfoSchemaTablesExtractor) SelectedTableIDs() []int64 {
+    strs := e.getSchemaObjectNames("tidb_table_id")
+    return parseIDs(strs)
+}
+
+// SelectedPartitionIDs gets the partition IDs specified in the predicate.
+func (e *InfoSchemaTablesExtractor) SelectedPartitionIDs() []int64 {
+    strs := e.getSchemaObjectNames("tidb_partition_id")
+    return parseIDs(strs)
+}
+
+func parseIDs(ids []model.CIStr) []int64 {
+    tableIDs := make([]int64, 0, len(ids))
+    for _, s := range ids {
+        v, err := strconv.ParseInt(s.L, 10, 64)
+        if err != nil {
+            continue
+        }
+        tableIDs = append(tableIDs, v)
+    }
+    slices.Sort(tableIDs)
+    return tableIDs
+}
+
+// getSchemaObjectNames gets the schema object names specified in the predicate for the given column name.
+func (e *InfoSchemaTablesExtractor) getSchemaObjectNames(colName string) []model.CIStr {
+    predVals, ok := e.ColPredicates[colName]
+    if ok && len(predVals) > 0 {
+        tableNames := make([]model.CIStr, 0, len(predVals))
+        predVals.IterateWith(func(n string) {
+            tableNames = append(tableNames, model.NewCIStr(n))
+        })
+        slices.SortFunc(tableNames, func(a, b model.CIStr) int {
+            return strings.Compare(a.L, b.L)
+        })
+        return tableNames
+    }
+    return nil
+}
diff --git a/pkg/util/set/string_set.go b/pkg/util/set/string_set.go
index 5a74790971070..e61f46182e390 100644
--- a/pkg/util/set/string_set.go
+++ b/pkg/util/set/string_set.go
@@ -85,3 +85,10 @@ func (s StringSet) Empty() bool {
func (s StringSet) Clear() {
    maps.Clear(s)
}
+
+// IterateWith iterates over the items in StringSet and passes each one to `fn`.
+func (s StringSet) IterateWith(fn func(string)) {
+    for k := range s {
+        fn(k)
+    }
+}
diff --git a/tests/integrationtest/r/infoschema/infoschema.result b/tests/integrationtest/r/infoschema/infoschema.result
index 63e7d48570d3d..2edffeff5e3d9 100644
--- a/tests/integrationtest/r/infoschema/infoschema.result
+++ b/tests/integrationtest/r/infoschema/infoschema.result
@@ -115,6 +115,7 @@ Projection_4 8000.00 root Column#5, Column#10
└─MemTableScan_6 10000.00 root table:TABLES table_schema:["infoschema__infoschema"]
select engine, DATA_LENGTH from information_schema.tables where lower(table_name) = 't5' and upper(table_schema) = 'INFOSCHEMA__INFOSCHEMA';
engine DATA_LENGTH
+InnoDB 8
explain select engine, DATA_LENGTH from information_schema.tables where (table_name ='t4' or lower(table_name) = 't5') and upper(table_schema) = 'INFOSCHEMA__INFOSCHEMA';
id estRows task access object operator info
Projection_4 8000.00 root Column#5, Column#10
@@ -130,3 +131,31 @@ MemTableScan_5 10000.00 root table:TABLES table_name:["T4","t4"], table_schema:[
select engine, DATA_LENGTH from information_schema.tables where table_name ='t4' and upper(table_name) ='T4' and table_schema = 'infoschema__infoschema';
engine DATA_LENGTH
InnoDB 8
+create table pt1(a int primary key, b int) partition by hash(a) partitions 4;
+create table pt2(a int primary key, b int) partition by hash(a) partitions 4;
+select TABLE_NAME, PARTITION_NAME from information_schema.partitions where table_schema = 'infoschema__infoschema';
+TABLE_NAME PARTITION_NAME
+pt1 p0
+pt1 p1
+pt1 p2
+pt1 p3
+pt2 p0
+pt2 p1
+pt2 p2
+pt2 p3
+t4 NULL
+t5 NULL
+select TABLE_NAME, PARTITION_NAME from information_schema.partitions where table_name = 'pt1' and table_schema = 'infoschema__infoschema';
+TABLE_NAME PARTITION_NAME
+pt1 p0
+pt1 p1
+pt1 p2
+pt1 p3
+select TABLE_NAME,
PARTITION_NAME from information_schema.partitions where table_name = 'pt2' and table_schema = 'infoschema__infoschema'; +TABLE_NAME PARTITION_NAME +pt2 p0 +pt2 p1 +pt2 p2 +pt2 p3 +select TABLE_NAME, PARTITION_NAME from information_schema.partitions where table_name = 'pt0' and table_schema = 'infoschema__infoschema'; +TABLE_NAME PARTITION_NAME diff --git a/tests/integrationtest/t/infoschema/infoschema.test b/tests/integrationtest/t/infoschema/infoschema.test index faad9b5573054..f4e0f993e817f 100644 --- a/tests/integrationtest/t/infoschema/infoschema.test +++ b/tests/integrationtest/t/infoschema/infoschema.test @@ -58,3 +58,14 @@ select engine, DATA_LENGTH from information_schema.tables where (table_name ='t4 explain select engine, DATA_LENGTH from information_schema.tables where table_name ='t4' and upper(table_name) ='T4' and table_schema = 'infoschema__infoschema'; select engine, DATA_LENGTH from information_schema.tables where table_name ='t4' and upper(table_name) ='T4' and table_schema = 'infoschema__infoschema'; +# TestPartitionsColumn +create table pt1(a int primary key, b int) partition by hash(a) partitions 4; +create table pt2(a int primary key, b int) partition by hash(a) partitions 4; +-- sorted_result +select TABLE_NAME, PARTITION_NAME from information_schema.partitions where table_schema = 'infoschema__infoschema'; +-- sorted_result +select TABLE_NAME, PARTITION_NAME from information_schema.partitions where table_name = 'pt1' and table_schema = 'infoschema__infoschema'; +-- sorted_result +select TABLE_NAME, PARTITION_NAME from information_schema.partitions where table_name = 'pt2' and table_schema = 'infoschema__infoschema'; +-- sorted_result +select TABLE_NAME, PARTITION_NAME from information_schema.partitions where table_name = 'pt0' and table_schema = 'infoschema__infoschema';
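
Reviewer note, not part of the patch: the snippet below is a minimal, self-contained sketch of the lookup pattern this change introduces in getMatchSchemas/getMatchTableInfos. The idea is to ask the predicate extractor which schema/table names the WHERE clause pinned down, look those up directly (skipping names that do not exist, as the ErrTableNotExists branch does), deduplicate by table ID in a map, and fall back to a full schema scan only when no table is specified. The identifiers selector, catalog, matchTables and fixedSelector are hypothetical stand-ins for illustration, not TiDB APIs.

// Standalone sketch of the "narrow by predicate, fall back to full scan" pattern.
package main

import (
	"fmt"
	"sort"
)

// selector mirrors the shape of TableSchemaSelector for this sketch.
type selector interface {
	HasTables() bool
	SelectedTableNames() []string
}

// catalog is a stand-in for infoschema.InfoSchema: table name -> table ID.
type catalog map[string]int64

// matchTables keeps only the tables named in the predicate when there are any,
// ignoring names that do not exist; otherwise it scans the whole schema.
func matchTables(sel selector, cat catalog) []string {
	if sel == nil || !sel.HasTables() {
		// No table predicate: full scan.
		all := make([]string, 0, len(cat))
		for name := range cat {
			all = append(all, name)
		}
		sort.Strings(all)
		return all
	}
	seen := make(map[int64]string, 8) // dedupe by table ID
	for _, name := range sel.SelectedTableNames() {
		id, ok := cat[name]
		if !ok {
			continue // unknown table: skip, do not fail the query
		}
		seen[id] = name
	}
	out := make([]string, 0, len(seen))
	for _, name := range seen {
		out = append(out, name)
	}
	sort.Strings(out) // map iteration order is unspecified
	return out
}

// fixedSelector is a trivial selector over a fixed list of table names.
type fixedSelector []string

func (f fixedSelector) HasTables() bool              { return len(f) > 0 }
func (f fixedSelector) SelectedTableNames() []string { return f }

func main() {
	cat := catalog{"pt1": 101, "pt2": 102, "t4": 103}
	fmt.Println(matchTables(fixedSelector{"pt1", "pt1", "missing"}, cat)) // [pt1]
	fmt.Println(matchTables(nil, cat))                                    // [pt1 pt2 t4]
}

One consequence of collecting the deduplicated table infos through a map (golang.org/x/exp/maps.Values in the real code) is that row order becomes unspecified, which is presumably why the new integration tests wrap their queries in -- sorted_result.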