diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 093b474fb2b..6304f7c7388 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -421,7 +421,7 @@ func (c *Conn) readEphemeralPacket(ctx context.Context) ([]byte, error) { // readEphemeralPacketDirect attempts to read a packet from the socket directly. // It needs to be used for the first handshake packet the server receives, -// so we do't buffer the SSL negotiation packet. As a shortcut, only +// so we don't buffer the SSL negotiation packet. As a shortcut, only // packets smaller than MaxPacketSize can be read here. // This function usually shouldn't be used - use readEphemeralPacket. func (c *Conn) readEphemeralPacketDirect(ctx context.Context) ([]byte, error) { @@ -1099,7 +1099,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { if c.cs != nil { log.Error("Received ComStmtPrepare with outstanding cursor") if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil { - log.Error("Error writing error packet to client: %v", werr) + log.Errorf("Error writing error packet to client: %v", werr) return werr } return nil @@ -1181,7 +1181,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { log.Errorf("unable to prepare query: %s", err.Error()) if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. - log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + log.Errorf("Error writing query error to client %v: %v", c.ConnectionID, werr) return werr } return nil @@ -1195,7 +1195,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { if c.cs != nil { log.Error("Received ComStmtExecute with outstanding cursor") if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil { - log.Error("Error writing error packet to client: %v", werr) + log.Errorf("Error writing error packet to client: %v", werr) return werr } return nil @@ -1211,7 +1211,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { if err != nil { if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. - log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + log.Errorf("Error writing query error to client %v: %v", c.ConnectionID, werr) return werr } return c.flush(ctx) @@ -1276,18 +1276,18 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { stmtID, ok := c.parseComStmtReset(data) c.recycleReadPacket() if !ok { - log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) + log.Errorf("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { - log.Error("Error writing error packet to client: %v", err) + log.Errorf("Error writing error packet to client: %v", err) return err } } prepare, ok := c.PrepareData[stmtID] if !ok { - log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data) + log.Errorf("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data) if werr := c.writeErrorPacket(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order: %v", data); werr != nil { - log.Error("Error writing error packet to client: %v", err) + log.Errorf("Error writing error packet to client: %v", err) return werr } } @@ -1301,7 +1301,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { c.discardCursor() if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { - log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) + log.Errorf("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) return err } case ComStmtFetch: @@ -1309,9 +1309,10 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { stmtID, numRows, ok := c.parseComStmtFetch(data) c.recycleReadPacket() if !ok { - log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) - if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil { - log.Error("Error writing error packet to client: %v", werr) + log.Errorf("Unable to parse COM_STMT_FETCH message on connection %v", c.ConnectionID) + if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, + "unable to parse COM_STMT_FETCH message on connection %v", c.ConnectionID); werr != nil { + log.Errorf("Error writing error packet to client: %v", werr) return werr } return c.flush(ctx) @@ -1321,7 +1322,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { if c.cs == nil || stmtID != c.cs.stmtID { log.Errorf("Requested stmtID does not match stmtID of open cursor. Client %v, returning error: %v", c.ConnectionID, data) if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil { - log.Error("Error writing error packet to client: %v", err) + log.Errorf("Error writing error packet to client: %v", err) return werr } return c.flush(ctx) diff --git a/go/mysql/server.go b/go/mysql/server.go index 3fed025dd29..f6be9351afd 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -473,12 +473,19 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3 return } - clientAuthResponse, err = c.readEphemeralPacket(context.Background()) + data, err := c.readEphemeralPacket(context.Background()) if err != nil { l.handleConnectionError(c, fmt.Sprintf("Error reading auth switch response for %s: %v", c, err)) return } + + var ok bool + clientAuthResponse, _, ok = readBytesCopy(data, 0, len(data)) c.recycleReadPacket() + if !ok { + l.handleConnectionError(c, fmt.Sprintf("Unable to copy client auth response for %s", c)) + return + } } userData, err := negotiatedAuthMethod.HandleAuthPluginData(c, user, serverAuthPluginData, clientAuthResponse, conn.RemoteAddr())