Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql/catalog/typedesc: validate OID range before converting it to ID #65352

Merged
merged 1 commit into from
May 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions pkg/ccl/backupccl/restore_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,23 @@ func rewriteTypesInExpr(expr string, rewrites DescRewriteMap) (string, error) {
if err != nil {
return "", err
}

ctx := tree.NewFmtCtx(tree.FmtSerializable)
ctx.SetIndexedTypeFormat(func(ctx *tree.FmtCtx, ref *tree.OIDTypeReference) {
newRef := ref
if rw, ok := rewrites[typedesc.UserDefinedTypeOIDToID(ref.OID)]; ok {
var id descpb.ID
id, err = typedesc.UserDefinedTypeOIDToID(ref.OID)
if err != nil {
return
}
if rw, ok := rewrites[id]; ok {
newRef = &tree.OIDTypeReference{OID: typedesc.TypeIDToOID(rw.ID)}
}
ctx.WriteString(newRef.SQLString())
})
if err != nil {
return "", err
}
ctx.FormatNode(parsed)
return ctx.CloseAndGetString(), nil
}
Expand Down Expand Up @@ -348,11 +357,15 @@ func allocateDescriptorRewrites(
// Ensure that all referenced types are present.
if col.Type.UserDefined() {
// TODO (rohany): This can be turned into an option later.
if _, ok := typesByID[typedesc.GetTypeDescID(col.Type)]; !ok {
id, err := typedesc.GetUserDefinedTypeDescID(col.Type)
if err != nil {
return nil, err
}
if _, ok := typesByID[id]; !ok {
return nil, errors.Errorf(
"cannot restore table %q without referenced type %d",
table.Name,
typedesc.GetTypeDescID(col.Type),
id,
)
}
}
Expand Down Expand Up @@ -1025,25 +1038,37 @@ func rewriteDatabaseDescs(databases []*dbdesc.Mutable, descriptorRewrites DescRe

// rewriteIDsInTypesT rewrites all ID's in the input types.T using the input
// ID rewrite mapping.
func rewriteIDsInTypesT(typ *types.T, descriptorRewrites DescRewriteMap) {
func rewriteIDsInTypesT(typ *types.T, descriptorRewrites DescRewriteMap) error {
if !typ.UserDefined() {
return
return nil
}
tid, err := typedesc.GetUserDefinedTypeDescID(typ)
if err != nil {
return err
}
// Collect potential new OID values.
var newOID, newArrayOID oid.Oid
if rw, ok := descriptorRewrites[typedesc.GetTypeDescID(typ)]; ok {
if rw, ok := descriptorRewrites[tid]; ok {
newOID = typedesc.TypeIDToOID(rw.ID)
}
if typ.Family() != types.ArrayFamily {
if rw, ok := descriptorRewrites[typedesc.GetArrayTypeDescID(typ)]; ok {
tid, err = typedesc.GetUserDefinedArrayTypeDescID(typ)
if err != nil {
return err
}
if rw, ok := descriptorRewrites[tid]; ok {
newArrayOID = typedesc.TypeIDToOID(rw.ID)
}
}
types.RemapUserDefinedTypeOIDs(typ, newOID, newArrayOID)
// If the type is an array, then we need to rewrite the element type as well.
if typ.Family() == types.ArrayFamily {
rewriteIDsInTypesT(typ.ArrayContents(), descriptorRewrites)
if err := rewriteIDsInTypesT(typ.ArrayContents(), descriptorRewrites); err != nil {
return err
}
}

return nil
}

// rewriteTypeDescs rewrites all ID's in the input slice of TypeDescriptors
Expand Down Expand Up @@ -1075,7 +1100,9 @@ func rewriteTypeDescs(types []*typedesc.Mutable, descriptorRewrites DescRewriteM
}
case descpb.TypeDescriptor_ALIAS:
// We need to rewrite any ID's present in the aliased types.T.
rewriteIDsInTypesT(typ.Alias, descriptorRewrites)
if err := rewriteIDsInTypesT(typ.Alias, descriptorRewrites); err != nil {
return err
}
default:
return errors.AssertionFailedf("unknown type kind %s", t.String())
}
Expand Down Expand Up @@ -1285,7 +1312,9 @@ func RewriteTableDescs(
// rewriteCol is a closure that performs the ID rewrite logic on a column.
rewriteCol := func(col *descpb.ColumnDescriptor) error {
// Rewrite the types.T's IDs present in the column.
rewriteIDsInTypesT(col.Type, descriptorRewrites)
if err := rewriteIDsInTypesT(col.Type, descriptorRewrites); err != nil {
return err
}
var newUsedSeqRefs []descpb.ID
for _, seqID := range col.UsesSequenceIds {
if rewrite, ok := descriptorRewrites[seqID]; ok {
Expand Down
32 changes: 25 additions & 7 deletions pkg/ccl/changefeedccl/schemafeed/schema_feed.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,27 @@ func (t *typeDependencyTracker) removeDependency(typeID, tableID descpb.ID) {
}
}

func (t *typeDependencyTracker) purgeTable(tbl catalog.TableDescriptor) {
func (t *typeDependencyTracker) purgeTable(tbl catalog.TableDescriptor) error {
for _, col := range tbl.UserDefinedTypeColumns() {
t.removeDependency(typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()), tbl.GetID())
id, err := typedesc.UserDefinedTypeOIDToID(col.GetType().Oid())
if err != nil {
return err
}
t.removeDependency(id, tbl.GetID())
}

return nil
}

func (t *typeDependencyTracker) ingestTable(tbl catalog.TableDescriptor) {
func (t *typeDependencyTracker) ingestTable(tbl catalog.TableDescriptor) error {
for _, col := range tbl.UserDefinedTypeColumns() {
t.addDependency(typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()), tbl.GetID())
id, err := typedesc.UserDefinedTypeOIDToID(col.GetType().Oid())
if err != nil {
return err
}
t.addDependency(id, tbl.GetID())
}
return nil
}

func (t *typeDependencyTracker) containsType(id descpb.ID) bool {
Expand Down Expand Up @@ -289,7 +300,10 @@ func (tf *SchemaFeed) primeInitialTableDescs(ctx context.Context) error {
// Register all types used by the initial set of tables.
for _, desc := range initialDescs {
tbl := desc.(catalog.TableDescriptor)
tf.mu.typeDeps.ingestTable(tbl)
if err := tf.mu.typeDeps.ingestTable(tbl); err != nil {
tf.mu.Unlock()
return err
}
}
tf.mu.Unlock()

Expand Down Expand Up @@ -533,7 +547,9 @@ func (tf *SchemaFeed) validateDescriptor(
}

// Purge the old version of the table from the type mapping.
tf.mu.typeDeps.purgeTable(lastVersion)
if err := tf.mu.typeDeps.purgeTable(lastVersion); err != nil {
return err
}

e := TableEvent{
Before: lastVersion,
Expand All @@ -559,7 +575,9 @@ func (tf *SchemaFeed) validateDescriptor(
}
}
// Add the types used by the table into the dependency tracker.
tf.mu.typeDeps.ingestTable(desc)
if err := tf.mu.typeDeps.ingestTable(desc); err != nil {
return err
}
tf.mu.previousTableVersion[desc.GetID()] = desc
return nil
default:
Expand Down
1 change: 0 additions & 1 deletion pkg/sql/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,6 @@ go_library(
"//pkg/sql/inverted",
"//pkg/sql/lex",
"//pkg/sql/mutations",
"//pkg/sql/oidext",
"//pkg/sql/opt",
"//pkg/sql/opt/cat",
"//pkg/sql/opt/constraint",
Expand Down
10 changes: 7 additions & 3 deletions pkg/sql/catalog/dbdesc/database_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,19 @@ func (desc *immutable) validateMultiRegion(vea catalog.ValidationErrorAccumulato

// GetReferencedDescIDs returns the IDs of all descriptors referenced by
// this descriptor, including itself.
func (desc *immutable) GetReferencedDescIDs() catalog.DescriptorIDSet {
func (desc *immutable) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) {
ids := catalog.MakeDescriptorIDSet(desc.GetID())
if id, err := desc.MultiRegionEnumID(); err == nil {
if desc.IsMultiRegion() {
id, err := desc.MultiRegionEnumID()
if err != nil {
return catalog.DescriptorIDSet{}, err
}
ids.Add(id)
}
for _, schema := range desc.Schemas {
ids.Add(schema.ID)
}
return ids
return ids, nil
}

// ValidateCrossReferences implements the catalog.Descriptor interface.
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/catalog/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ type Descriptor interface {

// GetReferencedDescIDs returns the IDs of all descriptors directly referenced
// by this descriptor, including itself.
GetReferencedDescIDs() DescriptorIDSet
GetReferencedDescIDs() (DescriptorIDSet, error)

// ValidateSelf checks the internal consistency of the descriptor.
ValidateSelf(vea ValidationErrorAccumulator)
Expand Down Expand Up @@ -358,7 +358,7 @@ type TypeDescriptor interface {
HydrateTypeInfoWithName(ctx context.Context, typ *types.T, name *tree.TypeName, res TypeDescriptorResolver) error
MakeTypesT(ctx context.Context, name *tree.TypeName, res TypeDescriptorResolver) (*types.T, error)
HasPendingSchemaChanges() bool
GetIDClosure() map[descpb.ID]struct{}
GetIDClosure() (map[descpb.ID]struct{}, error)
IsCompatibleWith(other TypeDescriptor) error

PrimaryRegionName() (descpb.RegionName, error)
Expand Down
12 changes: 10 additions & 2 deletions pkg/sql/catalog/descs/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2180,7 +2180,11 @@ func (dt DistSQLTypeResolver) ResolveType(

// ResolveTypeByOID implements the tree.TypeReferenceResolver interface.
func (dt DistSQLTypeResolver) ResolveTypeByOID(ctx context.Context, oid oid.Oid) (*types.T, error) {
name, desc, err := dt.GetTypeDescriptor(ctx, typedesc.UserDefinedTypeOIDToID(oid))
id, err := typedesc.UserDefinedTypeOIDToID(oid)
if err != nil {
return nil, err
}
name, desc, err := dt.GetTypeDescriptor(ctx, id)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -2213,7 +2217,11 @@ func (dt DistSQLTypeResolver) GetTypeDescriptor(
func (dt DistSQLTypeResolver) HydrateTypeSlice(ctx context.Context, typs []*types.T) error {
for _, t := range typs {
if t.UserDefined() {
name, desc, err := dt.GetTypeDescriptor(ctx, typedesc.GetTypeDescID(t))
id, err := typedesc.GetUserDefinedTypeDescID(t)
if err != nil {
return err
}
name, desc, err := dt.GetTypeDescriptor(ctx, id)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/catalog/schemadesc/schema_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ func (desc *immutable) ValidateSelf(vea catalog.ValidationErrorAccumulator) {

// GetReferencedDescIDs returns the IDs of all descriptors referenced by
// this descriptor, including itself.
func (desc *immutable) GetReferencedDescIDs() catalog.DescriptorIDSet {
return catalog.MakeDescriptorIDSet(desc.GetID(), desc.GetParentID())
func (desc *immutable) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) {
return catalog.MakeDescriptorIDSet(desc.GetID(), desc.GetParentID()), nil
}

// ValidateCrossReferences implements the catalog.Descriptor interface.
Expand Down
29 changes: 23 additions & 6 deletions pkg/sql/catalog/tabledesc/structured.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,27 +500,44 @@ func (desc *wrapper) getAllReferencedTypesInTableColumns(
// collect the closure of ID's referenced.
ids := make(map[descpb.ID]struct{})
for id := range visitor.OIDs {
typDesc, err := getType(typedesc.UserDefinedTypeOIDToID(id))
uid, err := typedesc.UserDefinedTypeOIDToID(id)
if err != nil {
return nil, err
}
for child := range typDesc.GetIDClosure() {
typDesc, err := getType(uid)
if err != nil {
return nil, err
}
children, err := typDesc.GetIDClosure()
if err != nil {
return nil, err
}
for child := range children {
ids[child] = struct{}{}
}
}

// Now add all of the column types in the table.
addIDsInColumn := func(c *descpb.ColumnDescriptor) {
for id := range typedesc.GetTypeDescriptorClosure(c.Type) {
addIDsInColumn := func(c *descpb.ColumnDescriptor) error {
children, err := typedesc.GetTypeDescriptorClosure(c.Type)
if err != nil {
return err
}
for id := range children {
ids[id] = struct{}{}
}
return nil
}
for i := range desc.Columns {
addIDsInColumn(&desc.Columns[i])
if err := addIDsInColumn(&desc.Columns[i]); err != nil {
return nil, err
}
}
for _, mut := range desc.Mutations {
if c := mut.GetColumn(); c != nil {
addIDsInColumn(c)
if err := addIDsInColumn(c); err != nil {
return nil, err
}
}
}

Expand Down
16 changes: 12 additions & 4 deletions pkg/sql/catalog/tabledesc/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (desc *wrapper) ValidateTxnCommit(

// GetReferencedDescIDs returns the IDs of all descriptors referenced by
// this descriptor, including itself.
func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet {
func (desc *wrapper) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) {
ids := catalog.MakeDescriptorIDSet(desc.GetID(), desc.GetParentID())
if desc.GetParentSchemaID() != keys.PublicSchemaID {
ids.Add(desc.GetParentSchemaID())
Expand All @@ -69,7 +69,11 @@ func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet {
}
// Collect user defined type Oids and sequence references in columns.
for _, col := range desc.DeletableColumns() {
for id := range typedesc.GetTypeDescriptorClosure(col.GetType()) {
children, err := typedesc.GetTypeDescriptorClosure(col.GetType())
if err != nil {
return catalog.DescriptorIDSet{}, err
}
for id := range children {
ids.Add(id)
}
for i := 0; i < col.NumUsesSequences(); i++ {
Expand All @@ -89,7 +93,11 @@ func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet {
})
// Add collected Oids to return set.
for oid := range visitor.OIDs {
ids.Add(typedesc.UserDefinedTypeOIDToID(oid))
id, err := typedesc.UserDefinedTypeOIDToID(oid)
if err != nil {
return catalog.DescriptorIDSet{}, err
}
ids.Add(id)
}
// Add view dependencies.
for _, id := range desc.GetDependsOn() {
Expand All @@ -102,7 +110,7 @@ func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet {
ids.Add(ref.ID)
}
// Add sequence dependencies
return ids
return ids, nil
}

// ValidateCrossReferences validates that each reference to another table is
Expand Down
2 changes: 2 additions & 0 deletions pkg/sql/catalog/typedesc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ go_test(
"//pkg/sql/catalog/dbdesc",
"//pkg/sql/catalog/descpb",
"//pkg/sql/catalog/schemadesc",
"//pkg/sql/oidext",
"//pkg/sql/privilege",
"//pkg/sql/types",
"//pkg/testutils",
"//pkg/testutils/serverutils",
"//pkg/util/leaktest",
"//pkg/util/randutil",
"@com_github_cockroachdb_redact//:redact",
"@com_github_lib_pq//oid",
"@com_github_stretchr_testify//require",
"@in_gopkg_yaml_v2//:yaml_v2",
],
Expand Down
Loading