Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always make sure to escape all strings #17649

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions go/test/endtoend/vreplication/vdiff_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,7 @@ func getVDiffInfo(json string) *vdiffInfo {
}

func encodeString(in string) string {
var buf strings.Builder
sqltypes.NewVarChar(in).EncodeSQL(&buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

// generateMoreCustomers creates additional test data for better tests
Expand Down
22 changes: 10 additions & 12 deletions go/vt/binlog/binlogplayer/binlog_player.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ func (blp *BinlogPlayer) setVReplicationState(state binlogdatapb.VReplicationWor
})
}
blp.blplStats.State.Store(state.String())
query := fmt.Sprintf("update _vt.vreplication set state='%v', message=%v where id=%v", state.String(), encodeString(MessageTruncate(message)), blp.uid)
query := fmt.Sprintf("update _vt.vreplication set state=%v, message=%v where id=%v", encodeString(state.String()), encodeString(MessageTruncate(message)), blp.uid)
if _, err := blp.dbClient.ExecuteFetch(query, 1); err != nil {
return fmt.Errorf("could not set state: %v: %v", query, err)
}
Expand Down Expand Up @@ -637,9 +637,9 @@ func CreateVReplication(workflow string, source *binlogdatapb.BinlogSource, posi
protoutil.SortBinlogSourceTables(source)
return fmt.Sprintf("insert into _vt.vreplication "+
"(workflow, source, pos, max_tps, max_replication_lag, time_updated, transaction_timestamp, state, db_name, workflow_type, workflow_sub_type, defer_secondary_keys, options) "+
"values (%v, %v, %v, %v, %v, %v, 0, '%v', %v, %d, %d, %v, %s)",
"values (%v, %v, %v, %v, %v, %v, 0, %v, %v, %d, %d, %v, %s)",
encodeString(workflow), encodeString(source.String()), encodeString(position), maxTPS, maxReplicationLag,
timeUpdated, binlogdatapb.VReplicationWorkflowState_Running.String(), encodeString(dbName), workflowType,
timeUpdated, encodeString(binlogdatapb.VReplicationWorkflowState_Running.String()), encodeString(dbName), workflowType,
workflowSubType, deferSecondaryKeys, encodeString("{}"))
}

Expand All @@ -649,9 +649,9 @@ func CreateVReplicationState(workflow string, source *binlogdatapb.BinlogSource,
protoutil.SortBinlogSourceTables(source)
return fmt.Sprintf("insert into _vt.vreplication "+
"(workflow, source, pos, max_tps, max_replication_lag, time_updated, transaction_timestamp, state, db_name, workflow_type, workflow_sub_type, options) "+
"values (%v, %v, %v, %v, %v, %v, 0, '%v', %v, %d, %d, %s)",
"values (%v, %v, %v, %v, %v, %v, 0, %v, %v, %d, %d, %s)",
encodeString(workflow), encodeString(source.String()), encodeString(position), throttler.MaxRateModuleDisabled,
throttler.ReplicationLagModuleDisabled, time.Now().Unix(), state.String(), encodeString(dbName),
throttler.ReplicationLagModuleDisabled, time.Now().Unix(), encodeString(state.String()), encodeString(dbName),
workflowType, workflowSubType, encodeString("{}"))
}

Expand Down Expand Up @@ -694,15 +694,15 @@ func GenerateUpdateTimeThrottled(uid int32, timeThrottledUnix int64, componentTh
// StartVReplicationUntil returns a statement to start the replication with a stop position.
func StartVReplicationUntil(uid int32, pos string) string {
return fmt.Sprintf(
"update _vt.vreplication set state='%v', stop_pos=%v where id=%v",
binlogdatapb.VReplicationWorkflowState_Running.String(), encodeString(pos), uid)
"update _vt.vreplication set state=%v, stop_pos=%v where id=%v",
encodeString(binlogdatapb.VReplicationWorkflowState_Running.String()), encodeString(pos), uid)
}

// StopVReplication returns a statement to stop the replication.
func StopVReplication(uid int32, message string) string {
return fmt.Sprintf(
"update _vt.vreplication set state='%v', message=%v where id=%v",
binlogdatapb.VReplicationWorkflowState_Stopped.String(), encodeString(MessageTruncate(message)), uid)
"update _vt.vreplication set state=%v, message=%v where id=%v",
encodeString(binlogdatapb.VReplicationWorkflowState_Stopped.String()), encodeString(MessageTruncate(message)), uid)
}

// DeleteVReplication returns a statement to delete the replication.
Expand All @@ -717,9 +717,7 @@ func MessageTruncate(msg string) string {
}

