Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed May 3, 2024
1 parent 18d753b commit 200d0b1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 53 deletions.
46 changes: 11 additions & 35 deletions go/adbc/driver/snowflake/binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"database/sql"
"database/sql/driver"
"fmt"
"io"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v17/arrow"
Expand Down Expand Up @@ -113,56 +112,33 @@ func (r *snowflakeBindReader) Release() {
}

func (r *snowflakeBindReader) Next() (array.RecordReader, error) {
for r.currentBatch == nil || r.nextIndex >= r.currentBatch.NumRows() {
if r.stream != nil && r.stream.Next() {
if r.currentBatch != nil {
r.currentBatch.Release()
}
r.currentBatch = r.stream.Record()
r.nextIndex = 0
continue
} else if r.stream != nil && r.stream.Err() != nil {
return nil, r.stream.Err()
} else {
// end-of-stream
return nil, nil
}
}

params, err := convertArrowToNamedValue(r.currentBatch, int(r.nextIndex))
params, err := r.NextParams()
if err != nil {
return nil, err
} else if params == nil {
// end-of-stream
return nil, nil
}
r.nextIndex++

return r.doQuery(params)
}

func (r *snowflakeBindReader) NextUpdate() error {
func (r *snowflakeBindReader) NextParams() ([]driver.NamedValue, error) {
for r.currentBatch == nil || r.nextIndex >= r.currentBatch.NumRows() {
if r.currentBatch != nil {
r.currentBatch.Release()
}
if r.stream != nil && r.stream.Next() {
if r.currentBatch != nil {
r.currentBatch.Release()
}
r.currentBatch = r.stream.Record()
r.nextIndex = 0
continue
} else if r.stream != nil && r.stream.Err() != nil {
return r.stream.Err()
return nil, r.stream.Err()
} else {
return io.EOF
return nil, nil
}
}

params, err := convertArrowToNamedValue(r.currentBatch, int(r.nextIndex))
if err != nil {
return err
}
r.nextIndex++

_, err = r.doQuery(params)
if err != nil {
return err
}
return nil
return params, err
}
34 changes: 16 additions & 18 deletions go/adbc/driver/snowflake/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"context"
"database/sql/driver"
"fmt"
"io"
"strconv"
"strings"

Expand Down Expand Up @@ -514,19 +513,6 @@ func (st *statement) ExecuteUpdate(ctx context.Context) (int64, error) {
if st.streamBind != nil || st.bound != nil {
numRows := int64(0)
bind := snowflakeBindReader{
doQuery: func(params []driver.NamedValue) (array.RecordReader, error) {
r, err := st.cnxn.cn.ExecContext(ctx, st.query, params)
if err != nil {
return nil, errToAdbcErr(adbc.StatusInternal, err)
}
n, err := r.RowsAffected()
if err != nil {
numRows = -1
} else if numRows >= 0 {
numRows += n
}
return nil, nil
},
currentBatch: st.bound,
stream: st.streamBind,
}
Expand All @@ -535,13 +521,25 @@ func (st *statement) ExecuteUpdate(ctx context.Context) (int64, error) {

defer bind.Release()
for {
err := bind.NextUpdate()
if err == io.EOF {
return numRows, nil
} else if err != nil {
params, err := bind.NextParams()
if err != nil {
return -1, err
} else if params == nil {
break
}

r, err := st.cnxn.cn.ExecContext(ctx, st.query, params)
if err != nil {
return -1, errToAdbcErr(adbc.StatusInternal, err)
}
n, err := r.RowsAffected()
if err != nil {
numRows = -1
} else if numRows >= 0 {
numRows += n
}
}
return numRows, nil
}

r, err := st.cnxn.cn.ExecContext(ctx, st.query, nil)
Expand Down

0 comments on commit 200d0b1

Please sign in to comment.