Skip to content

Commit

Permalink
Cleanup: fixing buffer reuse issues and incorrect log statements
Browse files Browse the repository at this point in the history
  • Loading branch information
fulghum committed Mar 3, 2025
1 parent c4f6bba commit dcbcb5d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
29 changes: 15 additions & 14 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ func (c *Conn) readEphemeralPacket(ctx context.Context) ([]byte, error) {

// readEphemeralPacketDirect attempts to read a packet from the socket directly.
// It needs to be used for the first handshake packet the server receives,
// so we do't buffer the SSL negotiation packet. As a shortcut, only
// so we don't buffer the SSL negotiation packet. As a shortcut, only
// packets smaller than MaxPacketSize can be read here.
// This function usually shouldn't be used - use readEphemeralPacket.
func (c *Conn) readEphemeralPacketDirect(ctx context.Context) ([]byte, error) {
Expand Down Expand Up @@ -1099,7 +1099,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
if c.cs != nil {
log.Error("Received ComStmtPrepare with outstanding cursor")
if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil {
log.Error("Error writing error packet to client: %v", werr)
log.Errorf("Error writing error packet to client: %v", werr)
return werr
}
return nil
Expand Down Expand Up @@ -1181,7 +1181,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
log.Errorf("unable to prepare query: %s", err.Error())
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr)
log.Errorf("Error writing query error to client %v: %v", c.ConnectionID, werr)
return werr
}
return nil
Expand All @@ -1195,7 +1195,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
if c.cs != nil {
log.Error("Received ComStmtExecute with outstanding cursor")
if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil {
log.Error("Error writing error packet to client: %v", werr)
log.Errorf("Error writing error packet to client: %v", werr)
return werr
}
return nil
Expand All @@ -1211,7 +1211,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
if err != nil {
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr)
log.Errorf("Error writing query error to client %v: %v", c.ConnectionID, werr)
return werr
}
return c.flush(ctx)
Expand Down Expand Up @@ -1276,18 +1276,18 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
stmtID, ok := c.parseComStmtReset(data)
c.recycleReadPacket()
if !ok {
log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data)
log.Errorf("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data)
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil {
log.Error("Error writing error packet to client: %v", err)
log.Errorf("Error writing error packet to client: %v", err)
return err
}
}

prepare, ok := c.PrepareData[stmtID]
if !ok {
log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data)
log.Errorf("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data)
if werr := c.writeErrorPacket(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order: %v", data); werr != nil {
log.Error("Error writing error packet to client: %v", err)
log.Errorf("Error writing error packet to client: %v", err)
return werr
}
}
Expand All @@ -1301,17 +1301,18 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
c.discardCursor()

if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err)
log.Errorf("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err)
return err
}
case ComStmtFetch:
c.startWriterBuffering()
stmtID, numRows, ok := c.parseComStmtFetch(data)
c.recycleReadPacket()
if !ok {
log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data)
if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil {
log.Error("Error writing error packet to client: %v", werr)
log.Errorf("Unable to parse COM_STMT_FETCH message on connection %v", c.ConnectionID)
if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError,
"unable to parse COM_STMT_FETCH message on connection %v", c.ConnectionID); werr != nil {
log.Errorf("Error writing error packet to client: %v", werr)
return werr
}
return c.flush(ctx)
Expand All @@ -1321,7 +1322,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
if c.cs == nil || stmtID != c.cs.stmtID {
log.Errorf("Requested stmtID does not match stmtID of open cursor. Client %v, returning error: %v", c.ConnectionID, data)
if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil {
log.Error("Error writing error packet to client: %v", err)
log.Errorf("Error writing error packet to client: %v", err)
return werr
}
return c.flush(ctx)
Expand Down
9 changes: 8 additions & 1 deletion go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,12 +473,19 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3
return
}

clientAuthResponse, err = c.readEphemeralPacket(context.Background())
data, err := c.readEphemeralPacket(context.Background())
if err != nil {
l.handleConnectionError(c, fmt.Sprintf("Error reading auth switch response for %s: %v", c, err))
return
}

var ok bool
clientAuthResponse, _, ok = readBytesCopy(data, 0, len(data))
c.recycleReadPacket()
if !ok {
l.handleConnectionError(c, fmt.Sprintf("Unable to copy client auth response for %s", c))
return
}
}

userData, err := negotiatedAuthMethod.HandleAuthPluginData(c, user, serverAuthPluginData, clientAuthResponse, conn.RemoteAddr())
Expand Down

0 comments on commit dcbcb5d

Please sign in to comment.