Skip to content

Commit 01ac72d

Browse files
author
Sajjad Rizvi
committed
sql,catalog,typedesc: validate OID range before converting it to ID
Previously the code to convert an OID to a descpb.ID assumed that the OID is larger than a predefined constant. There were no checks to validate the given OID during conversion. As a result, an out-of-range OID could be converted to an invalid descriptor ID without an error. The changes in this PR validate the range of given user-defined OID before conversion, which encourages the user to check the conversion for potential problems. Release note: None Fixes #58414
1 parent 7183943 commit 01ac72d

22 files changed

+314
-100
lines changed

pkg/ccl/backupccl/restore_planning.go

+39-10
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,23 @@ func rewriteTypesInExpr(expr string, rewrites DescRewriteMap) (string, error) {
116116
if err != nil {
117117
return "", err
118118
}
119+
119120
ctx := tree.NewFmtCtx(tree.FmtSerializable)
120121
ctx.SetIndexedTypeFormat(func(ctx *tree.FmtCtx, ref *tree.OIDTypeReference) {
121122
newRef := ref
122-
if rw, ok := rewrites[typedesc.UserDefinedTypeOIDToID(ref.OID)]; ok {
123+
var id descpb.ID
124+
id, err = typedesc.UserDefinedTypeOIDToID(ref.OID)
125+
if err != nil {
126+
return
127+
}
128+
if rw, ok := rewrites[id]; ok {
123129
newRef = &tree.OIDTypeReference{OID: typedesc.TypeIDToOID(rw.ID)}
124130
}
125131
ctx.WriteString(newRef.SQLString())
126132
})
133+
if err != nil {
134+
return "", err
135+
}
127136
ctx.FormatNode(parsed)
128137
return ctx.CloseAndGetString(), nil
129138
}
@@ -348,11 +357,15 @@ func allocateDescriptorRewrites(
348357
// Ensure that all referenced types are present.
349358
if col.Type.UserDefined() {
350359
// TODO (rohany): This can be turned into an option later.
351-
if _, ok := typesByID[typedesc.GetTypeDescID(col.Type)]; !ok {
360+
id, err := typedesc.GetUserDefinedTypeDescID(col.Type)
361+
if err != nil {
362+
return nil, err
363+
}
364+
if _, ok := typesByID[id]; !ok {
352365
return nil, errors.Errorf(
353366
"cannot restore table %q without referenced type %d",
354367
table.Name,
355-
typedesc.GetTypeDescID(col.Type),
368+
id,
356369
)
357370
}
358371
}
@@ -1025,25 +1038,37 @@ func rewriteDatabaseDescs(databases []*dbdesc.Mutable, descriptorRewrites DescRe
10251038

10261039
// rewriteIDsInTypesT rewrites all ID's in the input types.T using the input
10271040
// ID rewrite mapping.
1028-
func rewriteIDsInTypesT(typ *types.T, descriptorRewrites DescRewriteMap) {
1041+
func rewriteIDsInTypesT(typ *types.T, descriptorRewrites DescRewriteMap) error {
10291042
if !typ.UserDefined() {
1030-
return
1043+
return nil
1044+
}
1045+
tid, err := typedesc.GetUserDefinedTypeDescID(typ)
1046+
if err != nil {
1047+
return err
10311048
}
10321049
// Collect potential new OID values.
10331050
var newOID, newArrayOID oid.Oid
1034-
if rw, ok := descriptorRewrites[typedesc.GetTypeDescID(typ)]; ok {
1051+
if rw, ok := descriptorRewrites[tid]; ok {
10351052
newOID = typedesc.TypeIDToOID(rw.ID)
10361053
}
10371054
if typ.Family() != types.ArrayFamily {
1038-
if rw, ok := descriptorRewrites[typedesc.GetArrayTypeDescID(typ)]; ok {
1055+
tid, err = typedesc.GetUserDefinedArrayTypeDescID(typ)
1056+
if err != nil {
1057+
return err
1058+
}
1059+
if rw, ok := descriptorRewrites[tid]; ok {
10391060
newArrayOID = typedesc.TypeIDToOID(rw.ID)
10401061
}
10411062
}
10421063
types.RemapUserDefinedTypeOIDs(typ, newOID, newArrayOID)
10431064
// If the type is an array, then we need to rewrite the element type as well.
10441065
if typ.Family() == types.ArrayFamily {
1045-
rewriteIDsInTypesT(typ.ArrayContents(), descriptorRewrites)
1066+
if err := rewriteIDsInTypesT(typ.ArrayContents(), descriptorRewrites); err != nil {
1067+
return err
1068+
}
10461069
}
1070+
1071+
return nil
10471072
}
10481073

10491074
// rewriteTypeDescs rewrites all ID's in the input slice of TypeDescriptors
@@ -1075,7 +1100,9 @@ func rewriteTypeDescs(types []*typedesc.Mutable, descriptorRewrites DescRewriteM
10751100
}
10761101
case descpb.TypeDescriptor_ALIAS:
10771102
// We need to rewrite any ID's present in the aliased types.T.
1078-
rewriteIDsInTypesT(typ.Alias, descriptorRewrites)
1103+
if err := rewriteIDsInTypesT(typ.Alias, descriptorRewrites); err != nil {
1104+
return err
1105+
}
10791106
default:
10801107
return errors.AssertionFailedf("unknown type kind %s", t.String())
10811108
}
@@ -1285,7 +1312,9 @@ func RewriteTableDescs(
12851312
// rewriteCol is a closure that performs the ID rewrite logic on a column.
12861313
rewriteCol := func(col *descpb.ColumnDescriptor) error {
12871314
// Rewrite the types.T's IDs present in the column.
1288-
rewriteIDsInTypesT(col.Type, descriptorRewrites)
1315+
if err := rewriteIDsInTypesT(col.Type, descriptorRewrites); err != nil {
1316+
return err
1317+
}
12891318
var newUsedSeqRefs []descpb.ID
12901319
for _, seqID := range col.UsesSequenceIds {
12911320
if rewrite, ok := descriptorRewrites[seqID]; ok {

pkg/ccl/changefeedccl/schemafeed/schema_feed.go

+25-7
Original file line numberDiff line numberDiff line change
@@ -178,16 +178,27 @@ func (t *typeDependencyTracker) removeDependency(typeID, tableID descpb.ID) {
178178
}
179179
}
180180

181-
func (t *typeDependencyTracker) purgeTable(tbl catalog.TableDescriptor) {
181+
func (t *typeDependencyTracker) purgeTable(tbl catalog.TableDescriptor) error {
182182
for _, col := range tbl.UserDefinedTypeColumns() {
183-
t.removeDependency(typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()), tbl.GetID())
183+
id, err := typedesc.UserDefinedTypeOIDToID(col.GetType().Oid())
184+
if err != nil {
185+
return err
186+
}
187+
t.removeDependency(id, tbl.GetID())
184188
}
189+
190+
return nil
185191
}
186192

187-
func (t *typeDependencyTracker) ingestTable(tbl catalog.TableDescriptor) {
193+
func (t *typeDependencyTracker) ingestTable(tbl catalog.TableDescriptor) error {
188194
for _, col := range tbl.UserDefinedTypeColumns() {
189-
t.addDependency(typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()), tbl.GetID())
195+
id, err := typedesc.UserDefinedTypeOIDToID(col.GetType().Oid())
196+
if err != nil {
197+
return err
198+
}
199+
t.addDependency(id, tbl.GetID())
190200
}
201+
return nil
191202
}
192203

193204
func (t *typeDependencyTracker) containsType(id descpb.ID) bool {
@@ -289,7 +300,10 @@ func (tf *SchemaFeed) primeInitialTableDescs(ctx context.Context) error {
289300
// Register all types used by the initial set of tables.
290301
for _, desc := range initialDescs {
291302
tbl := desc.(catalog.TableDescriptor)
292-
tf.mu.typeDeps.ingestTable(tbl)
303+
if err := tf.mu.typeDeps.ingestTable(tbl); err != nil {
304+
tf.mu.Unlock()
305+
return err
306+
}
293307
}
294308
tf.mu.Unlock()
295309

@@ -533,7 +547,9 @@ func (tf *SchemaFeed) validateDescriptor(
533547
}
534548

535549
// Purge the old version of the table from the type mapping.
536-
tf.mu.typeDeps.purgeTable(lastVersion)
550+
if err := tf.mu.typeDeps.purgeTable(lastVersion); err != nil {
551+
return err
552+
}
537553

538554
e := TableEvent{
539555
Before: lastVersion,
@@ -559,7 +575,9 @@ func (tf *SchemaFeed) validateDescriptor(
559575
}
560576
}
561577
// Add the types used by the table into the dependency tracker.
562-
tf.mu.typeDeps.ingestTable(desc)
578+
if err := tf.mu.typeDeps.ingestTable(desc); err != nil {
579+
return err
580+
}
563581
tf.mu.previousTableVersion[desc.GetID()] = desc
564582
return nil
565583
default:

pkg/sql/BUILD.bazel

-1
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,6 @@ go_library(
302302
"//pkg/sql/inverted",
303303
"//pkg/sql/lex",
304304
"//pkg/sql/mutations",
305-
"//pkg/sql/oidext",
306305
"//pkg/sql/opt",
307306
"//pkg/sql/opt/cat",
308307
"//pkg/sql/opt/constraint",

pkg/sql/catalog/dbdesc/database_desc.go

+7-3
Original file line numberDiff line numberDiff line change
@@ -225,15 +225,19 @@ func (desc *immutable) validateMultiRegion(vea catalog.ValidationErrorAccumulato
225225

226226
// GetReferencedDescIDs returns the IDs of all descriptors referenced by
227227
// this descriptor, including itself.
228-
func (desc *immutable) GetReferencedDescIDs() catalog.DescriptorIDSet {
228+
func (desc *immutable) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) {
229229
ids := catalog.MakeDescriptorIDSet(desc.GetID())
230-
if id, err := desc.MultiRegionEnumID(); err == nil {
230+
if desc.IsMultiRegion() {
231+
id, err := desc.MultiRegionEnumID()
232+
if err != nil {
233+
return catalog.DescriptorIDSet{}, err
234+
}
231235
ids.Add(id)
232236
}
233237
for _, schema := range desc.Schemas {
234238
ids.Add(schema.ID)
235239
}
236-
return ids
240+
return ids, nil
237241
}
238242

239243
// ValidateCrossReferences implements the catalog.Descriptor interface.

pkg/sql/catalog/descriptor.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ type Descriptor interface {
140140

141141
// GetReferencedDescIDs returns the IDs of all descriptors directly referenced
142142
// by this descriptor, including itself.
143-
GetReferencedDescIDs() DescriptorIDSet
143+
GetReferencedDescIDs() (DescriptorIDSet, error)
144144

145145
// ValidateSelf checks the internal consistency of the descriptor.
146146
ValidateSelf(vea ValidationErrorAccumulator)
@@ -358,7 +358,7 @@ type TypeDescriptor interface {
358358
HydrateTypeInfoWithName(ctx context.Context, typ *types.T, name *tree.TypeName, res TypeDescriptorResolver) error
359359
MakeTypesT(ctx context.Context, name *tree.TypeName, res TypeDescriptorResolver) (*types.T, error)
360360
HasPendingSchemaChanges() bool
361-
GetIDClosure() map[descpb.ID]struct{}
361+
GetIDClosure() (map[descpb.ID]struct{}, error)
362362
IsCompatibleWith(other TypeDescriptor) error
363363

364364
PrimaryRegionName() (descpb.RegionName, error)

pkg/sql/catalog/descs/collection.go

+10-2
Original file line numberDiff line numberDiff line change
@@ -2180,7 +2180,11 @@ func (dt DistSQLTypeResolver) ResolveType(
21802180

21812181
// ResolveTypeByOID implements the tree.TypeReferenceResolver interface.
21822182
func (dt DistSQLTypeResolver) ResolveTypeByOID(ctx context.Context, oid oid.Oid) (*types.T, error) {
2183-
name, desc, err := dt.GetTypeDescriptor(ctx, typedesc.UserDefinedTypeOIDToID(oid))
2183+
id, err := typedesc.UserDefinedTypeOIDToID(oid)
2184+
if err != nil {
2185+
return nil, err
2186+
}
2187+
name, desc, err := dt.GetTypeDescriptor(ctx, id)
21842188
if err != nil {
21852189
return nil, err
21862190
}
@@ -2213,7 +2217,11 @@ func (dt DistSQLTypeResolver) GetTypeDescriptor(
22132217
func (dt DistSQLTypeResolver) HydrateTypeSlice(ctx context.Context, typs []*types.T) error {
22142218
for _, t := range typs {
22152219
if t.UserDefined() {
2216-
name, desc, err := dt.GetTypeDescriptor(ctx, typedesc.GetTypeDescID(t))
2220+
id, err := typedesc.GetUserDefinedTypeDescID(t)
2221+
if err != nil {
2222+
return err
2223+
}
2224+
name, desc, err := dt.GetTypeDescriptor(ctx, id)
22172225
if err != nil {
22182226
return err
22192227
}

pkg/sql/catalog/schemadesc/schema_desc.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ func (desc *immutable) ValidateSelf(vea catalog.ValidationErrorAccumulator) {
146146

147147
// GetReferencedDescIDs returns the IDs of all descriptors referenced by
148148
// this descriptor, including itself.
149-
func (desc *immutable) GetReferencedDescIDs() catalog.DescriptorIDSet {
150-
return catalog.MakeDescriptorIDSet(desc.GetID(), desc.GetParentID())
149+
func (desc *immutable) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) {
150+
return catalog.MakeDescriptorIDSet(desc.GetID(), desc.GetParentID()), nil
151151
}
152152

153153
// ValidateCrossReferences implements the catalog.Descriptor interface.

pkg/sql/catalog/tabledesc/structured.go

+23-6
Original file line numberDiff line numberDiff line change
@@ -500,27 +500,44 @@ func (desc *wrapper) getAllReferencedTypesInTableColumns(
500500
// collect the closure of ID's referenced.
501501
ids := make(map[descpb.ID]struct{})
502502
for id := range visitor.OIDs {
503-
typDesc, err := getType(typedesc.UserDefinedTypeOIDToID(id))
503+
uid, err := typedesc.UserDefinedTypeOIDToID(id)
504504
if err != nil {
505505
return nil, err
506506
}
507-
for child := range typDesc.GetIDClosure() {
507+
typDesc, err := getType(uid)
508+
if err != nil {
509+
return nil, err
510+
}
511+
children, err := typDesc.GetIDClosure()
512+
if err != nil {
513+
return nil, err
514+
}
515+
for child := range children {
508516
ids[child] = struct{}{}
509517
}
510518
}
511519

512520
// Now add all of the column types in the table.
513-
addIDsInColumn := func(c *descpb.ColumnDescriptor) {
514-
for id := range typedesc.GetTypeDescriptorClosure(c.Type) {
521+
addIDsInColumn := func(c *descpb.ColumnDescriptor) error {
522+
children, err := typedesc.GetTypeDescriptorClosure(c.Type)
523+
if err != nil {
524+
return err
525+
}
526+
for id := range children {
515527
ids[id] = struct{}{}
516528
}
529+
return nil
517530
}
518531
for i := range desc.Columns {
519-
addIDsInColumn(&desc.Columns[i])
532+
if err := addIDsInColumn(&desc.Columns[i]); err != nil {
533+
return nil, err
534+
}
520535
}
521536
for _, mut := range desc.Mutations {
522537
if c := mut.GetColumn(); c != nil {
523-
addIDsInColumn(c)
538+
if err := addIDsInColumn(c); err != nil {
539+
return nil, err
540+
}
524541
}
525542
}
526543

pkg/sql/catalog/tabledesc/validate.go

+12-4
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func (desc *wrapper) ValidateTxnCommit(
4747

4848
// GetReferencedDescIDs returns the IDs of all descriptors referenced by
4949
// this descriptor, including itself.
50-
func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet {
50+
func (desc *wrapper) GetReferencedDescIDs() (catalog.DescriptorIDSet, error) {
5151
ids := catalog.MakeDescriptorIDSet(desc.GetID(), desc.GetParentID())
5252
if desc.GetParentSchemaID() != keys.PublicSchemaID {
5353
ids.Add(desc.GetParentSchemaID())
@@ -69,7 +69,11 @@ func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet {
6969
}
7070
// Collect user defined type Oids and sequence references in columns.
7171
for _, col := range desc.DeletableColumns() {
72-
for id := range typedesc.GetTypeDescriptorClosure(col.GetType()) {
72+
children, err := typedesc.GetTypeDescriptorClosure(col.GetType())
73+
if err != nil {
74+
return catalog.DescriptorIDSet{}, err
75+
}
76+
for id := range children {
7377
ids.Add(id)
7478
}
7579
for i := 0; i < col.NumUsesSequences(); i++ {
@@ -89,7 +93,11 @@ func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet {
8993
})
9094
// Add collected Oids to return set.
9195
for oid := range visitor.OIDs {
92-
ids.Add(typedesc.UserDefinedTypeOIDToID(oid))
96+
id, err := typedesc.UserDefinedTypeOIDToID(oid)
97+
if err != nil {
98+
return catalog.DescriptorIDSet{}, err
99+
}
100+
ids.Add(id)
93101
}
94102
// Add view dependencies.
95103
for _, id := range desc.GetDependsOn() {
@@ -102,7 +110,7 @@ func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet {
102110
ids.Add(ref.ID)
103111
}
104112
// Add sequence dependencies
105-
return ids
113+
return ids, nil
106114
}
107115

108116
// ValidateCrossReferences validates that each reference to another table is

pkg/sql/catalog/typedesc/BUILD.bazel

+2
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@ go_test(
4646
"//pkg/sql/catalog/dbdesc",
4747
"//pkg/sql/catalog/descpb",
4848
"//pkg/sql/catalog/schemadesc",
49+
"//pkg/sql/oidext",
4950
"//pkg/sql/privilege",
5051
"//pkg/sql/types",
5152
"//pkg/testutils",
5253
"//pkg/testutils/serverutils",
5354
"//pkg/util/leaktest",
5455
"//pkg/util/randutil",
5556
"@com_github_cockroachdb_redact//:redact",
57+
"@com_github_lib_pq//oid",
5658
"@com_github_stretchr_testify//require",
5759
"@in_gopkg_yaml_v2//:yaml_v2",
5860
],

0 commit comments

Comments
 (0)