func encodeString(in string) string {
buf := bytes.NewBuffer(nil)
sqltypes.NewVarChar(in).EncodeSQL(buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

// ReadVReplicationPos returns a statement to query the gtid for a
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtctl/vdiff_env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func newTestVDiffEnv(t testing.TB, ctx context.Context, sourceShards, targetShar
// But this is one statement per stream.
env.tmc.setVRResults(
primary.tablet,
fmt.Sprintf("update _vt.vreplication set state='Running', stop_pos='%s', message='synchronizing for vdiff' where id=%d", vdiffSourceGtid, j+1),
fmt.Sprintf("update _vt.vreplication set state='Running', stop_pos=%s, message='synchronizing for vdiff' where id=%d", sqltypes.EncodeStringSQL(vdiffSourceGtid), j+1),
&sqltypes.Result{},
)
}
Expand Down
1 change: 0 additions & 1 deletion go/vt/vtctl/workflow/resharder.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ func (rs *resharder) createStreams(ctx context.Context) error {
if err != nil {
return err
}
optionsJSON = fmt.Sprintf("'%s'", optionsJSON)
for _, source := range rs.sourceShards {
if !key.KeyRangeIntersect(target.KeyRange, source.KeyRange) {
continue
Expand Down
10 changes: 5 additions & 5 deletions go/vt/vtctl/workflow/traffic_switcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -858,8 +858,8 @@ func (ts *trafficSwitcher) getReverseVReplicationUpdateQuery(targetCell string,
}

if ts.optCells != "" || ts.optTabletTypes != "" {
query := fmt.Sprintf("update _vt.vreplication set cell = '%s', tablet_types = '%s', options = '%s' where workflow = '%s' and db_name = '%s'",
ts.optCells, ts.optTabletTypes, options, ts.ReverseWorkflowName(), dbname)
query := fmt.Sprintf("update _vt.vreplication set cell = %s, tablet_types = %s, options = %s where workflow = %s and db_name = %s",
sqltypes.EncodeStringSQL(ts.optCells), sqltypes.EncodeStringSQL(ts.optTabletTypes), sqltypes.EncodeStringSQL(options), sqltypes.EncodeStringSQL(ts.ReverseWorkflowName()), sqltypes.EncodeStringSQL(dbname))
return query
}
return ""
Expand Down Expand Up @@ -941,8 +941,8 @@ func (ts *trafficSwitcher) createReverseVReplication(ctx context.Context) error
// For non-reference tables we return an error if there's no primary
// vindex as it's not clear what to do.
if len(vtable.ColumnVindexes) > 0 && len(vtable.ColumnVindexes[0].Columns) > 0 {
inKeyrange = fmt.Sprintf(" where in_keyrange(%s, '%s.%s', '%s')", sqlparser.String(vtable.ColumnVindexes[0].Columns[0]),
ts.SourceKeyspaceName(), vtable.ColumnVindexes[0].Name, key.KeyRangeString(source.GetShard().KeyRange))
inKeyrange = fmt.Sprintf(" where in_keyrange(%s, '%s.%s', %s)", sqlparser.String(vtable.ColumnVindexes[0].Columns[0]),
ts.SourceKeyspaceName(), vtable.ColumnVindexes[0].Name, encodeString(key.KeyRangeString(source.GetShard().KeyRange)))
} else {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "no primary vindex found for the %s table in the %s keyspace",
vtable.Name.String(), ts.SourceKeyspaceName())
Expand Down Expand Up @@ -1184,7 +1184,7 @@ func (ts *trafficSwitcher) freezeTargetVReplication(ctx context.Context) error {
// re-invoked after a freeze, it will skip all the previous steps
err := ts.ForAllTargets(func(target *MigrationTarget) error {
ts.Logger().Infof("Marking target streams frozen for workflow %s db_name %s", ts.WorkflowName(), target.GetPrimary().DbName())
query := fmt.Sprintf("update _vt.vreplication set message = '%s' where db_name=%s and workflow=%s", Frozen,
query := fmt.Sprintf("update _vt.vreplication set message = %s where db_name=%s and workflow=%s", encodeString(Frozen),
encodeString(target.GetPrimary().DbName()), encodeString(ts.WorkflowName()))
_, err := ts.TabletManagerClient().VReplicationExec(ctx, target.GetPrimary().Tablet, query)
return err
Expand Down
5 changes: 1 addition & 4 deletions go/vt/vtctl/workflow/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package workflow

import (
"bytes"
"context"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -627,9 +626,7 @@ func ReverseWorkflowName(workflow string) string {
// this public, but it doesn't belong in package workflow. Maybe package sqltypes,
// or maybe package sqlescape?
func encodeString(in string) string {
Copy link
Contributor

@timvaillancourt timvaillancourt Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This identical func is defined in several files. Should we make a common helper, or just use sqltypes.EncodeStringSQL(...) directly? 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we can clean up these as well now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@timvaillancourt (and also @mattlord, since you asked me about this too): given that I aim to backport this, I think this is better / safer in a separate change. Making this change means a lot of churn across many files, which makes the backport harder.

So I'm currently thinking of doing this cleanup in a follow-up PR that we don't need to backport.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

buf := bytes.NewBuffer(nil)
sqltypes.NewVarChar(in).EncodeSQL(buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

func getRenameFileName(tableName string) string {
Expand Down
5 changes: 1 addition & 4 deletions go/vt/vttablet/endtoend/vstreamer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package endtoend

import (
"bytes"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -472,9 +471,7 @@ func expectLogs(ctx context.Context, t *testing.T, query string, eventCh chan []
}

func encodeString(in string) string {
buf := bytes.NewBuffer(nil)
sqltypes.NewVarChar(in).EncodeSQL(buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

func validateSchemaInserted(client *framework.QueryClient, ddl string) bool {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vttablet/onlineddl/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1571,8 +1571,8 @@ func (e *Executor) ExecuteWithVReplication(ctx context.Context, onlineDDL *schem

{
// temporary hack. todo: this should be done when inserting any _vt.vreplication record across all workflow types
query := fmt.Sprintf("update _vt.vreplication set workflow_type = %d where workflow = '%s'",
binlogdatapb.VReplicationWorkflowType_OnlineDDL, v.workflow)
query := fmt.Sprintf("update _vt.vreplication set workflow_type = %d where workflow = %s",
binlogdatapb.VReplicationWorkflowType_OnlineDDL, sqltypes.EncodeStringSQL(v.workflow))
if _, err := e.vreplicationExec(ctx, tablet.Tablet, query); err != nil {
return vterrors.Wrapf(err, "VReplicationExec(%v, %s)", tablet.Tablet, query)
}
Expand Down
5 changes: 1 addition & 4 deletions go/vt/vttablet/tabletmanager/vdiff/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package vdiff
import (
"context"
"fmt"
"strings"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/evalengine"
Expand Down Expand Up @@ -59,9 +58,7 @@ func newMergeSorter(participants map[string]*shardStreamer, comparePKs []compare
// Utility functions

func encodeString(in string) string {
var buf strings.Builder
sqltypes.NewVarChar(in).EncodeSQL(&buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

func pkColsToGroupByParams(pkCols []int, collationEnv *collations.Environment) []*engine.GroupByParams {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ func NewInsertGenerator(state binlogdatapb.VReplicationWorkflowState, dbname str
func (ig *InsertGenerator) AddRow(workflow string, bls *binlogdatapb.BinlogSource, pos, cell, tabletTypes string,
workflowType binlogdatapb.VReplicationWorkflowType, workflowSubType binlogdatapb.VReplicationWorkflowSubType, deferSecondaryKeys bool, options string) {
if options == "" {
options = "'{}'"
options = "{}"
}
protoutil.SortBinlogSourceTables(bls)
fmt.Fprintf(ig.buf, "%s(%v, %v, %v, %v, %v, %v, %v, %v, 0, '%v', %v, %d, %d, %v, %v)",
fmt.Fprintf(ig.buf, "%s(%v, %v, %v, %v, %v, %v, %v, %v, 0, %v, %v, %d, %d, %v, %v)",
ig.prefix,
encodeString(workflow),
encodeString(bls.String()),
Expand All @@ -66,12 +66,12 @@ func (ig *InsertGenerator) AddRow(workflow string, bls *binlogdatapb.BinlogSourc
encodeString(cell),
encodeString(tabletTypes),
ig.now,
ig.state,
encodeString(ig.state),
encodeString(ig.dbname),
workflowType,
workflowSubType,
deferSecondaryKeys,
options,
encodeString(options),
)
ig.prefix = ", "
}
Expand Down
6 changes: 2 additions & 4 deletions go/vt/vttablet/tabletmanager/vreplication/vreplicator.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ func (vr *vreplicator) setState(state binlogdatapb.VReplicationWorkflowState, me
})
}
vr.stats.State.Store(state.String())
query := fmt.Sprintf("update _vt.vreplication set state='%v', message=%v where id=%v", state, encodeString(binlogplayer.MessageTruncate(message)), vr.id)
query := fmt.Sprintf("update _vt.vreplication set state=%v, message=%v where id=%v", encodeString(state.String()), encodeString(binlogplayer.MessageTruncate(message)), vr.id)
// If we're batching a transaction, then include the state update
// in the current transaction batch.
if vr.dbClient.InTransaction && vr.dbClient.maxBatchSize > 0 {
Expand All @@ -528,9 +528,7 @@ func (vr *vreplicator) setState(state binlogdatapb.VReplicationWorkflowState, me
}

func encodeString(in string) string {
var buf strings.Builder
sqltypes.NewVarChar(in).EncodeSQL(&buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

func (vr *vreplicator) getSettingFKCheck() error {
Expand Down
13 changes: 8 additions & 5 deletions go/vt/vttablet/tabletserver/schema/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ limitations under the License.
package schema

import (
"bytes"
"context"
"fmt"
"sync"
"time"

"vitess.io/vitess/go/bytes2"
"vitess.io/vitess/go/constants/sidecar"
"vitess.io/vitess/go/mysql/replication"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -231,10 +231,15 @@ func (tr *Tracker) saveCurrentSchemaToDb(ctx context.Context, gtid, ddl string,
}
defer conn.Recycle()

// We serialize a blob here, encodeString is for strings only
// and should not be used for binary data.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found this while cleaning up the helpers: here we were encoding a blob as a string, which shouldn't be done.

blobVal := sqltypes.MakeTrusted(sqltypes.VarBinary, blob)
buf := bytes2.Buffer{}
blobVal.EncodeSQLBytes2(&buf)
query := sqlparser.BuildParsedQuery("insert into %s.schema_version "+
"(pos, ddl, schemax, time_updated) "+
"values (%s, %s, %s, %d)", sidecar.GetIdentifier(), encodeString(gtid),
encodeString(ddl), encodeString(string(blob)), timestamp).Query
encodeString(ddl), buf.String(), timestamp).Query
_, err = conn.Conn.Exec(ctx, query, 1, false)
if err != nil {
return err
Expand All @@ -243,9 +248,7 @@ func (tr *Tracker) saveCurrentSchemaToDb(ctx context.Context, gtid, ddl string,
}

func encodeString(in string) string {
buf := bytes.NewBuffer(nil)
sqltypes.NewVarChar(in).EncodeSQL(buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

// MustReloadSchemaOnDDL returns true if the ddl is for the db which is part of the workflow and is not an online ddl artifact
Expand Down
4 changes: 1 addition & 3 deletions go/vt/vttablet/tabletserver/vstreamer/vstreamer.go
Original file line number Diff line number Diff line change
Expand Up @@ -960,9 +960,7 @@ type extColInfo struct {
}

func encodeString(in string) string {
buf := bytes.NewBuffer(nil)
sqltypes.NewVarChar(in).EncodeSQL(buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}

func (vs *vstreamer) processJournalEvent(vevents []*binlogdatapb.VEvent, plan *streamerPlan, rows mysql.Rows) ([]*binlogdatapb.VEvent, error) {
Expand Down
9 changes: 8 additions & 1 deletion go/vt/vttablet/tabletserver/vstreamer/vstreamer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"

"vitess.io/vitess/go/bytes2"
"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -371,9 +372,15 @@ func TestVersion(t *testing.T) {
}
blob, _ := dbSchema.MarshalVT()
gtid := "MariaDB/0-41983-20"
// We serialize a blob here, encodeString is for strings only
// and should not be used for binary data.
blobVal := sqltypes.MakeTrusted(sqltypes.VarBinary, blob)
buf := bytes2.Buffer{}
blobVal.EncodeSQLBytes2(&buf)

testcases := []testcase{{
input: []string{
fmt.Sprintf("insert into _vt.schema_version values(1, '%s', 123, 'create table t1', %v)", gtid, encodeString(string(blob))),
fmt.Sprintf("insert into _vt.schema_version values(1, '%s', 123, 'create table t1', %v)", gtid, buf.String()),
},
// External table events don't get sent.
output: [][]string{{
Expand Down
5 changes: 1 addition & 4 deletions go/vt/wrangler/keyspace.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package wrangler

import (
"bytes"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -125,7 +124,5 @@ func (wr *Wrangler) updateShardRecords(ctx context.Context, keyspace string, sha
}

func encodeString(in string) string {
buf := bytes.NewBuffer(nil)
sqltypes.NewVarChar(in).EncodeSQL(buf)
return buf.String()
return sqltypes.EncodeStringSQL(in)
}
Loading