Skip to content

Commit 9f9cacf

Browse files
author
Sajjad Rizvi
committed
sql,catalog,typedesc: validate OID range before converting it to ID
Previously the code to convert an OID to a descpb.ID assumed that the OID is larger than a predefined constant. There were no checks to validate the given OID during conversion. As a result, an out-of-range OID could be converted to an invalid descriptor ID without an error. The changes in this PR validate the range of given user-defined OID before conversion, which encourages the user to check the conversion for potential problems. Release note: None Fixes #58414
1 parent 39c8068 commit 9f9cacf

16 files changed

+198
-53
lines changed

pkg/ccl/backupccl/restore_planning.go

+38-10
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,22 @@ func rewriteTypesInExpr(expr string, rewrites DescRewriteMap) (string, error) {
116116
if err != nil {
117117
return "", err
118118
}
119+
119120
ctx := tree.NewFmtCtx(tree.FmtSerializable)
120121
ctx.SetIndexedTypeFormat(func(ctx *tree.FmtCtx, ref *tree.OIDTypeReference) {
121122
newRef := ref
122-
if rw, ok := rewrites[typedesc.UserDefinedTypeOIDToID(ref.OID)]; ok {
123+
id, err := typedesc.UserDefinedTypeOIDToID(ref.OID)
124+
if err != nil {
125+
return
126+
}
127+
if rw, ok := rewrites[id]; ok {
123128
newRef = &tree.OIDTypeReference{OID: typedesc.TypeIDToOID(rw.ID)}
124129
}
125130
ctx.WriteString(newRef.SQLString())
126131
})
132+
if err != nil {
133+
return "", err
134+
}
127135
ctx.FormatNode(parsed)
128136
return ctx.CloseAndGetString(), nil
129137
}
@@ -348,11 +356,15 @@ func allocateDescriptorRewrites(
348356
// Ensure that all referenced types are present.
349357
if col.Type.UserDefined() {
350358
// TODO (rohany): This can be turned into an option later.
351-
if _, ok := typesByID[typedesc.GetTypeDescID(col.Type)]; !ok {
359+
id, err := typedesc.GetUserDefinedTypeDescID(col.Type)
360+
if err != nil {
361+
return nil, err
362+
}
363+
if _, ok := typesByID[id]; !ok {
352364
return nil, errors.Errorf(
353365
"cannot restore table %q without referenced type %d",
354366
table.Name,
355-
typedesc.GetTypeDescID(col.Type),
367+
id,
356368
)
357369
}
358370
}
@@ -1025,25 +1037,37 @@ func rewriteDatabaseDescs(databases []*dbdesc.Mutable, descriptorRewrites DescRe
10251037

10261038
// rewriteIDsInTypesT rewrites all ID's in the input types.T using the input
10271039
// ID rewrite mapping.
1028-
func rewriteIDsInTypesT(typ *types.T, descriptorRewrites DescRewriteMap) {
1040+
func rewriteIDsInTypesT(typ *types.T, descriptorRewrites DescRewriteMap) error {
10291041
if !typ.UserDefined() {
1030-
return
1042+
return nil
1043+
}
1044+
tid, err := typedesc.GetUserDefinedTypeDescID(typ)
1045+
if err != nil {
1046+
return err
10311047
}
10321048
// Collect potential new OID values.
10331049
var newOID, newArrayOID oid.Oid
1034-
if rw, ok := descriptorRewrites[typedesc.GetTypeDescID(typ)]; ok {
1050+
if rw, ok := descriptorRewrites[tid]; ok {
10351051
newOID = typedesc.TypeIDToOID(rw.ID)
10361052
}
10371053
if typ.Family() != types.ArrayFamily {
1038-
if rw, ok := descriptorRewrites[typedesc.GetArrayTypeDescID(typ)]; ok {
1054+
tid, err = typedesc.GetUserDefinedArrayTypeDescID(typ)
1055+
if err != nil {
1056+
return err
1057+
}
1058+
if rw, ok := descriptorRewrites[tid]; ok {
10391059
newArrayOID = typedesc.TypeIDToOID(rw.ID)
10401060
}
10411061
}
10421062
types.RemapUserDefinedTypeOIDs(typ, newOID, newArrayOID)
10431063
// If the type is an array, then we need to rewrite the element type as well.
10441064
if typ.Family() == types.ArrayFamily {
1045-
rewriteIDsInTypesT(typ.ArrayContents(), descriptorRewrites)
1065+
if err := rewriteIDsInTypesT(typ.ArrayContents(), descriptorRewrites); err != nil {
1066+
return err
1067+
}
10461068
}
1069+
1070+
return nil
10471071
}
10481072

10491073
// rewriteTypeDescs rewrites all ID's in the input slice of TypeDescriptors
@@ -1075,7 +1099,9 @@ func rewriteTypeDescs(types []*typedesc.Mutable, descriptorRewrites DescRewriteM
10751099
}
10761100
case descpb.TypeDescriptor_ALIAS:
10771101
// We need to rewrite any ID's present in the aliased types.T.
1078-
rewriteIDsInTypesT(typ.Alias, descriptorRewrites)
1102+
if err := rewriteIDsInTypesT(typ.Alias, descriptorRewrites); err != nil {
1103+
return err
1104+
}
10791105
default:
10801106
return errors.AssertionFailedf("unknown type kind %s", t.String())
10811107
}
@@ -1285,7 +1311,9 @@ func RewriteTableDescs(
12851311
// rewriteCol is a closure that performs the ID rewrite logic on a column.
12861312
rewriteCol := func(col *descpb.ColumnDescriptor) error {
12871313
// Rewrite the types.T's IDs present in the column.
1288-
rewriteIDsInTypesT(col.Type, descriptorRewrites)
1314+
if err := rewriteIDsInTypesT(col.Type, descriptorRewrites); err != nil {
1315+
return err
1316+
}
12891317
var newUsedSeqRefs []descpb.ID
12901318
for _, seqID := range col.UsesSequenceIds {
12911319
if rewrite, ok := descriptorRewrites[seqID]; ok {

pkg/ccl/changefeedccl/schemafeed/schema_feed.go

+25-7
Original file line numberDiff line numberDiff line change
@@ -178,16 +178,27 @@ func (t *typeDependencyTracker) removeDependency(typeID, tableID descpb.ID) {
178178
}
179179
}
180180

181-
func (t *typeDependencyTracker) purgeTable(tbl catalog.TableDescriptor) {
181+
func (t *typeDependencyTracker) purgeTable(tbl catalog.TableDescriptor) error {
182182
for _, col := range tbl.UserDefinedTypeColumns() {
183-
t.removeDependency(typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()), tbl.GetID())
183+
id, err := typedesc.UserDefinedTypeOIDToID(col.GetType().Oid())
184+
if err != nil {
185+
return err
186+
}
187+
t.removeDependency(id, tbl.GetID())
184188
}
189+
190+
return nil
185191
}
186192

187-
func (t *typeDependencyTracker) ingestTable(tbl catalog.TableDescriptor) {
193+
func (t *typeDependencyTracker) ingestTable(tbl catalog.TableDescriptor) error {
188194
for _, col := range tbl.UserDefinedTypeColumns() {
189-
t.addDependency(typedesc.UserDefinedTypeOIDToID(col.GetType().Oid()), tbl.GetID())
195+
id, err := typedesc.UserDefinedTypeOIDToID(col.GetType().Oid())
196+
if err != nil {
197+
return err
198+
}
199+
t.addDependency(id, tbl.GetID())
190200
}
201+
return nil
191202
}
192203

193204
func (t *typeDependencyTracker) containsType(id descpb.ID) bool {
@@ -289,7 +300,10 @@ func (tf *SchemaFeed) primeInitialTableDescs(ctx context.Context) error {
289300
// Register all types used by the initial set of tables.
290301
for _, desc := range initialDescs {
291302
tbl := desc.(catalog.TableDescriptor)
292-
tf.mu.typeDeps.ingestTable(tbl)
303+
if err := tf.mu.typeDeps.ingestTable(tbl); err != nil {
304+
tf.mu.Unlock()
305+
return err
306+
}
293307
}
294308
tf.mu.Unlock()
295309

@@ -533,7 +547,9 @@ func (tf *SchemaFeed) validateDescriptor(
533547
}
534548

535549
// Purge the old version of the table from the type mapping.
536-
tf.mu.typeDeps.purgeTable(lastVersion)
550+
if err := tf.mu.typeDeps.purgeTable(lastVersion); err != nil {
551+
return err
552+
}
537553

538554
e := TableEvent{
539555
Before: lastVersion,
@@ -559,7 +575,9 @@ func (tf *SchemaFeed) validateDescriptor(
559575
}
560576
}
561577
// Add the types used by the table into the dependency tracker.
562-
tf.mu.typeDeps.ingestTable(desc)
578+
if err := tf.mu.typeDeps.ingestTable(desc); err != nil {
579+
return err
580+
}
563581
tf.mu.previousTableVersion[desc.GetID()] = desc
564582
return nil
565583
default:

pkg/sql/BUILD.bazel

-1
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,6 @@ go_library(
302302
"//pkg/sql/inverted",
303303
"//pkg/sql/lex",
304304
"//pkg/sql/mutations",
305-
"//pkg/sql/oidext",
306305
"//pkg/sql/opt",
307306
"//pkg/sql/opt/cat",
308307
"//pkg/sql/opt/constraint",

pkg/sql/catalog/descs/collection.go

+10-2
Original file line numberDiff line numberDiff line change
@@ -2180,7 +2180,11 @@ func (dt DistSQLTypeResolver) ResolveType(
21802180

21812181
// ResolveTypeByOID implements the tree.TypeReferenceResolver interface.
21822182
func (dt DistSQLTypeResolver) ResolveTypeByOID(ctx context.Context, oid oid.Oid) (*types.T, error) {
2183-
name, desc, err := dt.GetTypeDescriptor(ctx, typedesc.UserDefinedTypeOIDToID(oid))
2183+
id, err := typedesc.UserDefinedTypeOIDToID(oid)
2184+
if err != nil {
2185+
return nil, err
2186+
}
2187+
name, desc, err := dt.GetTypeDescriptor(ctx, id)
21842188
if err != nil {
21852189
return nil, err
21862190
}
@@ -2213,7 +2217,11 @@ func (dt DistSQLTypeResolver) GetTypeDescriptor(
22132217
func (dt DistSQLTypeResolver) HydrateTypeSlice(ctx context.Context, typs []*types.T) error {
22142218
for _, t := range typs {
22152219
if t.UserDefined() {
2216-
name, desc, err := dt.GetTypeDescriptor(ctx, typedesc.GetTypeDescID(t))
2220+
id, err := typedesc.GetUserDefinedTypeDescID(t)
2221+
if err != nil {
2222+
return err
2223+
}
2224+
name, desc, err := dt.GetTypeDescriptor(ctx, id)
22172225
if err != nil {
22182226
return err
22192227
}

pkg/sql/catalog/tabledesc/structured.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,11 @@ func (desc *wrapper) getAllReferencedTypesInTableColumns(
500500
// collect the closure of ID's referenced.
501501
ids := make(map[descpb.ID]struct{})
502502
for id := range visitor.OIDs {
503-
typDesc, err := getType(typedesc.UserDefinedTypeOIDToID(id))
503+
uid, err := typedesc.UserDefinedTypeOIDToID(id)
504+
if err != nil {
505+
return nil, err
506+
}
507+
typDesc, err := getType(uid)
504508
if err != nil {
505509
return nil, err
506510
}

pkg/sql/catalog/tabledesc/validate.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@ func (desc *wrapper) GetReferencedDescIDs() catalog.DescriptorIDSet {
8989
})
9090
// Add collected Oids to return set.
9191
for oid := range visitor.OIDs {
92-
ids.Add(typedesc.UserDefinedTypeOIDToID(oid))
92+
id, err := typedesc.UserDefinedTypeOIDToID(oid)
93+
if err != nil {
94+
panic(err)
95+
}
96+
ids.Add(id)
9397
}
9498
// Add view dependencies.
9599
for _, id := range desc.GetDependsOn() {

pkg/sql/catalog/typedesc/BUILD.bazel

+2
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@ go_test(
4646
"//pkg/sql/catalog/dbdesc",
4747
"//pkg/sql/catalog/descpb",
4848
"//pkg/sql/catalog/schemadesc",
49+
"//pkg/sql/oidext",
4950
"//pkg/sql/privilege",
5051
"//pkg/sql/types",
5152
"//pkg/testutils",
5253
"//pkg/testutils/serverutils",
5354
"//pkg/util/leaktest",
5455
"//pkg/util/randutil",
5556
"@com_github_cockroachdb_redact//:redact",
57+
"@com_github_lib_pq//oid",
5658
"@com_github_stretchr_testify//require",
5759
"@in_gopkg_yaml_v2//:yaml_v2",
5860
],

pkg/sql/catalog/typedesc/type_desc.go

+43-12
Original file line numberDiff line numberDiff line change
@@ -114,22 +114,34 @@ func TypeIDToOID(id descpb.ID) oid.Oid {
114114
}
115115

116116
// UserDefinedTypeOIDToID converts a user defined type OID into a
117-
// descriptor ID.
118-
func UserDefinedTypeOIDToID(oid oid.Oid) descpb.ID {
119-
return descpb.ID(oid) - oidext.CockroachPredefinedOIDMax
117+
// descriptor ID. OID of a user-defined type must be greater than
118+
// CockroachPredefinedOIDMax. The function throws an error if the
119+
// given OID is less than CockroachPredefinedMax.
120+
func UserDefinedTypeOIDToID(oid oid.Oid) (descpb.ID, error) {
121+
if descpb.ID(oid) <= oidext.CockroachPredefinedOIDMax {
122+
return 0, errors.Newf("user-defined OID %d should be greater "+
123+
"than predefined Max: %d.", oid, oidext.CockroachPredefinedOIDMax)
124+
}
125+
return descpb.ID(oid) - oidext.CockroachPredefinedOIDMax, nil
120126
}
121127

122-
// GetTypeDescID gets the type descriptor ID from a user defined type.
123-
func GetTypeDescID(t *types.T) descpb.ID {
128+
// GetUserDefinedTypeDescID gets the type descriptor ID from a user defined type.
129+
func GetUserDefinedTypeDescID(t *types.T) (descpb.ID, error) {
124130
return UserDefinedTypeOIDToID(t.Oid())
125131
}
126132

127-
// GetArrayTypeDescID gets the ID of the array type descriptor from a user
133+
// GetUserDefinedArrayTypeDescID gets the ID of the array type descriptor from a user
128134
// defined type.
129-
func GetArrayTypeDescID(t *types.T) descpb.ID {
135+
func GetUserDefinedArrayTypeDescID(t *types.T) (descpb.ID, error) {
130136
return UserDefinedTypeOIDToID(t.UserDefinedArrayOID())
131137
}
132138

139+
// makeTypeIDRangeError is a helper to create an error message for invalid type ID conversion.
140+
func makeTypeIDRangeError(t *types.T) error {
141+
return errors.Newf("user-defined OID %d should be greater than Max OID: %d. "+
142+
"The type has %s", t.Oid(), oidext.CockroachPredefinedOIDMax, t.DebugString())
143+
}
144+
133145
// TypeDesc implements the Descriptor interface.
134146
func (desc *immutable) TypeDesc() *descpb.TypeDescriptor {
135147
return &desc.TypeDescriptor
@@ -599,7 +611,10 @@ func (desc *immutable) ValidateCrossReferences(
599611
}
600612
case descpb.TypeDescriptor_ALIAS:
601613
if desc.GetAlias().UserDefined() {
602-
aliasedID := UserDefinedTypeOIDToID(desc.GetAlias().Oid())
614+
aliasedID, err := UserDefinedTypeOIDToID(desc.GetAlias().Oid())
615+
if err != nil {
616+
vea.Report(err)
617+
}
603618
if _, err := vdg.GetTypeDescriptor(aliasedID); err != nil {
604619
vea.Report(errors.Wrapf(err, "aliased type %d does not exist", aliasedID))
605620
}
@@ -724,7 +739,11 @@ func HydrateTypesInTableDescriptor(
724739
hydrateCol := func(col *descpb.ColumnDescriptor) error {
725740
if col.Type.UserDefined() {
726741
// Look up its type descriptor.
727-
name, typDesc, err := res.GetTypeDescriptor(ctx, GetTypeDescID(col.Type))
742+
td, err := GetUserDefinedTypeDescID(col.Type)
743+
if err != nil {
744+
return err
745+
}
746+
name, typDesc, err := res.GetTypeDescriptor(ctx, td)
728747
if err != nil {
729748
return err
730749
}
@@ -787,7 +806,11 @@ func (desc *immutable) HydrateTypeInfoWithName(
787806
case types.ArrayFamily:
788807
// Hydrate the element type.
789808
elemType := typ.ArrayContents()
790-
elemTypName, elemTypDesc, err := res.GetTypeDescriptor(ctx, GetTypeDescID(elemType))
809+
id, err := GetUserDefinedTypeDescID(elemType)
810+
if err != nil {
811+
return err
812+
}
813+
elemTypName, elemTypDesc, err := res.GetTypeDescriptor(ctx, id)
791814
if err != nil {
792815
return err
793816
}
@@ -925,9 +948,13 @@ func GetTypeDescriptorClosure(typ *types.T) map[descpb.ID]struct{} {
925948
if !typ.UserDefined() {
926949
return map[descpb.ID]struct{}{}
927950
}
951+
id, err := GetUserDefinedTypeDescID(typ)
952+
if err != nil {
953+
panic(err)
954+
}
928955
// Collect the type's descriptor ID.
929956
ret := map[descpb.ID]struct{}{
930-
GetTypeDescID(typ): {},
957+
id: {},
931958
}
932959
if typ.Family() == types.ArrayFamily {
933960
// If we have an array type, then collect all types in the contents.
@@ -937,7 +964,11 @@ func GetTypeDescriptorClosure(typ *types.T) map[descpb.ID]struct{} {
937964
}
938965
} else {
939966
// Otherwise, take the array type ID.
940-
ret[GetArrayTypeDescID(typ)] = struct{}{}
967+
id, err := GetUserDefinedArrayTypeDescID(typ)
968+
if err != nil {
969+
panic(err)
970+
}
971+
ret[id] = struct{}{}
941972
}
942973
return ret
943974
}

0 commit comments

Comments
 (0)