diff --git a/main.go b/main.go index dac72e0..f745ab3 100644 --- a/main.go +++ b/main.go @@ -2,14 +2,17 @@ package main import ( "bufio" + "context" "flag" "fmt" "io" "log" "os" "os/exec" + "os/signal" "strings" "sync" + "syscall" ) type database struct { @@ -61,6 +64,9 @@ func main() { out := make(chan string) go println(out) + quitContext, cancel := context.WithCancel(context.Background()) + go awaitSignal(cancel) + var wg sync.WaitGroup wg.Add(len(targetDatabases)) @@ -68,7 +74,7 @@ func main() { for _, k := range targetDatabases { go func(db database, k string) { defer wg.Done() - if r := runSQL(db, sql, k, len(targetDatabases) > 1, out); !r { + if r := runSQL(quitContext, db, sql, k, len(targetDatabases) > 1, out); !r { returnCode = 1 } }(databases[k], k) @@ -78,7 +84,7 @@ func main() { os.Exit(returnCode) } -func runSQL(db database, sql string, key string, prependKey bool, out chan string) bool { +func runSQL(quitContext context.Context, db database, sql string, key string, prependKey bool, out chan string) bool { userOption := "" if db.User != "" { userOption = fmt.Sprintf("-u %v ", db.User) @@ -105,10 +111,10 @@ func runSQL(db database, sql string, key string, prependKey bool, out chan strin var cmd *exec.Cmd if db.AppServer != "" { query := fmt.Sprintf(`'%v'`, strings.Replace(sql, `'`, `'"'"'`, -1)) - cmd = exec.Command("ssh", db.AppServer, mysql+options+query) + cmd = exec.CommandContext(quitContext, "ssh", db.AppServer, mysql+options+query) } else { args := append(trimEmpty(strings.Split(options, " ")), sql) - cmd = exec.Command("mysql", args...) + cmd = exec.CommandContext(quitContext, "mysql", args...) } stdout, err := cmd.StdoutPipe() @@ -187,3 +193,10 @@ func trimEmpty(s []string) []string { } return r } + +func awaitSignal(cancel context.CancelFunc) { + signals := make(chan os.Signal) + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + <-signals + cancel() +}