Skip to content

Commit

Permalink
Enable 'dolt sql' to connect to local server instances.
Browse files Browse the repository at this point in the history
Smart routing for sql command execution against running local servers.

#3922
  • Loading branch information
macneale4 authored May 22, 2023
2 parents a4a97d2 + cb8a970 commit 82cc919
Show file tree
Hide file tree
Showing 13 changed files with 514 additions and 182 deletions.
99 changes: 53 additions & 46 deletions go/cmd/dolt/commands/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ const (
UserFlag = "user"
DefaultUser = "root"
DefaultHost = "localhost"
UseDbFlag = "use-db"

welcomeMsg = `# Welcome to the DoltSQL shell.
# Statements must be terminated with ';'.
Expand Down Expand Up @@ -233,7 +234,7 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
fi, err := os.Stdin.Stat()
if err != nil {
if !osutil.IsWindows {
return HandleVErrAndExitCode(errhand.BuildDError("Couldn't stat STDIN. This is a bug.").Build(), usage)
return sqlHandleVErrAndExitCode(queryist, errhand.BuildDError("Couldn't stat STDIN. This is a bug.").Build(), usage)
}
} else {
isTty = fi.Mode()&os.ModeCharDevice != 0
Expand All @@ -246,11 +247,11 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
isTty = false
input, err = os.OpenFile(fileInput, os.O_RDONLY, os.ModePerm)
if err != nil {
return HandleVErrAndExitCode(errhand.BuildDError("couldn't open file %s", fileInput).Build(), usage)
return sqlHandleVErrAndExitCode(queryist, errhand.BuildDError("couldn't open file %s", fileInput).Build(), usage)
}
info, err := os.Stat(fileInput)
if err != nil {
return HandleVErrAndExitCode(errhand.BuildDError("couldn't get file size %s", fileInput).Build(), usage)
return sqlHandleVErrAndExitCode(queryist, errhand.BuildDError("couldn't get file size %s", fileInput).Build(), usage)
}

// initialize fileReadProg global variable if there is a file to process queries from
Expand All @@ -261,30 +262,59 @@ func (cmd SqlCmd) Exec(ctx context.Context, commandStr string, args []string, dE
if isTty {
err := execShell(sqlCtx, queryist, format)
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
return sqlHandleVErrAndExitCode(queryist, errhand.VerboseErrorFromError(err), usage)
}
} else if runInBatchMode {
se, ok := queryist.(*engine.SqlEngine)
if !ok {
misuse := fmt.Errorf("Using batch with non-local access pattern. Stop server if it is running")
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(misuse), usage)
return sqlHandleVErrAndExitCode(queryist, errhand.VerboseErrorFromError(misuse), usage)
}

verr := execBatch(sqlCtx, se, input, continueOnError, format)
if verr != nil {
return HandleVErrAndExitCode(verr, usage)
return sqlHandleVErrAndExitCode(queryist, verr, usage)
}
} else {
err := execMultiStatements(sqlCtx, queryist, input, continueOnError, format)
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
return sqlHandleVErrAndExitCode(queryist, errhand.VerboseErrorFromError(err), usage)
}
}
}

return 0
}

// sqlHandleVErrAndExitCode is a helper function to print errors to the user. Currently, the Queryist interface is used to
// determine if this is a local or remote execution. This is hacky, and too simplistic. We should possibly add an error
// messaging interface to the CliContext.
func sqlHandleVErrAndExitCode(queryist cli.Queryist, verr errhand.VerboseError, usage cli.UsagePrinter) int {
if verr != nil {
if msg := verr.Verbose(); strings.TrimSpace(msg) != "" {
if _, ok := queryist.(*engine.SqlEngine); !ok {
// We are in a context where we are attempting to connect to a remote database. These errors
// are unstructured, so we add some additional context around them.
tmpMsg := `You've encountered a new behavior in dolt which is not fully documented yet.
A local dolt server is using your dolt data directory, and in an attempt to service your request, we are attempting to
connect to it. That has failed. You should stop the server, or reach out to @macneale on https://discord.gg/gqr7K4VNKe
for help.`
cli.PrintErrln(tmpMsg)
msg = fmt.Sprintf("A local server is running, and dolt is failing to connect. Error connecting to remote database: \"%s\".\n", msg)
}
cli.PrintErrln(msg)
}

if verr.ShouldPrintUsage() {
usage()
}

return 1
}

return 0
}

// handleLegacyArguments is a temporary function to parse args, and print a error and explanation when the old form is provided.
func (cmd SqlCmd) handleLegacyArguments(ap *argparser.ArgParser, commandStr string, args []string) (*argparser.ArgParseResults, error) {

Expand Down Expand Up @@ -326,47 +356,47 @@ func (cmd SqlCmd) handleLegacyArguments(ap *argparser.ArgParser, commandStr stri

func listSavedQueries(ctx *sql.Context, qryist cli.Queryist, dEnv *env.DoltEnv, format engine.PrintResultFormat, usage cli.UsagePrinter) int {
if !dEnv.Valid() {
return HandleVErrAndExitCode(errhand.BuildDError("error: --%s must be used in a dolt database directory.", listSavedFlag).Build(), usage)
return sqlHandleVErrAndExitCode(qryist, errhand.BuildDError("error: --%s must be used in a dolt database directory.", listSavedFlag).Build(), usage)
}

workingRoot, err := dEnv.WorkingRoot(ctx)
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
return sqlHandleVErrAndExitCode(qryist, errhand.VerboseErrorFromError(err), usage)
}

hasQC, err := workingRoot.HasTable(ctx, doltdb.DoltQueryCatalogTableName)

if err != nil {
verr := errhand.BuildDError("error: Failed to read from repository.").AddCause(err).Build()
return HandleVErrAndExitCode(verr, usage)
return sqlHandleVErrAndExitCode(qryist, verr, usage)
}

if !hasQC {
return 0
}

query := "SELECT * FROM " + doltdb.DoltQueryCatalogTableName
return HandleVErrAndExitCode(execQuery(ctx, qryist, query, format), usage)
return sqlHandleVErrAndExitCode(qryist, execQuery(ctx, qryist, query, format), usage)
}

func executeSavedQuery(ctx *sql.Context, qryist cli.Queryist, dEnv *env.DoltEnv, savedQueryName string, format engine.PrintResultFormat, usage cli.UsagePrinter) int {
if !dEnv.Valid() {
return HandleVErrAndExitCode(errhand.BuildDError("error: --%s must be used in a dolt database directory.", executeFlag).Build(), usage)
return sqlHandleVErrAndExitCode(qryist, errhand.BuildDError("error: --%s must be used in a dolt database directory.", executeFlag).Build(), usage)
}

workingRoot, err := dEnv.WorkingRoot(ctx)
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
return sqlHandleVErrAndExitCode(qryist, errhand.VerboseErrorFromError(err), usage)
}

sq, err := dtables.RetrieveFromQueryCatalog(ctx, workingRoot, savedQueryName)

if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
return sqlHandleVErrAndExitCode(qryist, errhand.VerboseErrorFromError(err), usage)
}

cli.PrintErrf("Executing saved query '%s':\n%s\n", savedQueryName, sq.Query)
return HandleVErrAndExitCode(execQuery(ctx, qryist, sq.Query, format), usage)
return sqlHandleVErrAndExitCode(qryist, execQuery(ctx, qryist, sq.Query, format), usage)
}

func queryMode(
Expand All @@ -388,19 +418,19 @@ func queryMode(
se, ok := qryist.(*engine.SqlEngine)
if !ok {
misuse := fmt.Errorf("Using batch with non-local access pattern. Stop server if it is running")
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(misuse), usage)
return sqlHandleVErrAndExitCode(se, errhand.VerboseErrorFromError(misuse), usage)
}

batchInput := strings.NewReader(query)
verr := execBatch(ctx, se, batchInput, continueOnError, format)
if verr != nil {
return HandleVErrAndExitCode(verr, usage)
return sqlHandleVErrAndExitCode(qryist, verr, usage)
}
} else {
input := strings.NewReader(query)
err := execMultiStatements(ctx, qryist, input, continueOnError, format)
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
return sqlHandleVErrAndExitCode(qryist, errhand.VerboseErrorFromError(err), usage)
}
}

Expand All @@ -409,58 +439,35 @@ func queryMode(

func execSaveQuery(ctx *sql.Context, dEnv *env.DoltEnv, qryist cli.Queryist, apr *argparser.ArgParseResults, query string, format engine.PrintResultFormat, usage cli.UsagePrinter) int {
if !dEnv.Valid() {
return HandleVErrAndExitCode(errhand.BuildDError("error: --%s must be used in a dolt database directory.", saveFlag).Build(), usage)
return sqlHandleVErrAndExitCode(qryist, errhand.BuildDError("error: --%s must be used in a dolt database directory.", saveFlag).Build(), usage)
}

saveName := apr.GetValueOrDefault(saveFlag, "")

verr := execQuery(ctx, qryist, query, format)
if verr != nil {
return HandleVErrAndExitCode(verr, usage)
return sqlHandleVErrAndExitCode(qryist, verr, usage)
}

workingRoot, err := dEnv.WorkingRoot(ctx)
if err != nil {
return HandleVErrAndExitCode(errhand.BuildDError("error: failed to get working root").AddCause(err).Build(), usage)
return sqlHandleVErrAndExitCode(qryist, errhand.BuildDError("error: failed to get working root").AddCause(err).Build(), usage)
}

saveMessage := apr.GetValueOrDefault(messageFlag, "")
newRoot, verr := saveQuery(ctx, workingRoot, query, saveName, saveMessage)
if verr != nil {
return HandleVErrAndExitCode(verr, usage)
return sqlHandleVErrAndExitCode(qryist, verr, usage)
}

err = dEnv.UpdateWorkingRoot(ctx, newRoot)
if err != nil {
return HandleVErrAndExitCode(errhand.BuildDError("error: failed to update working root").AddCause(err).Build(), usage)
return sqlHandleVErrAndExitCode(qryist, errhand.BuildDError("error: failed to update working root").AddCause(err).Build(), usage)
}

return 0
}

// getMultiRepoEnv returns an appropriate MultiRepoEnv for this invocation of the command
func getMultiRepoEnv(ctx context.Context, workingDir string, dEnv *env.DoltEnv) (mrEnv *env.MultiRepoEnv, resolvedDir string, verr errhand.VerboseError) {
var err error
fs := dEnv.FS
if len(workingDir) > 0 {
fs, err = fs.WithWorkingDir(workingDir)
}
if err != nil {
return nil, "", errhand.VerboseErrorFromError(err)
}
resolvedDir, err = fs.Abs("")
if err != nil {
return nil, "", errhand.VerboseErrorFromError(err)
}

mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), fs, dEnv.Version, dEnv.IgnoreLockFile, dEnv)
if err != nil {
return nil, "", errhand.VerboseErrorFromError(err)
}

return mrEnv, resolvedDir, nil
}

func execBatch(
sqlCtx *sql.Context,
se *engine.SqlEngine,
Expand Down
67 changes: 58 additions & 9 deletions go/cmd/dolt/commands/sqlserver/sqlclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ import (
"github.com/dolthub/dolt/go/cmd/dolt/commands/engine"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/libraries/utils/iohelp"
)

const (
sqlClientDualFlag = "dual"
SqlClientQueryFlag = "query"
SqlClientUseDbFlag = "use-db"
sqlClientResultFormat = "result-format"
)

Expand Down Expand Up @@ -85,7 +85,7 @@ func (cmd SqlClientCmd) ArgParser() *argparser.ArgParser {
ap := SqlServerCmd{}.ArgParserWithName(cmd.Name())
ap.SupportsFlag(sqlClientDualFlag, "d", "Causes this command to spawn a dolt server that is automatically connected to.")
ap.SupportsString(SqlClientQueryFlag, "q", "string", "Sends the given query to the server and immediately exits.")
ap.SupportsString(SqlClientUseDbFlag, "", "db_name", fmt.Sprintf("Selects the given database before executing a query. "+
ap.SupportsString(commands.UseDbFlag, "", "db_name", fmt.Sprintf("Selects the given database before executing a query. "+
"By default, uses the current folder's name. Must be used with the --%s flag.", SqlClientQueryFlag))
ap.SupportsString(sqlClientResultFormat, "", "format", fmt.Sprintf("Returns the results in the given format. Must be used with the --%s flag.", SqlClientQueryFlag))
return ap
Expand Down Expand Up @@ -127,16 +127,16 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
cli.PrintErrln(color.RedString(fmt.Sprintf("--%s flag may not be used with --%s", sqlClientDualFlag, SqlClientQueryFlag)))
return 1
}
if apr.Contains(SqlClientUseDbFlag) {
cli.PrintErrln(color.RedString(fmt.Sprintf("--%s flag may not be used with --%s", sqlClientDualFlag, SqlClientUseDbFlag)))
if apr.Contains(commands.UseDbFlag) {
cli.PrintErrln(color.RedString(fmt.Sprintf("--%s flag may not be used with --%s", sqlClientDualFlag, commands.UseDbFlag)))
return 1
}
if apr.Contains(sqlClientResultFormat) {
cli.PrintErrln(color.RedString(fmt.Sprintf("--%s flag may not be used with --%s", sqlClientDualFlag, sqlClientResultFormat)))
return 1
}

serverConfig, err = GetServerConfig(dEnv, apr)
serverConfig, err = GetServerConfig(dEnv.FS, apr)
if err != nil {
cli.PrintErrln(color.RedString("Bad Configuration"))
cli.PrintErrln(err.Error())
Expand All @@ -159,7 +159,7 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
return 1
}
} else {
serverConfig, err = GetServerConfig(dEnv, apr)
serverConfig, err = GetServerConfig(dEnv.FS, apr)
if err != nil {
cli.PrintErrln(color.RedString("Bad Configuration"))
cli.PrintErrln(err.Error())
Expand All @@ -168,13 +168,13 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
}

query, hasQuery := apr.GetValue(SqlClientQueryFlag)
dbToUse, hasUseDb := apr.GetValue(SqlClientUseDbFlag)
dbToUse, hasUseDb := apr.GetValue(commands.UseDbFlag)
resultFormat, hasResultFormat := apr.GetValue(sqlClientResultFormat)
if !hasQuery && hasUseDb {
cli.PrintErrln(color.RedString(fmt.Sprintf("--%s may only be used with --%s", SqlClientUseDbFlag, SqlClientQueryFlag)))
cli.PrintErrln(color.RedString(fmt.Sprintf("--%s may only be used with --%s", commands.UseDbFlag, SqlClientQueryFlag)))
return 1
} else if !hasQuery && hasResultFormat {
cli.PrintErrln(color.RedString(fmt.Sprintf("--%s may only be used with --%s", SqlClientUseDbFlag, sqlClientResultFormat)))
cli.PrintErrln(color.RedString(fmt.Sprintf("--%s may only be used with --%s", commands.UseDbFlag, sqlClientResultFormat)))
return 1
}
if !hasUseDb && hasQuery {
Expand Down Expand Up @@ -453,3 +453,52 @@ func secondsSince(start time.Time, end time.Time) float64 {
timeDisplay := float64(seconds) + float64(milliRemainder)*.001
return timeDisplay
}

type ConnectionQueryist struct {
connection *dbr.Connection
}

func (c ConnectionQueryist) Query(ctx *sql.Context, query string) (sql.Schema, sql.RowIter, error) {
rows, err := c.connection.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
rowIter, err := NewMysqlRowWrapper(rows)
if err != nil {
return nil, nil, err
}
return rowIter.Schema(), rowIter, nil
}

var _ cli.Queryist = ConnectionQueryist{}

// BuildConnectionStringQueryist returns a Queryist that connects to the server specified by the given server config. Presence in this
// module isn't ideal, but it's the only way to get the server config into the queryist.
func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, apr *argparser.ArgParseResults, port int, database string) (cli.LateBindQueryist, error) {
serverConfig, err := GetServerConfig(cwdFS, apr)
if err != nil {
return nil, err
}

parsedMySQLConfig, err := mysqlDriver.ParseDSN(ConnectionString(serverConfig, database))
if err != nil {
return nil, err
}

parsedMySQLConfig.Addr = fmt.Sprintf("localhost:%d", port)

mysqlConnector, err := mysqlDriver.NewConnector(parsedMySQLConfig)
if err != nil {
return nil, err
}

conn := &dbr.Connection{DB: mysql.OpenDB(mysqlConnector), EventReceiver: nil, Dialect: dialect.MySQL}

queryist := ConnectionQueryist{connection: conn}

var lateBind cli.LateBindQueryist = func(ctx context.Context) (cli.Queryist, *sql.Context, func(), error) {
return queryist, sql.NewContext(ctx), func() { conn.Conn(ctx) }, nil
}

return lateBind, nil
}
10 changes: 5 additions & 5 deletions go/cmd/dolt/commands/sqlserver/sqlserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func startServer(ctx context.Context, versionStr, commandStr string, args []stri
if err := validateSqlServerArgs(apr); err != nil {
return 1
}
serverConfig, err := GetServerConfig(dEnv, apr)
serverConfig, err := GetServerConfig(dEnv.FS, apr)
if err != nil {
if serverController != nil {
serverController.StopServer()
Expand Down Expand Up @@ -246,16 +246,16 @@ func startServer(ctx context.Context, versionStr, commandStr string, args []stri

// GetServerConfig returns ServerConfig that is set either from yaml file if given, if not it is set with values defined
// on command line. Server config variables not defined are set to default values.
func GetServerConfig(dEnv *env.DoltEnv, apr *argparser.ArgParseResults) (ServerConfig, error) {
func GetServerConfig(cwdFS filesys.Filesys, apr *argparser.ArgParseResults) (ServerConfig, error) {
var yamlCfg YAMLConfig
if cfgFile, ok := apr.GetValue(configFileFlag); ok {
cfg, err := getYAMLServerConfig(dEnv.FS, cfgFile)
cfg, err := getYAMLServerConfig(cwdFS, cfgFile)
if err != nil {
return nil, err
}
yamlCfg = cfg.(YAMLConfig)
} else {
return getCommandLineServerConfig(dEnv, apr)
return getCommandLineServerConfig(apr)
}

// if command line user argument was given, replace yaml's user and password
Expand Down Expand Up @@ -350,7 +350,7 @@ func SetupDoltConfig(dEnv *env.DoltEnv, apr *argparser.ArgParseResults, config S

// getCommandLineServerConfig sets server config variables and persisted global variables with values defined on command line.
// If not defined, it sets variables to default values.
func getCommandLineServerConfig(dEnv *env.DoltEnv, apr *argparser.ArgParseResults) (ServerConfig, error) {
func getCommandLineServerConfig(apr *argparser.ArgParseResults) (ServerConfig, error) {
serverConfig := DefaultServerConfig()

if sock, ok := apr.GetValue(socketFlag); ok {
Expand Down
Loading

0 comments on commit 82cc919

Please sign in to comment.