Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redo signal handling #854

Merged
merged 4 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 6 additions & 45 deletions internal/cmd/shell/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"net"
"os"
"os/signal"
"path/filepath"
"time"

Expand Down Expand Up @@ -272,58 +271,20 @@ func (m *mysql) Run(ctx context.Context, sigc chan os.Signal, signals []os.Signa
c.Stderr = os.Stderr
c.Stdin = os.Stdin

// Set up a new channel for signals received while MySQL is active.
// This is registered for all signals so we forward them all to MySQL,
// so we behave as much as possible like a regular MySQL.
// When we exit this function, we stop the custom signal receiver.
msig := make(chan os.Signal, 1)
signal.Notify(msig)
defer signal.Stop(msig)

// We stop handling signals for our default setup from the CLI. This
// is needed, so we stop handling for example the default os.Interrupt
// that stops the shell and we forward it to MySQL.
// When we exit from this function, we restore the signals as they were.
signal.Stop(sigc)
defer signal.Notify(sigc, signals...)

err := c.Start()
if err != nil {
return err
c.SysProcAttr = sysProcAttr()
cancel := setupSignals(ctx, c, sigc, signals)
if cancel != nil {
defer cancel()
}

wait := make(chan error, 1)
go func() {
wait <- c.Wait()
close(wait)
}()

for {
select {
case sig := <-msig:
if err := c.Process.Signal(sig); err != nil {
// If we failed to send a signal to the process, just in case
// it's still alive, make sure we kill it.
_ = c.Process.Signal(os.Kill)
return err
}
case err := <-wait:
if err != nil {
// If we failed to wait for the process, just in case
// we send a hard kill to ensure the MySQL subprocess
// gets killed.
c.Process.Signal(os.Kill)
}
return err
}
}
return c.Run()
}

func formatMySQLBranch(database string, branch *ps.DatabaseBranch) string {
branchStr := branch.Name

if branch.Production {
branchStr = fmt.Sprintf("| %s |", branch.Name)
branchStr = fmt.Sprintf("|%s %s %s|", warnSign, branch.Name, warnSign)
}

return fmt.Sprintf("%s/%s> ", database, branchStr)
Expand Down
29 changes: 29 additions & 0 deletions internal/cmd/shell/shell_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//go:build !windows

package shell

import (
"context"
"os"
"os/exec"
"syscall"
)

// warnSign shows the warning signal for prod branches.
const warnSign = "⚠"

// sysProcAttr returns the attributes for starting the process
// that are platform specific. We set the Foreground flag for unix
// like platforms, which means the new process gets its own process
// group and runs on the foreground. This ensures proper signal handling.
func sysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
Foreground: true,
}
}

// setupSignals does not need to do any work since the foreground
// logic works well for these cases.
func setupSignals(_ context.Context, _ *exec.Cmd, _ chan os.Signal, _ []os.Signal) func() {
return nil
}
58 changes: 58 additions & 0 deletions internal/cmd/shell/shell_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
//go:build windows

package shell

import (
"context"
"os"
"os/exec"
"os/signal"
"syscall"
)

// warnSign shows the warning signal for prod branches. Windows
// doesn't handle Unicode characters well here, so fall back to ASCII.
const warnSign = "!"

// sysProcAttr returns the attributes for starting the process
// that are platform specific. On Windows no additional flags are set.
func sysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{}
}

// setupSignals handles setup for signals. On Windows we need to forward
// signals since there's otherwise no good way to foreground an independent
// MySQL shell. CREATE_NEW_CONSOLE signal wise is what we want, but we don't
// want to open up a new window but use the existing one. CREATE_NEW_PROCESS_GROUP
// doesn't handle Ctrl+c in the way that we want.
func setupSignals(ctx context.Context, c *exec.Cmd, sigc chan os.Signal, signals []os.Signal) func() {
// Set up a new channel for signals received while MySQL is active.
// This is registered for all signals, so we forward them all to MySQL,
// so we behave as much as possible as a regular MySQL.
// When we exit this function, we stop the custom signal receiver.
msig := make(chan os.Signal, 1)
signal.Notify(msig)

// We stop handling signals for our default setup from the CLI. This
// is needed, so we stop handling for example the default os.Interrupt
// that stops the shell and we forward it to MySQL.
// When we exit from this function, we restore the signals as they were.
signal.Stop(sigc)

go func() {
for {
select {
case sig := <-msig:
_ = c.Process.Signal(sig)
case <-ctx.Done():
return
}
}

}()

return func() {
signal.Stop(msig)
signal.Notify(sigc, signals...)
}
}