Skip to content

Commit

Permalink
SNOW-1313648 GO - Verify value bindings for all field types while exc…
Browse files Browse the repository at this point in the history
…eeding CLIENT_STAGE_ARRAY_BINDING_THRESHOLD (#1297)
  • Loading branch information
sfc-gh-ext-simba-jy authored Feb 13, 2025
1 parent 580e7e8 commit 17db71f
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 48 deletions.
33 changes: 24 additions & 9 deletions bind_uploader.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ func (bu *bindUploader) buildRowsAsBytes(columns []driver.NamedValue) ([][]byte,
}).exceptionTelemetry(bu.sc)
}

_, column := snowflakeArrayToString(&columns[0], true)
_, column, err := snowflakeArrayToString(&columns[0], true)
if err != nil {
return nil, err
}
numRows := len(column)
csvRows := make([][]byte, 0)
rows := make([][]interface{}, 0)
Expand All @@ -152,7 +155,10 @@ func (bu *bindUploader) buildRowsAsBytes(columns []driver.NamedValue) ([][]byte,
}
}
for colIdx := 1; colIdx < numColumns; colIdx++ {
_, column = snowflakeArrayToString(&columns[colIdx], true)
_, column, err = snowflakeArrayToString(&columns[colIdx], true)
if err != nil {
return nil, err
}
iNumRows := len(column)
if iNumRows != numRows {
return nil, (&SnowflakeError{
Expand Down Expand Up @@ -201,7 +207,10 @@ func (sc *snowflakeConn) processBindings(
requestID UUID,
req *execRequest) error {
arrayBindThreshold := sc.getArrayBindStageThreshold()
numBinds := arrayBindValueCount(bindings)
numBinds, err := arrayBindValueCount(bindings)
if err != nil {
return err
}
if 0 < arrayBindThreshold && arrayBindThreshold <= numBinds && !describeOnly && isArrayBind(bindings) {
uploader := bindUploader{
sc: sc,
Expand All @@ -215,7 +224,6 @@ func (sc *snowflakeConn) processBindings(
req.Bindings = nil
req.BindStage = uploader.stagePath
} else {
var err error
req.Bindings, err = getBindValues(bindings, sc.cfg.Params)
if err != nil {
return err
Expand Down Expand Up @@ -246,7 +254,10 @@ func getBindValues(bindings []driver.NamedValue, params map[string]*string) (map
var bv bindingValue
if t == sliceType {
// retrieve array binding data
t, val = snowflakeArrayToString(&binding, false)
t, val, err = snowflakeArrayToString(&binding, false)
if err != nil {
return nil, err
}
} else {
bv, err = valueToString(binding.Value, tsmode, params)
val = bv.value
Expand Down Expand Up @@ -280,12 +291,16 @@ func bindingName(nv driver.NamedValue, idx int) string {
return strconv.Itoa(idx)
}

func arrayBindValueCount(bindValues []driver.NamedValue) int {
func arrayBindValueCount(bindValues []driver.NamedValue) (int, error) {
if !isArrayBind(bindValues) {
return 0
return 0, nil
}
_, arr := snowflakeArrayToString(&bindValues[0], false)
return len(bindValues) * len(arr)
_, arr, err := snowflakeArrayToString(&bindValues[0], false)
if err != nil {
return 0, err
}

return len(bindValues) * len(arr), nil
}

func isArrayBind(bindings []driver.NamedValue) bool {
Expand Down
132 changes: 132 additions & 0 deletions bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,130 @@ func TestBulkArrayBinding(t *testing.T) {
})
}

func TestBindingsWithSameValue(t *testing.T) {
arrayInsertTable := "test_array_binding_insert"
stageBindingTable := "test_stage_binding_insert"
interfaceArrayTable := "test_interface_binding_insert"

runDBTest(t, func(dbt *DBTest) {
dbt.mustExec(fmt.Sprintf("create or replace table %v (c1 integer, c2 string, c3 timestamp_ltz, c4 timestamp_tz, c5 timestamp_ntz, c6 date, c7 time, c9 boolean, c10 double)", arrayInsertTable))
dbt.mustExec(fmt.Sprintf("create or replace table %v (c1 integer, c2 string, c3 timestamp_ltz, c4 timestamp_tz, c5 timestamp_ntz, c6 date, c7 time, c9 boolean, c10 double)", stageBindingTable))
dbt.mustExec(fmt.Sprintf("create or replace table %v (c1 integer, c2 string, c3 timestamp_ltz, c4 timestamp_tz, c5 timestamp_ntz, c6 date, c7 time, c9 boolean, c10 double)", interfaceArrayTable))

defer func() {
dbt.mustExec(fmt.Sprintf("drop table if exists %v", arrayInsertTable))
dbt.mustExec(fmt.Sprintf("drop table if exists %v", stageBindingTable))
dbt.mustExec(fmt.Sprintf("drop table if exists %v", interfaceArrayTable))
}()

numRows := 5

intArr := make([]int, numRows)
strArr := make([]string, numRows)
timeArr := make([]time.Time, numRows)
boolArr := make([]bool, numRows)
doubleArr := make([]float64, numRows)

intAnyArr := make([]any, numRows)
strAnyArr := make([]any, numRows)
timeAnyArr := make([]any, numRows)
boolAnyArr := make([]bool, numRows)
doubleAnyArr := make([]float64, numRows)

for i := 0; i < numRows; i++ {
intArr[i] = i
intAnyArr[i] = i

double := rand.Float64()
doubleArr[i] = double
doubleAnyArr[i] = double

strArr[i] = "test" + strconv.Itoa(i)
strAnyArr[i] = "test" + strconv.Itoa(i)

b := getRandomBool()
boolArr[i] = b
boolAnyArr[i] = b

date := getRandomDate()
timeArr[i] = date
timeAnyArr[i] = date
}

dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?, ?)", interfaceArrayTable), Array(&intAnyArr), Array(&strAnyArr), Array(&timeAnyArr, TimestampLTZType), Array(&timeAnyArr, TimestampTZType), Array(&timeAnyArr, TimestampNTZType), Array(&timeAnyArr, DateType), Array(&timeAnyArr, TimeType), Array(&boolArr), Array(&doubleArr))
dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?, ?)", arrayInsertTable), Array(&intArr), Array(&strArr), Array(&timeArr, TimestampLTZType), Array(&timeArr, TimestampTZType), Array(&timeArr, TimestampNTZType), Array(&timeArr, DateType), Array(&timeArr, TimeType), Array(&boolArr), Array(&doubleArr))
dbt.mustExec("ALTER SESSION SET CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = 1")
dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?, ?)", stageBindingTable), Array(&intArr), Array(&strArr), Array(&timeArr, TimestampLTZType), Array(&timeArr, TimestampTZType), Array(&timeArr, TimestampNTZType), Array(&timeArr, DateType), Array(&timeArr, TimeType), Array(&boolArr), Array(&doubleArr))

insertRows := dbt.mustQuery("select * from " + arrayInsertTable + " order by c1")
bindingRows := dbt.mustQuery("select * from " + stageBindingTable + " order by c1")
interfaceRows := dbt.mustQuery("select * from " + interfaceArrayTable + " order by c1")

defer func() {
assertNilF(t, insertRows.Close())
assertNilF(t, bindingRows.Close())
assertNilF(t, interfaceRows.Close())
}()
var i, bi, ii int
var s, bs, is string
var ltz, bltz, iltz, itz, btz, tz, intz, ntz, bntz, iDate, date, bDate, itt, tt, btt time.Time
var b, bb, ib bool
var d, bd, id float64

timeFormat := "15:04:05"
for k := 0; k < numRows; k++ {
assertTrueF(t, insertRows.Next())
assertNilF(t, insertRows.Scan(&i, &s, &ltz, &tz, &ntz, &date, &tt, &b, &d))

assertTrueF(t, bindingRows.Next())
assertNilF(t, bindingRows.Scan(&bi, &bs, &bltz, &btz, &bntz, &bDate, &btt, &bb, &bd))

assertTrueF(t, interfaceRows.Next())
assertNilF(t, interfaceRows.Scan(&ii, &is, &iltz, &itz, &intz, &iDate, &itt, &ib, &id))

assertEqualE(t, k, i)
assertEqualE(t, k, bi)
assertEqualE(t, k, ii)

assertEqualE(t, "test"+strconv.Itoa(k), s)
assertEqualE(t, "test"+strconv.Itoa(k), bs)
assertEqualE(t, "test"+strconv.Itoa(k), is)

utcTime := timeArr[k].UTC()
assertEqualE(t, ltz.UTC(), utcTime)
assertEqualE(t, bltz.UTC(), utcTime)
assertEqualE(t, iltz.UTC(), utcTime)

assertEqualE(t, tz.UTC(), utcTime)
assertEqualE(t, btz.UTC(), utcTime)
assertEqualE(t, itz.UTC(), utcTime)

assertEqualE(t, ntz.UTC(), utcTime)
assertEqualE(t, bntz.UTC(), utcTime)
assertEqualE(t, intz.UTC(), utcTime)

testingDate := timeArr[k].Truncate(24 * time.Hour)
assertEqualE(t, date, testingDate)
assertEqualE(t, bDate, testingDate)
assertEqualE(t, iDate, testingDate)

testingTime := timeArr[k].Format(timeFormat)
assertEqualE(t, tt.Format(timeFormat), testingTime)
assertEqualE(t, btt.Format(timeFormat), testingTime)
assertEqualE(t, itt.Format(timeFormat), testingTime)

assertEqualE(t, b, boolArr[k])
assertEqualE(t, bb, boolArr[k])
assertEqualE(t, ib, boolArr[k])

assertEqualE(t, d, doubleArr[k])
assertEqualE(t, bd, doubleArr[k])
assertEqualE(t, id, doubleArr[k])

}
})
}

func TestBulkArrayBindingTimeWithPrecision(t *testing.T) {
runDBTest(t, func(dbt *DBTest) {
dbt.mustExec(fmt.Sprintf("create or replace table %v (s time(0), ms time(3), us time(6), ns time(9))", dbname))
Expand Down Expand Up @@ -1423,3 +1547,11 @@ func testInsertLOBData(t *testing.T, useArrowFormat bool, isLiteral bool) {
dbt.mustExec(unsetFeatureMaxLOBSize)
})
}

func getRandomDate() time.Time {
return time.Date(rand.Intn(1582)+1, time.January, rand.Intn(40), rand.Intn(40), rand.Intn(40), rand.Intn(40), rand.Intn(40), time.UTC)
}

func getRandomBool() bool {
return rand.Int63n(time.Now().Unix())%2 == 0
}
Loading

0 comments on commit 17db71f

Please sign in to comment.