Skip to content

Commit

Permalink
executor: fix csv parser (#9005)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lingyu Song authored and jackysp committed Jan 15, 2019
1 parent cb43fc9 commit 33b4c3e
Show file tree
Hide file tree
Showing 3 changed files with 313 additions and 22 deletions.
7 changes: 6 additions & 1 deletion executor/executor_pkg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ func (s *testExecSuite) TestGetFieldsFromLine(c *C) {
`"\0\b\n\r\t\Z\\\ \c\'\""`,
[]string{string([]byte{0, '\b', '\n', '\r', '\t', 26, '\\', ' ', ' ', 'c', '\'', '"'})},
},
// Test mixed.
{
`"123",456,"\t7890",abcd`,
[]string{"123", "456", "\t7890", "abcd"},
},
}

ldInfo := LoadDataInfo{
Expand All @@ -214,7 +219,7 @@ func (s *testExecSuite) TestGetFieldsFromLine(c *C) {
}

_, err := ldInfo.getFieldsFromLine([]byte(`1,a string,100.20`))
c.Assert(err, NotNil)
c.Assert(err, IsNil)
}

func assertEqualStrings(c *C, got []field, expect []string) {
Expand Down
185 changes: 164 additions & 21 deletions executor/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package executor

import (
"bytes"
"context"
"fmt"
"strings"
Expand Down Expand Up @@ -209,7 +208,6 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error
if len(prevData) == 0 && len(curData) == 0 {
return nil, false, nil
}

var line []byte
var isEOF, hasStarting, reachLimit bool
if len(prevData) > 0 && len(curData) == 0 {
Expand All @@ -220,7 +218,6 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error
for len(curData) > 0 {
line, curData, hasStarting = e.getLine(prevData, curData)
prevData = nil

// If it doesn't find the terminated symbol and this data isn't the last data,
// the data can't be inserted.
if line == nil && !isEOF {
Expand Down Expand Up @@ -313,28 +310,174 @@ func (e *LoadDataInfo) addRecordLD(row []types.Datum) (int64, error) {
// field is a single column value split out of one input line.
type field struct {
	// str holds the raw bytes of the value (escape sequences are still present
	// when the value comes from fieldWriter).
	str []byte
	// maybeNull flags a value that may stand for NULL; fieldWriter always
	// leaves it false.
	maybeNull bool
	// enclosed reports whether the value was wrapped in the enclosing character.
	enclosed bool
}

// fieldWriter splits one line held in ReadBuf into fields, honouring an
// optional enclosing character, a (possibly multi-byte) field terminator and
// backslash escapes.
type fieldWriter struct {
	// pos is the index of the next unread byte in *ReadBuf.
	pos int
	// enclosedChar is the byte that may wrap a field.
	// NOTE(review): a zero value is compared like any other byte, so a NUL at
	// the start of a field is treated as an opening enclosure — confirm whether
	// callers use 0 to mean "no enclosure".
	enclosedChar byte
	// fieldTermChar is the first byte of the field terminator.
	fieldTermChar byte
	// term is the complete field terminator; term[0] is expected to equal
	// fieldTermChar.
	term *string
	// isEnclosed is true while the current field was opened by enclosedChar.
	isEnclosed bool
	// isLineStart is true until the first byte of the line has been examined.
	isLineStart bool
	// isFieldStart is true until the first byte of the current field has been
	// examined.
	isFieldStart bool
	// ReadBuf points at the line being parsed.
	ReadBuf *[]byte
	// OutputBuf accumulates the bytes of the field currently being read; for an
	// enclosed field its first byte is the opening enclosing character.
	OutputBuf []byte
}

// Init prepares w to parse the line in readBuf from its beginning.
// enclosedChar is the enclosing byte, fieldTermChar the first byte of the
// terminator, and term the full terminator string.
func (w *fieldWriter) Init(enclosedChar byte, fieldTermChar byte, readBuf *[]byte, term *string) {
	// Reset the cursor and pending output so a reused writer does not continue
	// from the state left by a previous line.
	w.pos = 0
	w.OutputBuf = w.OutputBuf[:0]
	w.isEnclosed = false
	w.isLineStart = true
	w.isFieldStart = true
	w.ReadBuf = readBuf
	w.enclosedChar = enclosedChar
	w.fieldTermChar = fieldTermChar
	w.term = term
}

// putback un-reads the most recently read byte. It is a no-op when nothing has
// been read, so the cursor can never become negative.
func (w *fieldWriter) putback() {
	if w.pos > 0 {
		w.pos--
	}
}

// getChar returns (true, b) with the next unread byte and advances the cursor,
// or (false, 0) when the buffer is exhausted.
func (w *fieldWriter) getChar() (bool, byte) {
	if w.pos < len(*w.ReadBuf) {
		ret := (*w.ReadBuf)[w.pos]
		w.pos++
		return true, ret
	}
	return false, 0
}

// isTerminator is called right after the first terminator byte has been read.
// It reports whether the remaining bytes of term follow; on success they are
// consumed, otherwise the cursor is restored.
func (w *fieldWriter) isTerminator() bool {
	checkpoint := w.pos
	for i := 1; i < len(*w.term); i++ {
		ok, ch := w.getChar()
		if !ok || ch != (*w.term)[i] {
			w.pos = checkpoint
			return false
		}
	}
	return true
}

// outputField returns the accumulated bytes as a field and resets the
// per-field state. When enclosed is true the leading enclosing character kept
// in OutputBuf is stripped.
func (w *fieldWriter) outputField(enclosed bool) field {
	start := 0
	if enclosed && len(w.OutputBuf) > 0 {
		start = 1
	}
	// Copy the bytes out: OutputBuf's backing array is reused for the next
	// field. The result is always non-nil, even when empty.
	str := make([]byte, 0, len(w.OutputBuf)-start)
	str = append(str, w.OutputBuf[start:]...)

	w.OutputBuf = w.OutputBuf[:0]
	w.isEnclosed = false
	w.isFieldStart = true
	return field{str: str, maybeNull: false, enclosed: enclosed}
}

// GetField reads the next field from the line.
// The first return value reports whether the end of the line was reached while
// reading this field, i.e. whether it is the last field.
func (w *fieldWriter) GetField() (bool, field) {
	if w.isLineStart {
		// Only act on a byte that was really read; an empty buffer must not
		// move the cursor backwards or fake an opening enclosure.
		if ok, ch := w.getChar(); ok {
			if ch == w.enclosedChar {
				w.isEnclosed = true
				w.isFieldStart, w.isLineStart = false, false
				w.OutputBuf = append(w.OutputBuf, ch)
			} else {
				w.putback()
			}
		}
	}
	for {
		ok, ch := w.getChar()
		if !ok {
			// End of line without a closing enclosure: emit what was collected.
			return true, w.outputField(false)
		}
		if ch == w.enclosedChar && w.isFieldStart {
			// Enclosing character at the start of a field opens an enclosed field.
			w.isEnclosed = true
			w.OutputBuf = append(w.OutputBuf, ch)
			w.isLineStart, w.isFieldStart = false, false
			continue
		}
		w.isLineStart, w.isFieldStart = false, false

		switch {
		case ch == w.fieldTermChar && !w.isEnclosed:
			// Possible field terminator in an unenclosed field.
			if w.isTerminator() {
				return false, w.outputField(false)
			}
			w.OutputBuf = append(w.OutputBuf, ch)

		case ch == w.enclosedChar && w.isEnclosed:
			// Enclosing character inside an enclosed field: look ahead to
			// decide whether it closes the field.
			more, next := w.getChar()
			switch {
			case !more:
				// Closing enclosure is the last byte of the line.
				return true, w.outputField(true)
			case next == w.enclosedChar:
				// Doubled enclosing character stands for a single literal one.
				w.OutputBuf = append(w.OutputBuf, next)
			case next == w.fieldTermChar:
				if w.isTerminator() {
					return false, w.outputField(true)
				}
				// Only a prefix of the terminator matched, so the enclosure did
				// not close the field: keep both bytes as data.
				w.OutputBuf = append(w.OutputBuf, w.enclosedChar, next)
			default:
				// No terminator follows: the enclosing character is data and
				// the following byte is re-examined by the main loop.
				w.OutputBuf = append(w.OutputBuf, w.enclosedChar)
				w.putback()
			}

		case ch == '\\':
			// TODO: escape only support '\'
			// Keep the backslash and the byte it escapes together so that the
			// escaped byte (enclosure, terminator or another backslash) is
			// never interpreted as syntax here; unescaping happens later.
			w.OutputBuf = append(w.OutputBuf, ch)
			if more, next := w.getChar(); more {
				w.OutputBuf = append(w.OutputBuf, next)
			}

		default:
			w.OutputBuf = append(w.OutputBuf, ch)
		}
	}
}

// getFieldsFromLine splits line according to fieldsInfo.
func (e *LoadDataInfo) getFieldsFromLine(line []byte) ([]field, error) {
var sep []byte
if e.FieldsInfo.Enclosed != 0 {
if line[0] != e.FieldsInfo.Enclosed || line[len(line)-1] != e.FieldsInfo.Enclosed {
return nil, errors.Errorf("line %s should begin and end with %c", string(line), e.FieldsInfo.Enclosed)
var (
reader fieldWriter
fields []field
)

if len(line) == 0 {
str := []byte("")
fields = append(fields, field{str, false, false})
return fields, nil
}

reader.Init(e.FieldsInfo.Enclosed, e.FieldsInfo.Terminated[0], &line, &e.FieldsInfo.Terminated)
for {
eol, f := reader.GetField()
f = f.escape()
if string(f.str) == "NULL" && !f.enclosed {
f.str = []byte{'N'}
f.maybeNull = true
}
fields = append(fields, f)
if eol {
break
}
line = line[1 : len(line)-1]
sep = make([]byte, 0, len(e.FieldsInfo.Terminated)+2)
sep = append(sep, e.FieldsInfo.Enclosed)
sep = append(sep, e.FieldsInfo.Terminated...)
sep = append(sep, e.FieldsInfo.Enclosed)
} else {
sep = []byte(e.FieldsInfo.Terminated)
}
rawCols := bytes.Split(line, sep)
fields := make([]field, 0, len(rawCols))
for _, v := range rawCols {
f := field{v, false}
fields = append(fields, f.escape())
}
return fields, nil
}
Expand All @@ -354,7 +497,7 @@ func (f *field) escape() field {
f.str[pos] = c
pos++
}
return field{f.str[:pos], f.maybeNull}
return field{f.str[:pos], f.maybeNull, f.enclosed}
}

func (f *field) escapeChar(c byte) byte {
Expand Down
143 changes: 143 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,149 @@ func runTestLoadData(c *C, server *Server) {
dbt.Assert(err, NotNil)
})

err = fp.Close()
c.Assert(err, IsNil)
err = os.Remove(path)
c.Assert(err, IsNil)

fp, err = os.Create(path)
c.Assert(err, IsNil)
c.Assert(fp, NotNil)

// Test mixed unenclosed and enclosed fields.
_, err = fp.WriteString(
"\"abc\",123\n" +
"def,456,\n" +
"hig,\"789\",")
c.Assert(err, IsNil)

runTestsOnNewDB(c, func(config *mysql.Config) {
config.AllowAllFiles = true
config.Strict = false
}, "LoadData", func(dbt *DBTest) {
dbt.mustExec("create table test (str varchar(10) default null, i int default null)")
_, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ',' enclosed by '"'`)
dbt.Assert(err1, IsNil)
var (
str string
id int
)
rows := dbt.mustQuery("select * from test")
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
err = rows.Scan(&str, &id)
dbt.Check(err, IsNil)
dbt.Check(str, DeepEquals, "abc")
dbt.Check(id, DeepEquals, 123)
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
rows.Scan(&str, &id)
dbt.Check(str, DeepEquals, "def")
dbt.Check(id, DeepEquals, 456)
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
rows.Scan(&str, &id)
dbt.Check(str, DeepEquals, "hig")
dbt.Check(id, DeepEquals, 789)
dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data"))
dbt.mustExec("delete from test")
})

err = fp.Close()
c.Assert(err, IsNil)
err = os.Remove(path)
c.Assert(err, IsNil)

fp, err = os.Create(path)
c.Assert(err, IsNil)
c.Assert(fp, NotNil)

// Test irregular csv file.
_, err = fp.WriteString(
`,\N,NULL,,` + "\n" +
"00,0,000000,,\n" +
`2003-03-03, 20030303,030303,\N` + "\n")
c.Assert(err, IsNil)

runTestsOnNewDB(c, func(config *mysql.Config) {
config.AllowAllFiles = true
config.Strict = false
}, "LoadData", func(dbt *DBTest) {
dbt.mustExec("create table test (a date, b date, c date not null, d date)")
_, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ','`)
dbt.Assert(err1, IsNil)
var (
a sql.NullString
b sql.NullString
d sql.NullString
c sql.NullString
)
rows := dbt.mustQuery("select * from test")
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
err = rows.Scan(&a, &b, &c, &d)
dbt.Check(err, IsNil)
dbt.Check(a.String, Equals, "0000-00-00")
dbt.Check(b.String, Equals, "")
dbt.Check(c.String, Equals, "0000-00-00")
dbt.Check(d.String, Equals, "0000-00-00")
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
rows.Scan(&a, &b, &c, &d)
dbt.Check(a.String, Equals, "0000-00-00")
dbt.Check(b.String, Equals, "0000-00-00")
dbt.Check(c.String, Equals, "0000-00-00")
dbt.Check(d.String, Equals, "0000-00-00")
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
rows.Scan(&a, &b, &c, &d)
dbt.Check(a.String, Equals, "2003-03-03")
dbt.Check(b.String, Equals, "2003-03-03")
dbt.Check(c.String, Equals, "2003-03-03")
dbt.Check(d.String, Equals, "")
dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data"))
dbt.mustExec("delete from test")
})

err = fp.Close()
c.Assert(err, IsNil)
err = os.Remove(path)
c.Assert(err, IsNil)

fp, err = os.Create(path)
c.Assert(err, IsNil)
c.Assert(fp, NotNil)

// Test double enclosed.
_, err = fp.WriteString(
`"field1","field2"` + "\n" +
`"a""b","cd""ef"` + "\n" +
`"a"b",c"d"e` + "\n")
c.Assert(err, IsNil)

runTestsOnNewDB(c, func(config *mysql.Config) {
config.AllowAllFiles = true
config.Strict = false
}, "LoadData", func(dbt *DBTest) {
dbt.mustExec("create table test (a varchar(20), b varchar(20))")
_, err1 := dbt.db.Exec(`load data local infile '/tmp/load_data_test.csv' into table test FIELDS TERMINATED BY ',' enclosed by '"'`)
dbt.Assert(err1, IsNil)
var (
a sql.NullString
b sql.NullString
)
rows := dbt.mustQuery("select * from test")
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
err = rows.Scan(&a, &b)
dbt.Check(err, IsNil)
dbt.Check(a.String, Equals, "field1")
dbt.Check(b.String, Equals, "field2")
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
rows.Scan(&a, &b)
dbt.Check(a.String, Equals, `a"b`)
dbt.Check(b.String, Equals, `cd"ef`)
dbt.Check(rows.Next(), IsTrue, Commentf("unexpected data"))
rows.Scan(&a, &b)
dbt.Check(a.String, Equals, `a"b`)
dbt.Check(b.String, Equals, `c"d"e`)
dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data"))
dbt.mustExec("delete from test")
})

// Test the case where the ClientLocalFiles capability is unsupported.
server.capability ^= tmysql.ClientLocalFiles
runTestsOnNewDB(c, func(config *mysql.Config) {
Expand Down

0 comments on commit 33b4c3e

Please sign in to comment.