From 08b02bb75b05e20bc31bec2a87a3ca23d989365f Mon Sep 17 00:00:00 2001
From: Aaron Zinger
encode(data: bytes, format: string) → string
Encodes data using format (hex / escape / base64).
format(string, anyelement...) → string
Interprets the first argument as a format string similar to C sprintf and interpolates the remaining arguments.
+from_ip(val: bytes) → string
Converts the byte string representation of an IP to its character string representation.
from_uuid(val: bytes) → string
Converts the byte string representation of a UUID to its character string representation.
diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 16b4e142bd91..43456a0d802d 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -398,7 +398,7 @@ ALL_TESTS = [ "//pkg/sql/schemachanger/screl:screl_test", "//pkg/sql/schemachanger/scrun:scrun_test", "//pkg/sql/schemachanger:schemachanger_test", - "//pkg/sql/sem/builtins:builtins_disallowed_imports_test", + "//pkg/sql/sem/builtins/pgformat:pgformat_test", "//pkg/sql/sem/builtins:builtins_test", "//pkg/sql/sem/cast:cast_test", "//pkg/sql/sem/catconstants:catconstants_disallowed_imports_test", @@ -1525,6 +1525,8 @@ GO_TARGETS = [ "//pkg/sql/sem/asof:asof", "//pkg/sql/sem/builtins/builtinconstants:builtinconstants", "//pkg/sql/sem/builtins/builtinsregistry:builtinsregistry", + "//pkg/sql/sem/builtins/pgformat:pgformat", + "//pkg/sql/sem/builtins/pgformat:pgformat_test", "//pkg/sql/sem/builtins:builtins", "//pkg/sql/sem/builtins:builtins_test", "//pkg/sql/sem/cast:cast", @@ -2497,6 +2499,7 @@ GET_X_DATA_TARGETS = [ "//pkg/sql/sem/builtins:get_x_data", "//pkg/sql/sem/builtins/builtinconstants:get_x_data", "//pkg/sql/sem/builtins/builtinsregistry:get_x_data", + "//pkg/sql/sem/builtins/pgformat:get_x_data", "//pkg/sql/sem/cast:get_x_data", "//pkg/sql/sem/catconstants:get_x_data", "//pkg/sql/sem/catid:get_x_data", diff --git a/pkg/sql/logictest/testdata/logic_test/builtin_function b/pkg/sql/logictest/testdata/logic_test/builtin_function index 13952783daf1..6c803d7b70fa 100644 --- a/pkg/sql/logictest/testdata/logic_test/builtin_function +++ b/pkg/sql/logictest/testdata/logic_test/builtin_function @@ -2482,6 +2482,61 @@ SELECT array_to_string(NULL, ','), array_to_string(NULL, 'foo', 'zerp') ---- NULL NULL +# Examples from https://www.postgresql.org/docs/9.3/functions-string.html#FUNCTIONS-STRING-FORMAT +query T +SELECT format('Hello %s', 'World') +---- +Hello World + +query T +SELECT format('INSERT INTO %I VALUES(%L)', 'locations', 'C:\Program Files') +---- +INSERT INTO locations VALUES(e'C:\\Program Files') + +query T 
+SELECT format('|%10s|', 'foo') +---- +| foo| + +query T +SELECT format('|%-10s|', 'foo') +---- +|foo | + +query T +SELECT format('|%*s|', 10, 'foo') +---- +| foo| + +query T +SELECT format('|%*s|', -10, 'foo') +---- +|foo | + +query T +SELECT format('|%-*s|', 10, 'foo') +---- +|foo | + +query T +SELECT format('|%-*s|', -10, 'foo') +---- +|foo | + +# Escaping $ into \x24 only needed in testlogic or prepared statements +query T +SELECT format(E'Testing %3\x24s, %2\x24s, %1\x24s', 'one', 'two', 'three') +---- +Testing three, two, one + +query T +SELECT format(E'Testing %3\x24s, %2\x24s, %s', 'one', 'two', 'three') +---- +Testing three, two, three + +query error pq: format\(\): error parsing format string: not enough arguments +SELECT format(E'%2\x24s','foo'); + subtest pg_is_in_recovery query B colnames diff --git a/pkg/sql/logictest/testdata/logic_test/format b/pkg/sql/logictest/testdata/logic_test/format new file mode 100644 index 000000000000..3084dca76ca1 --- /dev/null +++ b/pkg/sql/logictest/testdata/logic_test/format @@ -0,0 +1,199 @@ +# LogicTest: !fakedist-spec-planning + +# tests from https://github.com/postgres/postgres/blob/4ca9985957881c223b4802d309c0bbbcf8acd1c1/src/test/regress/sql/text.sql#L55 + +query T +select format(NULL) +---- +NULL + +query T +select format('Hello') +---- +Hello + +query T +select format('Hello %s', 'World') +---- +Hello World + +query T +select format('Hello %%') +---- +Hello % + +query T +select format('Hello %%%%') +---- +Hello %% + +query error pq: format\(\): error parsing format string: not enough arguments +select format('Hello %s %s', 'World') + +query error pq: format\(\): error parsing format string: not enough arguments +select format('Hello %s') + +query error pq: format\(\): error parsing format string: unrecognized verb x +select format('Hello %x', 20) + +query T +select format('INSERT INTO %I VALUES(%L,%L)', 'mytab', 10, 'Hello') +---- +INSERT INTO mytab VALUES('10','Hello') + +query T +select 
format('%s%s%s','Hello', NULL,'World') +---- +HelloWorld + +query T +select format('INSERT INTO %I VALUES(%L,%L)', 'mytab', 10, NULL) +---- +INSERT INTO mytab VALUES('10',NULL) + +query T +select format('INSERT INTO %I VALUES(%L,%L)', 'mytab', NULL, 'Hello'); +---- +INSERT INTO mytab VALUES(NULL,'Hello') + +query error pq: format\(\): error parsing format string: NULL cannot be formatted as a SQL identifier +select format('INSERT INTO %I VALUES(%L,%L)', NULL, 10, 'Hello') + +# Many of the below tests involve strings with a literal $. +# This can break TestLogic under some conditions. If you're seeing mysterious errors in this file, +# they can likely be fixed by escaping $ into \x24, e.g. replace '%1$s' with E'%\x24s'. +# For now, strings are left unescaped here for readability. +query T +select format('%1$s %3$s', 1, 2, 3) +---- +1 3 + +query T +select format('%1$s %12$s', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) +---- +1 12 + +query error pq: format\(\): error parsing format string: not enough arguments +select format('%1$s %4$s', 1, 2, 3) + +query error pq: format\(\): error parsing format string: not enough arguments +select format('%1$s %13$s', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) + +query error pq: format\(\): error parsing format string: positions must be positive and 1-indexed +select format('%0$s', 'Hello') + +query error pq: format\(\): error parsing format string: positions must be positive and 1-indexed +select format('%*0$s', 'Hello') + +query error pq: format\(\): error parsing format string: unterminated format specifier +select format('%1$', 1) + +query error pq: format\(\): error parsing format string: unterminated format specifier +select format('%1$1', 1) + +query error pq: format\(\): error parsing format string: unterminated format specifier +select format('%1$1', 1) + +# Mixing positional and non-positional placeholders is allowed here, unusually. 
+# A non-positional placeholder consumes the argument after the last one, +# whether or not the last one was positional. + +query T +select format('Hello %s %1$s %s', 'World', 'Hello again') +---- +Hello World World Hello again + +query T +select format('Hello %s %s, %2$s %2$s', 'World', 'Hello again') +---- +Hello World Hello again, Hello again Hello again + +query T +select format('>>%10s<<', 'Hello') +---- +>> Hello<< + +query T +select format('>>%10s<<', NULL) +---- +>> << + +query T +select format('>>%10s<<', '') +---- +>> << + +query T +select format('>>%-10s<<', '') +---- +>> << + +query T +select format('>>%-10s<<', 'Hello') +---- +>>Hello << + +query T +select format('>>%-10s<<', NULL) +---- +>> << + +query T +select format('>>%1$10s<<', 'Hello') +---- +>> Hello<< + +query T +select format('>>%1$-10I<<', 'Hello') +---- +>>"Hello" << + +query T +select format('>>%2$*1$L<<', 10, 'Hello') +---- +>> 'Hello'<< + +query T +select format('>>%2$*1$L<<', 10, NULL) +---- +>> NULL<< + +query T +select format('>>%*s<<', 10, 'Hello') +---- +>> Hello<< + +query T +select format('>>%*1$s<<', 10, 'Hello') +---- +>> Hello<< + +query T +select format('>>%-s<<', 'Hello') +---- +>>Hello<< + +query T +select format('>>%10L<<', NULL) +---- +>> NULL<< + +# Null is equivalent to zero minimum width. +# Zero minimum width has no effect. +query T +select format('>>%2$*1$L<<', NULL, 'Hello') +---- +>>'Hello'<< + +query T +select format('>>%2$*1$L<<', 0, 'Hello') +---- +>>'Hello'<< + +# This is an error in postgres, but our +# implementation allows width and position flags +# to be in either order. 
+query T +select format('>>%*1$2$L<<', 10, 'Hello') +---- +>> 'Hello'<< diff --git a/pkg/sql/sem/builtins/BUILD.bazel b/pkg/sql/sem/builtins/BUILD.bazel index 3b4a58474dc3..323c30e4c32f 100644 --- a/pkg/sql/sem/builtins/BUILD.bazel +++ b/pkg/sql/sem/builtins/BUILD.bazel @@ -1,6 +1,5 @@ load("//build/bazelutil/unused_checker:unused.bzl", "get_x_data") load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") -load("//pkg/testutils/buildutil:buildutil.bzl", "disallowed_imports_test") go_library( name = "builtins", @@ -71,6 +70,7 @@ go_library( "//pkg/sql/sem/asof", "//pkg/sql/sem/builtins/builtinconstants", "//pkg/sql/sem/builtins/builtinsregistry", + "//pkg/sql/sem/builtins/pgformat", "//pkg/sql/sem/catconstants", "//pkg/sql/sem/catid", "//pkg/sql/sem/eval", @@ -126,7 +126,6 @@ go_library( go_test( name = "builtins_test", - size = "medium", srcs = [ "aggregate_builtins_test.go", "all_builtins_test.go", @@ -157,6 +156,7 @@ go_test( "//pkg/sql/randgen", "//pkg/sql/sem/builtins/builtinconstants", "//pkg/sql/sem/builtins/builtinsregistry", + "//pkg/sql/sem/builtins/pgformat", "//pkg/sql/sem/eval", "//pkg/sql/sem/tree", "//pkg/sql/sem/tree/treewindow", @@ -179,15 +179,4 @@ go_test( ], ) -disallowed_imports_test( - src = "builtins", - disallowed_list = [ - "//pkg/sql/catalog/descs", - "//pkg/sql/catalog/tabledesc", - "//pkg/sql/catalog/schemadesc", - "//pkg/sql/catalog/dbdesc", - "//pkg/sql/catalog/typedesc", - ], -) - get_x_data(name = "get_x_data") diff --git a/pkg/sql/sem/builtins/builtins.go b/pkg/sql/sem/builtins/builtins.go index fd91cdc3f1c8..c4cbc0afd121 100644 --- a/pkg/sql/sem/builtins/builtins.go +++ b/pkg/sql/sem/builtins/builtins.go @@ -63,6 +63,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/rowenc/keyside" "github.com/cockroachdb/cockroach/pkg/sql/sem/asof" "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinconstants" + "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/pgformat" 
"github.com/cockroachdb/cockroach/pkg/sql/sem/catid" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" @@ -223,6 +224,8 @@ var regularBuiltins = map[string]builtinDefinition{ ), ), + "format": formatImpls, + "octet_length": makeBuiltin(tree.FunctionProperties{Category: builtinconstants.CategoryString}, stringOverload1( func(_ *eval.Context, s string) (tree.Datum, error) { @@ -7197,6 +7200,27 @@ var lengthImpls = func(incBitOverload bool) builtinDefinition { return b } +var formatImpls = makeBuiltin(tree.FunctionProperties{Category: builtinconstants.CategoryString}, + tree.Overload{ + Types: tree.VariadicType{FixedTypes: []*types.T{types.String}, VarType: types.Any}, + ReturnType: tree.FixedReturnType(types.String), + Fn: func(ctx *eval.Context, args tree.Datums) (tree.Datum, error) { + if args[0] == tree.DNull { + return tree.DNull, nil + } + formatStr := tree.MustBeDString(args[0]) + formatArgs := args[1:] + str, err := pgformat.Format(ctx, string(formatStr), formatArgs...) 
+ if err != nil { + return nil, pgerror.Wrap(err, pgcode.InvalidParameterValue, "error parsing format string") + } + return tree.NewDString(str), nil + }, + Info: "Interprets the first argument as a format string similar to C sprintf and interpolates the remaining arguments.", + Volatility: volatility.Stable, + NullableArgs: true, + }) + var substringImpls = makeBuiltin(tree.FunctionProperties{Category: builtinconstants.CategoryString}, tree.Overload{ Types: tree.ArgTypes{ diff --git a/pkg/sql/sem/builtins/builtins_test.go b/pkg/sql/sem/builtins/builtins_test.go index b80dd55fe3aa..070a8bf71ff7 100644 --- a/pkg/sql/sem/builtins/builtins_test.go +++ b/pkg/sql/sem/builtins/builtins_test.go @@ -23,6 +23,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinconstants" "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/builtinsregistry" + "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/pgformat" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" @@ -305,6 +306,35 @@ func TestEscapeFormatRandom(t *testing.T) { } } +func TestFormatWithWeirdFormatStrings(t *testing.T) { + defer leaktest.AfterTest(t)() + specialFormatChars := []byte{'%', '$', '-', '*', '0', '1', 's', 'I', 'L'} + numSpecial := len(specialFormatChars) + specialFreq := 0.2 + evalContext := eval.NewTestingEvalContext(cluster.MakeTestingClusterSettings()) + datums := make(tree.Datums, 10) + for i := range datums { + datums[i] = tree.NewDInt(tree.DInt(i)) + } + for i := 0; i < 1000; i++ { + b := make([]byte, rand.Intn(100)) + for j := 0; j < len(b); j++ { + if rand.Float64() < specialFreq { + b[j] = specialFormatChars[rand.Intn(numSpecial)] + } else { + b[j] = byte(rand.Intn(256)) + } + } + str := string(b) + // Mostly just making sure no panics + _, err := pgformat.Format(evalContext, str, datums...) 
+ if err != nil { + require.Regexp(t, `position|width|not enough arguments|unrecognized verb|unterminated format`, err.Error(), + "input string was %s", str) + } + } +} + func TestLPadRPad(t *testing.T) { defer leaktest.AfterTest(t)() testCases := []struct { diff --git a/pkg/sql/sem/builtins/pgformat/BUILD.bazel b/pkg/sql/sem/builtins/pgformat/BUILD.bazel new file mode 100644 index 000000000000..e08099659df3 --- /dev/null +++ b/pkg/sql/sem/builtins/pgformat/BUILD.bazel @@ -0,0 +1,46 @@ +load("//build/bazelutil/unused_checker:unused.bzl", "get_x_data") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "pgformat", + srcs = ["format.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins/pgformat", + visibility = ["//visibility:public"], + deps = [ + "//pkg/sql/lexbase", + "//pkg/sql/sem/eval", + "//pkg/sql/sem/tree", + "//pkg/sql/types", + "@com_github_cockroachdb_errors//:errors", + ], +) + +go_test( + name = "pgformat_test", + srcs = [ + "format_test.go", + "main_test.go", + ], + deps = [ + "//pkg/base", + "//pkg/keys", + "//pkg/security/securityassets", + "//pkg/security/securitytest", + "//pkg/server", + "//pkg/sql/catalog/desctestutils", + "//pkg/sql/randgen", + "//pkg/sql/sem/tree", + "//pkg/sql/types", + "//pkg/testutils/serverutils", + "//pkg/testutils/skip", + "//pkg/testutils/sqlutils", + "//pkg/testutils/testcluster", + "//pkg/util", + "//pkg/util/leaktest", + "//pkg/util/randutil", + "//pkg/util/timeutil", + "@com_github_stretchr_testify//require", + ], +) + +get_x_data(name = "get_x_data") diff --git a/pkg/sql/sem/builtins/pgformat/format.go b/pkg/sql/sem/builtins/pgformat/format.go new file mode 100644 index 000000000000..98c1b695d8eb --- /dev/null +++ b/pkg/sql/sem/builtins/pgformat/format.go @@ -0,0 +1,271 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package pgformat
+
+import (
+	"github.com/cockroachdb/cockroach/pkg/sql/lexbase"
+	"github.com/cockroachdb/cockroach/pkg/sql/sem/eval"
+	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
+	"github.com/cockroachdb/cockroach/pkg/sql/types"
+	"github.com/cockroachdb/errors"
+)
+
+// pp holds the state of one format() evaluation: the evaluation context,
+// the output buffer, and the flags collected so far for the format
+// specifier that is currently being parsed.
+type pp struct {
+	ctx *eval.Context
+	buf *tree.FmtCtx
+
+	// padRight is set once a '-' flag has been seen in the current specifier.
+	padRight bool
+	// num holds a parsed integer whose role (width or argument position) is
+	// not known until a '$' or a verb follows it; nil when none is pending.
+	num *int
+	// width is the minimum width of the next value; 0 means no padding.
+	width int
+}
+
+// clearState resets the per-specifier fields (padRight, num, width) and
+// leaves ctx and buf untouched.
+func (p *pp) clearState() {
+	p.padRight, p.num, p.width = false, nil, 0
+}
+
+// popInt returns the pending integer stored in p.num and clears it.
+// When no integer is pending it returns (0, false).
+func (p *pp) popInt() (v int, ok bool) {
+	if pending := p.num; pending != nil {
+		p.num = nil
+		return *pending, true
+	}
+	return 0, false
+}
+
+// Format interpolates the arguments a into format, following the syntax of
+// the postgres format() function, and returns the resulting string. It
+// returns an error when the format string cannot be processed.
+func Format(ctx *eval.Context, format string, a ...tree.Datum) (string, error) {
+	printer := pp{ctx: ctx, buf: ctx.FmtCtx(tree.FmtArrayToString)}
+	if err := printer.doPrintf(format, a); err != nil {
+		return "", err
+	}
+	return printer.buf.CloseAndGetString(), nil
+}
+
+// tooLarge reports whether the magnitude of x exceeds the limit (one
+// million) accepted for a formatting width or an argument position.
+func tooLarge(x int) bool {
+	const limit = 1000000
+	return x < -limit || limit < x
+}
+
+// parsenum converts ASCII to integer. num is 0 (and isnum is false) if no number present.
+// Scanning covers s[start:end] and stops at the first non-digit; newi is the
+// index of that byte (or end). If the value accumulated so far is already
+// tooLarge when another digit is found, parsenum gives up and returns
+// (0, false, end), which makes the caller skip the rest of the string.
+func parsenum(s string, start, end int) (num int, isnum bool, newi int) {
+	if start >= end {
+		return 0, false, end
+	}
+	for newi = start; newi < end && '0' <= s[newi] && s[newi] <= '9'; newi++ {
+		if tooLarge(num) {
+			return 0, false, end
+		}
+		num = num*10 + int(s[newi]-'0')
+		isnum = true
+	}
+	return
+}
+
+// runeCount reports the number of characters in b, counting every invalid
+// UTF-8 byte as one character (the same result as utf8.RuneCount). It is
+// written as a plain loop so that no additional import is needed.
+func runeCount(b []byte) int {
+	n := 0
+	for range string(b) {
+		n++
+	}
+	return n
+}
+
+// printArg appends arg to p.buf, rendered according to verb:
+//
+//	's' writes the formatted value; NULL writes nothing.
+//	'I' writes the value encoded as a SQL identifier; NULL is an error.
+//	'L' writes the value encoded as a SQL string; NULL writes NULL.
+//
+// Any other verb is rejected with an error. When p.width is non-zero the
+// output is padded with spaces to at least |p.width| characters: on the
+// right if padRight is set or the width is negative, on the left otherwise.
+// The width is measured in characters rather than bytes, as postgres does,
+// so values containing multi-byte characters are padded to the full width.
+func (p *pp) printArg(arg tree.Datum, verb rune) error {
+	// bareString formats arg into a temporary buffer and returns the text.
+	bareString := func() string {
+		bare := p.ctx.FmtCtx(tree.FmtArrayToString)
+		bare.FormatNode(arg)
+		return bare.CloseAndGetString()
+	}
+
+	// emit writes the rendered argument to buf.
+	var emit func(buf *tree.FmtCtx)
+	if arg == tree.DNull {
+		switch verb {
+		case 's':
+			emit = func(*tree.FmtCtx) {}
+		case 'I':
+			return errors.New("NULL cannot be formatted as a SQL identifier")
+		case 'L':
+			emit = func(buf *tree.FmtCtx) { buf.WriteString("NULL") }
+		default:
+			return errors.Newf("unrecognized verb %c", verb)
+		}
+	} else {
+		switch verb {
+		case 's':
+			emit = func(buf *tree.FmtCtx) { buf.FormatNode(arg) }
+		case 'I':
+			emit = func(buf *tree.FmtCtx) {
+				lexbase.EncodeRestrictedSQLIdent(&buf.Buffer, bareString(), lexbase.EncNoFlags)
+			}
+		case 'L':
+			emit = func(buf *tree.FmtCtx) {
+				lexbase.EncodeSQLString(&buf.Buffer, bareString())
+			}
+		default:
+			return errors.Newf("unrecognized verb %c", verb)
+		}
+	}
+	if p.width == 0 {
+		emit(p.buf)
+		return nil
+	}
+
+	// write emits the argument and reports how many characters (not bytes)
+	// were appended to buf.
+	write := func(buf *tree.FmtCtx) int {
+		before := buf.Len()
+		emit(buf)
+		return runeCount(buf.Bytes()[before:])
+	}
+
+	// negative width passed via * sets the - flag,
+	// does not toggle it.
+	if p.width < 0 {
+		p.width = -p.width
+		p.padRight = true
+	}
+
+	if p.padRight {
+		for n := write(p.buf); n < p.width; n++ {
+			p.buf.WriteRune(' ')
+		}
+		return nil
+	}
+
+	// Left padding: render into a scratch buffer first so that the number
+	// of spaces is known before the value itself is written.
+	scratch := p.ctx.FmtCtx(tree.FmtArrayToString)
+	for n := write(scratch); n < p.width; n++ {
+		p.buf.WriteRune(' ')
+	}
+	_, err := scratch.WriteTo(p.buf)
+	return err
+}
+
+// intFromArg gets the argNumth element of a.
+// On return, isInt reports whether the argument has integer type.
+// When the argument exists and can be cast to an integer, num is its value
+// and newArgNum is argNum+1; a NULL argument counts as the integer 0. When
+// argNum is out of range or the cast fails, isInt is false and newArgNum is
+// argNum. Values rejected by tooLarge yield num 0 and isInt false.
+func intFromArg(
+	ctx *eval.Context, a []tree.Datum, argNum int,
+) (num int, isInt bool, newArgNum int) {
+	newArgNum = argNum
+	if argNum < len(a) && argNum >= 0 {
+		datum := a[argNum]
+		// null is interpreted as 0 as in postgres.
+		if datum == tree.DNull {
+			return 0, true, argNum + 1
+		}
+		// This is much more permissive than postgres.
+		dInt, err := eval.PerformCast(ctx, datum, types.Int)
+		if err == nil {
+			num = int(tree.MustBeDInt(dInt))
+			isInt = true
+			newArgNum = argNum + 1
+		}
+		if tooLarge(num) {
+			num = 0
+			isInt = false
+		}
+	}
+	return
+}
+
+// doPrintf is copied from golang's internal implementation of fmt,
+// but modified to use the sql function format()'s syntax for width
+// and positional arguments.
+//
+// A format specifier starts with % and ends with one of the verbs s, I or L;
+// %% writes a literal percent sign. Before the verb it may contain a value
+// position (n$), the - flag, and a width given as a literal number, as *
+// (taken from the next argument) or as *n$ (taken from the nth argument).
+// These pieces are accepted in any order. An explicit value position always
+// selects the argument that is printed, regardless of where the width was
+// read from; without one, the value is the argument following the last one
+// consumed.
+func (p *pp) doPrintf(format string, a []tree.Datum) error {
+	end := len(format)
+	argNum := 0 // index of the next argument to consume
+formatLoop:
+	for i := 0; i < end; {
+		lasti := i
+		for i < end && format[i] != '%' {
+			i++
+		}
+		if i > lasti {
+			p.buf.WriteString(format[lasti:i])
+		}
+		if i >= end {
+			// done processing format string
+			break
+		}
+
+		// Process one verb
+		i++
+
+		p.clearState()
+		// valuePos is the 0-based argument index requested with "n$" for the
+		// value of this specifier. It is kept apart from argNum so that a
+		// width argument (* or *n$) parsed afterwards cannot overwrite it.
+		valuePos, hasValuePos := 0, false
+		for ; i < end; i++ {
+			c := format[i]
+			switch c {
+			case '-':
+				p.padRight = true
+			case 's', 'I', 'L':
+				if p.width == 0 {
+					p.width, _ = p.popInt()
+				}
+				if hasValuePos {
+					argNum = valuePos
+				}
+				if argNum >= len(a) {
+					return errors.New("not enough arguments")
+				}
+				if argNum < 0 {
+					return errors.New("positions must be positive and 1-indexed")
+				}
+				if err := p.printArg(a[argNum], rune(c)); err != nil {
+					return err
+				}
+				argNum++
+				i++
+				continue formatLoop
+			case '%':
+				p.buf.WriteByte(c)
+				i++
+				continue formatLoop
+			case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
+				var n int
+				n, _, i = parsenum(format, i, end)
+				p.num = &n
+				// parsenum leaves i on the first unconsumed byte; step back
+				// because the loop increments i again.
+				i--
+			case '$':
+				rawArgNum, ok := p.popInt()
+				if !ok {
+					return errors.New("empty positional argument")
+				}
+				valuePos, hasValuePos = rawArgNum-1, true
+			case '*':
+				// The width comes from the argument at position n for "*n$",
+				// or from the next argument for a bare "*".
+				widthArg := argNum
+				if rawArgNum, isNum, afterNum := parsenum(format, i+1, end); isNum {
+					i = afterNum
+					if i == end || format[i] != '$' {
+						return errors.New(`width argument position must be ended by "$"`)
+					}
+					if rawArgNum < 1 {
+						return errors.New("positions must be positive and 1-indexed")
+					}
+					widthArg = rawArgNum - 1
+				}
+				var isInt bool
+				p.width, isInt, argNum = intFromArg(p.ctx, a, widthArg)
+				if !isInt {
+					return errors.New("non-numeric width")
+				}
+			default:
+				return errors.Newf("unrecognized verb %c", c)
+			}
+
+		}
+		return errors.New("unterminated format specifier")
+	}
+	return nil
+}
diff --git a/pkg/sql/sem/builtins/pgformat/format_test.go b/pkg/sql/sem/builtins/pgformat/format_test.go
new file mode 100644
index 000000000000..1d2b29e7ff1b
--- /dev/null
+++ b/pkg/sql/sem/builtins/pgformat/format_test.go
@@ -0,0 +1,148 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package pgformat_test
+
+import (
+	"context"
+	"fmt"
+	"math/rand"
+	"strings"
+	"testing"
+
+	"github.com/cockroachdb/cockroach/pkg/base"
+	"github.com/cockroachdb/cockroach/pkg/keys"
+	"github.com/cockroachdb/cockroach/pkg/sql/catalog/desctestutils"
+	"github.com/cockroachdb/cockroach/pkg/sql/randgen"
+	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
+	"github.com/cockroachdb/cockroach/pkg/sql/types"
+	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils/skip"
+	"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
+	"github.com/cockroachdb/cockroach/pkg/util"
+	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
+	"github.com/stretchr/testify/require"
+)
+
+// Tests for the format() SQL function.
+func TestFormat(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	ctx := context.Background()
+	s, sqlDB, kvDB := serverutils.StartServer(t, base.TestServerArgs{})
+	defer s.Stopper().Stop(ctx)
+	typesToTest := make([]*types.T, 0, 256)
+	// Types we don't support that are present in types.OidToType
+	var skipType func(typ *types.T) bool
+	skipType = func(typ *types.T) bool {
+		switch typ.Family() {
+		case types.AnyFamily, types.OidFamily:
+			return true
+		case types.ArrayFamily:
+			if !randgen.IsAllowedForArray(typ.ArrayContents()) {
+				return true
+			}
+			if skipType(typ.ArrayContents()) {
+				return true
+			}
+		}
+		return !randgen.IsLegalColumnType(typ)
+	}
+	for _, typ := range types.OidToType {
+		if !skipType(typ) {
+			typesToTest = append(typesToTest, typ)
+		}
+	}
+
+	seed := rand.New(rand.NewSource(timeutil.Now().UnixNano()))
+	createTable := func(t *testing.T, tdb *sqlutils.SQLRunner, typ *types.T) (tableNamer func(string) string) {
+		columnSpec := fmt.Sprintf("c %s", typ.SQLString())
+		tableName := fmt.Sprintf("%s_table_%d", strings.Replace(typ.String(), "\"", "", -1), seed.Int())
+		tableName = strings.Replace(tableName, `[]`, `_array`, -1)
+
+		// Create the table.
+		createStmt := fmt.Sprintf(`CREATE TABLE %s (%s)`, tableName, columnSpec)
+		tdb.Exec(t, createStmt)
+		return func(s string) string { return strings.Replace(s, "tablename", tableName, -1) }
+	}
+	insertRows := func(t *testing.T, tdb *sqlutils.SQLRunner, r func(string) string) {
+		// Insert numRows rows of random data with the first row being all NULL.
+		numRows := 10 // arbitrary
+		if util.RaceEnabled {
+			numRows = 2
+		}
+		tab := desctestutils.TestingGetPublicTableDescriptor(kvDB, keys.SystemSQLCodec, "defaultdb", r(`tablename`))
+		for i := 0; i < numRows; i++ {
+			var row []string
+			for _, col := range tab.WritableColumns() {
+				if col.GetName() == "rowid" {
+					continue
+				}
+				var d tree.Datum
+				if i == 0 {
+					d = tree.DNull
+				} else {
+					const nullOk = false
+					d = randgen.RandDatum(seed, col.GetType(), nullOk)
+				}
+				row = append(row, tree.AsStringWithFlags(d, tree.FmtParsable))
+			}
+			tdb.Exec(t, fmt.Sprintf(r(`INSERT INTO tablename VALUES (%s)`),
+				strings.Join(row, ", ")))
+		}
+	}
+
+	for _, typ := range typesToTest {
+		t.Run(typ.String(), func(t *testing.T) {
+			conn, err := sqlDB.Conn(ctx)
+			require.NoError(t, err)
+			tdb := sqlutils.MakeSQLRunner(conn)
+			r := createTable(t, tdb, typ)
+			insertRows(t, tdb, r)
+			var values string
+			tdb.QueryRow(t, r(`SELECT array_to_string(array_agg(c::string),', ') FROM tablename`)).Scan(&values)
+			t.Log(values)
+			t.Run("%s does not error", func(t *testing.T) {
+				tdb.Exec(t, r(`SELECT format('%s',c) from tablename`))
+			})
+			t.Run("%I creates a valid identifier", func(t *testing.T) {
+				stmts := tdb.Query(t,
+					r(`SELECT format('CREATE TABLE IF NOT EXISTS %I (i int)', c) FROM tablename WHERE c IS NOT NULL`),
+				)
+				var shouldBeValidStmt string
+				for stmts.Next() {
+					require.NoError(t, stmts.Scan(&shouldBeValidStmt))
+					tdb.Exec(t, shouldBeValidStmt)
+				}
+			})
+			t.Run("%L creates a literal logically equivalent to the value", func(t *testing.T) {
+				if typ.Family() == types.ArrayFamily {
+					skip.WithIssue(t, 84274)
+				}
+				stmts := tdb.Query(t, r(`SELECT rowid, format('%L', c) FROM tablename WHERE c IS NOT NULL`))
+				var literal string
+				var rowid int
+				queries := make([]string, 0, 10)
+				for stmts.Next() {
+					require.NoError(t, stmts.Scan(&rowid, &literal))
+					queries = append(queries,
+						fmt.Sprintf(r(`SELECT count(*) FROM tablename WHERE rowid=%d AND c=%s`),
+							rowid, literal))
+				}
+				for _, query := range queries {
+					tdb.CheckQueryResults(t, query, [][]string{{`1`}})
+				}
+			})
+
+		})
+	}
+
+}
diff --git a/pkg/sql/sem/builtins/pgformat/fuzz.go b/pkg/sql/sem/builtins/pgformat/fuzz.go
new file mode 100644
index 000000000000..973a061e40b6
--- /dev/null
+++ b/pkg/sql/sem/builtins/pgformat/fuzz.go
@@ -0,0 +1,38 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+//go:build gofuzz
+// +build gofuzz
+
+package pgformat
+
+import (
+	"github.com/cockroachdb/cockroach/pkg/sql/sem/eval"
+	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
+)
+
+// FuzzFormat passes the input to pgformat.Format()
+// as both the format string and format arguments.
+//
+// Following the go-fuzz convention it returns 1 for inputs that Format
+// accepts, so that the fuzzer gives them a higher priority, and 0 for
+// inputs that Format rejects with an error.
+func FuzzFormat(input []byte) int {
+	ctx := eval.MakeTestingEvalContext(nil)
+	str := string(input)
+	args := make(tree.Datums, 16)
+	for i := range args {
+		args[i] = tree.NewDString(str)
+	}
+	if _, err := Format(&ctx, str, args...); err != nil {
+		return 0
+	}
+	return 1
+}
diff --git a/pkg/sql/sem/builtins/pgformat/main_test.go b/pkg/sql/sem/builtins/pgformat/main_test.go
new file mode 100644
index 000000000000..e0c54f8fc5cb
--- /dev/null
+++ b/pkg/sql/sem/builtins/pgformat/main_test.go
@@ -0,0 +1,36 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package pgformat_test
+
+import (
+	"os"
+	"testing"
+
+	"github.com/cockroachdb/cockroach/pkg/security/securityassets"
+	"github.com/cockroachdb/cockroach/pkg/security/securitytest"
+	"github.com/cockroachdb/cockroach/pkg/server"
+	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils/testcluster"
+	"github.com/cockroachdb/cockroach/pkg/util/randutil"
+)
+
+//go:generate ../../../util/leaktest/add-leaktest.sh *_test.go
+
+// TestMain installs the embedded security assets loader, calls
+// randutil.SeedForTests and registers the test server and test cluster
+// factories, then exits with the result of running the package's tests.
+func TestMain(m *testing.M) {
+	securityassets.SetLoader(securitytest.EmbeddedAssets)
+	randutil.SeedForTests()
+	serverutils.InitTestServerFactory(server.TestServerFactory)
+	serverutils.InitTestClusterFactory(testcluster.TestClusterFactory)
+	os.Exit(m.Run())
+}