From 109a9cb3ff293053d1dd0b406b37dda6f5d84f71 Mon Sep 17 00:00:00 2001 From: Jayc Date: Mon, 11 Oct 2021 16:14:51 +0530 Subject: [PATCH] Mssql version update 0.9 -> 0.10 --- go.mod | 2 +- go.sum | 2 + .../denisenkom/go-mssqldb/.gitignore | 8 + .../denisenkom/go-mssqldb/.golangci.yml | 10 + .../go-mssqldb/accesstokenconnector.go | 29 +- .../denisenkom/go-mssqldb/appveyor.yml | 3 + .../github.com/denisenkom/go-mssqldb/buf.go | 23 +- .../denisenkom/go-mssqldb/bulkcopy.go | 57 +-- .../denisenkom/go-mssqldb/conn_str.go | 65 ++- .../denisenkom/go-mssqldb/fedauth.go | 82 ++++ .../github.com/denisenkom/go-mssqldb/mssql.go | 243 +++++----- .../denisenkom/go-mssqldb/mssql_go110.go | 2 +- .../denisenkom/go-mssqldb/mssql_go19.go | 2 +- .../github.com/denisenkom/go-mssqldb/net.go | 44 +- .../github.com/denisenkom/go-mssqldb/ntlm.go | 13 +- .../github.com/denisenkom/go-mssqldb/rpc.go | 6 - .../github.com/denisenkom/go-mssqldb/tds.go | 429 +++++++++++++----- .../github.com/denisenkom/go-mssqldb/token.go | 411 ++++++++++------- .../denisenkom/go-mssqldb/token_string.go | 44 +- .../github.com/denisenkom/go-mssqldb/tran.go | 10 +- .../denisenkom/go-mssqldb/tvp_go19.go | 73 ++- .../github.com/denisenkom/go-mssqldb/types.go | 12 +- vendor/modules.txt | 2 +- 23 files changed, 1002 insertions(+), 570 deletions(-) create mode 100644 vendor/github.com/denisenkom/go-mssqldb/.gitignore create mode 100644 vendor/github.com/denisenkom/go-mssqldb/.golangci.yml create mode 100644 vendor/github.com/denisenkom/go-mssqldb/fedauth.go diff --git a/go.mod b/go.mod index 136594dfb7..6250dcc0a2 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/cenkalti/backoff v2.2.1+incompatible github.com/cenkalti/backoff/v4 v4.0.2 github.com/containerd/continuity v0.1.0 // indirect - github.com/denisenkom/go-mssqldb v0.9.0 + github.com/denisenkom/go-mssqldb v0.10.0 github.com/dgraph-io/badger/v2 v2.2007.2 github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 // indirect github.com/fsnotify/fsnotify v1.5.1 diff --git a/go.sum b/go.sum index aa682f1578..9468b8d72c 100644 --- a/go.sum +++ b/go.sum @@ -166,6 +166,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM= github.com/denisenkom/go-mssqldb v0.9.0 h1:RSohk2RsiZqLZ0zCjtfn3S4Gp4exhpBWHyQ7D0yGjAk= github.com/denisenkom/go-mssqldb v0.9.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8= +github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dgraph-io/badger/v2 v2.2007.2 h1:EjjK0KqwaFMlPin1ajhP943VPENHJdEz1KLIegjaI3k= github.com/dgraph-io/badger/v2 v2.2007.2/go.mod h1:26P/7fbL4kUZVEVKLAKXkBXKOydDmM2p1e+NhhnBCAE= github.com/dgraph-io/ristretto v0.0.3-0.20200630154024-f66de99634de h1:t0UHb5vdojIDUqktM6+xJAfScFBsVpXZmqC9dsgJmeA= diff --git a/vendor/github.com/denisenkom/go-mssqldb/.gitignore b/vendor/github.com/denisenkom/go-mssqldb/.gitignore new file mode 100644 index 0000000000..1dda7039b4 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/.gitignore @@ -0,0 +1,8 @@ +/.idea +/.connstr +.vscode +.terraform +*.tfstate* +*.log +*.swp +*~ diff --git a/vendor/github.com/denisenkom/go-mssqldb/.golangci.yml b/vendor/github.com/denisenkom/go-mssqldb/.golangci.yml new file mode 100644 index 0000000000..959cd5e613 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/.golangci.yml @@ -0,0 +1,10 @@ +linters: + enable: + # basic go linters + - gofmt + - golint + - govet + + # sql related linters + - rowserrcheck + - sqlclosecheck diff --git a/vendor/github.com/denisenkom/go-mssqldb/accesstokenconnector.go b/vendor/github.com/denisenkom/go-mssqldb/accesstokenconnector.go index 8dbe5099e4..8365e4d8be 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/accesstokenconnector.go +++ b/vendor/github.com/denisenkom/go-mssqldb/accesstokenconnector.go @@ -6,19 +6,8 @@ import ( "context" "database/sql/driver" "errors" - "fmt" ) -var _ driver.Connector = &accessTokenConnector{} - -// accessTokenConnector wraps Connector and injects a -// fresh access token when connecting to the database -type accessTokenConnector struct { - Connector - - accessTokenProvider func() (string, error) -} - // NewAccessTokenConnector creates a new connector from a DSN and a token provider. // The token provider func will be called when a new connection is requested and should return a valid access token. // The returned connector may be used with sql.OpenDB. @@ -32,20 +21,10 @@ func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) ( return nil, err } - c := &accessTokenConnector{ - Connector: *conn, - accessTokenProvider: tokenProvider, - } - return c, nil -} - -// Connect returns a new database connection -func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) { - var err error - c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider() - if err != nil { - return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err) + conn.params.fedAuthLibrary = fedAuthLibrarySecurityToken + conn.securityTokenProvider = func(ctx context.Context) (string, error) { + return tokenProvider() } - return c.Connector.Connect(ctx) + return conn, nil } diff --git a/vendor/github.com/denisenkom/go-mssqldb/appveyor.yml b/vendor/github.com/denisenkom/go-mssqldb/appveyor.yml index dfcb62de02..ecb893a3d7 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/appveyor.yml +++ b/vendor/github.com/denisenkom/go-mssqldb/appveyor.yml @@ -39,6 +39,9 @@ environment: - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 GOVERSION: 115 SQLINSTANCE: SQL2017 + - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 + GOVERSION: 116 + SQLINSTANCE: SQL2017 install: - set GOROOT=c:\go%GOVERSION% diff --git a/vendor/github.com/denisenkom/go-mssqldb/buf.go b/vendor/github.com/denisenkom/go-mssqldb/buf.go index ba39b40f17..bad2b00de5 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/buf.go +++ b/vendor/github.com/denisenkom/go-mssqldb/buf.go @@ -48,8 +48,8 @@ type tdsBuffer struct { func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer { return &tdsBuffer{ packetSize: int(bufsize), - wbuf: make([]byte, 1<<16), - rbuf: make([]byte, 1<<16), + wbuf: make([]byte, bufsize), + rbuf: make([]byte, bufsize), rpos: 8, transport: transport, } @@ -137,19 +137,28 @@ func (w *tdsBuffer) FinishPacket() error { var headerSize = binary.Size(header{}) func (r *tdsBuffer) readNextPacket() error { - h := header{} - var err error - err = binary.Read(r.transport, binary.BigEndian, &h) + buf := r.rbuf[:headerSize] + _, err := io.ReadFull(r.transport, buf) if err != nil { return err } + h := header{ + PacketType: packetType(buf[0]), + Status: buf[1], + Size: binary.BigEndian.Uint16(buf[2:4]), + Spid: binary.BigEndian.Uint16(buf[4:6]), + PacketNo: buf[6], + Pad: buf[7], + } if int(h.Size) > r.packetSize { - return errors.New("Invalid packet size, it is longer than buffer size") + return errors.New("invalid packet size, it is longer than buffer size") } if headerSize > int(h.Size) { - return errors.New("Invalid packet size, it is shorter than header size") + return errors.New("invalid packet size, it is shorter than header size") } _, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size]) + //s := base64.StdEncoding.EncodeToString(r.rbuf[headerSize:h.Size]) + //fmt.Print(s) if err != nil { return err } diff --git a/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go b/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go index 1d5eacb381..ba49d1ce7c 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go +++ b/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go @@ -44,8 +44,9 @@ type BulkOptions struct { type DataValue interface{} const ( - sqlDateFormat = "2006-01-02" - sqlTimeFormat = "2006-01-02 15:04:05.999999999Z07:00" + sqlDateFormat = "2006-01-02" + sqlDateTimeFormat = "2006-01-02 15:04:05.999999999Z07:00" + sqlTimeFormat = "15:04:05.9999999" ) func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) { @@ -86,7 +87,7 @@ func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) { b.bulkColumns = append(b.bulkColumns, *bulkCol) b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId) } else { - return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename) + return fmt.Errorf("column %s does not exist in destination table %s", colname, b.tablename) } } @@ -166,7 +167,7 @@ func (b *Bulk) AddRow(row []interface{}) (err error) { } if len(row) != len(b.bulkColumns) { - return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d", + return fmt.Errorf("row does not have the same number of columns than the destination table %d %d", len(row), len(b.bulkColumns)) } @@ -215,7 +216,7 @@ func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) { } func (b *Bulk) Done() (rowcount int64, err error) { - if b.headerSent == false { + if !b.headerSent { //no rows had been sent return 0, nil } @@ -233,24 +234,13 @@ func (b *Bulk) Done() (rowcount int64, err error) { buf.FinishPacket() - tokchan := make(chan tokenStruct, 5) - go processResponse(b.ctx, b.cn.sess, tokchan, nil) - - var rowCount int64 - for token := range tokchan { - switch token := token.(type) { - case doneStruct: - if token.Status&doneCount != 0 { - rowCount = int64(token.RowCount) - } - if token.isError() { - return 0, token.getError() - } - case error: - return 0, b.cn.checkBadConn(token) - } + reader := startReading(b.cn.sess, b.ctx, nil) + err = reader.iterateResponse() + if err != nil { + return 0, b.cn.checkBadConn(err) } - return rowCount, nil + + return reader.rowCount, nil } func (b *Bulk) createColMetadata() []byte { @@ -421,7 +411,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) res.ti.Size = len(res.buffer) case string: var t time.Time - if t, err = time.Parse(sqlTimeFormat, val); err != nil { + if t, err = time.Parse(sqlDateTimeFormat, val); err != nil { return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) } res.buffer = encodeDateTime2(t, int(col.ti.Scale)) @@ -437,7 +427,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) res.ti.Size = len(res.buffer) case string: var t time.Time - if t, err = time.Parse(sqlTimeFormat, val); err != nil { + if t, err = time.Parse(sqlDateTimeFormat, val); err != nil { return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) } res.buffer = encodeDateTimeOffset(t, int(col.ti.Scale)) @@ -468,7 +458,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) case time.Time: t = val case string: - if t, err = time.Parse(sqlTimeFormat, val); err != nil { + if t, err = time.Parse(sqlDateTimeFormat, val); err != nil { return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) } default: @@ -485,7 +475,22 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) } else { err = fmt.Errorf("mssql: invalid size of column %d", col.ti.Size) } - + case typeTimeN: + var t time.Time + switch val := val.(type) { + case time.Time: + res.buffer = encodeTime(val.Hour(), val.Minute(), val.Second(), val.Nanosecond(), int(col.ti.Scale)) + res.ti.Size = len(res.buffer) + case string: + if t, err = time.Parse(sqlTimeFormat, val); err != nil { + return res, fmt.Errorf("bulk: unable to convert string to time: %v", err) + } + res.buffer = encodeTime(t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), int(col.ti.Scale)) + res.ti.Size = len(res.buffer) + default: + err = fmt.Errorf("mssql: invalid type for time column: %T %s", val, val) + return + } // case typeMoney, typeMoney4, typeMoneyN: case typeDecimal, typeDecimalN, typeNumeric, typeNumericN: prec := col.ti.Prec diff --git a/vendor/github.com/denisenkom/go-mssqldb/conn_str.go b/vendor/github.com/denisenkom/go-mssqldb/conn_str.go index 26ac50f38d..d7d9e06af0 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/conn_str.go +++ b/vendor/github.com/denisenkom/go-mssqldb/conn_str.go @@ -37,11 +37,17 @@ type connectParams struct { failOverPartner string failOverPort uint64 packetSize uint16 - fedAuthAccessToken string + fedAuthLibrary int + fedAuthADALWorkflow byte } +// default packet size for TDS buffer +const defaultPacketSize = 4096 + func parseConnectParams(dsn string) (connectParams, error) { - var p connectParams + p := connectParams{ + fedAuthLibrary: fedAuthLibraryReserved, + } var params map[string]string if strings.HasPrefix(dsn, "odbc:") { @@ -65,7 +71,7 @@ func parseConnectParams(dsn string) (connectParams, error) { var err error p.logFlags, err = strconv.ParseUint(strlog, 10, 64) if err != nil { - return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error()) + return p, fmt.Errorf("invalid log parameter '%s': %s", strlog, err.Error()) } } server := params["server"] @@ -87,20 +93,19 @@ func parseConnectParams(dsn string) (connectParams, error) { var err error p.port, err = strconv.ParseUint(strport, 10, 16) if err != nil { - f := "Invalid tcp port '%v': %v" + f := "invalid tcp port '%v': %v" return p, fmt.Errorf(f, strport, err.Error()) } } // https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option - // Default packet size remains at 4096 bytes - p.packetSize = 4096 + p.packetSize = defaultPacketSize strpsize, ok := params["packet size"] if ok { var err error psize, err := strconv.ParseUint(strpsize, 0, 16) if err != nil { - f := "Invalid packet size '%v': %v" + f := "invalid packet size '%v': %v" return p, fmt.Errorf(f, strpsize, err.Error()) } @@ -123,7 +128,7 @@ func parseConnectParams(dsn string) (connectParams, error) { if strconntimeout, ok := params["connection timeout"]; ok { timeout, err := strconv.ParseUint(strconntimeout, 10, 64) if err != nil { - f := "Invalid connection timeout '%v': %v" + f := "invalid connection timeout '%v': %v" return p, fmt.Errorf(f, strconntimeout, err.Error()) } p.conn_timeout = time.Duration(timeout) * time.Second @@ -132,7 +137,7 @@ func parseConnectParams(dsn string) (connectParams, error) { if strdialtimeout, ok := params["dial timeout"]; ok { timeout, err := strconv.ParseUint(strdialtimeout, 10, 64) if err != nil { - f := "Invalid dial timeout '%v': %v" + f := "invalid dial timeout '%v': %v" return p, fmt.Errorf(f, strdialtimeout, err.Error()) } p.dial_timeout = time.Duration(timeout) * time.Second @@ -144,7 +149,7 @@ func parseConnectParams(dsn string) (connectParams, error) { if keepAlive, ok := params["keepalive"]; ok { timeout, err := strconv.ParseUint(keepAlive, 10, 64) if err != nil { - f := "Invalid keepAlive value '%s': %s" + f := "invalid keepAlive value '%s': %s" return p, fmt.Errorf(f, keepAlive, err.Error()) } p.keepAlive = time.Duration(timeout) * time.Second @@ -157,7 +162,7 @@ func parseConnectParams(dsn string) (connectParams, error) { var err error p.encrypt, err = strconv.ParseBool(encrypt) if err != nil { - f := "Invalid encrypt '%s': %s" + f := "invalid encrypt '%s': %s" return p, fmt.Errorf(f, encrypt, err.Error()) } } @@ -169,7 +174,7 @@ func parseConnectParams(dsn string) (connectParams, error) { var err error p.trustServerCertificate, err = strconv.ParseBool(trust) if err != nil { - f := "Invalid trust server certificate '%s': %s" + f := "invalid trust server certificate '%s': %s" return p, fmt.Errorf(f, trust, err.Error()) } } @@ -209,7 +214,7 @@ func parseConnectParams(dsn string) (connectParams, error) { if ok { if appintent == "ReadOnly" { if p.database == "" { - return p, fmt.Errorf("Database must be specified when ApplicationIntent is ReadOnly") + return p, fmt.Errorf("database must be specified when ApplicationIntent is ReadOnly") } p.typeFlags |= fReadOnlyIntent } @@ -225,7 +230,7 @@ func parseConnectParams(dsn string) (connectParams, error) { var err error p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16) if err != nil { - f := "Invalid tcp port '%v': %v" + f := "invalid tcp port '%v': %v" return p, fmt.Errorf(f, failOverPort, err.Error()) } } @@ -233,6 +238,30 @@ func parseConnectParams(dsn string) (connectParams, error) { return p, nil } +// convert connectionParams to url style connection string +// used mostly for testing +func (p connectParams) toUrl() *url.URL { + q := url.Values{} + if p.database != "" { + q.Add("database", p.database) + } + if p.logFlags != 0 { + q.Add("log", strconv.FormatUint(p.logFlags, 10)) + } + res := url.URL{ + Scheme: "sqlserver", + Host: p.host, + User: url.UserPassword(p.user, p.password), + } + if p.instance != "" { + res.Path = p.instance + } + if len(q) > 0 { + res.RawQuery = q.Encode() + } + return &res +} + func splitConnectionString(dsn string) (res map[string]string) { res = map[string]string{} parts := strings.Split(dsn, ";") @@ -340,7 +369,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) { case parserStateBeforeKey: switch { case c == '=': - return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i) + return res, fmt.Errorf("unexpected character = at index %d. Expected start of key or semi-colon or whitespace", i) case !unicode.IsSpace(c) && c != ';': state = parserStateKey key += string(c) @@ -419,7 +448,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) { case unicode.IsSpace(c): // Ignore whitespace default: - return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i) + return res, fmt.Errorf("unexpected character %c at index %d. Expected semi-colon or whitespace", c, i) } case parserStateEndValue: @@ -429,7 +458,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) { case unicode.IsSpace(c): // Ignore whitespace default: - return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i) + return res, fmt.Errorf("unexpected character %c at index %d. Expected semi-colon or whitespace", c, i) } } } @@ -444,7 +473,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) { case parserStateBareValue: res[key] = strings.TrimRightFunc(value, unicode.IsSpace) case parserStateBracedValue: - return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn)) + return res, fmt.Errorf("unexpected end of braced value at index %d", len(dsn)) case parserStateBracedValueClosingBrace: // End of braced value res[key] = value case parserStateEndValue: // Okay diff --git a/vendor/github.com/denisenkom/go-mssqldb/fedauth.go b/vendor/github.com/denisenkom/go-mssqldb/fedauth.go new file mode 100644 index 0000000000..86fed253e1 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/fedauth.go @@ -0,0 +1,82 @@ +package mssql + +import ( + "context" + "errors" +) + +// Federated authentication library affects the login data structure and message sequence. +const ( + // fedAuthLibraryLiveIDCompactToken specifies the Microsoft Live ID Compact Token authentication scheme + fedAuthLibraryLiveIDCompactToken = 0x00 + + // fedAuthLibrarySecurityToken specifies a token-based authentication where the token is available + // without additional information provided during the login sequence. + fedAuthLibrarySecurityToken = 0x01 + + // fedAuthLibraryADAL specifies a token-based authentication where a token is obtained during the + // login sequence using the server SPN and STS URL provided by the server during login. + fedAuthLibraryADAL = 0x02 + + // fedAuthLibraryReserved is used to indicate that no federated authentication scheme applies. + fedAuthLibraryReserved = 0x7F +) + +// Federated authentication ADAL workflow affects the mechanism used to authenticate. +const ( + // fedAuthADALWorkflowPassword uses a username/password to obtain a token from Active Directory + fedAuthADALWorkflowPassword = 0x01 + + // fedAuthADALWorkflowPassword uses the Windows identity to obtain a token from Active Directory + fedAuthADALWorkflowIntegrated = 0x02 + + // fedAuthADALWorkflowMSI uses the managed identity service to obtain a token + fedAuthADALWorkflowMSI = 0x03 +) + +// newSecurityTokenConnector creates a new connector from a DSN and a token provider. +// When invoked, token provider implementations should contact the security token +// service specified and obtain the appropriate token, or return an error +// to indicate why a token is not available. +// The returned connector may be used with sql.OpenDB. +func newSecurityTokenConnector(dsn string, tokenProvider func(ctx context.Context) (string, error)) (*Connector, error) { + if tokenProvider == nil { + return nil, errors.New("mssql: tokenProvider cannot be nil") + } + + conn, err := NewConnector(dsn) + if err != nil { + return nil, err + } + + conn.params.fedAuthLibrary = fedAuthLibrarySecurityToken + conn.securityTokenProvider = tokenProvider + + return conn, nil +} + +// newADALTokenConnector creates a new connector from a DSN and a Active Directory token provider. +// Token provider implementations are called during federated +// authentication login sequences where the server provides a service +// principal name and security token service endpoint that should be used +// to obtain the token. Implementations should contact the security token +// service specified and obtain the appropriate token, or return an error +// to indicate why a token is not available. +// +// The returned connector may be used with sql.OpenDB. +func newActiveDirectoryTokenConnector(dsn string, adalWorkflow byte, tokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)) (*Connector, error) { + if tokenProvider == nil { + return nil, errors.New("mssql: tokenProvider cannot be nil") + } + + conn, err := NewConnector(dsn) + if err != nil { + return nil, err + } + + conn.params.fedAuthLibrary = fedAuthLibraryADAL + conn.params.fedAuthADALWorkflow = adalWorkflow + conn.adalTokenProvider = tokenProvider + + return conn, nil +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql.go b/vendor/github.com/denisenkom/go-mssqldb/mssql.go index 25c268edc6..6e2f4af894 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/mssql.go +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql.go @@ -58,6 +58,7 @@ func (d *Driver) OpenConnector(dsn string) (*Connector, error) { if err != nil { return nil, err } + return &Connector{ params: params, driver: d, @@ -100,6 +101,12 @@ type Connector struct { params connectParams driver *Driver + // callback that can provide a security token during login + securityTokenProvider func(ctx context.Context) (string, error) + + // callback that can provide a security token during ADAL login + adalTokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error) + // SessionInitSQL is executed after marking a given session to be reset. // When not present, the next query will still reset the session to the // database defaults. @@ -148,15 +155,7 @@ type Conn struct { processQueryText bool connectionGood bool - outs map[string]interface{} - returnStatus *ReturnStatus -} - -func (c *Conn) setReturnStatus(s ReturnStatus) { - if c.returnStatus == nil { - return - } - *c.returnStatus = s + outs map[string]interface{} } func (c *Conn) checkBadConn(err error) error { @@ -201,20 +200,15 @@ func (c *Conn) clearOuts() { } func (c *Conn) simpleProcessResp(ctx context.Context) error { - tokchan := make(chan tokenStruct, 5) - go processResponse(ctx, c.sess, tokchan, c.outs) + reader := startReading(c.sess, ctx, c.outs) c.clearOuts() - for tok := range tokchan { - switch token := tok.(type) { - case doneStruct: - if token.isError() { - return c.checkBadConn(token.getError()) - } - case error: - return c.checkBadConn(token) - } + + var resultError error + err := reader.iterateResponse() + if err != nil { + return c.checkBadConn(err) } - return nil + return resultError } func (c *Conn) Commit() error { @@ -239,7 +233,7 @@ func (c *Conn) sendCommitRequest() error { c.sess.log.Printf("Failed to send CommitXact with %v", err) } c.connectionGood = false - return fmt.Errorf("Faild to send CommitXact: %v", err) + return fmt.Errorf("faild to send CommitXact: %v", err) } return nil } @@ -266,7 +260,7 @@ func (c *Conn) sendRollbackRequest() error { c.sess.log.Printf("Failed to send RollbackXact with %v", err) } c.connectionGood = false - return fmt.Errorf("Failed to send RollbackXact: %v", err) + return fmt.Errorf("failed to send RollbackXact: %v", err) } return nil } @@ -303,7 +297,7 @@ func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) erro c.sess.log.Printf("Failed to send BeginXact with %v", err) } c.connectionGood = false - return fmt.Errorf("Failed to send BeginXact: %v", err) + return fmt.Errorf("failed to send BeginXact: %v", err) } return nil } @@ -478,7 +472,7 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) { conn.sess.log.Printf("Failed to send Rpc with %v", err) } conn.connectionGood = false - return fmt.Errorf("Failed to send RPC: %v", err) + return fmt.Errorf("failed to send RPC: %v", err) } } return @@ -595,38 +589,46 @@ func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver } func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) { - tokchan := make(chan tokenStruct, 5) ctx, cancel := context.WithCancel(ctx) - go processResponse(ctx, s.c.sess, tokchan, s.c.outs) + reader := startReading(s.c.sess, ctx, s.c.outs) s.c.clearOuts() // process metadata var cols []columnStruct loop: - for tok := range tokchan { - switch token := tok.(type) { - // By ignoring DONE token we effectively - // skip empty result-sets. - // This improves results in queries like that: - // set nocount on; select 1 - // see TestIgnoreEmptyResults test - //case doneStruct: - //break loop - case []columnStruct: - cols = token - break loop - case doneStruct: - if token.isError() { - cancel() - return nil, s.c.checkBadConn(token.getError()) + for { + tok, err := reader.nextToken() + if err == nil { + if tok == nil { + break + } else { + switch token := tok.(type) { + // By ignoring DONE token we effectively + // skip empty result-sets. + // This improves results in queries like that: + // set nocount on; select 1 + // see TestIgnoreEmptyResults test + //case doneStruct: + //break loop + case []columnStruct: + cols = token + break loop + case doneStruct: + if token.isError() { + // need to cleanup cancellable context + cancel() + return nil, s.c.checkBadConn(token.getError()) + } + case ReturnStatus: + s.c.sess.setReturnStatus(token) + } } - case ReturnStatus: - s.c.setReturnStatus(token) - case error: + } else { + // need to cleanup cancellable context cancel() - return nil, s.c.checkBadConn(token) + return nil, s.c.checkBadConn(err) } } - res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel} + res = &Rows{stmt: s, reader: reader, cols: cols, cancel: cancel} return } @@ -648,48 +650,46 @@ func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, } func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) { - tokchan := make(chan tokenStruct, 5) - go processResponse(ctx, s.c.sess, tokchan, s.c.outs) + reader := startReading(s.c.sess, ctx, s.c.outs) s.c.clearOuts() - var rowCount int64 - for token := range tokchan { - switch token := token.(type) { - case doneInProcStruct: - if token.Status&doneCount != 0 { - rowCount += int64(token.RowCount) - } - case doneStruct: - if token.Status&doneCount != 0 { - rowCount += int64(token.RowCount) - } - if token.isError() { - return nil, token.getError() - } - case ReturnStatus: - s.c.setReturnStatus(token) - case error: - return nil, token - } + err = reader.iterateResponse() + if err != nil { + return nil, s.c.checkBadConn(err) } - return &Result{s.c, rowCount}, nil + return &Result{s.c, reader.rowCount}, nil } type Rows struct { - stmt *Stmt - cols []columnStruct - tokchan chan tokenStruct - + stmt *Stmt + cols []columnStruct + reader *tokenProcessor nextCols []columnStruct cancel func() } func (rc *Rows) Close() error { + // need to add a test which returns lots of rows + // and check closing after reading only few rows rc.cancel() - for _ = range rc.tokchan { + + for { + tok, err := rc.reader.nextToken() + if err == nil { + if tok == nil { + return nil + } else { + // continue consuming tokens + continue + } + } else { + if err == rc.reader.ctx.Err() { + return nil + } else { + return err + } + } } - rc.tokchan = nil - return nil } func (rc *Rows) Columns() (res []string) { @@ -707,27 +707,34 @@ func (rc *Rows) Next(dest []driver.Value) error { if rc.nextCols != nil { return io.EOF } - for tok := range rc.tokchan { - switch tokdata := tok.(type) { - case []columnStruct: - rc.nextCols = tokdata - return io.EOF - case []interface{}: - for i := range dest { - dest[i] = tokdata[i] - } - return nil - case doneStruct: - if tokdata.isError() { - return rc.stmt.c.checkBadConn(tokdata.getError()) + for { + tok, err := rc.reader.nextToken() + if err == nil { + if tok == nil { + return io.EOF + } else { + switch tokdata := tok.(type) { + case []columnStruct: + rc.nextCols = tokdata + return io.EOF + case []interface{}: + for i := range dest { + dest[i] = tokdata[i] + } + return nil + case doneStruct: + if tokdata.isError() { + return rc.stmt.c.checkBadConn(tokdata.getError()) + } + case ReturnStatus: + rc.stmt.c.sess.setReturnStatus(tokdata) + } } - case ReturnStatus: - rc.stmt.c.setReturnStatus(tokdata) - case error: - return rc.stmt.c.checkBadConn(tokdata) + + } else { + return rc.stmt.c.checkBadConn(err) } } - return io.EOF } func (rc *Rows) HasNextResultSet() bool { @@ -895,35 +902,41 @@ func (c *Conn) Ping(ctx context.Context) error { var _ driver.ConnBeginTx = &Conn{} -// BeginTx satisfies ConnBeginTx. -func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - if !c.connectionGood { - return nil, driver.ErrBadConn - } - if opts.ReadOnly { - return nil, errors.New("Read-only transactions are not supported") - } - - var tdsIsolation isoLevel - switch sql.IsolationLevel(opts.Isolation) { +func convertIsolationLevel(level sql.IsolationLevel) (isoLevel, error) { + switch level { case sql.LevelDefault: - tdsIsolation = isolationUseCurrent + return isolationUseCurrent, nil case sql.LevelReadUncommitted: - tdsIsolation = isolationReadUncommited + return isolationReadUncommited, nil case sql.LevelReadCommitted: - tdsIsolation = isolationReadCommited + return isolationReadCommited, nil case sql.LevelWriteCommitted: - return nil, errors.New("LevelWriteCommitted isolation level is not supported") + return isolationUseCurrent, errors.New("LevelWriteCommitted isolation level is not supported") case sql.LevelRepeatableRead: - tdsIsolation = isolationRepeatableRead + return isolationRepeatableRead, nil case sql.LevelSnapshot: - tdsIsolation = isolationSnapshot + return isolationSnapshot, nil case sql.LevelSerializable: - tdsIsolation = isolationSerializable + return isolationSerializable, nil case sql.LevelLinearizable: - return nil, errors.New("LevelLinearizable isolation level is not supported") + return isolationUseCurrent, errors.New("LevelLinearizable isolation level is not supported") default: - return nil, errors.New("Isolation level is not supported or unknown") + return isolationUseCurrent, errors.New("isolation level is not supported or unknown") + } +} + +// BeginTx satisfies ConnBeginTx. +func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if !c.connectionGood { + return nil, driver.ErrBadConn + } + if opts.ReadOnly { + return nil, errors.New("read-only transactions are not supported") + } + + tdsIsolation, err := convertIsolationLevel(sql.IsolationLevel(opts.Isolation)) + if err != nil { + return nil, err } return c.begin(ctx, tdsIsolation) } diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go b/vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go index 6d76fbad08..e4edc752b5 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go @@ -48,5 +48,5 @@ func (c *Connector) Driver() driver.Driver { } func (r *Result) LastInsertId() (int64, error) { - return -1, errors.New("LastInsertId is not supported. Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query.") + return -1, errors.New("LastInsertId is not supported. Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query") } diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go b/vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go index a2bd1167ba..2b4edeba6c 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go @@ -110,7 +110,7 @@ func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error { return nil case *ReturnStatus: *v = 0 // By default the return value should be zero. - c.returnStatus = v + c.sess.returnStatus = v return driver.ErrRemoveArgument case TVP: return nil diff --git a/vendor/github.com/denisenkom/go-mssqldb/net.go b/vendor/github.com/denisenkom/go-mssqldb/net.go index 94858cc74f..bb7b784cbf 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/net.go +++ b/vendor/github.com/denisenkom/go-mssqldb/net.go @@ -7,8 +7,8 @@ import ( ) type timeoutConn struct { - c net.Conn - timeout time.Duration + c net.Conn + timeout time.Duration } func newTimeoutConn(conn net.Conn, timeout time.Duration) *timeoutConn { @@ -51,21 +51,21 @@ func (c timeoutConn) RemoteAddr() net.Addr { } func (c timeoutConn) SetDeadline(t time.Time) error { - panic("Not implemented") + return c.c.SetDeadline(t) } func (c timeoutConn) SetReadDeadline(t time.Time) error { - panic("Not implemented") + return c.c.SetReadDeadline(t) } func (c timeoutConn) SetWriteDeadline(t time.Time) error { - panic("Not implemented") + return c.c.SetWriteDeadline(t) } // this connection is used during TLS Handshake // TDS protocol requires TLS handshake messages to be sent inside TDS packets type tlsHandshakeConn struct { - buf *tdsBuffer + buf *tdsBuffer packetPending bool continueRead bool } @@ -75,7 +75,7 @@ func (c *tlsHandshakeConn) Read(b []byte) (n int, err error) { c.packetPending = false err = c.buf.FinishPacket() if err != nil { - err = fmt.Errorf("Cannot send handshake packet: %s", err.Error()) + err = fmt.Errorf("cannot send handshake packet: %s", err.Error()) return } c.continueRead = false @@ -84,7 +84,7 @@ func (c *tlsHandshakeConn) Read(b []byte) (n int, err error) { var packet packetType packet, err = c.buf.BeginRead() if err != nil { - err = fmt.Errorf("Cannot read handshake packet: %s", err.Error()) + err = fmt.Errorf("cannot read handshake packet: %s", err.Error()) return } if packet != packPrelogin { @@ -105,27 +105,27 @@ func (c *tlsHandshakeConn) Write(b []byte) (n int, err error) { } func (c *tlsHandshakeConn) Close() error { - panic("Not implemented") + return c.buf.transport.Close() } func (c *tlsHandshakeConn) LocalAddr() net.Addr { - panic("Not implemented") + return nil } func (c *tlsHandshakeConn) RemoteAddr() net.Addr { - panic("Not implemented") + return nil } -func (c *tlsHandshakeConn) SetDeadline(t time.Time) error { - panic("Not implemented") +func (c *tlsHandshakeConn) SetDeadline(_ time.Time) error { + return nil } -func (c *tlsHandshakeConn) SetReadDeadline(t time.Time) error { - panic("Not implemented") +func (c *tlsHandshakeConn) SetReadDeadline(_ time.Time) error { + return nil } -func (c *tlsHandshakeConn) SetWriteDeadline(t time.Time) error { - panic("Not implemented") +func (c *tlsHandshakeConn) SetWriteDeadline(_ time.Time) error { + return nil } // this connection just delegates all methods to it's wrapped connection @@ -148,21 +148,21 @@ func (c passthroughConn) Close() error { } func (c passthroughConn) LocalAddr() net.Addr { - panic("Not implemented") + return c.c.LocalAddr() } func (c passthroughConn) RemoteAddr() net.Addr { - panic("Not implemented") + return c.c.RemoteAddr() } func (c passthroughConn) SetDeadline(t time.Time) error { - panic("Not implemented") + return c.c.SetDeadline(t) } func (c passthroughConn) SetReadDeadline(t time.Time) error { - panic("Not implemented") + return c.c.SetReadDeadline(t) } func (c passthroughConn) SetWriteDeadline(t time.Time) error { - panic("Not implemented") + return c.c.SetWriteDeadline(t) } diff --git a/vendor/github.com/denisenkom/go-mssqldb/ntlm.go b/vendor/github.com/denisenkom/go-mssqldb/ntlm.go index ea9148aed0..90adb5a026 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/ntlm.go +++ b/vendor/github.com/denisenkom/go-mssqldb/ntlm.go @@ -14,6 +14,7 @@ import ( "time" "unicode/utf16" + //lint:ignore SA1019 MD4 is used by legacy NTLM "golang.org/x/crypto/md4" ) @@ -126,18 +127,6 @@ func createDesKey(bytes, material []byte) { material[7] = (byte)(bytes[6] << 1) } -func oddParity(bytes []byte) { - for i := 0; i < len(bytes); i++ { - b := bytes[i] - needsParity := (((b >> 7) ^ (b >> 6) ^ (b >> 5) ^ (b >> 4) ^ (b >> 3) ^ (b >> 2) ^ (b >> 1)) & 0x01) == 0 - if needsParity { - bytes[i] = bytes[i] | byte(0x01) - } else { - bytes[i] = bytes[i] & byte(0xfe) - } - } -} - func encryptDes(key []byte, cleartext []byte, ciphertext []byte) { var desKey [8]byte createDesKey(key, desKey[:]) diff --git a/vendor/github.com/denisenkom/go-mssqldb/rpc.go b/vendor/github.com/denisenkom/go-mssqldb/rpc.go index 4ca22578fa..f7d4c00efc 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/rpc.go +++ b/vendor/github.com/denisenkom/go-mssqldb/rpc.go @@ -22,12 +22,6 @@ type param struct { buffer []byte } -const ( - fWithRecomp = 1 - fNoMetaData = 2 - fReuseMetaData = 4 -) - var ( sp_Cursor = procId{1, ""} sp_CursorOpen = procId{2, ""} diff --git a/vendor/github.com/denisenkom/go-mssqldb/tds.go b/vendor/github.com/denisenkom/go-mssqldb/tds.go index 67139c6a4a..e1b633007e 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/tds.go +++ b/vendor/github.com/denisenkom/go-mssqldb/tds.go @@ -82,19 +82,20 @@ const ( // https://msdn.microsoft.com/en-us/library/dd304214.aspx const ( packSQLBatch packetType = 1 - packRPCRequest = 3 - packReply = 4 + packRPCRequest packetType = 3 + packReply packetType = 4 // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx - packAttention = 6 - - packBulkLoadBCP = 7 - packTransMgrReq = 14 - packNormal = 15 - packLogin7 = 16 - packSSPIMessage = 17 - packPrelogin = 18 + packAttention packetType = 6 + + packBulkLoadBCP packetType = 7 + packFedAuthToken packetType = 8 + packTransMgrReq packetType = 14 + packNormal packetType = 15 + packLogin7 packetType = 16 + packSSPIMessage packetType = 17 + packPrelogin packetType = 18 ) // prelogin fields @@ -118,6 +119,17 @@ const ( encryptReq = 3 // Encryption is required. ) +const ( + featExtSESSIONRECOVERY byte = 0x01 + featExtFEDAUTH byte = 0x02 + featExtCOLUMNENCRYPTION byte = 0x04 + featExtGLOBALTRANSACTIONS byte = 0x05 + featExtAZURESQLSUPPORT byte = 0x08 + featExtDATACLASSIFICATION byte = 0x09 + featExtUTF8SUPPORT byte = 0x0A + featExtTERMINATOR byte = 0xFF +) + type tdsSession struct { buf *tdsBuffer loginAck loginAckStruct @@ -129,6 +141,7 @@ type tdsSession struct { log optionalLogger routedServer string routedPort uint16 + returnStatus *ReturnStatus } const ( @@ -155,13 +168,13 @@ func (p keySlice) Less(i, j int) bool { return p[i] < p[j] } func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // http://msdn.microsoft.com/en-us/library/dd357559.aspx -func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error { +func writePrelogin(packetType packetType, w *tdsBuffer, fields map[uint8][]byte) error { var err error - w.BeginPacket(packPrelogin, false) + w.BeginPacket(packetType, false) offset := uint16(5*len(fields) + 1) keys := make(keySlice, 0, len(fields)) - for k, _ := range fields { + for k := range fields { keys = append(keys, k) } sort.Sort(keys) @@ -210,12 +223,15 @@ func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) { if err != nil { return nil, err } - if packet_type != 4 { - return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE") + if packet_type != packReply { + return nil, errors.New("invalid respones, expected packet type 4, PRELOGIN RESPONSE") + } + if len(struct_buf) == 0 { + return nil, errors.New("invalid empty PRELOGIN response, it must contain at least one byte") } offset := 0 results := map[uint8][]byte{} - for true { + for { rec_type := struct_buf[offset] if rec_type == preloginTERMINATOR { break @@ -240,6 +256,16 @@ const ( fIntSecurity = 0x80 ) +// OptionFlags3 +// http://msdn.microsoft.com/en-us/library/dd304019.aspx +const ( + fChangePassword = 1 + fSendYukonBinaryXML = 2 + fUserInstance = 4 + fUnknownCollationHandling = 8 + fExtension = 0x10 +) + // TypeFlags const ( // 4 bits for fSQLType @@ -247,12 +273,6 @@ const ( fReadOnlyIntent = 32 ) -// OptionFlags3 -// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac -const ( - fExtension = 0x10 -) - type login struct { TDSVersion uint32 PacketSize uint32 @@ -295,7 +315,7 @@ func (e *featureExts) Add(f featureExt) error { } id := f.featureID() if _, exists := e.features[id]; exists { - f := "Login error: Feature with ID '%v' is already present in FeatureExt block." + f := "login error: Feature with ID '%v' is already present in FeatureExt block" return fmt.Errorf(f, id) } if e.features == nil { @@ -326,37 +346,63 @@ func (e featureExts) toBytes() []byte { return d } -type featureExtFedAuthSTS struct { - FedAuthEcho bool +// featureExtFedAuth tracks federated authentication state before and during login +type featureExtFedAuth struct { + // FedAuthLibrary is populated by the federated authentication provider. + FedAuthLibrary int + + // ADALWorkflow is populated by the federated authentication provider. + ADALWorkflow byte + + // FedAuthEcho is populated from the prelogin response + FedAuthEcho bool + + // FedAuthToken is populated during login with the value from the provider. FedAuthToken string - Nonce []byte + + // Nonce is populated during login with the value from the provider. + Nonce []byte + + // Signature is populated during login with the value from the server. + Signature []byte } -func (e *featureExtFedAuthSTS) featureID() byte { - return 0x02 +func (e *featureExtFedAuth) featureID() byte { + return featExtFEDAUTH } -func (e *featureExtFedAuthSTS) toBytes() []byte { +func (e *featureExtFedAuth) toBytes() []byte { if e == nil { return nil } - options := byte(0x01) << 1 // 0x01 => STS bFedAuthLibrary 7BIT + options := byte(e.FedAuthLibrary) << 1 if e.FedAuthEcho { options |= 1 // fFedAuthEcho } - d := make([]byte, 5) - d[0] = options + // Feature extension format depends on the federated auth library. + // Options are described at + // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac + var d []byte + + switch e.FedAuthLibrary { + case fedAuthLibrarySecurityToken: + d = make([]byte, 5) + d[0] = options - // looks like string in - // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508 - tokenBytes := str2ucs2(e.FedAuthToken) - binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work - d = append(d, tokenBytes...) + // looks like string in + // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508 + tokenBytes := str2ucs2(e.FedAuthToken) + binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work + d = append(d, tokenBytes...) - if len(e.Nonce) == 32 { - d = append(d, e.Nonce...) + if len(e.Nonce) == 32 { + d = append(d, e.Nonce...) + } + + case fedAuthLibraryADAL: + d = []byte{options, e.ADALWorkflow} } return d @@ -418,7 +464,7 @@ func str2ucs2(s string) []byte { func ucs22str(s []byte) (string, error) { if len(s)%2 != 0 { - return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s)) + return "", fmt.Errorf("illegal UCS2 string length: %d", len(s)) } buf := make([]uint16, len(s)/2) for i := 0; i < len(s); i += 2 { @@ -436,7 +482,7 @@ func manglePassword(password string) []byte { } // http://msdn.microsoft.com/en-us/library/dd304019.aspx -func sendLogin(w *tdsBuffer, login login) error { +func sendLogin(w *tdsBuffer, login *login) error { w.BeginPacket(packLogin7, false) hostname := str2ucs2(login.HostName) username := str2ucs2(login.UserName) @@ -572,6 +618,36 @@ func sendLogin(w *tdsBuffer, login login) error { return w.FinishPacket() } +// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/827d9632-2957-4d54-b9ea-384530ae79d0 +func sendFedAuthInfo(w *tdsBuffer, fedAuth *featureExtFedAuth) (err error) { + fedauthtoken := str2ucs2(fedAuth.FedAuthToken) + tokenlen := len(fedauthtoken) + datalen := 4 + tokenlen + len(fedAuth.Nonce) + + w.BeginPacket(packFedAuthToken, false) + err = binary.Write(w, binary.LittleEndian, uint32(datalen)) + if err != nil { + return + } + + err = binary.Write(w, binary.LittleEndian, uint32(tokenlen)) + if err != nil { + return + } + + _, err = w.Write(fedauthtoken) + if err != nil { + return + } + + _, err = w.Write(fedAuth.Nonce) + if err != nil { + return + } + + return w.FinishPacket() +} + func readUcs2(r io.Reader, numchars int) (res string, err error) { buf := make([]byte, numchars*2) _, err = io.ReadFull(r, buf) @@ -770,12 +846,13 @@ type auth interface { // use the first one that allows a connection. func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) { var ips []net.IP - ips, err = net.LookupIP(p.host) - if err != nil { - ip := net.ParseIP(p.host) - if ip == nil { - return nil, err + ip := net.ParseIP(p.host) + if ip == nil { + ips, err = net.LookupIP(p.host) + if err != nil { + return } + } else { ips = []net.IP{ip} } if len(ips) == 1 { @@ -802,7 +879,7 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne } // Wait for either the *first* successful connection, or all the errors wait_loop: - for i, _ := range ips { + for i := range ips { select { case conn = <-connChan: // Got a connection to use, close any others @@ -824,12 +901,123 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne } // Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection if conn == nil { - f := "Unable to open tcp connection with host '%v:%v': %v" + f := "unable to open tcp connection with host '%v:%v': %v" return nil, fmt.Errorf(f, p.host, resolveServerPort(p.port), err.Error()) } return conn, err } +func preparePreloginFields(p connectParams, fe *featureExtFedAuth) map[uint8][]byte { + instance_buf := []byte(p.instance) + instance_buf = append(instance_buf, 0) // zero terminate instance name + + var encrypt byte + if p.disableEncryption { + encrypt = encryptNotSup + } else if p.encrypt { + encrypt = encryptOn + } else { + encrypt = encryptOff + } + + fields := map[uint8][]byte{ + preloginVERSION: {0, 0, 0, 0, 0, 0}, + preloginENCRYPTION: {encrypt}, + preloginINSTOPT: instance_buf, + preloginTHREADID: {0, 0, 0, 0}, + preloginMARS: {0}, // MARS disabled + } + + if fe.FedAuthLibrary != fedAuthLibraryReserved { + fields[preloginFEDAUTHREQUIRED] = []byte{1} + } + + return fields +} + +func interpretPreloginResponse(p connectParams, fe *featureExtFedAuth, fields map[uint8][]byte) (encrypt byte, err error) { + // If the server returns the preloginFEDAUTHREQUIRED field, then federated authentication + // is supported. The actual value may be 0 or 1, where 0 means either SSPI or federated + // authentication is allowed, while 1 means only federated authentication is allowed. + if fedAuthSupport, ok := fields[preloginFEDAUTHREQUIRED]; ok { + if len(fedAuthSupport) != 1 { + return 0, fmt.Errorf("Federated authentication flag length should be 1: is %d", len(fedAuthSupport)) + } + + // We need to be able to echo the value back to the server + fe.FedAuthEcho = fedAuthSupport[0] != 0 + } else if fe.FedAuthLibrary != fedAuthLibraryReserved { + return 0, fmt.Errorf("Federated authentication is not supported by the server") + } + + encryptBytes, ok := fields[preloginENCRYPTION] + if !ok { + return 0, fmt.Errorf("encrypt negotiation failed") + } + encrypt = encryptBytes[0] + if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) { + return 0, fmt.Errorf("server does not support encryption") + } + + return +} + +func prepareLogin(ctx context.Context, c *Connector, p connectParams, log optionalLogger, auth auth, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) { + l = &login{ + TDSVersion: verTDS74, + PacketSize: packetSize, + Database: p.database, + OptionFlags2: fODBC, // to get unlimited TEXTSIZE + HostName: p.workstation, + ServerName: p.host, + AppName: p.appname, + TypeFlags: p.typeFlags, + } + switch { + case fe.FedAuthLibrary == fedAuthLibrarySecurityToken: + if p.logFlags&logDebug != 0 { + log.Println("Starting federated authentication using security token") + } + + fe.FedAuthToken, err = c.securityTokenProvider(ctx) + if err != nil { + if p.logFlags&logDebug != 0 { + log.Printf("Failed to retrieve service principal token for federated authentication security token library: %v", err) + } + return nil, err + } + + l.FeatureExt.Add(fe) + + case fe.FedAuthLibrary == fedAuthLibraryADAL: + if p.logFlags&logDebug != 0 { + log.Println("Starting federated authentication using ADAL") + } + + l.FeatureExt.Add(fe) + + case auth != nil: + if p.logFlags&logDebug != 0 { + log.Println("Starting SSPI login") + } + + l.SSPI, err = auth.InitialBytes() + if err != nil { + return nil, err + } + + l.OptionFlags2 |= fIntSecurity + return l, nil + + default: + // Default to SQL server authentication with user and password + l.UserName = p.user + l.Password = p.password + } + + return l, nil +} + func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) { dialCtx := ctx if p.dial_timeout > 0 { @@ -842,24 +1030,24 @@ func connect(ctx context.Context, c *Connector, log optionalLogger, p connectPar // both instance name and port specified // when port is specified instance name is not used // you should not provide instance name when you provide port - log.Println("WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored"); + log.Println("WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored") } if p.instance != "" && p.port == 0 { p.instance = strings.ToUpper(p.instance) d := c.getDialer(&p) instances, err := getInstances(dialCtx, d, p.host) if err != nil { - f := "Unable to get instances from Sql Server Browser on host %v: %v" + f := "unable to get instances from Sql Server Browser on host %v: %v" return nil, fmt.Errorf(f, p.host, err.Error()) } strport, ok := instances[p.instance]["tcp"] if !ok { - f := "No instance matching '%v' returned from host '%v'" + f := "no instance matching '%v' returned from host '%v'" return nil, fmt.Errorf(f, p.instance, p.host) } port, err := strconv.ParseUint(strport, 0, 16) if err != nil { - f := "Invalid tcp port returned from Sql Server Browser '%v': %v" + f := "invalid tcp port returned from Sql Server Browser '%v': %v" return nil, fmt.Errorf(f, strport, err.Error()) } p.port = port @@ -880,25 +1068,14 @@ initiate_connection: logFlags: p.logFlags, } - instance_buf := []byte(p.instance) - instance_buf = append(instance_buf, 0) // zero terminate instance name - var encrypt byte - if p.disableEncryption { - encrypt = encryptNotSup - } else if p.encrypt { - encrypt = encryptOn - } else { - encrypt = encryptOff - } - fields := map[uint8][]byte{ - preloginVERSION: {0, 0, 0, 0, 0, 0}, - preloginENCRYPTION: {encrypt}, - preloginINSTOPT: instance_buf, - preloginTHREADID: {0, 0, 0, 0}, - preloginMARS: {0}, // MARS disabled + fedAuth := &featureExtFedAuth{ + FedAuthLibrary: p.fedAuthLibrary, + ADALWorkflow: p.fedAuthADALWorkflow, } - err = writePrelogin(outbuf, fields) + fields := preparePreloginFields(p, fedAuth) + + err = writePrelogin(packPrelogin, outbuf, fields) if err != nil { return nil, err } @@ -908,13 +1085,9 @@ initiate_connection: return nil, err } - encryptBytes, ok := fields[preloginENCRYPTION] - if !ok { - return nil, fmt.Errorf("Encrypt negotiation failed") - } - encrypt = encryptBytes[0] - if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) { - return nil, fmt.Errorf("Server does not support encryption") + encrypt, err := interpretPreloginResponse(p, fedAuth, fields) + if err != nil { + return nil, err } if encrypt != encryptNotSup { @@ -922,7 +1095,7 @@ initiate_connection: if p.certificate != "" { pem, err := ioutil.ReadFile(p.certificate) if err != nil { - return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err) + return nil, fmt.Errorf("cannot read certificate %q: %v", p.certificate, err) } certs := x509.NewCertPool() certs.AppendCertsFromPEM(pem) @@ -954,54 +1127,46 @@ initiate_connection: } } - login := login{ - TDSVersion: verTDS74, - PacketSize: uint32(outbuf.PackageSize()), - Database: p.database, - OptionFlags2: fODBC, // to get unlimited TEXTSIZE - HostName: p.workstation, - ServerName: p.host, - AppName: p.appname, - TypeFlags: p.typeFlags, - } auth, authOk := getAuth(p.user, p.password, p.serverSPN, p.workstation) - switch { - case p.fedAuthAccessToken != "": // accesstoken ignores user/password - featurext := &featureExtFedAuthSTS{ - FedAuthEcho: len(fields[preloginFEDAUTHREQUIRED]) > 0 && fields[preloginFEDAUTHREQUIRED][0] == 1, - FedAuthToken: p.fedAuthAccessToken, - Nonce: fields[preloginNONCEOPT], - } - login.FeatureExt.Add(featurext) - case authOk: - login.SSPI, err = auth.InitialBytes() - if err != nil { - return nil, err - } - login.OptionFlags2 |= fIntSecurity + if authOk { defer auth.Free() - default: - login.UserName = p.user - login.Password = p.password + } else { + auth = nil } + + login, err := prepareLogin(ctx, c, p, log, auth, fedAuth, uint32(outbuf.PackageSize())) + if err != nil { + return nil, err + } + err = sendLogin(outbuf, login) if err != nil { return nil, err } - // processing login response - success := false - for { - tokchan := make(chan tokenStruct, 5) - go processResponse(context.Background(), &sess, tokchan, nil) - for tok := range tokchan { + // Loop until a packet containing a login acknowledgement is received. + // SSPI and federated authentication scenarios may require multiple + // packet exchanges to complete the login sequence. + for loginAck := false; !loginAck; { + reader := startReading(&sess, ctx, nil) + + for { + tok, err := reader.nextToken() + if err != nil { + return nil, err + } + + if tok == nil { + break + } + switch token := tok.(type) { case sspiMsg: sspi_msg, err := auth.NextBytes(token) if err != nil { return nil, err } - if sspi_msg != nil && len(sspi_msg) > 0 { + if len(sspi_msg) > 0 { outbuf.BeginPacket(packSSPIMessage, false) _, err = outbuf.Write(sspi_msg) if err != nil { @@ -1013,23 +1178,41 @@ initiate_connection: } sspi_msg = nil } + // TODO: for Live ID authentication it may be necessary to + // compare fedAuth.Nonce == token.Nonce and keep track of signature + //case fedAuthAckStruct: + //fedAuth.Signature = token.Signature + case fedAuthInfoStruct: + // For ADAL workflows this contains the STS URL and server SPN. + // If received outside of an ADAL workflow, ignore. + if c == nil || c.adalTokenProvider == nil { + continue + } + + // Request the AD token given the server SPN and STS URL + fedAuth.FedAuthToken, err = c.adalTokenProvider(ctx, token.ServerSPN, token.STSURL) + if err != nil { + return nil, err + } + + // Now need to send the token as a FEDINFO packet + err = sendFedAuthInfo(outbuf, fedAuth) + if err != nil { + return nil, err + } case loginAckStruct: - success = true sess.loginAck = token - case error: - return nil, fmt.Errorf("Login error: %s", token.Error()) + loginAck = true case doneStruct: if token.isError() { - return nil, fmt.Errorf("Login error: %s", token.getError()) + return nil, fmt.Errorf("login error: %s", token.getError()) } - goto loginEnd + case error: + return nil, fmt.Errorf("login error: %s", token.Error()) } } } -loginEnd: - if !success { - return nil, fmt.Errorf("Login failed") - } + if sess.routedServer != "" { toconn.Close() p.host = sess.routedServer @@ -1041,3 +1224,9 @@ loginEnd: } return &sess, nil } + +func (sess *tdsSession) setReturnStatus(status ReturnStatus) { + if sess.returnStatus != nil { + *sess.returnStatus = status + } +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/token.go b/vendor/github.com/denisenkom/go-mssqldb/token.go index 6aa99aa974..c9d452562b 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/token.go +++ b/vendor/github.com/denisenkom/go-mssqldb/token.go @@ -6,12 +6,11 @@ import ( "errors" "fmt" "io" - "net" + "io/ioutil" "strconv" - "strings" ) -//go:generate stringer -type token +//go:generate go run golang.org/x/tools/cmd/stringer -type token type token byte @@ -29,6 +28,7 @@ const ( tokenNbcRow token = 210 // 0xd2 tokenEnvChange token = 227 // 0xE3 tokenSSPI token = 237 // 0xED + tokenFedAuthInfo token = 238 // 0xEE tokenDone token = 253 // 0xFD tokenDoneProc token = 254 tokenDoneInProc token = 255 @@ -70,6 +70,11 @@ const ( envRouting = 20 ) +const ( + fedAuthInfoSTSURL = 0x01 + fedAuthInfoSPN = 0x02 +) + // COLMETADATA flags // https://msdn.microsoft.com/en-us/library/dd357363.aspx const ( @@ -105,26 +110,6 @@ func (d doneStruct) getError() Error { type doneInProcStruct doneStruct -var doneFlags2str = map[uint16]string{ - doneFinal: "final", - doneMore: "more", - doneError: "error", - doneInxact: "inxact", - doneCount: "count", - doneAttn: "attn", - doneSrvError: "srverror", -} - -func doneFlags2Str(flags uint16) string { - strs := make([]string, 0, len(doneFlags2str)) - for flag, tag := range doneFlags2str { - if flags&flag != 0 { - strs = append(strs, tag) - } - } - return strings.Join(strs, "|") -} - // ENVCHANGE stream // http://msdn.microsoft.com/en-us/library/dd303449.aspx func processEnvChg(sess *tdsSession) { @@ -380,9 +365,8 @@ func processEnvChg(sess *tdsSession) { default: // ignore rest of records because we don't know how to skip those sess.log.Printf("WARN: Unknown ENVCHANGE record detected with type id = %d\n", envtype) - break + return } - } } @@ -425,6 +409,78 @@ func parseSSPIMsg(r *tdsBuffer) sspiMsg { return sspiMsg(buf) } +type fedAuthInfoStruct struct { + STSURL string + ServerSPN string +} + +type fedAuthInfoOpt struct { + fedAuthInfoID byte + dataLength, dataOffset uint32 +} + +func parseFedAuthInfo(r *tdsBuffer) fedAuthInfoStruct { + size := r.uint32() + + var STSURL, SPN string + var err error + + // Each fedAuthInfoOpt is one byte to indicate the info ID, + // then a four byte offset and a four byte length. + count := r.uint32() + offset := uint32(4) + opts := make([]fedAuthInfoOpt, count) + + for i := uint32(0); i < count; i++ { + fedAuthInfoID := r.byte() + dataLength := r.uint32() + dataOffset := r.uint32() + offset += 1 + 4 + 4 + + opts[i] = fedAuthInfoOpt{ + fedAuthInfoID: fedAuthInfoID, + dataLength: dataLength, + dataOffset: dataOffset, + } + } + + data := make([]byte, size-offset) + r.ReadFull(data) + + for i := uint32(0); i < count; i++ { + if opts[i].dataOffset < offset { + badStreamPanicf("Fed auth info opt stated data offset %d is before data begins in packet at %d", + opts[i].dataOffset, offset) + // returns via panic + } + + if opts[i].dataOffset+opts[i].dataLength > size { + badStreamPanicf("Fed auth info opt stated data length %d added to stated offset exceeds size of packet %d", + opts[i].dataOffset+opts[i].dataLength, size) + // returns via panic + } + + optData := data[opts[i].dataOffset-offset : opts[i].dataOffset-offset+opts[i].dataLength] + switch opts[i].fedAuthInfoID { + case fedAuthInfoSTSURL: + STSURL, err = ucs22str(optData) + case fedAuthInfoSPN: + SPN, err = ucs22str(optData) + default: + err = fmt.Errorf("Unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID)) + } + + if err != nil { + badStreamPanic(err) + } + } + + return fedAuthInfoStruct{ + STSURL: STSURL, + ServerSPN: SPN, + } +} + type loginAckStruct struct { Interface uint8 TDSVersion uint32 @@ -449,19 +505,43 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct { } // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/2eb82f8e-11f0-46dc-b42d-27302fa4701a -func parseFeatureExtAck(r *tdsBuffer) { - // at most 1 featureAck per feature in featureExt - // go-mssqldb will add at most 1 feature, the spec defines 7 different features - for i := 0; i < 8; i++ { - featureID := r.byte() // FeatureID - if featureID == 0xff { - return +type fedAuthAckStruct struct { + Nonce []byte + Signature []byte +} + +func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { + ack := map[byte]interface{}{} + + for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() { + length := r.uint32() + + switch feature { + case featExtFEDAUTH: + // In theory we need to know the federated authentication library to + // know how to parse, but the alternatives provide compatible structures. + fedAuthAck := fedAuthAckStruct{} + if length >= 32 { + fedAuthAck.Nonce = make([]byte, 32) + r.ReadFull(fedAuthAck.Nonce) + length -= 32 + } + if length >= 32 { + fedAuthAck.Signature = make([]byte, 32) + r.ReadFull(fedAuthAck.Signature) + length -= 32 + } + ack[feature] = fedAuthAck + + } + + // Skip unprocessed bytes + if length > 0 { + io.CopyN(ioutil.Discard, r, int64(length)) } - size := r.uint32() // FeatureAckDataLen - d := make([]byte, size) - r.ReadFull(d) } - panic("parsed more than 7 featureAck's, protocol implementation error?") + + return ack } // http://msdn.microsoft.com/en-us/library/dd357363.aspx @@ -579,7 +659,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin } var columns []columnStruct errs := make([]Error, 0, 5) - for { + for tokens := 0; ; tokens += 1 { token := token(sess.buf.byte()) if sess.logFlags&logDebug != 0 { sess.log.Printf("got token %v", token) @@ -588,6 +668,9 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin case tokenSSPI: ch <- parseSSPIMsg(sess.buf) return + case tokenFedAuthInfo: + ch <- parseFedAuthInfo(sess.buf) + return case tokenReturnStatus: returnStatus := parseReturnStatus(sess.buf) ch <- returnStatus @@ -595,7 +678,8 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin loginAck := parseLoginAck(sess.buf) ch <- loginAck case tokenFeatureExtAck: - parseFeatureExtAck(sess.buf) + featureExtAck := parseFeatureExtAck(sess.buf) + ch <- featureExtAck case tokenOrder: order := parseOrder(sess.buf) ch <- order @@ -670,158 +754,137 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin } } -type parseRespIter byte - -const ( - parseRespIterContinue parseRespIter = iota // Continue parsing current token. - parseRespIterNext // Fetch the next token. - parseRespIterDone // Done with parsing the response. -) - -type parseRespState byte - -const ( - parseRespStateNormal parseRespState = iota // Normal response state. - parseRespStateCancel // Query is canceled, wait for server to confirm. - parseRespStateClosing // Waiting for tokens to come through. -) - -type parseResp struct { - sess *tdsSession - ctxDone <-chan struct{} - state parseRespState - cancelError error +type tokenProcessor struct { + tokChan chan tokenStruct + ctx context.Context + sess *tdsSession + outs map[string]interface{} + lastRow []interface{} + rowCount int64 + firstError error } -func (ts *parseResp) sendAttention(ch chan tokenStruct) parseRespIter { - if err := sendAttention(ts.sess.buf); err != nil { - ts.dlogf("failed to send attention signal %v", err) - ch <- err - return parseRespIterDone +func startReading(sess *tdsSession, ctx context.Context, outs map[string]interface{}) *tokenProcessor { + tokChan := make(chan tokenStruct, 5) + go processSingleResponse(sess, tokChan, outs) + return &tokenProcessor{ + tokChan: tokChan, + ctx: ctx, + sess: sess, + outs: outs, } - ts.state = parseRespStateCancel - return parseRespIterContinue } -func (ts *parseResp) dlog(msg string) { - // logging from goroutine is disabled to prevent - // data race detection from firing - // The race is probably happening when - // test logger changes between tests. - /*if ts.sess.logFlags&logDebug != 0 { - ts.sess.log.Println(msg) - }*/ -} -func (ts *parseResp) dlogf(f string, v ...interface{}) { - /*if ts.sess.logFlags&logDebug != 0 { - ts.sess.log.Printf(f, v...) - }*/ +func (t *tokenProcessor) iterateResponse() error { + for { + tok, err := t.nextToken() + if err == nil { + if tok == nil { + return t.firstError + } else { + switch token := tok.(type) { + case []columnStruct: + t.sess.columns = token + case []interface{}: + t.lastRow = token + case doneInProcStruct: + if token.Status&doneCount != 0 { + t.rowCount += int64(token.RowCount) + } + case doneStruct: + if token.Status&doneCount != 0 { + t.rowCount += int64(token.RowCount) + } + if token.isError() && t.firstError == nil { + t.firstError = token.getError() + } + case ReturnStatus: + t.sess.setReturnStatus(token) + /*case error: + if resultError == nil { + resultError = token + }*/ + } + } + } else { + return err + } + } } -func (ts *parseResp) iter(ctx context.Context, ch chan tokenStruct, tokChan chan tokenStruct) parseRespIter { - switch ts.state { +func (t tokenProcessor) nextToken() (tokenStruct, error) { + // we do this separate non-blocking check on token channel to + // prioritize it over cancellation channel + select { + case tok, more := <-t.tokChan: + err, more := tok.(error) + if more { + // this is an error and not a token + return nil, err + } else { + return tok, nil + } default: - panic("unknown state") - case parseRespStateNormal: - select { - case tok, ok := <-tokChan: - if !ok { - ts.dlog("response finished") - return parseRespIterDone - } - if err, ok := tok.(net.Error); ok && err.Timeout() { - ts.cancelError = err - ts.dlog("got timeout error, sending attention signal to server") - return ts.sendAttention(ch) - } - // Pass the token along. - ch <- tok - return parseRespIterContinue - - case <-ts.ctxDone: - ts.ctxDone = nil - ts.dlog("got cancel message, sending attention signal to server") - return ts.sendAttention(ch) + // there are no tokens on the channel, will need to wait + } + + select { + case tok, more := <-t.tokChan: + if more { + err, ok := tok.(error) + if ok { + // this is an error and not a token + return nil, err + } else { + return tok, nil + } + } else { + // completed reading response + return nil, nil + } + case <-t.ctx.Done(): + if err := sendAttention(t.sess.buf); err != nil { + // unable to send attention, current connection is bad + // notify caller and close channel + return nil, err } - case parseRespStateCancel: // Read all responses until a DONE or error is received.Auth - select { - case tok, ok := <-tokChan: - if !ok { - ts.dlog("response finished but waiting for attention ack") - return parseRespIterNext - } - switch tok := tok.(type) { - default: - // Ignore all other tokens while waiting. - // The TDS spec says other tokens may arrive after an attention - // signal is sent. Ignore these tokens and continue looking for - // a DONE with attention confirm mark. - case doneStruct: - if tok.Status&doneAttn != 0 { - ts.dlog("got cancellation confirmation from server") - if ts.cancelError != nil { - ch <- ts.cancelError - ts.cancelError = nil - } else { - ch <- ctx.Err() - } - return parseRespIterDone - } - // If an error happens during cancel, pass it along and just stop. - // We are uncertain to receive more tokens. - case error: - ch <- tok - ts.state = parseRespStateClosing - } - return parseRespIterContinue - case <-ts.ctxDone: - ts.ctxDone = nil - ts.state = parseRespStateClosing - return parseRespIterContinue + // now the server should send cancellation confirmation + // it is possible that we already received full response + // just before we sent cancellation request + // in this case current response would not contain confirmation + // and we would need to read one more response + + // first lets finish reading current response and look + // for confirmation in it + if readCancelConfirmation(t.tokChan) { + // we got confirmation in current response + return nil, t.ctx.Err() } - case parseRespStateClosing: // Wait for current token chan to close. - if _, ok := <-tokChan; !ok { - ts.dlog("response finished") - return parseRespIterDone + // we did not get cancellation confirmation in the current response + // read one more response, it must be there + t.tokChan = make(chan tokenStruct, 5) + go processSingleResponse(t.sess, t.tokChan, t.outs) + if readCancelConfirmation(t.tokChan) { + return nil, t.ctx.Err() } - return parseRespIterContinue + // we did not get cancellation confirmation, something is not + // right, this connection is not usable anymore + return nil, errors.New("did not get cancellation confirmation from the server") } } -func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) { - ts := &parseResp{ - sess: sess, - ctxDone: ctx.Done(), - } - defer func() { - // Ensure any remaining error is piped through - // or the query may look like it executed when it actually failed. - if ts.cancelError != nil { - ch <- ts.cancelError - ts.cancelError = nil - } - close(ch) - }() - - // Loop over multiple responses. - for { - ts.dlog("initiating response reading") - - tokChan := make(chan tokenStruct) - go processSingleResponse(sess, tokChan, outs) - - // Loop over multiple tokens in response. - tokensLoop: - for { - switch ts.iter(ctx, ch, tokChan) { - case parseRespIterContinue: - // Nothing, continue to next token. - case parseRespIterNext: - break tokensLoop - case parseRespIterDone: - return +func readCancelConfirmation(tokChan chan tokenStruct) bool { + for tok := range tokChan { + switch tok := tok.(type) { + default: + // just skip token + case doneStruct: + if tok.Status&doneAttn != 0 { + // got cancellation confirmation, exit + return true } } } + return false } diff --git a/vendor/github.com/denisenkom/go-mssqldb/token_string.go b/vendor/github.com/denisenkom/go-mssqldb/token_string.go index c075b23be0..a473182cf8 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/token_string.go +++ b/vendor/github.com/denisenkom/go-mssqldb/token_string.go @@ -1,29 +1,24 @@ -// Code generated by "stringer -type token"; DO NOT EDIT +// Code generated by "stringer -type token"; DO NOT EDIT. package mssql -import "fmt" +import "strconv" const ( _token_name_0 = "tokenReturnStatus" _token_name_1 = "tokenColMetadata" - _token_name_2 = "tokenOrdertokenErrortokenInfo" - _token_name_3 = "tokenLoginAck" - _token_name_4 = "tokenRowtokenNbcRow" - _token_name_5 = "tokenEnvChange" - _token_name_6 = "tokenSSPI" - _token_name_7 = "tokenDonetokenDoneProctokenDoneInProc" + _token_name_2 = "tokenOrdertokenErrortokenInfotokenReturnValuetokenLoginAcktokenFeatureExtAck" + _token_name_3 = "tokenRowtokenNbcRow" + _token_name_4 = "tokenEnvChange" + _token_name_5 = "tokenSSPItokenFedAuthInfo" + _token_name_6 = "tokenDonetokenDoneProctokenDoneInProc" ) var ( - _token_index_0 = [...]uint8{0, 17} - _token_index_1 = [...]uint8{0, 16} - _token_index_2 = [...]uint8{0, 10, 20, 29} - _token_index_3 = [...]uint8{0, 13} - _token_index_4 = [...]uint8{0, 8, 19} - _token_index_5 = [...]uint8{0, 14} - _token_index_6 = [...]uint8{0, 9} - _token_index_7 = [...]uint8{0, 9, 22, 37} + _token_index_2 = [...]uint8{0, 10, 20, 29, 45, 58, 76} + _token_index_3 = [...]uint8{0, 8, 19} + _token_index_5 = [...]uint8{0, 9, 25} + _token_index_6 = [...]uint8{0, 9, 22, 37} ) func (i token) String() string { @@ -32,22 +27,21 @@ func (i token) String() string { return _token_name_0 case i == 129: return _token_name_1 - case 169 <= i && i <= 171: + case 169 <= i && i <= 174: i -= 169 return _token_name_2[_token_index_2[i]:_token_index_2[i+1]] - case i == 173: - return _token_name_3 case 209 <= i && i <= 210: i -= 209 - return _token_name_4[_token_index_4[i]:_token_index_4[i+1]] + return _token_name_3[_token_index_3[i]:_token_index_3[i+1]] case i == 227: - return _token_name_5 - case i == 237: - return _token_name_6 + return _token_name_4 + case 237 <= i && i <= 238: + i -= 237 + return _token_name_5[_token_index_5[i]:_token_index_5[i+1]] case 253 <= i && i <= 255: i -= 253 - return _token_name_7[_token_index_7[i]:_token_index_7[i+1]] + return _token_name_6[_token_index_6[i]:_token_index_6[i+1]] default: - return fmt.Sprintf("token(%d)", i) + return "token(" + strconv.FormatInt(int64(i), 10) + ")" } } diff --git a/vendor/github.com/denisenkom/go-mssqldb/tran.go b/vendor/github.com/denisenkom/go-mssqldb/tran.go index cb6436816f..9b21972423 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/tran.go +++ b/vendor/github.com/denisenkom/go-mssqldb/tran.go @@ -21,11 +21,11 @@ type isoLevel uint8 const ( isolationUseCurrent isoLevel = 0 - isolationReadUncommited = 1 - isolationReadCommited = 2 - isolationRepeatableRead = 3 - isolationSerializable = 4 - isolationSnapshot = 5 + isolationReadUncommited isoLevel = 1 + isolationReadCommited isoLevel = 2 + isolationRepeatableRead isoLevel = 3 + isolationSerializable isoLevel = 4 + isolationSnapshot isoLevel = 5 ) func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel, name string, resetSession bool) (err error) { diff --git a/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go b/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go index 64e5e21fbd..d3890af954 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go +++ b/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go @@ -4,6 +4,7 @@ package mssql import ( "bytes" + "database/sql" "encoding/binary" "errors" "fmt" @@ -97,6 +98,9 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd for columnStrIdx, fieldIdx := range tvpFieldIndexes { field := refStr.Field(fieldIdx) tvpVal := field.Interface() + if tvp.verifyStandardTypeOnNull(buf, tvpVal) { + continue + } valOf := reflect.ValueOf(tvpVal) elemKind := field.Kind() if elemKind == reflect.Ptr && valOf.IsNil() { @@ -155,7 +159,7 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) { defaultValues = append(defaultValues, v.Interface()) continue } - defaultValues = append(defaultValues, reflect.Zero(field.Type).Interface()) + defaultValues = append(defaultValues, tvp.createZeroType(reflect.Zero(field.Type).Interface())) } if columnCount-len(tvpFieldIndexes) == columnCount { @@ -209,19 +213,23 @@ func getSchemeAndName(tvpName string) (string, string, error) { } splitVal := strings.Split(tvpName, ".") if len(splitVal) > 2 { - return "", "", errors.New("wrong tvp name") + return "", "", ErrorObjectName } + const ( + openSquareBrackets = "[" + closeSquareBrackets = "]" + ) if len(splitVal) == 2 { res := make([]string, 2) for key, value := range splitVal { - tmp := strings.Replace(value, "[", "", -1) - tmp = strings.Replace(tmp, "]", "", -1) + tmp := strings.Replace(value, openSquareBrackets, "", -1) + tmp = strings.Replace(tmp, closeSquareBrackets, "", -1) res[key] = tmp } return res[0], res[1], nil } - tmp := strings.Replace(splitVal[0], "[", "", -1) - tmp = strings.Replace(tmp, "]", "", -1) + tmp := strings.Replace(splitVal[0], openSquareBrackets, "", -1) + tmp = strings.Replace(tmp, closeSquareBrackets, "", -1) return "", tmp, nil } @@ -229,3 +237,56 @@ func getSchemeAndName(tvpName string) (string, string, error) { func getCountSQLSeparators(str string) int { return strings.Count(str, sqlSeparator) } + +// verify types https://golang.org/pkg/database/sql/ +func (tvp TVP) createZeroType(fieldVal interface{}) interface{} { + const ( + defaultBool = false + defaultFloat64 = float64(0) + defaultInt64 = int64(0) + defaultString = "" + ) + + switch fieldVal.(type) { + case sql.NullBool: + return defaultBool + case sql.NullFloat64: + return defaultFloat64 + case sql.NullInt64: + return defaultInt64 + case sql.NullString: + return defaultString + } + return fieldVal +} + +// verify types https://golang.org/pkg/database/sql/ +func (tvp TVP) verifyStandardTypeOnNull(buf *bytes.Buffer, tvpVal interface{}) bool { + const ( + defaultNull = uint8(0) + ) + + switch val := tvpVal.(type) { + case sql.NullBool: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, defaultNull) + return true + } + case sql.NullFloat64: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, defaultNull) + return true + } + case sql.NullInt64: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, defaultNull) + return true + } + case sql.NullString: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL)) + return true + } + } + return false +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/types.go b/vendor/github.com/denisenkom/go-mssqldb/types.go index b6e7fb2b52..cae199244d 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/types.go +++ b/vendor/github.com/denisenkom/go-mssqldb/types.go @@ -665,7 +665,7 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} { default: buf = bytes.NewBuffer(make([]byte, 0, size)) } - for true { + for { chunksize := r.uint32() if chunksize == 0 { break @@ -690,6 +690,10 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} { } func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) { + if buf == nil { + err = binary.Write(w, binary.LittleEndian, uint64(_PLP_NULL)) + return + } if err = binary.Write(w, binary.LittleEndian, uint64(_UNKNOWN_PLP_LEN)); err != nil { return } @@ -807,7 +811,6 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) { default: badStreamPanicf("Invalid type %d", ti.TypeId) } - return } func decodeMoney(buf []byte) []byte { @@ -834,8 +837,7 @@ func decodeGuid(buf []byte) []byte { } func decodeDecimal(prec uint8, scale uint8, buf []byte) []byte { - var sign uint8 - sign = buf[0] + sign := buf[0] var dec decimal.Decimal dec.SetPositive(sign != 0) dec.SetPrec(prec) @@ -1187,7 +1189,7 @@ func makeDecl(ti typeInfo) string { return fmt.Sprintf("char(%d)", ti.Size) case typeBigVarChar, typeVarChar: if ti.Size > 8000 || ti.Size == 0 { - return fmt.Sprintf("varchar(max)") + return "varchar(max)" } else { return fmt.Sprintf("varchar(%d)", ti.Size) } diff --git a/vendor/modules.txt b/vendor/modules.txt index a6c3192e75..f0e4bcec2c 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -155,7 +155,7 @@ github.com/cloudflare/golz4 github.com/containerd/continuity/pathdriver # github.com/davecgh/go-spew v1.1.1 github.com/davecgh/go-spew/spew -# github.com/denisenkom/go-mssqldb v0.9.0 +# github.com/denisenkom/go-mssqldb v0.10.0 ## explicit github.com/denisenkom/go-mssqldb github.com/denisenkom/go-mssqldb/internal/cp