diff --git a/pkg/ccl/backupccl/restore_planning.go b/pkg/ccl/backupccl/restore_planning.go index d1bf9496265f..052d75e85f5e 100644 --- a/pkg/ccl/backupccl/restore_planning.go +++ b/pkg/ccl/backupccl/restore_planning.go @@ -116,14 +116,23 @@ func rewriteTypesInExpr(expr string, rewrites DescRewriteMap) (string, error) { if err != nil { return "", err } + ctx := tree.NewFmtCtx(tree.FmtSerializable) ctx.SetIndexedTypeFormat(func(ctx *tree.FmtCtx, ref *tree.OIDTypeReference) { newRef := ref - if rw, ok := rewrites[typedesc.UserDefinedTypeOIDToID(ref.OID)]; ok { + var id descpb.ID + id, err = typedesc.UserDefinedTypeOIDToID(ref.OID) + if err != nil { + return + } + if rw, ok := rewrites[id]; ok { newRef = &tree.OIDTypeReference{OID: typedesc.TypeIDToOID(rw.ID)} } ctx.WriteString(newRef.SQLString()) }) + if err != nil { + return "", err + } ctx.FormatNode(parsed) return ctx.CloseAndGetString(), nil } @@ -348,11 +357,15 @@ func allocateDescriptorRewrites( // Ensure that all referenced types are present. if col.Type.UserDefined() { // TODO (rohany): This can be turned into an option later. - if _, ok := typesByID[typedesc.GetTypeDescID(col.Type)]; !ok { + id, err := typedesc.GetUserDefinedTypeDescID(col.Type) + if err != nil { + return nil, err + } + if _, ok := typesByID[id]; !ok { return nil, errors.Errorf( "cannot restore table %q without referenced type %d", table.Name, - typedesc.GetTypeDescID(col.Type), + id, ) } } @@ -1025,25 +1038,37 @@ func rewriteDatabaseDescs(databases []*dbdesc.Mutable, descriptorRewrites DescRe // rewriteIDsInTypesT rewrites all ID's in the input types.T using the input // ID rewrite mapping. -func rewriteIDsInTypesT(typ *types.T, descriptorRewrites DescRewriteMap) { +func rewriteIDsInTypesT(typ *types.T, descriptorRewrites DescRewriteMap) error { if !typ.UserDefined() { - return + return nil + } + tid, err := typedesc.GetUserDefinedTypeDescID(typ) + if err != nil { + return err } // Collect potential new OID values. var newOID, newArrayOID oid.Oid - if rw, ok := descriptorRewrites[typedesc.GetTypeDescID(typ)]; ok { + if rw, ok := descriptorRewrites[tid]; ok { newOID = typedesc.TypeIDToOID(rw.ID) } if typ.Family() != types.ArrayFamily { - if rw, ok := descriptorRewrites[typedesc.GetArrayTypeDescID(typ)]; ok { + tid, err = typedesc.GetUserDefinedArrayTypeDescID(typ) + if err != nil { + return err + } + if rw, ok := descriptorRewrites[tid]; ok { newArrayOID = typedesc.TypeIDToOID(rw.ID) } } types.RemapUserDefinedTypeOIDs(typ, newOID, newArrayOID) // If the type is an array, then we need to rewrite the element type as well. if typ.Family() == types.ArrayFamily { - rewriteIDsInTypesT(typ.ArrayContents(), descriptorRewrites) + if err := rewriteIDsInTypesT(typ.ArrayContents(), descriptorRewrites); err != nil { + return err + } } + + return nil } // rewriteTypeDescs rewrites all ID's in the input slice of TypeDescriptors @@ -1075,7 +1100,9 @@ func rewriteTypeDescs(types []*typedesc.Mutable, descriptorRewrites DescRewriteM } case descpb.TypeDescriptor_ALIAS: // We need to rewrite any ID's present in the aliased types.T. - rewriteIDsInTypesT(typ.Alias, descriptorRewrites) + if err := rewriteIDsInTypesT(typ.Alias, descriptorRewrites); err != nil { + return err + } default: return errors.AssertionFailedf("unknown type kind %s", t.String()) } @@ -1285,7 +1312,9 @@ func RewriteTableDescs( // rewriteCol is a closure that performs the ID rewrite logic on a column. rewriteCol := func(col *descpb.ColumnDescriptor) error { // Rewrite the types.T's IDs present in the column. - rewriteIDsInTypesT(col.Type, descriptorRewrites) + if err := rewriteIDsInTypesT(col.Type, descriptorRewrites); err != nil { + return err + } var newUsedSeqRefs []descpb.ID for _, seqID := range col.UsesSequenceIds { if rewrite, ok := descriptorRewrites[seqID]; ok { diff --git a/pkg/ccl/changefeedccl/schemafeed/schema_feed.go b/pkg/ccl/changefeedccl/schemafeed/schema_feed.go index 628229e39bf0..a0366f99d746 100644 --- a/pkg/ccl/changefeedccl/schemafeed/schema_feed.go +++ b/pkg/ccl/changefeedccl/schemafeed/schema_feed.go @@ -178,16 +178,27 @@ func (t *typeDependencyTracker) removeDependency(typeID, tableID descpb.ID) { } } -func (t *typeDependencyTracker) purgeTable(tbl catalog.TableDescriptor) { +func (t *typeDependencyTracker) purgeTable(tbl catalog.TableDescriptor) error { for _, col := range tbl.UserDefinedTypeColumns() { - t.removeDependency(typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()), tbl.GetID()) + id, err := typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()) + if err != nil { + return err + } + t.removeDependency(id, tbl.GetID()) } + + return nil } -func (t *typeDependencyTracker) ingestTable(tbl catalog.TableDescriptor) { +func (t *typeDependencyTracker) ingestTable(tbl catalog.TableDescriptor) error { for _, col := range tbl.UserDefinedTypeColumns() { - t.addDependency(typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()), tbl.GetID()) + id, err := typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()) + if err != nil { + return err + } + t.addDependency(id, tbl.GetID()) } + return nil } func (t *typeDependencyTracker) containsType(id descpb.ID) bool { @@ -289,7 +300,10 @@ func (tf *SchemaFeed) primeInitialTableDescs(ctx context.Context) error { // Register all types used by the initial set of tables. for _, desc := range initialDescs { tbl := desc.(catalog.TableDescriptor) - tf.mu.typeDeps.ingestTable(tbl) + if err := tf.mu.typeDeps.ingestTable(tbl); err != nil { + tf.mu.Unlock() + return err + } } tf.mu.Unlock() @@ -533,7 +547,9 @@ func (tf *SchemaFeed) validateDescriptor( } // Purge the old version of the table from the type mapping. - tf.mu.typeDeps.purgeTable(lastVersion) + if err := tf.mu.typeDeps.purgeTable(lastVersion); err != nil { + return err + } e := TableEvent{ Before: lastVersion, @@ -559,7 +575,9 @@ func (tf *SchemaFeed) validateDescriptor( } } // Add the types used by the table into the dependency tracker. - tf.mu.typeDeps.ingestTable(desc) + if err := tf.mu.typeDeps.ingestTable(desc); err != nil { + return err + } tf.mu.previousTableVersion[desc.GetID()] = desc return nil default: diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel index a595af981f85..85aa241635ae 100644 --- a/pkg/sql/BUILD.bazel +++ b/pkg/sql/BUILD.bazel @@ -302,7 +302,6 @@ go_library( "//pkg/sql/inverted", "//pkg/sql/lex", "//pkg/sql/mutations", - "//pkg/sql/oidext", "//pkg/sql/opt", "//pkg/sql/opt/cat", "//pkg/sql/opt/constraint", diff --git a/pkg/sql/catalog/dbdesc/database_desc.go b/pkg/sql/catalog/dbdesc/database_desc.go index b083b89817d3..6d26ac76d701 100644 --- a/pkg/sql/catalog/dbdesc/database_desc.go +++ b/pkg/sql/catalog/dbdesc/database_desc.go @@ -225,15 +225,19 @@ func (desc *immutable) validateMultiRegion(vea catalog.ValidationErrorAccumulato // GetReferencedDescIDs returns the IDs of all descriptors referenced by // this descriptor, including itself. -func (desc *immutable) GetReferencedDescIDs() catalog.DescriptorIDSet { +func (desc *immutable) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) { ids := catalog.MakeDescriptorIDSet(desc.GetID()) - if id, err := desc.MultiRegionEnumID(); err == nil { + if desc.IsMultiRegion() { + id, err := desc.MultiRegionEnumID() + if err != nil { + return catalog.DescriptorIDSet{}, err + } ids.Add(id) } for _, schema := range desc.Schemas { ids.Add(schema.ID) } - return ids + return ids, nil } // ValidateCrossReferences implements the catalog.Descriptor interface. diff --git a/pkg/sql/catalog/descriptor.go b/pkg/sql/catalog/descriptor.go index 3d22ff81bf91..1b07c5f632d2 100644 --- a/pkg/sql/catalog/descriptor.go +++ b/pkg/sql/catalog/descriptor.go @@ -140,7 +140,7 @@ type Descriptor interface { // GetReferencedDescIDs returns the IDs of all descriptors directly referenced // by this descriptor, including itself. - GetReferencedDescIDs() DescriptorIDSet + GetReferencedDescIDs() (DescriptorIDSet, error) // ValidateSelf checks the internal consistency of the descriptor. ValidateSelf(vea ValidationErrorAccumulator) @@ -358,7 +358,7 @@ type TypeDescriptor interface { HydrateTypeInfoWithName(ctx context.Context, typ *types.T, name *tree.TypeName, res TypeDescriptorResolver) error MakeTypesT(ctx context.Context, name *tree.TypeName, res TypeDescriptorResolver) (*types.T, error) HasPendingSchemaChanges() bool - GetIDClosure() map[descpb.ID]struct{} + GetIDClosure() (map[descpb.ID]struct{}, error) IsCompatibleWith(other TypeDescriptor) error PrimaryRegionName() (descpb.RegionName, error) diff --git a/pkg/sql/catalog/descs/collection.go b/pkg/sql/catalog/descs/collection.go index b1501e8f7b70..371044db6554 100644 --- a/pkg/sql/catalog/descs/collection.go +++ b/pkg/sql/catalog/descs/collection.go @@ -2180,7 +2180,11 @@ func (dt DistSQLTypeResolver) ResolveType( // ResolveTypeByOID implements the tree.TypeReferenceResolver interface. func (dt DistSQLTypeResolver) ResolveTypeByOID(ctx context.Context, oid oid.Oid) (*types.T, error) { - name, desc, err := dt.GetTypeDescriptor(ctx, typedesc.UserDefinedTypeOIDToID(oid)) + id, err := typedesc.UserDefinedTypeOIDToID(oid) + if err != nil { + return nil, err + } + name, desc, err := dt.GetTypeDescriptor(ctx, id) if err != nil { return nil, err } @@ -2213,7 +2217,11 @@ func (dt DistSQLTypeResolver) GetTypeDescriptor( func (dt DistSQLTypeResolver) HydrateTypeSlice(ctx context.Context, typs []*types.T) error { for _, t := range typs { if t.UserDefined() { - name, desc, err := dt.GetTypeDescriptor(ctx, typedesc.GetTypeDescID(t)) + id, err := typedesc.GetUserDefinedTypeDescID(t) + if err != nil { + return err + } + name, desc, err := dt.GetTypeDescriptor(ctx, id) if err != nil { return err } diff --git a/pkg/sql/catalog/schemadesc/schema_desc.go b/pkg/sql/catalog/schemadesc/schema_desc.go index 27701353c397..abadad7655f3 100644 --- a/pkg/sql/catalog/schemadesc/schema_desc.go +++ b/pkg/sql/catalog/schemadesc/schema_desc.go @@ -146,8 +146,8 @@ func (desc *immutable) ValidateSelf(vea catalog.ValidationErrorAccumulator) { // GetReferencedDescIDs returns the IDs of all descriptors referenced by // this descriptor, including itself. -func (desc *immutable) GetReferencedDescIDs() catalog.DescriptorIDSet { - return catalog.MakeDescriptorIDSet(desc.GetID(), desc.GetParentID()) +func (desc *immutable) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) { + return catalog.MakeDescriptorIDSet(desc.GetID(), desc.GetParentID()), nil } // ValidateCrossReferences implements the catalog.Descriptor interface. diff --git a/pkg/sql/catalog/tabledesc/structured.go b/pkg/sql/catalog/tabledesc/structured.go index aa218aa0cfce..123c84693cf6 100644 --- a/pkg/sql/catalog/tabledesc/structured.go +++ b/pkg/sql/catalog/tabledesc/structured.go @@ -500,27 +500,44 @@ func (desc *wrapper) getAllReferencedTypesInTableColumns( // collect the closure of ID's referenced. ids := make(map[descpb.ID]struct{}) for id := range visitor.OIDs { - typDesc, err := getType(typedesc.UserDefinedTypeOIDToID(id)) + uid, err := typedesc.UserDefinedTypeOIDToID(id) if err != nil { return nil, err } - for child := range typDesc.GetIDClosure() { + typDesc, err := getType(uid) + if err != nil { + return nil, err + } + children, err := typDesc.GetIDClosure() + if err != nil { + return nil, err + } + for child := range children { ids[child] = struct{}{} } } // Now add all of the column types in the table. - addIDsInColumn := func(c *descpb.ColumnDescriptor) { - for id := range typedesc.GetTypeDescriptorClosure(c.Type) { + addIDsInColumn := func(c *descpb.ColumnDescriptor) error { + children, err := typedesc.GetTypeDescriptorClosure(c.Type) + if err != nil { + return err + } + for id := range children { ids[id] = struct{}{} } + return nil } for i := range desc.Columns { - addIDsInColumn(&desc.Columns[i]) + if err := addIDsInColumn(&desc.Columns[i]); err != nil { + return nil, err + } } for _, mut := range desc.Mutations { if c := mut.GetColumn(); c != nil { - addIDsInColumn(c) + if err := addIDsInColumn(c); err != nil { + return nil, err + } } } diff --git a/pkg/sql/catalog/tabledesc/validate.go b/pkg/sql/catalog/tabledesc/validate.go index 3ba8c901b617..79400ba9c084 100644 --- a/pkg/sql/catalog/tabledesc/validate.go +++ b/pkg/sql/catalog/tabledesc/validate.go @@ -47,7 +47,7 @@ func (desc *wrapper) ValidateTxnCommit( // GetReferencedDescIDs returns the IDs of all descriptors referenced by // this descriptor, including itself. -func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet { +func (desc *wrapper) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) { ids := catalog.MakeDescriptorIDSet(desc.GetID(), desc.GetParentID()) if desc.GetParentSchemaID() != keys.PublicSchemaID { ids.Add(desc.GetParentSchemaID()) @@ -69,7 +69,11 @@ func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet { } // Collect user defined type Oids and sequence references in columns. for _, col := range desc.DeletableColumns() { - for id := range typedesc.GetTypeDescriptorClosure(col.GetType()) { + children, err := typedesc.GetTypeDescriptorClosure(col.GetType()) + if err != nil { + return catalog.DescriptorIDSet{}, err + } + for id := range children { ids.Add(id) } for i := 0; i < col.NumUsesSequences(); i++ { @@ -89,7 +93,11 @@ func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet { }) // Add collected Oids to return set. for oid := range visitor.OIDs { - ids.Add(typedesc.UserDefinedTypeOIDToID(oid)) + id, err := typedesc.UserDefinedTypeOIDToID(oid) + if err != nil { + return catalog.DescriptorIDSet{}, err + } + ids.Add(id) } // Add view dependencies. for _, id := range desc.GetDependsOn() { @@ -102,7 +110,7 @@ func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet { ids.Add(ref.ID) } // Add sequence dependencies - return ids + return ids, nil } // ValidateCrossReferences validates that each reference to another table is diff --git a/pkg/sql/catalog/typedesc/BUILD.bazel b/pkg/sql/catalog/typedesc/BUILD.bazel index 10b9127b8ef1..a03794ab256c 100644 --- a/pkg/sql/catalog/typedesc/BUILD.bazel +++ b/pkg/sql/catalog/typedesc/BUILD.bazel @@ -46,6 +46,7 @@ go_test( "//pkg/sql/catalog/dbdesc", "//pkg/sql/catalog/descpb", "//pkg/sql/catalog/schemadesc", + "//pkg/sql/oidext", "//pkg/sql/privilege", "//pkg/sql/types", "//pkg/testutils", @@ -53,6 +54,7 @@ go_test( "//pkg/util/leaktest", "//pkg/util/randutil", "@com_github_cockroachdb_redact//:redact", + "@com_github_lib_pq//oid", "@com_github_stretchr_testify//require", "@in_gopkg_yaml_v2//:yaml_v2", ], diff --git a/pkg/sql/catalog/typedesc/type_desc.go b/pkg/sql/catalog/typedesc/type_desc.go index 5bb2ee0760fc..17c370ae42f9 100644 --- a/pkg/sql/catalog/typedesc/type_desc.go +++ b/pkg/sql/catalog/typedesc/type_desc.go @@ -114,19 +114,25 @@ func TypeIDToOID(id descpb.ID) oid.Oid { } // UserDefinedTypeOIDToID converts a user defined type OID into a -// descriptor ID. -func UserDefinedTypeOIDToID(oid oid.Oid) descpb.ID { - return descpb.ID(oid) - oidext.CockroachPredefinedOIDMax +// descriptor ID. OID of a user-defined type must be greater than +// CockroachPredefinedOIDMax. The function returns an error if the +// given OID is less than or equals to CockroachPredefinedMax. +func UserDefinedTypeOIDToID(oid oid.Oid) (descpb.ID, error) { + if descpb.ID(oid) <= oidext.CockroachPredefinedOIDMax { + return 0, errors.Newf("user-defined OID %d should be greater "+ + "than predefined Max: %d.", oid, oidext.CockroachPredefinedOIDMax) + } + return descpb.ID(oid) - oidext.CockroachPredefinedOIDMax, nil } -// GetTypeDescID gets the type descriptor ID from a user defined type. -func GetTypeDescID(t *types.T) descpb.ID { +// GetUserDefinedTypeDescID gets the type descriptor ID from a user defined type. +func GetUserDefinedTypeDescID(t *types.T) (descpb.ID, error) { return UserDefinedTypeOIDToID(t.Oid()) } -// GetArrayTypeDescID gets the ID of the array type descriptor from a user +// GetUserDefinedArrayTypeDescID gets the ID of the array type descriptor from a user // defined type. -func GetArrayTypeDescID(t *types.T) descpb.ID { +func GetUserDefinedArrayTypeDescID(t *types.T) (descpb.ID, error) { return UserDefinedTypeOIDToID(t.UserDefinedArrayOID()) } @@ -554,16 +560,20 @@ func (desc *immutable) validateEnumMembers(vea catalog.ValidationErrorAccumulato // GetReferencedDescIDs returns the IDs of all descriptors referenced by // this descriptor, including itself. -func (desc *immutable) GetReferencedDescIDs() catalog.DescriptorIDSet { +func (desc *immutable) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) { ids := catalog.MakeDescriptorIDSet(desc.GetReferencingDescriptorIDs()...) ids.Add(desc.GetParentID()) if desc.GetParentSchemaID() != keys.PublicSchemaID { ids.Add(desc.GetParentSchemaID()) } - for id := range desc.GetIDClosure() { + children, err := desc.GetIDClosure() + if err != nil { + return catalog.DescriptorIDSet{}, err + } + for id := range children { ids.Add(id) } - return ids + return ids, nil } // ValidateCrossReferences performs cross reference checks on the type descriptor. @@ -599,7 +609,10 @@ func (desc *immutable) ValidateCrossReferences( } case descpb.TypeDescriptor_ALIAS: if desc.GetAlias().UserDefined() { - aliasedID := UserDefinedTypeOIDToID(desc.GetAlias().Oid()) + aliasedID, err := UserDefinedTypeOIDToID(desc.GetAlias().Oid()) + if err != nil { + vea.Report(err) + } if _, err := vdg.GetTypeDescriptor(aliasedID); err != nil { vea.Report(errors.Wrapf(err, "aliased type %d does not exist", aliasedID)) } @@ -724,7 +737,11 @@ func HydrateTypesInTableDescriptor( hydrateCol := func(col *descpb.ColumnDescriptor) error { if col.Type.UserDefined() { // Look up its type descriptor. - name, typDesc, err := res.GetTypeDescriptor(ctx, GetTypeDescID(col.Type)) + td, err := GetUserDefinedTypeDescID(col.Type) + if err != nil { + return err + } + name, typDesc, err := res.GetTypeDescriptor(ctx, td) if err != nil { return err } @@ -787,7 +804,11 @@ func (desc *immutable) HydrateTypeInfoWithName( case types.ArrayFamily: // Hydrate the element type. elemType := typ.ArrayContents() - elemTypName, elemTypDesc, err := res.GetTypeDescriptor(ctx, GetTypeDescID(elemType)) + id, err := GetUserDefinedTypeDescID(elemType) + if err != nil { + return err + } + elemTypName, elemTypDesc, err := res.GetTypeDescriptor(ctx, id) if err != nil { return err } @@ -901,14 +922,17 @@ func (desc *immutable) HasPendingSchemaChanges() bool { // GetIDClosure returns all type descriptor IDs that are referenced by this // type descriptor. -func (desc *immutable) GetIDClosure() map[descpb.ID]struct{} { +func (desc *immutable) GetIDClosure() (map[descpb.ID]struct{}, error) { ret := make(map[descpb.ID]struct{}) // Collect the descriptor's own ID. ret[desc.ID] = struct{}{} if desc.Kind == descpb.TypeDescriptor_ALIAS { // If this descriptor is an alias for another type, then get collect the // closure for alias. - children := GetTypeDescriptorClosure(desc.Alias) + children, err := GetTypeDescriptorClosure(desc.Alias) + if err != nil { + return nil, err + } for id := range children { ret[id] = struct{}{} } @@ -916,28 +940,39 @@ func (desc *immutable) GetIDClosure() map[descpb.ID]struct{} { // Otherwise, take the array type ID. ret[desc.ArrayTypeID] = struct{}{} } - return ret + return ret, nil } // GetTypeDescriptorClosure returns all type descriptor IDs that are // referenced by this input types.T. -func GetTypeDescriptorClosure(typ *types.T) map[descpb.ID]struct{} { +func GetTypeDescriptorClosure(typ *types.T) (map[descpb.ID]struct{}, error) { if !typ.UserDefined() { - return map[descpb.ID]struct{}{} + return map[descpb.ID]struct{}{}, nil + } + id, err := GetUserDefinedTypeDescID(typ) + if err != nil { + return nil, err } // Collect the type's descriptor ID. ret := map[descpb.ID]struct{}{ - GetTypeDescID(typ): {}, + id: {}, } if typ.Family() == types.ArrayFamily { // If we have an array type, then collect all types in the contents. - children := GetTypeDescriptorClosure(typ.ArrayContents()) + children, err := GetTypeDescriptorClosure(typ.ArrayContents()) + if err != nil { + return nil, err + } for id := range children { ret[id] = struct{}{} } } else { // Otherwise, take the array type ID. - ret[GetArrayTypeDescID(typ)] = struct{}{} + id, err := GetUserDefinedArrayTypeDescID(typ) + if err != nil { + return nil, err + } + ret[id] = struct{}{} } - return ret + return ret, nil } diff --git a/pkg/sql/catalog/typedesc/type_desc_test.go b/pkg/sql/catalog/typedesc/type_desc_test.go index 91e32be2502c..0a1a495f16ba 100644 --- a/pkg/sql/catalog/typedesc/type_desc_test.go +++ b/pkg/sql/catalog/typedesc/type_desc_test.go @@ -13,6 +13,7 @@ package typedesc_test import ( "context" "fmt" + "math" "testing" "github.com/cockroachdb/cockroach/pkg/keys" @@ -22,10 +23,12 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/catalog/schemadesc" "github.com/cockroachdb/cockroach/pkg/sql/catalog/typedesc" + "github.com/cockroachdb/cockroach/pkg/sql/oidext" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/lib/pq/oid" "github.com/stretchr/testify/require" ) @@ -784,3 +787,28 @@ func TestValidateTypeDesc(t *testing.T) { } } } + +func TestOIDToIDConversion(t *testing.T) { + tests := []struct { + oid oid.Oid + ok bool + name string + }{ + {oid.Oid(0), false, "default OID"}, + {oid.Oid(1), false, "Standard OID"}, + {oid.Oid(oidext.CockroachPredefinedOIDMax), false, "max standard OID"}, + {oid.Oid(oidext.CockroachPredefinedOIDMax + 1), true, "user-defined OID"}, + {oid.Oid(math.MaxUint32), true, "max user-defined OID"}, + } + + for _, test := range tests { + t.Run(fmt.Sprint(test.oid), func(t *testing.T) { + _, err := typedesc.UserDefinedTypeOIDToID(test.oid) + if test.ok { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} diff --git a/pkg/sql/catalog/validate.go b/pkg/sql/catalog/validate.go index 3ad024fbfc7c..bb2b6e846a2d 100644 --- a/pkg/sql/catalog/validate.go +++ b/pkg/sql/catalog/validate.go @@ -437,9 +437,14 @@ type collectorState struct { } // addDirectReferences adds all immediate neighbors of desc to the state. -func (cs *collectorState) addDirectReferences(desc Descriptor) { +func (cs *collectorState) addDirectReferences(desc Descriptor) error { cs.vdg.Descriptors[desc.GetID()] = desc - desc.GetReferencedDescIDs().ForEach(cs.referencedBy.Add) + idSet, err := desc.GetReferencedDescIDs() + if err != nil { + return err + } + idSet.ForEach(cs.referencedBy.Add) + return nil } // getMissingDescs fetches the descriptors which have corresponding IDs in the @@ -491,7 +496,9 @@ func collectDescriptorsForValidation( referencedBy: MakeDescriptorIDSet(), } for _, desc := range descriptors { - cs.addDirectReferences(desc) + if err := cs.addDirectReferences(desc); err != nil { + return nil, err + } } newDescs, err := cs.getMissingDescs(ctx, maybeBatchDescGetter) if err != nil { @@ -503,7 +510,9 @@ func collectDescriptorsForValidation( } switch newDesc.(type) { case DatabaseDescriptor, TypeDescriptor: - cs.addDirectReferences(newDesc) + if err := cs.addDirectReferences(newDesc); err != nil { + return nil, err + } } } _, err = cs.getMissingDescs(ctx, maybeBatchDescGetter) diff --git a/pkg/sql/database_region_change_finalizer.go b/pkg/sql/database_region_change_finalizer.go index 1b7560c0f326..e911f2eddb64 100644 --- a/pkg/sql/database_region_change_finalizer.go +++ b/pkg/sql/database_region_change_finalizer.go @@ -169,8 +169,14 @@ func (r *databaseRegionChangeFinalizer) repartitionRegionalByRowTables( // the table descriptor with the new type metadata. for i := range tableDesc.Columns { col := &tableDesc.Columns[i] - if col.Type.UserDefined() && typedesc.UserDefinedTypeOIDToID(col.Type.Oid()) == r.typeID { - col.Type.TypeMeta = types.UserDefinedTypeMetadata{} + if col.Type.UserDefined() { + tid, err := typedesc.UserDefinedTypeOIDToID(col.Type.Oid()) + if err != nil { + return err + } + if tid == r.typeID { + col.Type.TypeMeta = types.UserDefinedTypeMetadata{} + } } } if err := typedesc.HydrateTypesInTableDescriptor( diff --git a/pkg/sql/opt/optbuilder/builder.go b/pkg/sql/opt/optbuilder/builder.go index 441576614e43..5228724bd4bf 100644 --- a/pkg/sql/opt/optbuilder/builder.go +++ b/pkg/sql/opt/optbuilder/builder.go @@ -429,7 +429,11 @@ func (b *Builder) maybeTrackRegclassDependenciesForViews(texpr tree.TypedExpr) { func (b *Builder) maybeTrackUserDefinedTypeDepsForViews(texpr tree.TypedExpr) { if b.trackViewDeps { if texpr.ResolvedType().UserDefined() { - for id := range typedesc.GetTypeDescriptorClosure(texpr.ResolvedType()) { + children, err := typedesc.GetTypeDescriptorClosure(texpr.ResolvedType()) + if err != nil { + panic(err) + } + for id := range children { b.viewTypeDeps.Add(int(id)) } } diff --git a/pkg/sql/opt/testutils/testcat/test_catalog.go b/pkg/sql/opt/testutils/testcat/test_catalog.go index 466c92b0d691..a17eecb91659 100644 --- a/pkg/sql/opt/testutils/testcat/test_catalog.go +++ b/pkg/sql/opt/testutils/testcat/test_catalog.go @@ -794,7 +794,11 @@ func (tt *Table) CollectTypes(ord int) (descpb.IDs, error) { ids := make(descpb.IDs, 0, len(visitor.OIDs)) for collectedOid := range visitor.OIDs { - ids = append(ids, typedesc.UserDefinedTypeOIDToID(collectedOid)) + id, err := typedesc.UserDefinedTypeOIDToID(collectedOid) + if err != nil { + return nil, err + } + ids = append(ids, id) } return ids, nil } diff --git a/pkg/sql/opt_catalog.go b/pkg/sql/opt_catalog.go index beada93cbb70..293ad01cd19c 100644 --- a/pkg/sql/opt_catalog.go +++ b/pkg/sql/opt_catalog.go @@ -2274,7 +2274,11 @@ func collectTypes(col catalog.Column) (descpb.IDs, error) { ids := make(descpb.IDs, 0, len(visitor.OIDs)) for collectedOid := range visitor.OIDs { - ids = append(ids, typedesc.UserDefinedTypeOIDToID(collectedOid)) + id, err := typedesc.UserDefinedTypeOIDToID(collectedOid) + if err != nil { + return nil, err + } + ids = append(ids, id) } return ids, nil } diff --git a/pkg/sql/pg_catalog.go b/pkg/sql/pg_catalog.go index c7a52038a0ac..0528acd06725 100644 --- a/pkg/sql/pg_catalog.go +++ b/pkg/sql/pg_catalog.go @@ -30,7 +30,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/schemaexpr" "github.com/cockroachdb/cockroach/pkg/sql/catalog/tabledesc" "github.com/cockroachdb/cockroach/pkg/sql/catalog/typedesc" - "github.com/cockroachdb/cockroach/pkg/sql/oidext" "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" @@ -2437,17 +2436,20 @@ https://www.postgresql.org/docs/9.5/catalog-pg-type.html`, return true, nil } - // This oid is not a user-defined type and we didn't find it in the - // map of predefined types, return false. Note that in common usage we - // only really expect the value 0 here (which cockroach uses internally - // in the typelem field amongst others). Users, however, may join on - // this index with any value. - if ooid <= oidext.CockroachPredefinedOIDMax { + // Check if it is a user defined type. + if !types.IsOIDUserDefinedType(ooid) { + // This oid is not a user-defined type and we didn't find it in the + // map of predefined types, return false. Note that in common usage we + // only really expect the value 0 here (which cockroach uses internally + // in the typelem field amongst others). Users, however, may join on + // this index with any value. return false, nil } - // Check if it is a user defined type. - id := typedesc.UserDefinedTypeOIDToID(ooid) + id, err := typedesc.UserDefinedTypeOIDToID(ooid) + if err != nil { + return false, err + } typDesc, err := p.Descriptors().GetImmutableTypeByID(ctx, p.txn, id, tree.ObjectLookupFlags{}) if err != nil { if errors.Is(err, catalog.ErrDescriptorNotFound) { diff --git a/pkg/sql/resolver.go b/pkg/sql/resolver.go index 0e55c2f06d68..0b901ba4a06c 100644 --- a/pkg/sql/resolver.go +++ b/pkg/sql/resolver.go @@ -276,7 +276,16 @@ func (p *planner) IsTypeVisible( if _, ok := types.OidToType[typeID]; ok { return true, true, nil } - typName, _, err := p.GetTypeDescriptor(ctx, typedesc.UserDefinedTypeOIDToID(typeID)) + + if !types.IsOIDUserDefinedType(typeID) { + return false, false, nil //nolint:returnerrcheck + } + + id, err := typedesc.UserDefinedTypeOIDToID(typeID) + if err != nil { + return false, false, err + } + typName, _, err := p.GetTypeDescriptor(ctx, id) if err != nil { // If a "not found" error happened here, we return "not exists" rather than // the error. @@ -361,7 +370,11 @@ func (p *planner) ResolveType( // ResolveTypeByOID implements the tree.TypeResolver interface. func (p *planner) ResolveTypeByOID(ctx context.Context, oid oid.Oid) (*types.T, error) { - name, desc, err := p.GetTypeDescriptor(ctx, typedesc.UserDefinedTypeOIDToID(oid)) + id, err := typedesc.UserDefinedTypeOIDToID(oid) + if err != nil { + return nil, err + } + name, desc, err := p.GetTypeDescriptor(ctx, id) if err != nil { return nil, err } diff --git a/pkg/sql/sem/tree/casts.go b/pkg/sql/sem/tree/casts.go index 15011abd173a..63df8464be0e 100644 --- a/pkg/sql/sem/tree/casts.go +++ b/pkg/sql/sem/tree/casts.go @@ -1302,7 +1302,11 @@ func performIntToOidCast(ctx *EvalContext, t *types.T, v DInt) (Datum, error) { ret := &DOid{semanticType: t, DInt: v} if typ, ok := types.OidToType[oid.Oid(v)]; ok { ret.name = typ.PGName() - } else if typ, err := ctx.Planner.ResolveTypeByOID(ctx.Context, oid.Oid(v)); err == nil { + } else if types.IsOIDUserDefinedType(oid.Oid(v)) { + typ, err := ctx.Planner.ResolveTypeByOID(ctx.Context, oid.Oid(v)) + if err != nil { + return nil, err + } ret.name = typ.PGName() } return ret, nil diff --git a/pkg/sql/type_change.go b/pkg/sql/type_change.go index 569c96d99432..26750b1a7f7e 100644 --- a/pkg/sql/type_change.go +++ b/pkg/sql/type_change.go @@ -658,7 +658,10 @@ func findUsagesOfEnumValue( if !ok { return true, expr, nil } - id := typedesc.UserDefinedTypeOIDToID(typeOid.OID) + id, err := typedesc.UserDefinedTypeOIDToID(typeOid.OID) + if err != nil { + return false, expr, err + } if id != typeID { return true, expr, nil } @@ -680,8 +683,12 @@ func findUsagesOfEnumValue( if !ok { return true, expr, nil } + id, err := typedesc.UserDefinedTypeOIDToID(typeOid.OID) + if err != nil { + return false, expr, err + } // -1 since the type of this CastExpr is the array type. - id := typedesc.UserDefinedTypeOIDToID(typeOid.OID) - 1 + id = id - 1 if id != typeID { return true, expr, nil } @@ -726,7 +733,10 @@ func findUsagesOfEnumValueInViewQuery( if !ok { return true, expr, nil } - id := typedesc.UserDefinedTypeOIDToID(typeOid.OID) + id, err := typedesc.UserDefinedTypeOIDToID(typeOid.OID) + if err != nil { + return false, expr, err + } if id != typeID { return true, expr, nil } @@ -816,22 +826,28 @@ func (t *typeSchemaChanger) canRemoveEnumValue( } } - if typeDesc.ID == typedesc.GetTypeDescID(col.GetType()) { - if !firstClause { - query.WriteString(" OR") + if col.GetType().UserDefined() { + tid, terr := typedesc.GetUserDefinedTypeDescID(col.GetType()) + if terr != nil { + return terr } - sqlPhysRep, err := convertToSQLStringRepresentation(member.PhysicalRepresentation) - if err != nil { - return err + if typeDesc.ID == tid { + if !firstClause { + query.WriteString(" OR") + } + sqlPhysRep, err := convertToSQLStringRepresentation(member.PhysicalRepresentation) + if err != nil { + return err + } + colName := col.ColName() + query.WriteString(fmt.Sprintf( + " t.%s = %s", + colName.String(), + sqlPhysRep, + )) + firstClause = false + validationQueryConstructed = true } - colName := col.ColName() - query.WriteString(fmt.Sprintf( - " t.%s = %s", - colName.String(), - sqlPhysRep, - )) - firstClause = false - validationQueryConstructed = true } } query.WriteString(" LIMIT 1") @@ -923,7 +939,14 @@ func (t *typeSchemaChanger) canRemoveEnumValueFromArrayUsages( // ) WHERE unnest = 'enum_value' firstClause := true for _, col := range desc.PublicColumns() { - if arrayTypeDesc.GetID() == typedesc.GetTypeDescID(col.GetType()) { + if !col.GetType().UserDefined() { + continue + } + tid, terr := typedesc.GetUserDefinedTypeDescID(col.GetType()) + if terr != nil { + return terr + } + if arrayTypeDesc.GetID() == tid { if !firstClause { unionUnnests.WriteString(" UNION ") }