Skip to content

Commit

Permalink
Merge pull request #8439 from planetscale/no-scatter
Browse files Browse the repository at this point in the history
Add "no-scatter" flag to prohibit the use of scatter queries
  • Loading branch information
harshit-gangal authored Jul 13, 2021
2 parents 194f15d + b452048 commit 0ecc59e
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 16 deletions.
20 changes: 20 additions & 0 deletions go/vt/sqlparser/comments.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ const (
DirectiveIgnoreMaxPayloadSize = "IGNORE_MAX_PAYLOAD_SIZE"
// DirectiveIgnoreMaxMemoryRows skips memory row validation when set.
DirectiveIgnoreMaxMemoryRows = "IGNORE_MAX_MEMORY_ROWS"
// DirectiveAllowScatter lets scatter plans pass through even when they are turned off by `no-scatter`.
DirectiveAllowScatter = "ALLOW_SCATTER"
)

func isNonSpace(r rune) bool {
Expand Down Expand Up @@ -385,3 +387,21 @@ func IgnoreMaxMaxMemoryRowsDirective(stmt Statement) bool {
return false
}
}

// AllowScatterDirective returns true if the allow scatter override is set to true
func AllowScatterDirective(stmt Statement) bool {
var directives CommentDirectives
switch stmt := stmt.(type) {
case *Select:
directives = ExtractCommentDirectives(stmt.Comments)
case *Insert:
directives = ExtractCommentDirectives(stmt.Comments)
case *Update:
directives = ExtractCommentDirectives(stmt.Comments)
case *Delete:
directives = ExtractCommentDirectives(stmt.Comments)
default:
return false
}
return directives.IsSet(DirectiveAllowScatter)
}
2 changes: 1 addition & 1 deletion go/vt/vtexplain/vtexplain_vtgate.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func initVtgateExecutor(vSchemaStr, ksShardMapStr string, opts *Options) error {

streamSize := 10
var schemaTracker vtgate.SchemaInfo // no schema tracker for these tests
vtgateExecutor = vtgate.NewExecutor(context.Background(), explainTopo, vtexplainCell, resolver, opts.Normalize, false /*do not warn for sharded only*/, streamSize, cache.DefaultConfig, schemaTracker)
vtgateExecutor = vtgate.NewExecutor(context.Background(), explainTopo, vtexplainCell, resolver, opts.Normalize, false /*do not warn for sharded only*/, streamSize, cache.DefaultConfig, schemaTracker, false /*no-scatter*/)

return nil
}
Expand Down
36 changes: 31 additions & 5 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ type Executor struct {

vm *VSchemaManager
schemaTracker SchemaInfo

// allowScatter will fail planning if set to false and a plan contains any scatter queries
allowScatter bool
}

