Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extended ExecuteSelectStreaming #596

Merged
merged 7 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ err := conn.ExecuteSelectStreaming(`select id, name from table LIMIT 100500`, &r
// Copy it if you need.
// ...
}
return false, nil
})
return nil
}, nil)

// ...
```
Expand Down
10 changes: 7 additions & 3 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ type Conn struct {
// This function will be called for every row in resultset from ExecuteSelectStreaming.
type SelectPerRowCallback func(row []FieldValue) error

// This function will be called once per result from ExecuteSelectStreaming
type SelectPerResultCallback func(result *Result) error

func getNetProto(addr string) string {
proto := "tcp"
if strings.Contains(addr, "/") {
Expand Down Expand Up @@ -183,6 +186,7 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {

// ExecuteSelectStreaming will call perRowCallback for every row in resultset
// WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields.
// When given, perResultCallback will be called once per result
//
// ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving.
//
Expand All @@ -193,14 +197,14 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {
// // Use the row as you want.
// // You must not save FieldValue.AsString() value after this callback is done. Copy it if you need.
// return nil
// })
// }, nil)
//
func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback) error {
func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error {
if err := c.writeCommandStr(COM_QUERY, command); err != nil {
return errors.Trace(err)
}

return c.readResultStreaming(false, result, perRowCallback)
return c.readResultStreaming(false, result, perRowCallback, perResultCallback)
}

func (c *Conn) Begin() error {
Expand Down
18 changes: 15 additions & 3 deletions client/resp.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ func (c *Conn) readResult(binary bool) (*Result, error) {
return c.readResultset(firstPkgBuf, binary)
}

func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectPerRowCallback) error {
func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback) error {
firstPkgBuf, err := c.ReadPacketReuseMem(utils.ByteSliceGet(16)[:0])
defer utils.ByteSlicePut(firstPkgBuf)

Expand Down Expand Up @@ -267,7 +267,7 @@ func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectP
return ErrMalformPacket
}

return c.readResultsetStreaming(firstPkgBuf, binary, result, perRowCb)
return c.readResultsetStreaming(firstPkgBuf, binary, result, perRowCb, perResCb)
}

func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) {
Expand All @@ -293,7 +293,7 @@ func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) {
return result, nil
}

func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result, perRowCb SelectPerRowCallback) error {
func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback) error {
columnCount, _, n := LengthEncodedInt(data)

if n-len(data) != 0 {
Expand All @@ -307,14 +307,26 @@ func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result,
result.Reset(int(columnCount))
}

// this is a streaming resultset
result.Resultset.Streaming = true

if err := c.readResultColumns(result); err != nil {
return errors.Trace(err)
}

if perResCb != nil {
if err := perResCb(result); err != nil {
return err
}
}

if err := c.readResultRowsStreaming(result, binary, perRowCb); err != nil {
return errors.Trace(err)
}

// this resultset is done streaming
result.Resultset.StreamingDone = true

return nil
}

Expand Down
8 changes: 8 additions & 0 deletions client/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ func (s *Stmt) Execute(args ...interface{}) (*Result, error) {
return s.conn.readResult(true)
}

func (s *Stmt) ExecuteSelectStreaming(result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback, args ...interface{}) error {
if err := s.write(args...); err != nil {
return errors.Trace(err)
}

return s.conn.readResultStreaming(true, result, perRowCb, perResCb)
}

func (s *Stmt) Close() error {
if err := s.conn.writeCommandUint32(COM_STMT_CLOSE, s.id); err != nil {
return errors.Trace(err)
Expand Down
3 changes: 3 additions & 0 deletions mysql/resultset.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ type Resultset struct {
RawPkg []byte

RowDatas []RowData

Streaming bool
StreamingDone bool
}

var (
Expand Down
4 changes: 2 additions & 2 deletions mysql/resultset_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/siddontang/go/hack"
)

func formatTextValue(value interface{}) ([]byte, error) {
func FormatTextValue(value interface{}) ([]byte, error) {
switch v := value.(type) {
case int8:
return strconv.AppendInt(nil, int64(v), 10), nil
Expand Down Expand Up @@ -165,7 +165,7 @@ func BuildSimpleTextResultset(names []string, values [][]interface{}) (*Resultse
return nil, errors.Errorf("row types aren't consistent")
}
}
b, err = formatTextValue(value)
b, err = FormatTextValue(value)

if err != nil {
return nil, errors.Trace(err)
Expand Down
2 changes: 1 addition & 1 deletion server/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (c *Conn) HandleCommand() error {

v := c.dispatch(data)

err = c.writeValue(v)
err = c.WriteValue(v)

if c.Conn != nil {
c.ResetSequence()
Expand Down
30 changes: 29 additions & 1 deletion server/resp.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ func (c *Conn) writeAuthMoreDataFastAuth() error {
}

func (c *Conn) writeResultset(r *Resultset) error {
// for a streaming resultset, that handled rowdata separately in a callback
// of type SelectPerRowCallback, we can suffice by ending the stream with
// an EOF
if r.StreamingDone {
return c.writeEOF()
}

columnLen := PutLengthEncodedInt(uint64(len(r.Fields)))

data := make([]byte, 4, 1024)
Expand All @@ -129,6 +136,12 @@ func (c *Conn) writeResultset(r *Resultset) error {
return err
}

// streaming resultsets handle rowdata in a separate callback of type
// SelectPerRowCallback so we're done here
if r.Streaming {
return nil
}

for _, v := range r.RowDatas {
data = data[0:4]
data = append(data, v...)
Expand Down Expand Up @@ -163,10 +176,23 @@ func (c *Conn) writeFieldList(fs []*Field, data []byte) error {
return nil
}

func (c *Conn) writeFieldValues(fv []FieldValue) error {
data := make([]byte, 4, 1024)
for _, v := range fv {
tv, err := FormatTextValue(v.Value())
if err != nil {
return err
}
data = append(data, PutLengthEncodedString(tv)...)
}

return c.WritePacket(data)
}

type noResponse struct{}
type eofResponse struct{}

func (c *Conn) writeValue(value interface{}) error {
func (c *Conn) WriteValue(value interface{}) error {
switch v := value.(type) {
case noResponse:
return nil
Expand All @@ -184,6 +210,8 @@ func (c *Conn) writeValue(value interface{}) error {
}
case []*Field:
return c.writeFieldList(v, nil)
case []FieldValue:
return c.writeFieldValues(v)
case *Stmt:
return c.writePrepare(v)
default:
Expand Down