var executorOnce sync.Once
Expand All @@ -118,6 +121,7 @@ func NewExecutor(
streamSize int,
cacheCfg *cache.Config,
schemaTracker SchemaInfo,
noScatter bool,
) *Executor {
e := &Executor{
serv: serv,
Expand All @@ -130,6 +134,7 @@ func NewExecutor(
warnShardedOnly: warnOnShardedOnly,
streamSize: streamSize,
schemaTracker: schemaTracker,
allowScatter: !noScatter,
}

vschemaacl.Init()
Expand Down Expand Up @@ -208,7 +213,7 @@ func (e *Executor) execute(ctx context.Context, safeSession *SafeSession, sql st
}

func (e *Executor) legacyExecute(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *LogStats) (sqlparser.StatementType, *sqltypes.Result, error) {
//Start an implicit transaction if necessary.
// Start an implicit transaction if necessary.
if !safeSession.Autocommit && !safeSession.InTransaction() {
if err := e.txConn.Begin(ctx, safeSession); err != nil {
return 0, nil, err
Expand Down Expand Up @@ -398,7 +403,7 @@ func (e *Executor) handleCommit(ctx context.Context, safeSession *SafeSession, l
return &sqltypes.Result{}, err
}

//Commit commits the existing transactions
// Commit commits the existing transactions
func (e *Executor) Commit(ctx context.Context, safeSession *SafeSession) error {
return e.txConn.Commit(ctx, safeSession)
}
Expand Down Expand Up @@ -552,7 +557,7 @@ func getValueFor(expr *sqlparser.SetExpr) (interface{}, error) {
}

func (e *Executor) handleSetVitessMetadata(ctx context.Context, name, value string) (*sqltypes.Result, error) {
//TODO(kalfonso): move to its own acl check and consolidate into an acl component that can handle multiple operations (vschema, metadata)
// TODO(kalfonso): move to its own acl check and consolidate into an acl component that can handle multiple operations (vschema, metadata)
user := callerid.ImmediateCallerIDFromContext(ctx)
allowed := vschemaacl.Authorized(user)
if !allowed {
Expand Down Expand Up @@ -1240,7 +1245,8 @@ func (e *Executor) getPlan(vcursor *vcursorImpl, sql string, comments sqlparser.
if !skipQueryPlanCache && !sqlparser.SkipQueryPlanCacheDirective(statement) && sqlparser.CachePlan(statement) {
e.plans.Set(planKey, plan)
}
return plan, nil

return e.checkThatPlanIsValid(stmt, plan)
}

// skipQueryPlanCache extracts SkipQueryPlanCache from session
Expand Down Expand Up @@ -1457,7 +1463,7 @@ func (e *Executor) handlePrepare(ctx context.Context, safeSession *SafeSession,
var errCount uint64
if err != nil {
logStats.Error = err
errCount = 1 //nolint
errCount = 1 // nolint
return nil, err
}
logStats.RowsAffected = qr.RowsAffected
Expand Down Expand Up @@ -1515,3 +1521,23 @@ func (e *Executor) startVStream(ctx context.Context, rss []*srvtopo.ResolvedShar
vs.stream(ctx)
return nil
}

func (e *Executor) checkThatPlanIsValid(stmt sqlparser.Statement, plan *engine.Plan) (*engine.Plan, error) {
if e.allowScatter || sqlparser.AllowScatterDirective(stmt) {
return plan, nil
}
// we go over all the primitives in the plan, searching for a route that is of SelectScatter opcode
badPrimitive := engine.Find(func(node engine.Primitive) bool {
router, ok := node.(*engine.Route)
if !ok {
return false
}
return router.Opcode == engine.SelectScatter
}, plan.Instructions)

if badPrimitive == nil {
return plan, nil
}

return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "plan includes scatter, which is disallowed using the `no_scatter` command line argument")
}
16 changes: 11 additions & 5 deletions go/vt/vtgate/executor_framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"strings"
"testing"

vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"

"github.com/stretchr/testify/require"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -398,7 +400,7 @@ func createLegacyExecutorEnv() (executor *Executor, sbc1, sbc2, sbclookup *sandb
bad.VSchema = badVSchema

getSandbox(KsTestUnsharded).VSchema = unshardedVSchema
executor = NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil)
executor = NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil, false)

key.AnyShardPicker = DestinationAnyShardPickerFirstShard{}
return executor, sbc1, sbc2, sbclookup
Expand Down Expand Up @@ -433,7 +435,7 @@ func createExecutorEnv() (executor *Executor, sbc1, sbc2, sbclookup *sandboxconn
bad.VSchema = badVSchema

getSandbox(KsTestUnsharded).VSchema = unshardedVSchema
executor = NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil)
executor = NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil, false)

key.AnyShardPicker = DestinationAnyShardPickerFirstShard{}
return executor, sbc1, sbc2, sbclookup
Expand All @@ -453,19 +455,23 @@ func createCustomExecutor(vschema string) (executor *Executor, sbc1, sbc2, sbclo
sbclookup = hc.AddTestTablet(cell, "0", 1, KsTestUnsharded, "0", topodatapb.TabletType_MASTER, true, 1, nil)
getSandbox(KsTestUnsharded).VSchema = unshardedVSchema

executor = NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil)
executor = NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil, false)
return executor, sbc1, sbc2, sbclookup
}

func executorExec(executor *Executor, sql string, bv map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
func executorExecSession(executor *Executor, sql string, bv map[string]*querypb.BindVariable, session *vtgatepb.Session) (*sqltypes.Result, error) {
return executor.Execute(
context.Background(),
"TestExecute",
NewSafeSession(masterSession),
NewSafeSession(session),
sql,
bv)
}

func executorExec(executor *Executor, sql string, bv map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return executorExecSession(executor, sql, bv, masterSession)
}

func executorPrepare(executor *Executor, sql string, bv map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
return executor.Prepare(
context.Background(),
Expand Down
45 changes: 43 additions & 2 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1049,7 +1049,7 @@ func TestStreamSelectIN(t *testing.T) {
}

func createExecutor(serv *sandboxTopo, cell string, resolver *Resolver) *Executor {
return NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil)
return NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil, false)
}

func TestSelectScatter(t *testing.T) {
Expand Down Expand Up @@ -2540,7 +2540,7 @@ func TestStreamOrderByLimitWithMultipleResults(t *testing.T) {
count++
}

executor := NewExecutor(context.Background(), serv, cell, resolver, true, false, testBufferSize, cache.DefaultConfig, nil)
executor := NewExecutor(context.Background(), serv, cell, resolver, true, false, testBufferSize, cache.DefaultConfig, nil, false)
before := runtime.NumGoroutine()

query := "select id, col from user order by id limit 2"
Expand All @@ -2553,3 +2553,44 @@ func TestStreamOrderByLimitWithMultipleResults(t *testing.T) {
time.Sleep(100 * time.Millisecond)
assert.GreaterOrEqual(t, before, runtime.NumGoroutine(), "left open goroutines lingering")
}

func TestSelectScatterFails(t *testing.T) {
sess := &vtgatepb.Session{}
cell := "aa"
hc := discovery.NewFakeHealthCheck()
s := createSandbox("TestExecutor")
s.VSchema = executorVSchema
getSandbox(KsTestUnsharded).VSchema = unshardedVSchema
serv := new(sandboxTopo)
resolver := newTestResolver(hc, serv, cell)

shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"}
for i, shard := range shards {
sbc := hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_MASTER, true, 1, nil)
sbc.SetResults([]*sqltypes.Result{{
Fields: []*querypb.Field{
{Name: "col1", Type: sqltypes.Int32},
{Name: "col2", Type: sqltypes.Int32},
{Name: "weight_string(col2)", Type: sqltypes.VarBinary},
},
InsertID: 0,
Rows: [][]sqltypes.Value{{
sqltypes.NewInt32(1),
sqltypes.NewInt32(int32(i % 4)),
sqltypes.NULL,
}},
}})
}

executor := createExecutor(serv, cell, resolver)
executor.allowScatter = false
logChan := QueryLogger.Subscribe("Test")
defer QueryLogger.Unsubscribe(logChan)

_, err := executorExecSession(executor, "select id from user", nil, sess)
require.Error(t, err)
assert.Contains(t, err.Error(), "scatter")

_, err = executorExecSession(executor, "select /*vt+ ALLOW_SCATTER */ id from user", nil, sess)
require.NoError(t, err)
}
2 changes: 1 addition & 1 deletion go/vt/vtgate/executor_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestStreamSQLSharded(t *testing.T) {
for _, shard := range shards {
_ = hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_MASTER, true, 1, nil)
}
executor := NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil)
executor := NewExecutor(context.Background(), serv, cell, resolver, false, false, testBufferSize, cache.DefaultConfig, nil, false)

sql := "stream * from sharded_user_msgs"
result, err := executorStreamMessages(executor, sql)
Expand Down
5 changes: 3 additions & 2 deletions go/vt/vtgate/vtgate.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ var (
warnMemoryRows = flag.Int("warn_memory_rows", 30000, "Warning threshold for in-memory results. A row count higher than this amount will cause the VtGateWarnings.ResultsExceeded counter to be incremented.")
defaultDDLStrategy = flag.String("ddl_strategy", string(schema.DDLStrategyDirect), "Set default strategy for DDL statements. Override with @@ddl_strategy session variable")
dbDDLPlugin = flag.String("dbddl_plugin", "fail", "controls how to handle CREATE/DROP DATABASE. use it if you are using your own database provisioning service")
noScatter = flag.Bool("no_scatter", false, "when set to true, the planner will fail instead of producing a plan that includes scatter queries")

// TODO(deepthi): change these two vars to unexported and move to healthcheck.go when LegacyHealthcheck is removed

Expand Down Expand Up @@ -214,7 +215,7 @@ func Init(ctx context.Context, serv srvtopo.Server, cell string, tabletTypesToWa
LFU: *queryPlanCacheLFU,
}

executor := NewExecutor(ctx, serv, cell, resolver, *normalizeQueries, *warnShardedOnly, *streamBufferSize, cacheCfg, si)
executor := NewExecutor(ctx, serv, cell, resolver, *normalizeQueries, *warnShardedOnly, *streamBufferSize, cacheCfg, si, *noScatter)

// connect the schema tracker with the vschema manager
if *enableSchemaChangeSignal {
Expand Down Expand Up @@ -618,7 +619,7 @@ func LegacyInit(ctx context.Context, hc discovery.LegacyHealthCheck, serv srvtop
}

rpcVTGate = &VTGate{
executor: NewExecutor(ctx, serv, cell, resolver, *normalizeQueries, *warnShardedOnly, *streamBufferSize, cacheCfg, nil),
executor: NewExecutor(ctx, serv, cell, resolver, *normalizeQueries, *warnShardedOnly, *streamBufferSize, cacheCfg, nil, *noScatter),
resolver: resolver,
vsm: vsm,
txConn: tc,
Expand Down

0 comments on commit 0ecc59e

Please sign in to comment.