From 302242849ba09dbb4f4b6d95155421dffafb6105 Mon Sep 17 00:00:00 2001 From: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Date: Tue, 19 Mar 2024 11:43:42 -0400 Subject: [PATCH] refactor(go/adbc/driver): driverbase implementation for connection (#1590) Implementation of Connection driver base, along with a refactor of Driver and Database bases. The bases have been refactored in the following way: - The `*Impl` interface (e.g. `DatabaseImpl`) now explicitly implements the corresponding `adbc` interface (e.g. `adbc.Database`). - We now check to guarantee the `DatabaseImplBase` implements the entire `DatabaseImpl` interface with stub methods or default implementations. - A new interface has been added (e.g. `driverbase.Database`) which contains all methods the _output_ of driverbase constructor `NewDatabase()` should be. This helps document and guarantee the "extra" behavior provided by using the driverbase. This interface should be internal to the library. - By embedding `DatabaseImpl` in the `database` struct (and similarly for the other bases) it automatically inherits implementations coming from the `DatabaseImpl`. This way we don't need to write out all the implementations a second time, hence the deletes. - The Connection base uses a builder for its constructor to register any helper methods (see discussion in comments). The Driver and Database bases use simple function constructors because they don't have any helpers to register. This felt simpler but I can make those into trivial builders as well if we prefer to have consistency between them. A new `DriverInfo` type has been introduced to help consolidate the collection and validation of metadata for `GetInfo()`. There are more small changes such as refactors of the flightsql and snowflake drivers to make use of the added functionality, as well as a new set of tests for the driverbase. Please let me know if anything else could use clarification. Resolves #1105. --- go/adbc/adbc.go | 11 + go/adbc/driver/driverbase/driver.go | 66 -- .../driver/flightsql/flightsql_connection.go | 578 ++++++----------- .../driver/flightsql/flightsql_database.go | 25 +- go/adbc/driver/flightsql/flightsql_driver.go | 45 +- .../driver/flightsql/flightsql_statement.go | 12 +- .../driver/internal/driverbase/connection.go | 497 +++++++++++++++ .../{ => internal}/driverbase/database.go | 111 ++-- go/adbc/driver/internal/driverbase/driver.go | 116 ++++ .../driver/internal/driverbase/driver_info.go | 176 ++++++ .../internal/driverbase/driver_info_test.go | 88 +++ .../driver/internal/driverbase/driver_test.go | 595 ++++++++++++++++++ .../driver/{ => internal}/driverbase/error.go | 0 .../{ => internal}/driverbase/logging.go | 0 go/adbc/driver/snowflake/connection.go | 293 +++------ go/adbc/driver/snowflake/driver.go | 45 +- go/adbc/driver/snowflake/driver_test.go | 4 + .../driver/snowflake/snowflake_database.go | 41 +- go/adbc/driver/snowflake/statement.go | 2 +- go/adbc/go.mod | 1 + go/adbc/go.sum | 1 + 21 files changed, 1861 insertions(+), 846 deletions(-) delete mode 100644 go/adbc/driver/driverbase/driver.go create mode 100644 go/adbc/driver/internal/driverbase/connection.go rename go/adbc/driver/{ => internal}/driverbase/database.go (52%) create mode 100644 go/adbc/driver/internal/driverbase/driver.go create mode 100644 go/adbc/driver/internal/driverbase/driver_info.go create mode 100644 go/adbc/driver/internal/driverbase/driver_info_test.go create mode 100644 go/adbc/driver/internal/driverbase/driver_test.go rename go/adbc/driver/{ => internal}/driverbase/error.go (100%) rename go/adbc/driver/{ => internal}/driverbase/logging.go (100%) diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go index f5514626ad..6968faacf5 100644 --- a/go/adbc/adbc.go +++ b/go/adbc/adbc.go @@ -355,6 +355,17 @@ const ( InfoDriverADBCVersion InfoCode = 103 // DriverADBCVersion ) +type InfoValueTypeCode = arrow.UnionTypeCode + +const ( + InfoValueStringType InfoValueTypeCode = 0 + InfoValueBooleanType InfoValueTypeCode = 1 + InfoValueInt64Type InfoValueTypeCode = 2 + InfoValueInt32BitmaskType InfoValueTypeCode = 3 + InfoValueStringListType InfoValueTypeCode = 4 + InfoValueInt32ToInt32ListMapType InfoValueTypeCode = 5 +) + type ObjectDepth int const ( diff --git a/go/adbc/driver/driverbase/driver.go b/go/adbc/driver/driverbase/driver.go deleted file mode 100644 index e4cfb99602..0000000000 --- a/go/adbc/driver/driverbase/driver.go +++ /dev/null @@ -1,66 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Package driverbase provides a framework for implementing ADBC drivers in -// Go. It intends to reduce boilerplate for common functionality and managing -// state transitions. -package driverbase - -import ( - "github.com/apache/arrow-adbc/go/adbc" - "github.com/apache/arrow/go/v16/arrow/memory" -) - -// DriverImpl is an interface that drivers implement to provide -// vendor-specific functionality. -type DriverImpl interface { - Base() *DriverImplBase - NewDatabase(opts map[string]string) (adbc.Database, error) -} - -// DriverImplBase is a struct that provides default implementations of the -// DriverImpl interface. It is meant to be used as a composite struct for a -// driver's DriverImpl implementation. -type DriverImplBase struct { - Alloc memory.Allocator - ErrorHelper ErrorHelper -} - -func NewDriverImplBase(name string, alloc memory.Allocator) DriverImplBase { - if alloc == nil { - alloc = memory.DefaultAllocator - } - return DriverImplBase{Alloc: alloc, ErrorHelper: ErrorHelper{DriverName: name}} -} - -func (base *DriverImplBase) Base() *DriverImplBase { - return base -} - -// driver is the actual implementation of adbc.Driver. -type driver struct { - impl DriverImpl -} - -// NewDriver wraps a DriverImpl to create an adbc.Driver. -func NewDriver(impl DriverImpl) adbc.Driver { - return &driver{impl} -} - -func (drv *driver) NewDatabase(opts map[string]string) (adbc.Database, error) { - return drv.impl.NewDatabase(opts) -} diff --git a/go/adbc/driver/flightsql/flightsql_connection.go b/go/adbc/driver/flightsql/flightsql_connection.go index e71ac308df..83807856ec 100644 --- a/go/adbc/driver/flightsql/flightsql_connection.go +++ b/go/adbc/driver/flightsql/flightsql_connection.go @@ -28,6 +28,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow-adbc/go/adbc/driver/internal" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow/go/v16/arrow" "github.com/apache/arrow/go/v16/arrow/array" "github.com/apache/arrow/go/v16/arrow/flight" @@ -43,7 +44,9 @@ import ( "google.golang.org/protobuf/proto" ) -type cnxn struct { +type connectionImpl struct { + driverbase.ConnectionImplBase + cl *flightsql.Client db *databaseImpl @@ -54,6 +57,82 @@ type cnxn struct { supportInfo support } +// GetCurrentCatalog implements driverbase.CurrentNamespacer. +func (c *connectionImpl) GetCurrentCatalog() (string, error) { + options, err := c.getSessionOptions(context.Background()) + if err != nil { + return "", err + } + if catalog, ok := options["catalog"]; ok { + if val, ok := catalog.(string); ok { + return val, nil + } + return "", c.Base().ErrorHelper.Errorf(adbc.StatusInternal, "server returned non-string catalog %#v", catalog) + } + return "", c.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "current catalog not supported") +} + +// GetCurrentDbSchema implements driverbase.CurrentNamespacer. +func (c *connectionImpl) GetCurrentDbSchema() (string, error) { + options, err := c.getSessionOptions(context.Background()) + if err != nil { + return "", err + } + if schema, ok := options["schema"]; ok { + if val, ok := schema.(string); ok { + return val, nil + } + return "", c.Base().ErrorHelper.Errorf(adbc.StatusInternal, "server returned non-string schema %#v", schema) + } + return "", c.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "current schema not supported") +} + +// SetCurrentCatalog implements driverbase.CurrentNamespacer. +func (c *connectionImpl) SetCurrentCatalog(value string) error { + return c.setSessionOptions(context.Background(), "catalog", value) +} + +// SetCurrentDbSchema implements driverbase.CurrentNamespacer. +func (c *connectionImpl) SetCurrentDbSchema(value string) error { + return c.setSessionOptions(context.Background(), "schema", value) +} + +func (c *connectionImpl) SetAutocommit(enabled bool) error { + if enabled && c.txn == nil { + // no-op don't even error if the server didn't support transactions + return nil + } + + if !c.supportInfo.transactions { + return errNoTransactionSupport + } + + ctx := metadata.NewOutgoingContext(context.Background(), c.hdrs) + var err error + if c.txn != nil { + if err = c.txn.Commit(ctx, c.timeouts); err != nil { + return adbc.Error{ + Msg: "[Flight SQL] failed to update autocommit: " + err.Error(), + Code: adbc.StatusIO, + } + } + } + + if enabled { + c.txn = nil + return nil + } + + if c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts); err != nil { + return adbc.Error{ + Msg: "[Flight SQL] failed to update autocommit: " + err.Error(), + Code: adbc.StatusIO, + } + } + + return nil +} + var adbcToFlightSQLInfo = map[adbc.InfoCode]flightsql.SqlInfo{ adbc.InfoVendorName: flightsql.SqlInfoFlightSqlServerName, adbc.InfoVendorVersion: flightsql.SqlInfoFlightSqlServerVersion, @@ -97,7 +176,7 @@ func doGet(ctx context.Context, cl *flightsql.Client, endpoint *flight.FlightEnd return nil, err } -func (c *cnxn) getSessionOptions(ctx context.Context) (map[string]interface{}, error) { +func (c *connectionImpl) getSessionOptions(ctx context.Context) (map[string]interface{}, error) { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD rawOptions, err := c.cl.GetSessionOptions(ctx, &flight.GetSessionOptionsRequest{}, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) @@ -140,7 +219,7 @@ func (c *cnxn) getSessionOptions(ctx context.Context) (map[string]interface{}, e return options, nil } -func (c *cnxn) setSessionOptions(ctx context.Context, key string, val interface{}) error { +func (c *connectionImpl) setSessionOptions(ctx context.Context, key string, val interface{}) error { req := flight.SetSessionOptionsRequest{} var err error @@ -206,7 +285,7 @@ func getSessionOption[T any](options map[string]interface{}, key string, default return value, nil } -func (c *cnxn) GetOption(key string) (string, error) { +func (c *connectionImpl) GetOption(key string) (string, error) { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) headers := c.hdrs.Get(name) @@ -226,51 +305,6 @@ func (c *cnxn) GetOption(key string) (string, error) { return c.timeouts.queryTimeout.String(), nil case OptionTimeoutUpdate: return c.timeouts.updateTimeout.String(), nil - case adbc.OptionKeyAutoCommit: - if c.txn != nil { - // No autocommit - return adbc.OptionValueDisabled, nil - } else { - // Autocommit - return adbc.OptionValueEnabled, nil - } - case adbc.OptionKeyCurrentCatalog: - options, err := c.getSessionOptions(context.Background()) - if err != nil { - return "", err - } - if catalog, ok := options["catalog"]; ok { - if val, ok := catalog.(string); ok { - return val, nil - } - return "", adbc.Error{ - Msg: fmt.Sprintf("[FlightSQL] Server returned non-string catalog %#v", catalog), - Code: adbc.StatusInternal, - } - } - return "", adbc.Error{ - Msg: "[FlightSQL] current catalog not supported", - Code: adbc.StatusNotFound, - } - - case adbc.OptionKeyCurrentDbSchema: - options, err := c.getSessionOptions(context.Background()) - if err != nil { - return "", err - } - if schema, ok := options["schema"]; ok { - if val, ok := schema.(string); ok { - return val, nil - } - return "", adbc.Error{ - Msg: fmt.Sprintf("[FlightSQL] Server returned non-string schema %#v", schema), - Code: adbc.StatusInternal, - } - } - return "", adbc.Error{ - Msg: "[FlightSQL] current schema not supported", - Code: adbc.StatusNotFound, - } case OptionSessionOptions: options, err := c.getSessionOptions(context.Background()) if err != nil { @@ -333,7 +367,7 @@ func (c *cnxn) GetOption(key string) (string, error) { } } -func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { +func (c *connectionImpl) GetOptionBytes(key string) ([]byte, error) { switch key { case OptionSessionOptions: options, err := c.getSessionOptions(context.Background()) @@ -356,7 +390,7 @@ func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { } } -func (c *cnxn) GetOptionInt(key string) (int64, error) { +func (c *connectionImpl) GetOptionInt(key string) (int64, error) { switch key { case OptionTimeoutFetch: fallthrough @@ -378,13 +412,10 @@ func (c *cnxn) GetOptionInt(key string) (int64, error) { return getSessionOption(options, name, int64(0), "an integer") } - return 0, adbc.Error{ - Msg: "[Flight SQL] unknown connection option", - Code: adbc.StatusNotFound, - } + return c.ConnectionImplBase.GetOptionInt(key) } -func (c *cnxn) GetOptionDouble(key string) (float64, error) { +func (c *connectionImpl) GetOptionDouble(key string) (float64, error) { switch key { case OptionTimeoutFetch: return c.timeouts.fetchTimeout.Seconds(), nil @@ -402,13 +433,10 @@ func (c *cnxn) GetOptionDouble(key string) (float64, error) { return getSessionOption(options, name, float64(0.0), "a floating-point") } - return 0.0, adbc.Error{ - Msg: "[Flight SQL] unknown connection option", - Code: adbc.StatusNotFound, - } + return c.ConnectionImplBase.GetOptionDouble(key) } -func (c *cnxn) SetOption(key, value string) error { +func (c *connectionImpl) SetOption(key, value string) error { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) if value == "" { @@ -422,56 +450,6 @@ func (c *cnxn) SetOption(key, value string) error { switch key { case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: return c.timeouts.setTimeoutString(key, value) - case adbc.OptionKeyAutoCommit: - autocommit := true - switch value { - case adbc.OptionValueEnabled: - // Do nothing - case adbc.OptionValueDisabled: - autocommit = false - default: - return adbc.Error{ - Msg: "[Flight SQL] invalid value for option " + key + ": " + value, - Code: adbc.StatusInvalidArgument, - } - } - - if autocommit && c.txn == nil { - // no-op don't even error if the server didn't support transactions - return nil - } - - if !c.supportInfo.transactions { - return errNoTransactionSupport - } - - ctx := metadata.NewOutgoingContext(context.Background(), c.hdrs) - var err error - if c.txn != nil { - if err = c.txn.Commit(ctx, c.timeouts); err != nil { - return adbc.Error{ - Msg: "[Flight SQL] failed to update autocommit: " + err.Error(), - Code: adbc.StatusIO, - } - } - } - - if autocommit { - c.txn = nil - return nil - } - - if c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts); err != nil { - return adbc.Error{ - Msg: "[Flight SQL] failed to update autocommit: " + err.Error(), - Code: adbc.StatusIO, - } - } - return nil - case adbc.OptionKeyCurrentCatalog: - return c.setSessionOptions(context.Background(), "catalog", value) - case adbc.OptionKeyCurrentDbSchema: - return c.setSessionOptions(context.Background(), "schema", value) } switch { @@ -506,20 +484,10 @@ func (c *cnxn) SetOption(key, value string) error { return c.setSessionOptions(context.Background(), name, nil) } - return adbc.Error{ - Msg: "[Flight SQL] unknown connection option", - Code: adbc.StatusNotImplemented, - } + return c.ConnectionImplBase.SetOption(key, value) } -func (c *cnxn) SetOptionBytes(key string, value []byte) error { - return adbc.Error{ - Msg: "[Flight SQL] unknown connection option", - Code: adbc.StatusNotImplemented, - } -} - -func (c *cnxn) SetOptionInt(key string, value int64) error { +func (c *connectionImpl) SetOptionInt(key string, value int64) error { switch key { case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: return c.timeouts.setTimeout(key, float64(value)) @@ -529,13 +497,10 @@ func (c *cnxn) SetOptionInt(key string, value int64) error { return c.setSessionOptions(context.Background(), name, value) } - return adbc.Error{ - Msg: "[Flight SQL] unknown connection option", - Code: adbc.StatusNotImplemented, - } + return c.ConnectionImplBase.SetOptionInt(key, value) } -func (c *cnxn) SetOptionDouble(key string, value float64) error { +func (c *connectionImpl) SetOptionDouble(key string, value float64) error { switch key { case OptionTimeoutFetch: fallthrough @@ -549,231 +514,117 @@ func (c *cnxn) SetOptionDouble(key string, value float64) error { return c.setSessionOptions(context.Background(), name, value) } - return adbc.Error{ - Msg: "[Flight SQL] unknown connection option", - Code: adbc.StatusNotImplemented, - } + return c.ConnectionImplBase.SetOptionDouble(key, value) } -// GetInfo returns metadata about the database/driver. -// -// The result is an Arrow dataset with the following schema: -// -// Field Name | Field Type -// ----------------------------|----------------------------- -// info_name | uint32 not null -// info_value | INFO_SCHEMA -// -// INFO_SCHEMA is a dense union with members: -// -// Field Name (Type Code) | Field Type -// ----------------------------|----------------------------- -// string_value (0) | utf8 -// bool_value (1) | bool -// int64_value (2) | int64 -// int32_bitmask (3) | int32 -// string_list (4) | list -// int32_to_int32_list_map (5) | map> -// -// Each metadatum is identified by an integer code. The recognized -// codes are defined as constants. Codes [0, 10_000) are reserved -// for ADBC usage. Drivers/vendors will ignore requests for unrecognized -// codes (the row will be omitted from the result). -func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { - const strValTypeID arrow.UnionTypeCode = 0 - const intValTypeID arrow.UnionTypeCode = 2 +func (c *connectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes []adbc.InfoCode) error { + driverInfo := c.ConnectionImplBase.DriverInfo if len(infoCodes) == 0 { - infoCodes = infoSupportedCodes + infoCodes = driverInfo.InfoSupportedCodes() } - bldr := array.NewRecordBuilder(c.cl.Alloc, adbc.GetInfoSchema) - defer bldr.Release() - bldr.Reserve(len(infoCodes)) - - infoNameBldr := bldr.Field(0).(*array.Uint32Builder) - infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) - strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder) - intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder) - translated := make([]flightsql.SqlInfo, 0, len(infoCodes)) for _, code := range infoCodes { if t, ok := adbcToFlightSQLInfo[code]; ok { translated = append(translated, t) - continue } + } - switch code { - case adbc.InfoDriverName: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - strInfoBldr.Append(infoDriverName) - case adbc.InfoDriverVersion: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - strInfoBldr.Append(infoDriverVersion) - case adbc.InfoDriverArrowVersion: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - strInfoBldr.Append(infoDriverArrowVersion) - case adbc.InfoDriverADBCVersion: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(intValTypeID) - intInfoBldr.Append(adbc.AdbcVersion1_1_0) - } + // None of the requested info codes are available on the server, so just return the local info + if len(translated) == 0 { + return nil } ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD info, err := c.cl.GetSqlInfo(ctx, translated, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) - if err == nil { - for i, endpoint := range info.Endpoint { - var header, trailer metadata.MD - rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) - if err != nil { - return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) - } - for rdr.Next() { - rec := rdr.Record() - field := rec.Column(0).(*array.Uint32) - info := rec.Column(1).(*array.DenseUnion) - - for i := 0; i < int(rec.NumRows()); i++ { - switch flightsql.SqlInfo(field.Value(i)) { - case flightsql.SqlInfoFlightSqlServerName: - infoNameBldr.Append(uint32(adbc.InfoVendorName)) - case flightsql.SqlInfoFlightSqlServerVersion: - infoNameBldr.Append(uint32(adbc.InfoVendorVersion)) - case flightsql.SqlInfoFlightSqlServerArrowVersion: - infoNameBldr.Append(uint32(adbc.InfoVendorArrowVersion)) - default: - continue - } + // Just return local driver info if GetSqlInfo hasn't been implemented on the server + if grpcstatus.Code(err) == grpccodes.Unimplemented { + return nil + } + + if err != nil { + return adbcFromFlightStatus(err, "GetInfo(GetSqlInfo)") + } + + // No error, go get the SqlInfo from the server + for i, endpoint := range info.Endpoint { + var header, trailer metadata.MD + rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) + if err != nil { + return adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) + } - infoValueBldr.Append(info.TypeCode(i)) - // we know we're only doing string fields here right now - v := info.Field(info.ChildID(i)).(*array.String). - Value(int(info.ValueOffset(i))) - strInfoBldr.Append(v) + for rdr.Next() { + rec := rdr.Record() + field := rec.Column(0).(*array.Uint32) + info := rec.Column(1).(*array.DenseUnion) + + var adbcInfoCode adbc.InfoCode + for i := 0; i < int(rec.NumRows()); i++ { + switch flightsql.SqlInfo(field.Value(i)) { + case flightsql.SqlInfoFlightSqlServerName: + adbcInfoCode = adbc.InfoVendorName + case flightsql.SqlInfoFlightSqlServerVersion: + adbcInfoCode = adbc.InfoVendorVersion + case flightsql.SqlInfoFlightSqlServerArrowVersion: + adbcInfoCode = adbc.InfoVendorArrowVersion + default: + continue } - } - if err := checkContext(rdr.Err(), ctx); err != nil { - return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) + // we know we're only doing string fields here right now + v := info.Field(info.ChildID(i)).(*array.String). + Value(int(info.ValueOffset(i))) + if err := driverInfo.RegisterInfoCode(adbcInfoCode, strings.Clone(v)); err != nil { + return err + } } } - } else if grpcstatus.Code(err) != grpccodes.Unimplemented { - return nil, adbcFromFlightStatus(err, "GetInfo(GetSqlInfo)") + + if err := checkContext(rdr.Err(), ctx); err != nil { + return adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) + } } - final := bldr.NewRecord() - defer final.Release() - return array.NewRecordReader(adbc.GetInfoSchema, []arrow.Record{final}) + return nil } -// GetObjects gets a hierarchical view of all catalogs, database schemas, -// tables, and columns. -// -// The result is an Arrow Dataset with the following schema: -// -// Field Name | Field Type -// ----------------------------|---------------------------- -// catalog_name | utf8 -// catalog_db_schemas | list -// -// DB_SCHEMA_SCHEMA is a Struct with the fields: -// -// Field Name | Field Type -// ----------------------------|---------------------------- -// db_schema_name | utf8 -// db_schema_tables | list -// -// TABLE_SCHEMA is a Struct with the fields: -// -// Field Name | Field Type -// ----------------------------|---------------------------- -// table_name | utf8 not null -// table_type | utf8 not null -// table_columns | list -// table_constraints | list -// -// COLUMN_SCHEMA is a Struct with the fields: -// -// Field Name | Field Type | Comments -// ----------------------------|---------------------|--------- -// column_name | utf8 not null | -// ordinal_position | int32 | (1) -// remarks | utf8 | (2) -// xdbc_data_type | int16 | (3) -// xdbc_type_name | utf8 | (3) -// xdbc_column_size | int32 | (3) -// xdbc_decimal_digits | int16 | (3) -// xdbc_num_prec_radix | int16 | (3) -// xdbc_nullable | int16 | (3) -// xdbc_column_def | utf8 | (3) -// xdbc_sql_data_type | int16 | (3) -// xdbc_datetime_sub | int16 | (3) -// xdbc_char_octet_length | int32 | (3) -// xdbc_is_nullable | utf8 | (3) -// xdbc_scope_catalog | utf8 | (3) -// xdbc_scope_schema | utf8 | (3) -// xdbc_scope_table | utf8 | (3) -// xdbc_is_autoincrement | bool | (3) -// xdbc_is_generatedcolumn | bool | (3) -// -// 1. The column's ordinal position in the table (starting from 1). -// 2. Database-specific description of the column. -// 3. Optional Value. Should be null if not supported by the driver. -// xdbc_values are meant to provide JDBC/ODBC-compatible metadata -// in an agnostic manner. -// -// CONSTRAINT_SCHEMA is a Struct with the fields: -// -// Field Name | Field Type | Comments -// ----------------------------|---------------------|--------- -// constraint_name | utf8 | -// constraint_type | utf8 not null | (1) -// constraint_column_names | list not null | (2) -// constraint_column_usage | list | (3) -// -// 1. One of 'CHECK', 'FOREIGN KEY', 'PRIMARY KEY', or 'UNIQUE'. -// 2. The columns on the current table that are constrained, in order. -// 3. For FOREIGN KEY only, the referenced table and columns. -// -// USAGE_SCHEMA is a Struct with fields: -// -// Field Name | Field Type -// ----------------------------|---------------------------- -// fk_catalog | utf8 -// fk_db_schema | utf8 -// fk_table | utf8 not null -// fk_column_name | utf8 not null -// -// For the parameters: If nil is passed, then that parameter will not -// be filtered by at all. If an empty string, then only objects without -// that property (ie: catalog or db schema) will be returned. -// -// tableName and columnName must be either nil (do not filter by -// table name or column name) or non-empty. -// -// All non-empty, non-nil strings should be a search pattern (as described -// earlier). -func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { - ctx = metadata.NewOutgoingContext(ctx, c.hdrs) - g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog, DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType: tableType} - if err := g.Init(c.db.Alloc, c.getObjectsDbSchemas, c.getObjectsTables); err != nil { - return nil, err +// Helper function to read and validate a metadata stream +func (c *connectionImpl) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info *flight.FlightInfo, opts ...grpc.CallOption) (array.RecordReader, error) { + // use a default queueSize for the reader + rdr, err := newRecordReader(ctx, c.db.Alloc, c.cl, info, c.clientCache, 5, opts...) + if err != nil { + return nil, adbcFromFlightStatus(err, "DoGet") } - defer g.Release() - var header, trailer metadata.MD + if !rdr.Schema().Equal(expectedSchema) { + rdr.Release() + return nil, adbc.Error{ + Msg: fmt.Sprintf("Invalid schema returned for: expected %s, got %s", expectedSchema.String(), rdr.Schema().String()), + Code: adbc.StatusInternal, + } + } + return rdr, nil +} + +func (c *connectionImpl) GetObjectsCatalogs(ctx context.Context, catalog *string) ([]string, error) { + var ( + header, trailer metadata.MD + numCatalogs int64 + ) // To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response. info, err := c.cl.GetCatalogs(ctx, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) if err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } + if info.TotalRecords > 0 { + numCatalogs = info.TotalRecords + } + header = metadata.MD{} trailer = metadata.MD{} rdr, err := c.readInfo(ctx, schema_ref.Catalogs, info, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) @@ -782,48 +633,25 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog * } defer rdr.Release() - foundCatalog := false + catalogs := make([]string, 0, numCatalogs) for rdr.Next() { arr := rdr.Record().Column(0).(*array.String) for i := 0; i < arr.Len(); i++ { // XXX: force copy since accessor is unsafe catalogName := string([]byte(arr.Value(i))) - g.AppendCatalog(catalogName) - foundCatalog = true + catalogs = append(catalogs, catalogName) } } - // Implementations like Dremio report no catalogs, but still have schemas - if !foundCatalog && depth != adbc.ObjectDepthCatalogs { - g.AppendCatalog("") - } - if err := checkContext(rdr.Err(), ctx); err != nil { return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetObjects(GetCatalogs)") } - return g.Finish() -} - -// Helper function to read and validate a metadata stream -func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info *flight.FlightInfo, opts ...grpc.CallOption) (array.RecordReader, error) { - // use a default queueSize for the reader - rdr, err := newRecordReader(ctx, c.db.Alloc, c.cl, info, c.clientCache, 5, opts...) - if err != nil { - return nil, adbcFromFlightStatus(err, "DoGet") - } - if !rdr.Schema().Equal(expectedSchema) { - rdr.Release() - return nil, adbc.Error{ - Msg: fmt.Sprintf("Invalid schema returned for: expected %s, got %s", expectedSchema.String(), rdr.Schema().String()), - Code: adbc.StatusInternal, - } - } - return rdr, nil + return catalogs, nil } // Helper function to build up a map of catalogs to DB schemas -func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords []internal.Metadata) (result map[string][]string, err error) { +func (c *connectionImpl) GetObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords []internal.Metadata) (result map[string][]string, err error) { if depth == adbc.ObjectDepthCatalogs { return } @@ -864,7 +692,7 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, return } -func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (result internal.SchemaToTableInfo, err error) { +func (c *connectionImpl) GetObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (result internal.SchemaToTableInfo, err error) { if depth == adbc.ObjectDepthCatalogs || depth == adbc.ObjectDepthDBSchemas { return } @@ -944,7 +772,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat return } -func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { +func (c *connectionImpl) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { opts := &flightsql.GetTablesOpts{ Catalog: catalog, DbSchemaFilterPattern: dbSchema, @@ -1023,7 +851,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st // Field Name | Field Type // ----------------|-------------- // table_type | utf8 not null -func (c *cnxn) GetTableTypes(ctx context.Context) (array.RecordReader, error) { +func (c *connectionImpl) GetTableTypes(ctx context.Context) (array.RecordReader, error) { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD info, err := c.cl.GetTableTypes(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) @@ -1040,18 +868,7 @@ func (c *cnxn) GetTableTypes(ctx context.Context) (array.RecordReader, error) { // Behavior is undefined if this is mixed with SQL transaction statements. // When not supported, the convention is that it should act as if autocommit // is enabled and return INVALID_STATE errors. -func (c *cnxn) Commit(ctx context.Context) error { - if c.txn == nil { - return adbc.Error{ - Msg: "[Flight SQL] Cannot commit when autocommit is enabled", - Code: adbc.StatusInvalidState, - } - } - - if !c.supportInfo.transactions { - return errNoTransactionSupport - } - +func (c *connectionImpl) Commit(ctx context.Context) error { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD err := c.txn.Commit(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) @@ -1074,18 +891,7 @@ func (c *cnxn) Commit(ctx context.Context) error { // Behavior is undefined if this is mixed with SQL transaction statements. // When not supported, the convention is that it should act as if autocommit // is enabled and return INVALID_STATE errors. -func (c *cnxn) Rollback(ctx context.Context) error { - if c.txn == nil { - return adbc.Error{ - Msg: "[Flight SQL] Cannot rollback when autocommit is enabled", - Code: adbc.StatusInvalidState, - } - } - - if !c.supportInfo.transactions { - return errNoTransactionSupport - } - +func (c *connectionImpl) Rollback(ctx context.Context) error { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) var header, trailer metadata.MD err := c.txn.Rollback(ctx, c.timeouts, grpc.Header(&header), grpc.Trailer(&trailer)) @@ -1103,7 +909,7 @@ func (c *cnxn) Rollback(ctx context.Context) error { } // NewStatement initializes a new statement object tied to this connection -func (c *cnxn) NewStatement() (adbc.Statement, error) { +func (c *connectionImpl) NewStatement() (adbc.Statement, error) { return &statement{ alloc: c.db.Alloc, clientCache: c.clientCache, @@ -1114,7 +920,7 @@ func (c *cnxn) NewStatement() (adbc.Statement, error) { }, nil } -func (c *cnxn) execute(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.FlightInfo, error) { +func (c *connectionImpl) execute(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if c.txn != nil { return c.txn.Execute(ctx, query, opts...) } @@ -1122,7 +928,7 @@ func (c *cnxn) execute(ctx context.Context, query string, opts ...grpc.CallOptio return c.cl.Execute(ctx, query, opts...) } -func (c *cnxn) executeSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) { +func (c *connectionImpl) executeSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) { if c.txn != nil { return c.txn.GetExecuteSchema(ctx, query, opts...) } @@ -1130,7 +936,7 @@ func (c *cnxn) executeSchema(ctx context.Context, query string, opts ...grpc.Cal return c.cl.GetExecuteSchema(ctx, query, opts...) } -func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) { +func (c *connectionImpl) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if c.txn != nil { return c.txn.ExecuteSubstrait(ctx, plan, opts...) } @@ -1138,7 +944,7 @@ func (c *cnxn) executeSubstrait(ctx context.Context, plan flightsql.SubstraitPla return c.cl.ExecuteSubstrait(ctx, plan, opts...) } -func (c *cnxn) executeSubstraitSchema(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) { +func (c *connectionImpl) executeSubstraitSchema(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) { if c.txn != nil { return c.txn.GetExecuteSubstraitSchema(ctx, plan, opts...) } @@ -1146,7 +952,7 @@ func (c *cnxn) executeSubstraitSchema(ctx context.Context, plan flightsql.Substr return c.cl.GetExecuteSubstraitSchema(ctx, plan, opts...) } -func (c *cnxn) executeUpdate(ctx context.Context, query string, opts ...grpc.CallOption) (n int64, err error) { +func (c *connectionImpl) executeUpdate(ctx context.Context, query string, opts ...grpc.CallOption) (n int64, err error) { if c.txn != nil { return c.txn.ExecuteUpdate(ctx, query, opts...) } @@ -1154,7 +960,7 @@ func (c *cnxn) executeUpdate(ctx context.Context, query string, opts ...grpc.Cal return c.cl.ExecuteUpdate(ctx, query, opts...) } -func (c *cnxn) executeSubstraitUpdate(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (n int64, err error) { +func (c *connectionImpl) executeSubstraitUpdate(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (n int64, err error) { if c.txn != nil { return c.txn.ExecuteSubstraitUpdate(ctx, plan, opts...) } @@ -1162,7 +968,7 @@ func (c *cnxn) executeSubstraitUpdate(ctx context.Context, plan flightsql.Substr return c.cl.ExecuteSubstraitUpdate(ctx, plan, opts...) } -func (c *cnxn) poll(ctx context.Context, query string, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { +func (c *connectionImpl) poll(ctx context.Context, query string, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { if c.txn != nil { return c.txn.ExecutePoll(ctx, query, retryDescriptor, opts...) } @@ -1170,7 +976,7 @@ func (c *cnxn) poll(ctx context.Context, query string, retryDescriptor *flight.F return c.cl.ExecutePoll(ctx, query, retryDescriptor, opts...) } -func (c *cnxn) pollSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { +func (c *connectionImpl) pollSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { if c.txn != nil { return c.txn.ExecuteSubstraitPoll(ctx, plan, retryDescriptor, opts...) } @@ -1178,7 +984,7 @@ func (c *cnxn) pollSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, return c.cl.ExecuteSubstraitPoll(ctx, plan, retryDescriptor, opts...) } -func (c *cnxn) prepare(ctx context.Context, query string, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { +func (c *connectionImpl) prepare(ctx context.Context, query string, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { if c.txn != nil { return c.txn.Prepare(ctx, query, opts...) } @@ -1186,7 +992,7 @@ func (c *cnxn) prepare(ctx context.Context, query string, opts ...grpc.CallOptio return c.cl.Prepare(ctx, query, opts...) } -func (c *cnxn) prepareSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { +func (c *connectionImpl) prepareSubstrait(ctx context.Context, plan flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { if c.txn != nil { return c.txn.PrepareSubstrait(ctx, plan, opts...) } @@ -1195,7 +1001,7 @@ func (c *cnxn) prepareSubstrait(ctx context.Context, plan flightsql.SubstraitPla } // Close closes this connection and releases any associated resources. -func (c *cnxn) Close() error { +func (c *connectionImpl) Close() error { if c.cl == nil { return adbc.Error{ Msg: "[Flight SQL Connection] trying to close already closed connection", @@ -1225,7 +1031,7 @@ func (c *cnxn) Close() error { // results can then be read independently using the returned RecordReader. // // A partition can be retrieved by using ExecutePartitions on a statement. -func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte) (rdr array.RecordReader, err error) { +func (c *connectionImpl) ReadPartition(ctx context.Context, serializedPartition []byte) (rdr array.RecordReader, err error) { var info flight.FlightInfo if err := proto.Unmarshal(serializedPartition, &info); err != nil { return nil, adbc.Error{ @@ -1251,5 +1057,5 @@ func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte) (r } var ( - _ adbc.PostInitOptions = (*cnxn)(nil) + _ adbc.PostInitOptions = (*connectionImpl)(nil) ) diff --git a/go/adbc/driver/flightsql/flightsql_database.go b/go/adbc/driver/flightsql/flightsql_database.go index 5e5e3af978..9f0848c3f9 100644 --- a/go/adbc/driver/flightsql/flightsql_database.go +++ b/go/adbc/driver/flightsql/flightsql_database.go @@ -29,7 +29,7 @@ import ( "time" "github.com/apache/arrow-adbc/go/adbc" - "github.com/apache/arrow-adbc/go/adbc/driver/driverbase" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow/go/v16/arrow/array" "github.com/apache/arrow/go/v16/arrow/flight" "github.com/apache/arrow/go/v16/arrow/flight/flightsql" @@ -51,7 +51,6 @@ func (d *dbDialOpts) rebuild() { d.opts = []grpc.DialOption{ grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(d.maxMsgSize), grpc.MaxCallSendMsgSize(d.maxMsgSize)), - grpc.WithUserAgent("ADBC Flight SQL Driver " + infoDriverVersion), } if d.block { d.opts = append(d.opts, grpc.WithBlock()) @@ -383,7 +382,12 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl creds = insecure.NewCredentials() target = "unix:" + uri.Path } - dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds)) + + driverVersion, ok := d.DatabaseImplBase.DriverInfo.GetInfoDriverVersion() + if !ok { + driverVersion = driverbase.UnknownVersion + } + dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds), grpc.WithUserAgent("ADBC Flight SQL Driver "+driverVersion)) d.Logger.DebugContext(ctx, "new client", "location", loc) cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...) @@ -503,9 +507,18 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { } } - return &cnxn{cl: cl, db: d, clientCache: cache, - hdrs: make(metadata.MD), timeouts: d.timeout, - supportInfo: cnxnSupport}, nil + conn := &connectionImpl{ + cl: cl, db: d, clientCache: cache, + hdrs: make(metadata.MD), timeouts: d.timeout, supportInfo: cnxnSupport, + ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase), + } + + return driverbase.NewConnectionBuilder(conn). + WithDriverInfoPreparer(conn). + WithAutocommitSetter(conn). + WithDbObjectsEnumerator(conn). + WithCurrentNamespacer(conn). + Connection(), nil } type bearerAuthMiddleware struct { diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go index d437f0829b..441370a9e2 100644 --- a/go/adbc/driver/flightsql/flightsql_driver.go +++ b/go/adbc/driver/flightsql/flightsql_driver.go @@ -33,12 +33,10 @@ package flightsql import ( "net/url" - "runtime/debug" - "strings" "time" "github.com/apache/arrow-adbc/go/adbc" - "github.com/apache/arrow-adbc/go/adbc/driver/driverbase" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow/go/v16/arrow/memory" "golang.org/x/exp/maps" "google.golang.org/grpc/metadata" @@ -69,56 +67,19 @@ const ( infoDriverName = "ADBC Flight SQL Driver - Go" ) -var ( - infoDriverVersion string - infoDriverArrowVersion string - infoSupportedCodes []adbc.InfoCode -) - var errNoTransactionSupport = adbc.Error{ Msg: "[Flight SQL] server does not report transaction support", Code: adbc.StatusNotImplemented, } -func init() { - if info, ok := debug.ReadBuildInfo(); ok { - for _, dep := range info.Deps { - switch { - case dep.Path == "github.com/apache/arrow-adbc/go/adbc/driver/flightsql": - infoDriverVersion = dep.Version - case strings.HasPrefix(dep.Path, "github.com/apache/arrow/go/"): - infoDriverArrowVersion = dep.Version - } - } - } - // XXX: Deps not populated in tests - // https://github.com/golang/go/issues/33976 - if infoDriverVersion == "" { - infoDriverVersion = "(unknown or development build)" - } - if infoDriverArrowVersion == "" { - infoDriverArrowVersion = "(unknown or development build)" - } - - infoSupportedCodes = []adbc.InfoCode{ - adbc.InfoDriverName, - adbc.InfoDriverVersion, - adbc.InfoDriverArrowVersion, - adbc.InfoDriverADBCVersion, - adbc.InfoVendorName, - adbc.InfoVendorVersion, - adbc.InfoVendorArrowVersion, - } -} - type driverImpl struct { driverbase.DriverImplBase } // NewDriver creates a new Flight SQL driver using the given Arrow allocator. func NewDriver(alloc memory.Allocator) adbc.Driver { - impl := driverImpl{DriverImplBase: driverbase.NewDriverImplBase("Flight SQL", alloc)} - return driverbase.NewDriver(&impl) + info := driverbase.DefaultDriverInfo("Flight SQL") + return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc)}) } func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) { diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go index d78b653c81..c68eba8cdb 100644 --- a/go/adbc/driver/flightsql/flightsql_statement.go +++ b/go/adbc/driver/flightsql/flightsql_statement.go @@ -72,7 +72,7 @@ func (s *sqlOrSubstrait) setSubstraitPlan(plan []byte) { s.substraitPlan = plan } -func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*flight.FlightInfo, error) { +func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *connectionImpl, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if s.sqlQuery != "" { return cnxn.execute(ctx, s.sqlQuery, opts...) } else if s.substraitPlan != nil { @@ -85,7 +85,7 @@ func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *cnxn, opts ...grpc.C } } -func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*arrow.Schema, error) { +func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn *connectionImpl, opts ...grpc.CallOption) (*arrow.Schema, error) { var ( res *flight.SchemaResult err error @@ -108,7 +108,7 @@ func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn *cnxn, opts ... return flight.DeserializeSchema(res.Schema, cnxn.cl.Alloc) } -func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (int64, error) { +func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *connectionImpl, opts ...grpc.CallOption) (int64, error) { if s.sqlQuery != "" { return cnxn.executeUpdate(ctx, s.sqlQuery, opts...) } else if s.substraitPlan != nil { @@ -121,7 +121,7 @@ func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *cnxn, opts ... } } -func (s *sqlOrSubstrait) poll(ctx context.Context, cnxn *cnxn, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { +func (s *sqlOrSubstrait) poll(ctx context.Context, cnxn *connectionImpl, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { if s.sqlQuery != "" { return cnxn.poll(ctx, s.sqlQuery, retryDescriptor, opts...) } else if s.substraitPlan != nil { @@ -134,7 +134,7 @@ func (s *sqlOrSubstrait) poll(ctx context.Context, cnxn *cnxn, retryDescriptor * } } -func (s *sqlOrSubstrait) prepare(ctx context.Context, cnxn *cnxn, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { +func (s *sqlOrSubstrait) prepare(ctx context.Context, cnxn *connectionImpl, opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) { if s.sqlQuery != "" { return cnxn.prepare(ctx, s.sqlQuery, opts...) } else if s.substraitPlan != nil { @@ -156,7 +156,7 @@ type incrementalState struct { type statement struct { alloc memory.Allocator - cnxn *cnxn + cnxn *connectionImpl clientCache gcache.Cache hdrs metadata.MD diff --git a/go/adbc/driver/internal/driverbase/connection.go b/go/adbc/driver/internal/driverbase/connection.go new file mode 100644 index 0000000000..68b0a9bc69 --- /dev/null +++ b/go/adbc/driver/internal/driverbase/connection.go @@ -0,0 +1,497 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package driverbase + +import ( + "context" + "fmt" + + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow-adbc/go/adbc/driver/internal" + "github.com/apache/arrow/go/v16/arrow" + "github.com/apache/arrow/go/v16/arrow/array" + "github.com/apache/arrow/go/v16/arrow/memory" + "golang.org/x/exp/slog" +) + +const ( + ConnectionMessageOptionUnknown = "Unknown connection option" + ConnectionMessageOptionUnsupported = "Unsupported connection option" + ConnectionMessageCannotCommit = "Cannot commit when autocommit is enabled" + ConnectionMessageCannotRollback = "Cannot rollback when autocommit is enabled" +) + +// ConnectionImpl is an interface that drivers implement to provide +// vendor-specific functionality. +type ConnectionImpl interface { + adbc.Connection + adbc.GetSetOptions + Base() *ConnectionImplBase +} + +// CurrentNamespacer is an interface that drivers may implement to delegate +// stateful namespacing with DB catalogs and schemas. The appropriate (Get/Set)Options +// implementations will be provided using the results of these methods. +type CurrentNamespacer interface { + GetCurrentCatalog() (string, error) + GetCurrentDbSchema() (string, error) + SetCurrentCatalog(string) error + SetCurrentDbSchema(string) error +} + +// DriverInfoPreparer is an interface that drivers may implement to add/update +// DriverInfo values whenever adbc.Connection.GetInfo() is called. +type DriverInfoPreparer interface { + PrepareDriverInfo(ctx context.Context, infoCodes []adbc.InfoCode) error +} + +// TableTypeLister is an interface that drivers may implement to simplify the +// implementation of adbc.Connection.GetTableTypes() for backends that do not natively +// send these values as arrow records. The conversion of the result to a RecordReader +// is handled automatically. +type TableTypeLister interface { + ListTableTypes(ctx context.Context) ([]string, error) +} + +// AutocommitSetter is an interface that drivers may implement to simplify the +// implementation of autocommit state management. There is no need to implement +// this for backends that do not support autocommit, as this is already the default +// behavior. SetAutocommit should only attempt to update the autocommit state in the +// backend. Local driver state is automatically updated if the result of this call +// does not produce an error. (Get/Set)Options implementations are provided automatically +// as well/ +type AutocommitSetter interface { + SetAutocommit(enabled bool) error +} + +// DbObjectsEnumerator is an interface that drivers may implement to simplify the +// implementation of adbc.Connection.GetObjects(). By independently implementing lookup +// for catalogs, dbSchemas and tables, the driverbase is able to provide the full +// GetObjects functionality for arbitrary search patterns and lookup depth. +type DbObjectsEnumerator interface { + GetObjectsCatalogs(ctx context.Context, catalog *string) ([]string, error) + GetObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, metadataRecords []internal.Metadata) (map[string][]string, error) + GetObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (map[internal.CatalogAndSchema][]internal.TableInfo, error) +} + +// Connection is the interface satisfied by the result of the NewConnection constructor, +// given that an input is provided satisfying the ConnectionImpl interface. +type Connection interface { + adbc.Connection + adbc.GetSetOptions +} + +// ConnectionImplBase is a struct that provides default implementations of the +// ConnectionImpl interface. It is meant to be used as a composite struct for a +// driver's ConnectionImpl implementation. +type ConnectionImplBase struct { + Alloc memory.Allocator + ErrorHelper ErrorHelper + DriverInfo *DriverInfo + Logger *slog.Logger + + Autocommit bool + Closed bool +} + +// NewConnectionImplBase instantiates ConnectionImplBase. +// +// - database is a DatabaseImplBase containing the common resources from the parent +// database, allowing the Arrow allocator, error handler, and logger to be reused. +func NewConnectionImplBase(database *DatabaseImplBase) ConnectionImplBase { + return ConnectionImplBase{ + Alloc: database.Alloc, + ErrorHelper: database.ErrorHelper, + DriverInfo: database.DriverInfo, + Logger: database.Logger, + Autocommit: true, + Closed: false, + } +} + +func (base *ConnectionImplBase) Base() *ConnectionImplBase { + return base +} + +func (base *ConnectionImplBase) Commit(ctx context.Context) error { + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Commit") +} + +func (base *ConnectionImplBase) Rollback(context.Context) error { + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Rollback") +} + +func (base *ConnectionImplBase) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { + + if len(infoCodes) == 0 { + infoCodes = base.DriverInfo.InfoSupportedCodes() + } + + bldr := array.NewRecordBuilder(base.Alloc, adbc.GetInfoSchema) + defer bldr.Release() + bldr.Reserve(len(infoCodes)) + + infoNameBldr := bldr.Field(0).(*array.Uint32Builder) + infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) + strInfoBldr := infoValueBldr.Child(int(adbc.InfoValueStringType)).(*array.StringBuilder) + intInfoBldr := infoValueBldr.Child(int(adbc.InfoValueInt64Type)).(*array.Int64Builder) + + for _, code := range infoCodes { + switch code { + case adbc.InfoDriverName: + name, ok := base.DriverInfo.GetInfoDriverName() + if !ok { + continue + } + + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(adbc.InfoValueStringType) + strInfoBldr.Append(name) + case adbc.InfoDriverVersion: + version, ok := base.DriverInfo.GetInfoDriverVersion() + if !ok { + continue + } + + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(adbc.InfoValueStringType) + strInfoBldr.Append(version) + case adbc.InfoDriverArrowVersion: + arrowVersion, ok := base.DriverInfo.GetInfoDriverArrowVersion() + if !ok { + continue + } + + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(adbc.InfoValueStringType) + strInfoBldr.Append(arrowVersion) + case adbc.InfoDriverADBCVersion: + adbcVersion, ok := base.DriverInfo.GetInfoDriverADBCVersion() + if !ok { + continue + } + + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(adbc.InfoValueInt64Type) + intInfoBldr.Append(adbcVersion) + case adbc.InfoVendorName: + name, ok := base.DriverInfo.GetInfoVendorName() + if !ok { + continue + } + + infoNameBldr.Append(uint32(code)) + infoValueBldr.Append(adbc.InfoValueStringType) + strInfoBldr.Append(name) + default: + infoNameBldr.Append(uint32(code)) + value, ok := base.DriverInfo.GetInfoForInfoCode(code) + if !ok { + infoValueBldr.AppendNull() + continue + } + + // TODO: Handle other custom info types + infoValueBldr.Append(adbc.InfoValueStringType) + strInfoBldr.Append(fmt.Sprint(value)) + } + } + + final := bldr.NewRecord() + defer final.Release() + return array.NewRecordReader(adbc.GetInfoSchema, []arrow.Record{final}) +} + +func (base *ConnectionImplBase) Close() error { + return nil +} + +func (base *ConnectionImplBase) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "GetObjects") +} + +func (base *ConnectionImplBase) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "GetTableSchema") +} + +func (base *ConnectionImplBase) GetTableTypes(context.Context) (array.RecordReader, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "GetTableTypes") +} + +func (base *ConnectionImplBase) NewStatement() (adbc.Statement, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "NewStatement") +} + +func (base *ConnectionImplBase) ReadPartition(ctx context.Context, serializedPartition []byte) (array.RecordReader, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "ReadPartition") +} + +func (base *ConnectionImplBase) GetOption(key string) (string, error) { + return "", base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) GetOptionBytes(key string) ([]byte, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) GetOptionDouble(key string) (float64, error) { + return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) GetOptionInt(key string) (int64, error) { + return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) SetOption(key string, val string) error { + switch key { + case adbc.OptionKeyAutoCommit: + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", ConnectionMessageOptionUnsupported, key) + } + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) SetOptionBytes(key string, val []byte) error { + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) SetOptionDouble(key string, val float64) error { + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +func (base *ConnectionImplBase) SetOptionInt(key string, val int64) error { + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", ConnectionMessageOptionUnknown, key) +} + +type connection struct { + ConnectionImpl + + dbObjectsEnumerator DbObjectsEnumerator + currentNamespacer CurrentNamespacer + driverInfoPreparer DriverInfoPreparer + tableTypeLister TableTypeLister + autocommitSetter AutocommitSetter +} + +type ConnectionBuilder struct { + connection *connection +} + +func NewConnectionBuilder(impl ConnectionImpl) *ConnectionBuilder { + return &ConnectionBuilder{connection: &connection{ConnectionImpl: impl}} +} + +func (b *ConnectionBuilder) WithDbObjectsEnumerator(helper DbObjectsEnumerator) *ConnectionBuilder { + if b == nil { + panic("nil ConnectionBuilder: cannot reuse after calling Connection()") + } + b.connection.dbObjectsEnumerator = helper + return b +} + +func (b *ConnectionBuilder) WithCurrentNamespacer(helper CurrentNamespacer) *ConnectionBuilder { + if b == nil { + panic("nil ConnectionBuilder: cannot reuse after calling Connection()") + } + b.connection.currentNamespacer = helper + return b +} + +func (b *ConnectionBuilder) WithDriverInfoPreparer(helper DriverInfoPreparer) *ConnectionBuilder { + if b == nil { + panic("nil ConnectionBuilder: cannot reuse after calling Connection()") + } + b.connection.driverInfoPreparer = helper + return b +} + +func (b *ConnectionBuilder) WithAutocommitSetter(helper AutocommitSetter) *ConnectionBuilder { + if b == nil { + panic("nil ConnectionBuilder: cannot reuse after calling Connection()") + } + b.connection.autocommitSetter = helper + return b +} + +func (b *ConnectionBuilder) WithTableTypeLister(helper TableTypeLister) *ConnectionBuilder { + if b == nil { + panic("nil ConnectionBuilder: cannot reuse after calling Connection()") + } + b.connection.tableTypeLister = helper + return b +} + +func (b *ConnectionBuilder) Connection() Connection { + conn := b.connection + b.connection = nil + return conn +} + +// GetObjects implements Connection. +func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { + helper := cnxn.dbObjectsEnumerator + + // If the dbObjectsEnumerator has not been set, then the driver implementor has elected to provide their own GetObjects implementation + if helper == nil { + return cnxn.ConnectionImpl.GetObjects(ctx, depth, catalog, dbSchema, tableName, columnName, tableType) + } + + // To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response. + g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog, DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType: tableType} + if err := g.Init(cnxn.Base().Alloc, helper.GetObjectsDbSchemas, helper.GetObjectsTables); err != nil { + return nil, err + } + defer g.Release() + + catalogs, err := helper.GetObjectsCatalogs(ctx, catalog) + if err != nil { + return nil, err + } + + foundCatalog := false + for _, catalog := range catalogs { + g.AppendCatalog(catalog) + foundCatalog = true + } + + // Implementations like Dremio report no catalogs, but still have schemas + if !foundCatalog && depth != adbc.ObjectDepthCatalogs { + g.AppendCatalog("") + } + return g.Finish() +} + +func (cnxn *connection) GetOption(key string) (string, error) { + switch key { + case adbc.OptionKeyAutoCommit: + if cnxn.Base().Autocommit { + return adbc.OptionValueEnabled, nil + } else { + return adbc.OptionValueDisabled, nil + } + case adbc.OptionKeyCurrentCatalog: + if cnxn.currentNamespacer != nil { + val, err := cnxn.currentNamespacer.GetCurrentCatalog() + if err != nil { + return "", cnxn.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "failed to get current catalog: %s", err) + } + return val, nil + } + case adbc.OptionKeyCurrentDbSchema: + if cnxn.currentNamespacer != nil { + val, err := cnxn.currentNamespacer.GetCurrentDbSchema() + if err != nil { + return "", cnxn.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "failed to get current db schema: %s", err) + } + return val, nil + } + } + return cnxn.ConnectionImpl.GetOption(key) +} + +func (cnxn *connection) SetOption(key string, val string) error { + switch key { + case adbc.OptionKeyAutoCommit: + if cnxn.autocommitSetter != nil { + + var autocommit bool + switch val { + case adbc.OptionValueEnabled: + autocommit = true + case adbc.OptionValueDisabled: + autocommit = false + default: + return cnxn.Base().ErrorHelper.Errorf(adbc.StatusInvalidArgument, "cannot set value %s for key %s", val, key) + } + + err := cnxn.autocommitSetter.SetAutocommit(autocommit) + if err == nil { + // Only update the driver state if the action was successful + cnxn.Base().Autocommit = autocommit + } + + return err + } + case adbc.OptionKeyCurrentCatalog: + if cnxn.currentNamespacer != nil { + return cnxn.currentNamespacer.SetCurrentCatalog(val) + } + case adbc.OptionKeyCurrentDbSchema: + if cnxn.currentNamespacer != nil { + return cnxn.currentNamespacer.SetCurrentDbSchema(val) + } + } + return cnxn.ConnectionImpl.SetOption(key, val) +} + +func (cnxn *connection) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { + if cnxn.driverInfoPreparer != nil { + if err := cnxn.driverInfoPreparer.PrepareDriverInfo(ctx, infoCodes); err != nil { + return nil, err + } + } + + return cnxn.Base().GetInfo(ctx, infoCodes) +} + +func (cnxn *connection) GetTableTypes(ctx context.Context) (array.RecordReader, error) { + if cnxn.tableTypeLister == nil { + return cnxn.ConnectionImpl.GetTableTypes(ctx) + } + + tableTypes, err := cnxn.tableTypeLister.ListTableTypes(ctx) + if err != nil { + return nil, err + } + + bldr := array.NewRecordBuilder(cnxn.Base().Alloc, adbc.TableTypesSchema) + defer bldr.Release() + + bldr.Field(0).(*array.StringBuilder).AppendValues(tableTypes, nil) + final := bldr.NewRecord() + defer final.Release() + return array.NewRecordReader(adbc.TableTypesSchema, []arrow.Record{final}) +} + +func (cnxn *connection) Commit(ctx context.Context) error { + if cnxn.Base().Autocommit { + return cnxn.Base().ErrorHelper.Errorf(adbc.StatusInvalidState, ConnectionMessageCannotCommit) + } + return cnxn.ConnectionImpl.Commit(ctx) +} + +func (cnxn *connection) Rollback(ctx context.Context) error { + if cnxn.Base().Autocommit { + return cnxn.Base().ErrorHelper.Errorf(adbc.StatusInvalidState, ConnectionMessageCannotRollback) + } + return cnxn.ConnectionImpl.Rollback(ctx) +} + +func (cnxn *connection) Close() error { + if cnxn.Base().Closed { + return cnxn.Base().ErrorHelper.Errorf(adbc.StatusInvalidState, "Trying to close already closed connection") + } + + err := cnxn.ConnectionImpl.Close() + if err == nil { + cnxn.Base().Closed = true + } + + return err +} + +var _ ConnectionImpl = (*ConnectionImplBase)(nil) diff --git a/go/adbc/driver/driverbase/database.go b/go/adbc/driver/internal/driverbase/database.go similarity index 52% rename from go/adbc/driver/driverbase/database.go rename to go/adbc/driver/internal/driverbase/database.go index b08b77fcaa..9ab00967a5 100644 --- a/go/adbc/driver/driverbase/database.go +++ b/go/adbc/driver/internal/driverbase/database.go @@ -25,14 +25,24 @@ import ( "golang.org/x/exp/slog" ) +const ( + DatabaseMessageOptionUnknown = "Unknown database option" +) + // DatabaseImpl is an interface that drivers implement to provide // vendor-specific functionality. type DatabaseImpl interface { + adbc.Database adbc.GetSetOptions Base() *DatabaseImplBase - Open(context.Context) (adbc.Connection, error) - Close() error - SetOptions(map[string]string) error +} + +// Database is the interface satisfied by the result of the NewDatabase constructor, +// given an input is provided satisfying the DatabaseImpl interface. +type Database interface { + adbc.Database + adbc.GetSetOptions + adbc.DatabaseLogging } // DatabaseImplBase is a struct that provides default implementations of the @@ -41,14 +51,16 @@ type DatabaseImpl interface { type DatabaseImplBase struct { Alloc memory.Allocator ErrorHelper ErrorHelper + DriverInfo *DriverInfo Logger *slog.Logger } -// NewDatabaseImplBase instantiates DatabaseImplBase. name is the driver's -// name and is used to construct error messages. alloc is an Arrow allocator -// to use. +// NewDatabaseImplBase instantiates DatabaseImplBase. +// +// - driver is a DriverImplBase containing the common resources from the parent +// driver, allowing the Arrow allocator and error handler to be reused. func NewDatabaseImplBase(driver *DriverImplBase) DatabaseImplBase { - return DatabaseImplBase{Alloc: driver.Alloc, ErrorHelper: driver.ErrorHelper, Logger: nilLogger()} + return DatabaseImplBase{Alloc: driver.Alloc, ErrorHelper: driver.ErrorHelper, DriverInfo: driver.DriverInfo, Logger: nilLogger()} } func (base *DatabaseImplBase) Base() *DatabaseImplBase { @@ -56,97 +68,72 @@ func (base *DatabaseImplBase) Base() *DatabaseImplBase { } func (base *DatabaseImplBase) GetOption(key string) (string, error) { - return "", base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown database option '%s'", key) + return "", base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base *DatabaseImplBase) GetOptionBytes(key string) ([]byte, error) { - return nil, base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown database option '%s'", key) + return nil, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base *DatabaseImplBase) GetOptionDouble(key string) (float64, error) { - return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown database option '%s'", key) + return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base *DatabaseImplBase) GetOptionInt(key string) (int64, error) { - return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown database option '%s'", key) + return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base *DatabaseImplBase) SetOption(key string, val string) error { - return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown database option '%s'", key) + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base *DatabaseImplBase) SetOptionBytes(key string, val []byte) error { - return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown database option '%s'", key) + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base *DatabaseImplBase) SetOptionDouble(key string, val float64) error { - return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown database option '%s'", key) + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", DatabaseMessageOptionUnknown, key) } func (base *DatabaseImplBase) SetOptionInt(key string, val int64) error { - return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown database option '%s'", key) -} - -// database is the implementation of adbc.Database. -type database struct { - impl DatabaseImpl -} - -// NewDatabase wraps a DatabaseImpl to create an adbc.Database. -func NewDatabase(impl DatabaseImpl) adbc.Database { - return &database{ - impl: impl, - } -} - -func (db *database) GetOption(key string) (string, error) { - return db.impl.GetOption(key) -} - -func (db *database) GetOptionBytes(key string) ([]byte, error) { - return db.impl.GetOptionBytes(key) -} - -func (db *database) GetOptionDouble(key string) (float64, error) { - return db.impl.GetOptionDouble(key) -} - -func (db *database) GetOptionInt(key string) (int64, error) { - return db.impl.GetOptionInt(key) -} - -func (db *database) SetOption(key string, val string) error { - return db.impl.SetOption(key, val) + return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'", DatabaseMessageOptionUnknown, key) } -func (db *database) SetOptionBytes(key string, val []byte) error { - return db.impl.SetOptionBytes(key, val) +func (base *DatabaseImplBase) Close() error { + return nil } -func (db *database) SetOptionDouble(key string, val float64) error { - return db.impl.SetOptionDouble(key, val) +func (base *DatabaseImplBase) Open(ctx context.Context) (adbc.Connection, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Open") } -func (db *database) SetOptionInt(key string, val int64) error { - return db.impl.SetOptionInt(key, val) +func (base *DatabaseImplBase) SetOptions(options map[string]string) error { + for key, val := range options { + if err := base.SetOption(key, val); err != nil { + return err + } + } + return nil } -func (db *database) Open(ctx context.Context) (adbc.Connection, error) { - return db.impl.Open(ctx) +// database is the implementation of adbc.Database. +type database struct { + DatabaseImpl } -func (db *database) Close() error { - return db.impl.Close() +// NewDatabase wraps a DatabaseImpl to create an adbc.Database. +func NewDatabase(impl DatabaseImpl) Database { + return &database{ + DatabaseImpl: impl, + } } func (db *database) SetLogger(logger *slog.Logger) { if logger != nil { - db.impl.Base().Logger = logger + db.Base().Logger = logger } else { - db.impl.Base().Logger = nilLogger() + db.Base().Logger = nilLogger() } } -func (db *database) SetOptions(opts map[string]string) error { - return db.impl.SetOptions(opts) -} +var _ DatabaseImpl = (*DatabaseImplBase)(nil) diff --git a/go/adbc/driver/internal/driverbase/driver.go b/go/adbc/driver/internal/driverbase/driver.go new file mode 100644 index 0000000000..bd3e11c086 --- /dev/null +++ b/go/adbc/driver/internal/driverbase/driver.go @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Package driverbase provides a framework for implementing ADBC drivers in +// Go. It intends to reduce boilerplate for common functionality and managing +// state transitions. +package driverbase + +import ( + "runtime/debug" + "strings" + + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow/go/v16/arrow/memory" +) + +var ( + infoDriverVersion string + infoDriverArrowVersion string +) + +func init() { + if info, ok := debug.ReadBuildInfo(); ok { + for _, dep := range info.Deps { + switch { + case dep.Path == "github.com/apache/arrow-adbc/go/adbc": + infoDriverVersion = dep.Version + case strings.HasPrefix(dep.Path, "github.com/apache/arrow/go/"): + infoDriverArrowVersion = dep.Version + } + } + } +} + +// DriverImpl is an interface that drivers implement to provide +// vendor-specific functionality. +type DriverImpl interface { + adbc.Driver + Base() *DriverImplBase +} + +// Driver is the interface satisfied by the result of the NewDriver constructor, +// given an input is provided satisfying the DriverImpl interface. +type Driver interface { + adbc.Driver +} + +// DriverImplBase is a struct that provides default implementations of the +// DriverImpl interface. It is meant to be used as a composite struct for a +// driver's DriverImpl implementation. +type DriverImplBase struct { + Alloc memory.Allocator + ErrorHelper ErrorHelper + DriverInfo *DriverInfo +} + +func (base *DriverImplBase) NewDatabase(opts map[string]string) (adbc.Database, error) { + return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "NewDatabase") +} + +// NewDriverImplBase instantiates DriverImplBase. +// +// - info contains build and vendor info, as well as the name to construct error messages. +// - alloc is an Arrow allocator to use. +func NewDriverImplBase(info *DriverInfo, alloc memory.Allocator) DriverImplBase { + if alloc == nil { + alloc = memory.DefaultAllocator + } + + if infoDriverVersion != "" { + if err := info.RegisterInfoCode(adbc.InfoDriverVersion, infoDriverVersion); err != nil { + panic(err) + } + } + + if infoDriverArrowVersion != "" { + if err := info.RegisterInfoCode(adbc.InfoDriverArrowVersion, infoDriverArrowVersion); err != nil { + panic(err) + } + } + + return DriverImplBase{ + Alloc: alloc, + ErrorHelper: ErrorHelper{DriverName: info.GetName()}, + DriverInfo: info, + } +} + +func (base *DriverImplBase) Base() *DriverImplBase { + return base +} + +type driver struct { + DriverImpl +} + +// NewDriver wraps a DriverImpl to create a Driver. +func NewDriver(impl DriverImpl) Driver { + return &driver{DriverImpl: impl} +} + +var _ DriverImpl = (*DriverImplBase)(nil) diff --git a/go/adbc/driver/internal/driverbase/driver_info.go b/go/adbc/driver/internal/driverbase/driver_info.go new file mode 100644 index 0000000000..e68aa16c2c --- /dev/null +++ b/go/adbc/driver/internal/driverbase/driver_info.go @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package driverbase + +import ( + "fmt" + "sort" + + "github.com/apache/arrow-adbc/go/adbc" +) + +const ( + UnknownVersion = "(unknown or development build)" + DefaultInfoDriverADBCVersion = adbc.AdbcVersion1_1_0 +) + +func DefaultDriverInfo(name string) *DriverInfo { + defaultInfoVendorName := name + defaultInfoDriverName := fmt.Sprintf("ADBC %s Driver - Go", name) + + return &DriverInfo{ + name: name, + info: map[adbc.InfoCode]any{ + adbc.InfoVendorName: defaultInfoVendorName, + adbc.InfoDriverName: defaultInfoDriverName, + adbc.InfoDriverVersion: UnknownVersion, + adbc.InfoDriverArrowVersion: UnknownVersion, + adbc.InfoVendorVersion: UnknownVersion, + adbc.InfoVendorArrowVersion: UnknownVersion, + adbc.InfoDriverADBCVersion: DefaultInfoDriverADBCVersion, + }, + } +} + +type DriverInfo struct { + name string + info map[adbc.InfoCode]any +} + +func (di *DriverInfo) GetName() string { return di.name } + +func (di *DriverInfo) InfoSupportedCodes() []adbc.InfoCode { + // The keys of the info map are used to determine which info codes are supported. + // This means that any info codes the driver knows about should be set to some default + // at init, even if we don't know the value yet. + codes := make([]adbc.InfoCode, 0, len(di.info)) + for code := range di.info { + codes = append(codes, code) + } + + // Sorting info codes helps present them to the client in a consistent way. + // It also helps add some determinism to internal tests. + // The ordering is in no way part of the API contract and should not be relied upon. + sort.SliceStable(codes, func(i, j int) bool { + return codes[i] < codes[j] + }) + return codes +} + +func (di *DriverInfo) RegisterInfoCode(code adbc.InfoCode, value any) error { + switch code { + case adbc.InfoVendorName: + if err := ensureType[string](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + case adbc.InfoVendorVersion: + if err := ensureType[string](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + case adbc.InfoVendorArrowVersion: + if err := ensureType[string](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + case adbc.InfoDriverName: + if err := ensureType[string](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + case adbc.InfoDriverVersion: + if err := ensureType[string](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + case adbc.InfoDriverArrowVersion: + if err := ensureType[string](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + case adbc.InfoDriverADBCVersion: + if err := ensureType[int64](value); err != nil { + return fmt.Errorf("info_code %d: %w", code, err) + } + } + + di.info[code] = value + return nil +} + +func (di *DriverInfo) GetInfoForInfoCode(code adbc.InfoCode) (any, bool) { + val, ok := di.info[code] + return val, ok +} + +func (di *DriverInfo) GetInfoVendorName() (string, bool) { + return di.getStringInfoCode(adbc.InfoVendorName) +} + +func (di *DriverInfo) GetInfoVendorVersion() (string, bool) { + return di.getStringInfoCode(adbc.InfoVendorVersion) +} + +func (di *DriverInfo) GetInfoVendorArrowVersion() (string, bool) { + return di.getStringInfoCode(adbc.InfoVendorArrowVersion) +} + +func (di *DriverInfo) GetInfoDriverName() (string, bool) { + return di.getStringInfoCode(adbc.InfoDriverName) +} + +func (di *DriverInfo) GetInfoDriverVersion() (string, bool) { + return di.getStringInfoCode(adbc.InfoDriverVersion) +} + +func (di *DriverInfo) GetInfoDriverArrowVersion() (string, bool) { + return di.getStringInfoCode(adbc.InfoDriverArrowVersion) +} + +func (di *DriverInfo) GetInfoDriverADBCVersion() (int64, bool) { + return di.getInt64InfoCode(adbc.InfoDriverADBCVersion) +} + +func (di *DriverInfo) getStringInfoCode(code adbc.InfoCode) (string, bool) { + val, ok := di.GetInfoForInfoCode(code) + if !ok { + return "", false + } + + if err := ensureType[string](val); err != nil { + panic(err) + } + + return val.(string), true +} + +func (di *DriverInfo) getInt64InfoCode(code adbc.InfoCode) (int64, bool) { + val, ok := di.GetInfoForInfoCode(code) + if !ok { + return int64(0), false + } + + if err := ensureType[int64](val); err != nil { + panic(err) + } + + return val.(int64), true +} + +func ensureType[T any](value any) error { + typedVal, ok := value.(T) + if !ok { + return fmt.Errorf("expected info_value %v to be of type %T but found %T", value, typedVal, value) + } + return nil +} diff --git a/go/adbc/driver/internal/driverbase/driver_info_test.go b/go/adbc/driver/internal/driverbase/driver_info_test.go new file mode 100644 index 0000000000..2bad25d056 --- /dev/null +++ b/go/adbc/driver/internal/driverbase/driver_info_test.go @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package driverbase_test + +import ( + "strings" + "testing" + + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" + "github.com/stretchr/testify/require" +) + +func TestDriverInfo(t *testing.T) { + driverInfo := driverbase.DefaultDriverInfo("test") + + // The provided name is used for ErrorHelper, certain info code values, etc + require.Equal(t, "test", driverInfo.GetName()) + + // These are the info codes that are set for every driver + expectedDefaultInfoCodes := []adbc.InfoCode{ + adbc.InfoVendorName, + adbc.InfoVendorVersion, + adbc.InfoVendorArrowVersion, + adbc.InfoDriverName, + adbc.InfoDriverVersion, + adbc.InfoDriverArrowVersion, + adbc.InfoDriverADBCVersion, + } + require.ElementsMatch(t, expectedDefaultInfoCodes, driverInfo.InfoSupportedCodes()) + + // We get some formatted default values out of the box + vendorName, ok := driverInfo.GetInfoVendorName() + require.True(t, ok) + require.Equal(t, "test", vendorName) + + driverName, ok := driverInfo.GetInfoDriverName() + require.True(t, ok) + require.Equal(t, "ADBC test Driver - Go", driverName) + + // We can register a string value to an info code that expects a string + require.NoError(t, driverInfo.RegisterInfoCode(adbc.InfoDriverVersion, "string_value")) + + // We cannot register a non-string value to that same info code + err := driverInfo.RegisterInfoCode(adbc.InfoDriverVersion, 123) + require.Error(t, err) + require.Equal(t, "info_code 101: expected info_value 123 to be of type string but found int", err.Error()) + + // We can also set vendor-specific info codes but they won't get type checked + require.NoError(t, driverInfo.RegisterInfoCode(adbc.InfoCode(10_001), "string_value")) + require.NoError(t, driverInfo.RegisterInfoCode(adbc.InfoCode(10_001), 123)) + + // Retrieving known info codes is type-safe + driverVersion, ok := driverInfo.GetInfoDriverName() + require.True(t, ok) + require.NotEmpty(t, strings.Clone(driverVersion)) // do string stuff + + adbcVersion, ok := driverInfo.GetInfoDriverADBCVersion() + require.True(t, ok) + require.NotEmpty(t, adbcVersion+int64(123)) // do int64 stuff + + // We can also retrieve arbitrary info codes, but the result's type must be asserted + arrowVersion, ok := driverInfo.GetInfoForInfoCode(adbc.InfoDriverArrowVersion) + require.True(t, ok) + _, ok = arrowVersion.(string) + require.True(t, ok) + + // We can check if info codes have been set or not + _, ok = driverInfo.GetInfoForInfoCode(adbc.InfoCode(10_001)) + require.True(t, ok) + _, ok = driverInfo.GetInfoForInfoCode(adbc.InfoCode(10_002)) + require.False(t, ok) +} diff --git a/go/adbc/driver/internal/driverbase/driver_test.go b/go/adbc/driver/internal/driverbase/driver_test.go new file mode 100644 index 0000000000..f43a049bbe --- /dev/null +++ b/go/adbc/driver/internal/driverbase/driver_test.go @@ -0,0 +1,595 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package driverbase_test + +import ( + "context" + "fmt" + "testing" + + "golang.org/x/exp/slog" + + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow-adbc/go/adbc/driver/internal" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" + "github.com/apache/arrow/go/v16/arrow" + "github.com/apache/arrow/go/v16/arrow/array" + "github.com/apache/arrow/go/v16/arrow/memory" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +const ( + OptionKeyRecognized = "recognized" + OptionKeyUnrecognized = "unrecognized" +) + +// NewDriver creates a new adbc.Driver for testing. In addition to a memory.Allocator, it takes +// a slog.Handler to use for all structured logging as well as a useHelpers flag to determine whether +// the test should register helper methods or use the default driverbase implementation. +func NewDriver(alloc memory.Allocator, handler slog.Handler, useHelpers bool) adbc.Driver { + info := driverbase.DefaultDriverInfo("MockDriver") + _ = info.RegisterInfoCode(adbc.InfoCode(10_001), "my custom info") + return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc), handler: handler, useHelpers: useHelpers}) +} + +func TestDefaultDriver(t *testing.T) { + var handler MockedHandler + handler.On("Handle", mock.Anything, mock.Anything).Return(nil) + + ctx := context.TODO() + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + drv := NewDriver(alloc, &handler, false) // Do not use helper implementations; only default behavior + + db, err := drv.NewDatabase(nil) + require.NoError(t, err) + defer db.Close() + + require.NoError(t, db.SetOptions(map[string]string{OptionKeyRecognized: "should-pass"})) + + err = db.SetOptions(map[string]string{OptionKeyUnrecognized: "should-fail"}) + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] Unknown database option 'unrecognized'", err.Error()) + + cnxn, err := db.Open(ctx) + require.NoError(t, err) + defer func() { + // Cannot close more than once + require.NoError(t, cnxn.Close()) + require.Error(t, cnxn.Close()) + }() + + err = cnxn.Commit(ctx) + require.Error(t, err) + require.Equal(t, "Invalid State: [MockDriver] Cannot commit when autocommit is enabled", err.Error()) + + err = cnxn.Rollback(ctx) + require.Error(t, err) + require.Equal(t, "Invalid State: [MockDriver] Cannot rollback when autocommit is enabled", err.Error()) + + info, err := cnxn.GetInfo(ctx, nil) + require.NoError(t, err) + getInfoTable := tableFromRecordReader(info) + defer getInfoTable.Release() + + // This is what the driverbase provided GetInfo result should look like out of the box, + // with one custom setting registered at initialization + expectedGetInfoTable, err := array.TableFromJSON(alloc, adbc.GetInfoSchema, []string{`[ + { + "info_name": 0, + "info_value": [0, "MockDriver"] + }, + { + "info_name": 1, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 2, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 100, + "info_value": [0, "ADBC MockDriver Driver - Go"] + }, + { + "info_name": 101, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 102, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 103, + "info_value": [2, 1001000] + }, + { + "info_name": 10001, + "info_value": [0, "my custom info"] + } + ]`}) + require.NoError(t, err) + defer expectedGetInfoTable.Release() + + require.Truef(t, array.TableEqual(expectedGetInfoTable, getInfoTable), "expected: %s\ngot: %s", expectedGetInfoTable, getInfoTable) + + _, err = cnxn.GetObjects(ctx, adbc.ObjectDepthAll, nil, nil, nil, nil, nil) + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] GetObjects", err.Error()) + + _, err = cnxn.GetTableTypes(ctx) + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] GetTableTypes", err.Error()) + + autocommit, err := cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyAutoCommit) + require.NoError(t, err) + require.Equal(t, adbc.OptionValueEnabled, autocommit) + + err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyAutoCommit, "false") + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] Unsupported connection option 'adbc.connection.autocommit'", err.Error()) + + _, err = cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog) + require.Error(t, err) + require.Equal(t, "Not Found: [MockDriver] Unknown connection option 'adbc.connection.catalog'", err.Error()) + + err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentCatalog, "test_catalog") + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] Unknown connection option 'adbc.connection.catalog'", err.Error()) + + // We passed a mock handler into the driver to use for logs, so we can check actual messages logged + expectedLogMessages := []logMessage{ + {Message: "Opening a new connection", Level: "INFO", Attrs: map[string]string{"withHelpers": "false"}}, + } + + logMessages := make([]logMessage, 0, len(handler.Calls)) + for _, call := range handler.Calls { + sr, ok := call.Arguments.Get(1).(slog.Record) + require.True(t, ok) + logMessages = append(logMessages, newLogMessage(sr)) + } + + for _, expected := range expectedLogMessages { + var found bool + for _, message := range logMessages { + if messagesEqual(message, expected) { + found = true + break + } + } + require.Truef(t, found, "expected message was never logged: %v", expected) + } + +} + +func TestCustomizedDriver(t *testing.T) { + var handler MockedHandler + handler.On("Handle", mock.Anything, mock.Anything).Return(nil) + + ctx := context.TODO() + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + drv := NewDriver(alloc, &handler, true) // Use helper implementations + + db, err := drv.NewDatabase(nil) + require.NoError(t, err) + defer db.Close() + + require.NoError(t, db.SetOptions(map[string]string{OptionKeyRecognized: "should-pass"})) + + err = db.SetOptions(map[string]string{OptionKeyUnrecognized: "should-fail"}) + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] Unknown database option 'unrecognized'", err.Error()) + + cnxn, err := db.Open(ctx) + require.NoError(t, err) + defer cnxn.Close() + + err = cnxn.Commit(ctx) + require.Error(t, err) + require.Equal(t, "Invalid State: [MockDriver] Cannot commit when autocommit is enabled", err.Error()) + + err = cnxn.Rollback(ctx) + require.Error(t, err) + require.Equal(t, "Invalid State: [MockDriver] Cannot rollback when autocommit is enabled", err.Error()) + + info, err := cnxn.GetInfo(ctx, nil) + require.NoError(t, err) + getInfoTable := tableFromRecordReader(info) + defer getInfoTable.Release() + + // This is the arrow table representation of GetInfo produced by merging: + // - the default DriverInfo set at initialization + // - the DriverInfo set once in the NewDriver constructor + // - the DriverInfo set dynamically when GetInfo is called by implementing DriverInfoPreparer interface + expectedGetInfoTable, err := array.TableFromJSON(alloc, adbc.GetInfoSchema, []string{`[ + { + "info_name": 0, + "info_value": [0, "MockDriver"] + }, + { + "info_name": 1, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 2, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 100, + "info_value": [0, "ADBC MockDriver Driver - Go"] + }, + { + "info_name": 101, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 102, + "info_value": [0, "(unknown or development build)"] + }, + { + "info_name": 103, + "info_value": [2, 1001000] + }, + { + "info_name": 10001, + "info_value": [0, "my custom info"] + }, + { + "info_name": 10002, + "info_value": [0, "this was fetched dynamically"] + } + ]`}) + require.NoError(t, err) + defer expectedGetInfoTable.Release() + + require.Truef(t, array.TableEqual(expectedGetInfoTable, getInfoTable), "expected: %s\ngot: %s", expectedGetInfoTable, getInfoTable) + + dbObjects, err := cnxn.GetObjects(ctx, adbc.ObjectDepthAll, nil, nil, nil, nil, nil) + require.NoError(t, err) + dbObjectsTable := tableFromRecordReader(dbObjects) + defer dbObjectsTable.Release() + + // This is the arrow table representation of the GetObjects output we get by implementing + // the simplified TableTypeLister interface + expectedDbObjectsTable, err := array.TableFromJSON(alloc, adbc.GetObjectsSchema, []string{`[ + { + "catalog_name": "default", + "catalog_db_schemas": [ + { + "db_schema_name": "public", + "db_schema_tables": [ + { + "table_name": "foo", + "table_type": "TABLE", + "table_columns": [], + "table_constraints": [] + } + ] + }, + { + "db_schema_name": "test", + "db_schema_tables": [ + { + "table_name": "bar", + "table_type": "TABLE", + "table_columns": [], + "table_constraints": [] + } + ] + } + ] + }, + { + "catalog_name": "my_db", + "catalog_db_schemas": [ + { + "db_schema_name": "public", + "db_schema_tables": [ + { + "table_name": "baz", + "table_type": "TABLE", + "table_columns": [], + "table_constraints": [] + } + ] + } + ] + } + ]`}) + require.NoError(t, err) + defer expectedDbObjectsTable.Release() + + require.Truef(t, array.TableEqual(expectedDbObjectsTable, dbObjectsTable), "expected: %s\ngot: %s", expectedDbObjectsTable, dbObjectsTable) + + tableTypes, err := cnxn.GetTableTypes(ctx) + require.NoError(t, err) + tableTypeTable := tableFromRecordReader(tableTypes) + defer tableTypeTable.Release() + + // This is the arrow table representation of the GetTableTypes output we get by implementing + // the simplified TableTypeLister interface + expectedTableTypesTable, err := array.TableFromJSON(alloc, adbc.TableTypesSchema, []string{`[ + { "table_type": "TABLE" }, + { "table_type": "VIEW" } + ]`}) + require.NoError(t, err) + defer expectedTableTypesTable.Release() + + require.Truef(t, array.TableEqual(expectedTableTypesTable, tableTypeTable), "expected: %s\ngot: %s", expectedTableTypesTable, tableTypeTable) + + autocommit, err := cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyAutoCommit) + require.NoError(t, err) + require.Equal(t, adbc.OptionValueEnabled, autocommit) + + // By implementing AutocommitSetter, we are able to successfully toggle autocommit + err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyAutoCommit, "false") + require.NoError(t, err) + + // We haven't implemented Commit, but we get NotImplemented instead of InvalidState because + // Autocommit has been explicitly disabled + err = cnxn.Commit(ctx) + require.Error(t, err) + require.Equal(t, "Not Implemented: [MockDriver] Commit", err.Error()) + + // By implementing CurrentNamespacer, we can now get/set the current catalog/dbschema + // Default current(catalog|dbSchema) is driver-specific, but the stub implementation falls back + // to a 'not found' error instead of 'not implemented' + _, err = cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog) + require.Error(t, err) + require.Equal(t, "Not Found: [MockDriver] failed to get current catalog: current catalog is not set", err.Error()) + + err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentCatalog, "test_catalog") + require.NoError(t, err) + + currentCatalog, err := cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog) + require.NoError(t, err) + require.Equal(t, "test_catalog", currentCatalog) + + _, err = cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema) + require.Error(t, err) + require.Equal(t, "Not Found: [MockDriver] failed to get current db schema: current db schema is not set", err.Error()) + + err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentDbSchema, "test_schema") + require.NoError(t, err) + + currentDbSchema, err := cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema) + require.NoError(t, err) + require.Equal(t, "test_schema", currentDbSchema) + + // We passed a mock handler into the driver to use for logs, so we can check actual messages logged + expectedLogMessages := []logMessage{ + {Message: "Opening a new connection", Level: "INFO", Attrs: map[string]string{"withHelpers": "true"}}, + {Message: "SetAutocommit", Level: "DEBUG", Attrs: map[string]string{"enabled": "false"}}, + {Message: "SetCurrentCatalog", Level: "DEBUG", Attrs: map[string]string{"val": "test_catalog"}}, + {Message: "SetCurrentDbSchema", Level: "DEBUG", Attrs: map[string]string{"val": "test_schema"}}, + } + + logMessages := make([]logMessage, 0, len(handler.Calls)) + for _, call := range handler.Calls { + sr, ok := call.Arguments.Get(1).(slog.Record) + require.True(t, ok) + logMessages = append(logMessages, newLogMessage(sr)) + } + + for _, expected := range expectedLogMessages { + var found bool + for _, message := range logMessages { + if messagesEqual(message, expected) { + found = true + break + } + } + require.Truef(t, found, "expected message was never logged: %v", expected) + } +} + +type driverImpl struct { + driverbase.DriverImplBase + + handler slog.Handler + useHelpers bool +} + +func (drv *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) { + db := driverbase.NewDatabase( + &databaseImpl{DatabaseImplBase: driverbase.NewDatabaseImplBase(&drv.DriverImplBase), + drv: drv, + useHelpers: drv.useHelpers, + }) + db.SetLogger(slog.New(drv.handler)) + return db, nil +} + +type databaseImpl struct { + driverbase.DatabaseImplBase + drv *driverImpl + + useHelpers bool +} + +// SetOptions implements adbc.Database. +func (d *databaseImpl) SetOptions(options map[string]string) error { + for k, v := range options { + if err := d.SetOption(k, v); err != nil { + return err + } + } + return nil +} + +// Only need to implement keys we recognize. +// Any other values will fallthrough to default failure message. +func (d *databaseImpl) SetOption(key, value string) error { + switch key { + case OptionKeyRecognized: + _ = value // pretend to recognize the setting + return nil + } + return d.DatabaseImplBase.SetOption(key, value) +} + +func (db *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { + db.DatabaseImplBase.Logger.Info("Opening a new connection", "withHelpers", db.useHelpers) + cnxn := &connectionImpl{ConnectionImplBase: driverbase.NewConnectionImplBase(&db.DatabaseImplBase), db: db} + bldr := driverbase.NewConnectionBuilder(cnxn) + if db.useHelpers { // this toggles between the NewDefaultDriver and NewCustomizedDriver scenarios + return bldr. + WithAutocommitSetter(cnxn). + WithCurrentNamespacer(cnxn). + WithTableTypeLister(cnxn). + WithDriverInfoPreparer(cnxn). + WithDbObjectsEnumerator(cnxn). + Connection(), nil + } + return bldr.Connection(), nil +} + +type connectionImpl struct { + driverbase.ConnectionImplBase + db *databaseImpl + + currentCatalog string + currentDbSchema string +} + +func (c *connectionImpl) SetAutocommit(enabled bool) error { + c.Base().Logger.Debug("SetAutocommit", "enabled", enabled) + return nil +} + +func (c *connectionImpl) GetCurrentCatalog() (string, error) { + if c.currentCatalog == "" { + return "", fmt.Errorf("current catalog is not set") + } + return c.currentCatalog, nil +} + +func (c *connectionImpl) GetCurrentDbSchema() (string, error) { + if c.currentDbSchema == "" { + return "", fmt.Errorf("current db schema is not set") + } + return c.currentDbSchema, nil +} + +func (c *connectionImpl) SetCurrentCatalog(val string) error { + c.Base().Logger.Debug("SetCurrentCatalog", "val", val) + c.currentCatalog = val + return nil +} + +func (c *connectionImpl) SetCurrentDbSchema(val string) error { + c.Base().Logger.Debug("SetCurrentDbSchema", "val", val) + c.currentDbSchema = val + return nil +} + +func (c *connectionImpl) ListTableTypes(ctx context.Context) ([]string, error) { + return []string{"TABLE", "VIEW"}, nil +} + +func (c *connectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes []adbc.InfoCode) error { + return c.ConnectionImplBase.DriverInfo.RegisterInfoCode(adbc.InfoCode(10_002), "this was fetched dynamically") +} + +func (c *connectionImpl) GetObjectsCatalogs(ctx context.Context, catalog *string) ([]string, error) { + return []string{"default", "my_db"}, nil +} + +func (c *connectionImpl) GetObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, metadataRecords []internal.Metadata) (map[string][]string, error) { + return map[string][]string{ + "default": {"public", "test"}, + "my_db": {"public"}, + }, nil +} + +func (c *connectionImpl) GetObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (map[internal.CatalogAndSchema][]internal.TableInfo, error) { + return map[internal.CatalogAndSchema][]internal.TableInfo{ + {Catalog: "default", Schema: "public"}: {internal.TableInfo{Name: "foo", TableType: "TABLE"}}, + {Catalog: "default", Schema: "test"}: {internal.TableInfo{Name: "bar", TableType: "TABLE"}}, + {Catalog: "my_db", Schema: "public"}: {internal.TableInfo{Name: "baz", TableType: "TABLE"}}, + }, nil +} + +// MockedHandler is a mock.Mock that implements the slog.Handler interface. +// It is used to assert specific behavior for loggers it is injected into. +type MockedHandler struct { + mock.Mock +} + +func (h *MockedHandler) Enabled(ctx context.Context, level slog.Level) bool { return true } +func (h *MockedHandler) WithAttrs(attrs []slog.Attr) slog.Handler { return h } +func (h *MockedHandler) WithGroup(name string) slog.Handler { return h } +func (h *MockedHandler) Handle(ctx context.Context, r slog.Record) error { + // We only care to assert the message value, and want to isolate nondetermistic behavior (e.g. timestamp) + args := h.Called(ctx, r) + return args.Error(0) +} + +// logMessage is a container for log attributes we would like to compare for equality during tests. +// It intentionally omits timestamps and other sources of nondeterminism. +type logMessage struct { + Message string + Level string + Attrs map[string]string +} + +// newLogMessage constructs a logMessage from a slog.Record, containing only deterministic fields. +func newLogMessage(r slog.Record) logMessage { + message := logMessage{Message: r.Message, Level: r.Level.String(), Attrs: make(map[string]string)} + r.Attrs(func(a slog.Attr) bool { + message.Attrs[a.Key] = a.Value.String() + return true + }) + return message +} + +// messagesEqual compares two logMessages and returns whether they are equal. +func messagesEqual(expected, actual logMessage) bool { + if expected.Message != actual.Message { + return false + } + if expected.Level != actual.Level { + return false + } + if len(expected.Attrs) != len(actual.Attrs) { + return false + } + for k, v := range expected.Attrs { + if actual.Attrs[k] != v { + return false + } + } + return true +} + +func tableFromRecordReader(rdr array.RecordReader) arrow.Table { + defer rdr.Release() + + recs := make([]arrow.Record, 0) + for rdr.Next() { + rec := rdr.Record() + rec.Retain() + defer rec.Release() + recs = append(recs, rec) + } + return array.NewTableFromRecords(rdr.Schema(), recs) +} diff --git a/go/adbc/driver/driverbase/error.go b/go/adbc/driver/internal/driverbase/error.go similarity index 100% rename from go/adbc/driver/driverbase/error.go rename to go/adbc/driver/internal/driverbase/error.go diff --git a/go/adbc/driver/driverbase/logging.go b/go/adbc/driver/internal/driverbase/logging.go similarity index 100% rename from go/adbc/driver/driverbase/logging.go rename to go/adbc/driver/internal/driverbase/logging.go diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index 8252665070..4b023b0505 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -30,6 +30,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow-adbc/go/adbc/driver/internal" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow/go/v16/arrow" "github.com/apache/arrow/go/v16/arrow/array" "github.com/snowflakedb/gosnowflake" @@ -50,7 +51,9 @@ type snowflakeConn interface { QueryArrowStream(context.Context, string, ...driver.NamedValue) (gosnowflake.ArrowStreamLoader, error) } -type cnxn struct { +type connectionImpl struct { + driverbase.ConnectionImplBase + cn snowflakeConn db *databaseImpl ctor gosnowflake.Connector @@ -60,6 +63,59 @@ type cnxn struct { useHighPrecision bool } +// ListTableTypes implements driverbase.TableTypeLister. +func (*connectionImpl) ListTableTypes(ctx context.Context) ([]string, error) { + return []string{"BASE TABLE", "TEMPORARY TABLE", "VIEW"}, nil +} + +// GetCurrentCatalog implements driverbase.CurrentNamespacer. +func (c *connectionImpl) GetCurrentCatalog() (string, error) { + return c.getStringQuery("SELECT CURRENT_DATABASE()") +} + +// GetCurrentDbSchema implements driverbase.CurrentNamespacer. +func (c *connectionImpl) GetCurrentDbSchema() (string, error) { + return c.getStringQuery("SELECT CURRENT_SCHEMA()") +} + +// SetCurrentCatalog implements driverbase.CurrentNamespacer. +func (c *connectionImpl) SetCurrentCatalog(value string) error { + _, err := c.cn.ExecContext(context.Background(), "USE DATABASE ?", []driver.NamedValue{{Value: value}}) + return err +} + +// SetCurrentDbSchema implements driverbase.CurrentNamespacer. +func (c *connectionImpl) SetCurrentDbSchema(value string) error { + _, err := c.cn.ExecContext(context.Background(), "USE SCHEMA ?", []driver.NamedValue{{Value: value}}) + return err +} + +// SetAutocommit implements driverbase.AutocommitSetter. +func (c *connectionImpl) SetAutocommit(enabled bool) error { + if enabled { + if c.activeTransaction { + _, err := c.cn.ExecContext(context.Background(), "COMMIT", nil) + if err != nil { + return errToAdbcErr(adbc.StatusInternal, err) + } + c.activeTransaction = false + } + _, err := c.cn.ExecContext(context.Background(), "ALTER SESSION SET AUTOCOMMIT = true", nil) + return err + } + + if !c.activeTransaction { + _, err := c.cn.ExecContext(context.Background(), "BEGIN", nil) + if err != nil { + return errToAdbcErr(adbc.StatusInternal, err) + } + c.activeTransaction = true + } + _, err := c.cn.ExecContext(context.Background(), "ALTER SESSION SET AUTOCOMMIT = false", nil) + return err + +} + // Metadata methods // Generally these methods return an array.RecordReader that // can be consumed to retrieve metadata about the database as Arrow @@ -77,80 +133,6 @@ type cnxn struct { // characters, or "_" to match exactly one character. (See the // documentation of DatabaseMetaData in JDBC or "Pattern Value Arguments" // in the ODBC documentation.) Escaping is not currently supported. -// GetInfo returns metadata about the database/driver. -// -// The result is an Arrow dataset with the following schema: -// -// Field Name | Field Type -// ----------------------------|----------------------------- -// info_name | uint32 not null -// info_value | INFO_SCHEMA -// -// INFO_SCHEMA is a dense union with members: -// -// Field Name (Type Code) | Field Type -// ----------------------------|----------------------------- -// string_value (0) | utf8 -// bool_value (1) | bool -// int64_value (2) | int64 -// int32_bitmask (3) | int32 -// string_list (4) | list -// int32_to_int32_list_map (5) | map> -// -// Each metadatum is identified by an integer code. The recognized -// codes are defined as constants. Codes [0, 10_000) are reserved -// for ADBC usage. Drivers/vendors will ignore requests for unrecognized -// codes (the row will be omitted from the result). -func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.RecordReader, error) { - const strValTypeID arrow.UnionTypeCode = 0 - const intValTypeID arrow.UnionTypeCode = 2 - - if len(infoCodes) == 0 { - infoCodes = infoSupportedCodes - } - - bldr := array.NewRecordBuilder(c.db.Alloc, adbc.GetInfoSchema) - defer bldr.Release() - bldr.Reserve(len(infoCodes)) - - infoNameBldr := bldr.Field(0).(*array.Uint32Builder) - infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder) - strInfoBldr := infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder) - intInfoBldr := infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder) - - for _, code := range infoCodes { - switch code { - case adbc.InfoDriverName: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - strInfoBldr.Append(infoDriverName) - case adbc.InfoDriverVersion: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - strInfoBldr.Append(infoDriverVersion) - case adbc.InfoDriverArrowVersion: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - strInfoBldr.Append(infoDriverArrowVersion) - case adbc.InfoDriverADBCVersion: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(intValTypeID) - intInfoBldr.Append(adbc.AdbcVersion1_1_0) - case adbc.InfoVendorName: - infoNameBldr.Append(uint32(code)) - infoValueBldr.Append(strValTypeID) - strInfoBldr.Append(infoVendorName) - default: - infoNameBldr.Append(uint32(code)) - infoValueBldr.AppendNull() - } - } - - final := bldr.NewRecord() - defer final.Release() - return array.NewRecordReader(adbc.GetInfoSchema, []arrow.Record{final}) -} - // GetObjects gets a hierarchical view of all catalogs, database schemas, // tables, and columns. // @@ -238,7 +220,7 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re // // All non-empty, non-nil strings should be a search pattern (as described // earlier). -func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { +func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { metadataRecords, err := c.populateMetadata(ctx, depth, catalog, dbSchema, tableName, columnName, tableType) if err != nil { return nil, err @@ -266,7 +248,7 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog * return g.Finish() } -func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords []internal.Metadata) (result map[string][]string, err error) { +func (c *connectionImpl) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords []internal.Metadata) (result map[string][]string, err error) { if depth == adbc.ObjectDepthCatalogs { return } @@ -452,7 +434,7 @@ func toXdbcDataType(dt arrow.DataType) (xdbcType internal.XdbcDataType) { } } -func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (result internal.SchemaToTableInfo, err error) { +func (c *connectionImpl) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (result internal.SchemaToTableInfo, err error) { if depth == adbc.ObjectDepthCatalogs || depth == adbc.ObjectDepthDBSchemas { return } @@ -524,7 +506,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat return } -func (c *cnxn) populateMetadata(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) ([]internal.Metadata, error) { +func (c *connectionImpl) populateMetadata(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) ([]internal.Metadata, error) { var metadataRecords []internal.Metadata catalogMetadataRecords, err := c.getCatalogsMetadata(ctx) if err != nil { @@ -556,7 +538,7 @@ func (c *cnxn) populateMetadata(ctx context.Context, depth adbc.ObjectDepth, cat return metadataRecords, nil } -func (c *cnxn) getCatalogsMetadata(ctx context.Context) ([]internal.Metadata, error) { +func (c *connectionImpl) getCatalogsMetadata(ctx context.Context) ([]internal.Metadata, error) { metadataRecords := make([]internal.Metadata, 0) rows, err := c.sqldb.QueryContext(ctx, prepareCatalogsSQL(), nil) @@ -585,7 +567,7 @@ func (c *cnxn) getCatalogsMetadata(ctx context.Context) ([]internal.Metadata, er return metadataRecords, nil } -func (c *cnxn) getDbSchemasMetadata(ctx context.Context, matchingCatalogNames []string, catalog *string, dbSchema *string) ([]internal.Metadata, error) { +func (c *connectionImpl) getDbSchemasMetadata(ctx context.Context, matchingCatalogNames []string, catalog *string, dbSchema *string) ([]internal.Metadata, error) { var metadataRecords []internal.Metadata query, queryArgs := prepareDbSchemasSQL(matchingCatalogNames, catalog, dbSchema) rows, err := c.sqldb.QueryContext(ctx, query, queryArgs...) @@ -604,7 +586,7 @@ func (c *cnxn) getDbSchemasMetadata(ctx context.Context, matchingCatalogNames [] return metadataRecords, nil } -func (c *cnxn) getTablesMetadata(ctx context.Context, matchingCatalogNames []string, catalog *string, dbSchema *string, tableName *string, tableType []string) ([]internal.Metadata, error) { +func (c *connectionImpl) getTablesMetadata(ctx context.Context, matchingCatalogNames []string, catalog *string, dbSchema *string, tableName *string, tableType []string) ([]internal.Metadata, error) { metadataRecords := make([]internal.Metadata, 0) query, queryArgs := prepareTablesSQL(matchingCatalogNames, catalog, dbSchema, tableName, tableType) rows, err := c.sqldb.QueryContext(ctx, query, queryArgs...) @@ -623,7 +605,7 @@ func (c *cnxn) getTablesMetadata(ctx context.Context, matchingCatalogNames []str return metadataRecords, nil } -func (c *cnxn) getColumnsMetadata(ctx context.Context, matchingCatalogNames []string, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) ([]internal.Metadata, error) { +func (c *connectionImpl) getColumnsMetadata(ctx context.Context, matchingCatalogNames []string, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) ([]internal.Metadata, error) { metadataRecords := make([]internal.Metadata, 0) query, queryArgs := prepareColumnsSQL(matchingCatalogNames, catalog, dbSchema, tableName, columnName, tableType) rows, err := c.sqldb.QueryContext(ctx, query, queryArgs...) @@ -870,29 +852,7 @@ func descToField(name, typ, isnull, primary string, comment sql.NullString) (fie return } -func (c *cnxn) GetOption(key string) (string, error) { - switch key { - case adbc.OptionKeyAutoCommit: - if c.activeTransaction { - // No autocommit - return adbc.OptionValueDisabled, nil - } else { - // Autocommit - return adbc.OptionValueEnabled, nil - } - case adbc.OptionKeyCurrentCatalog: - return c.getStringQuery("SELECT CURRENT_DATABASE()") - case adbc.OptionKeyCurrentDbSchema: - return c.getStringQuery("SELECT CURRENT_SCHEMA()") - } - - return "", adbc.Error{ - Msg: "[Snowflake] unknown connection option", - Code: adbc.StatusNotFound, - } -} - -func (c *cnxn) getStringQuery(query string) (string, error) { +func (c *connectionImpl) getStringQuery(query string) (string, error) { result, err := c.cn.QueryContext(context.Background(), query, nil) if err != nil { return "", errToAdbcErr(adbc.StatusInternal, err) @@ -928,28 +888,7 @@ func (c *cnxn) getStringQuery(query string) (string, error) { return value, nil } -func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { - return nil, adbc.Error{ - Msg: "[Snowflake] unknown connection option", - Code: adbc.StatusNotFound, - } -} - -func (c *cnxn) GetOptionInt(key string) (int64, error) { - return 0, adbc.Error{ - Msg: "[Snowflake] unknown connection option", - Code: adbc.StatusNotFound, - } -} - -func (c *cnxn) GetOptionDouble(key string) (float64, error) { - return 0.0, adbc.Error{ - Msg: "[Snowflake] unknown connection option", - Code: adbc.StatusNotFound, - } -} - -func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { +func (c *connectionImpl) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { tblParts := make([]string, 0, 3) if catalog != nil { tblParts = append(tblParts, strconv.Quote(*catalog)) @@ -990,35 +929,11 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st return sc, nil } -// GetTableTypes returns a list of the table types in the database. -// -// The result is an arrow dataset with the following schema: -// -// Field Name | Field Type -// ----------------|-------------- -// table_type | utf8 not null -func (c *cnxn) GetTableTypes(_ context.Context) (array.RecordReader, error) { - bldr := array.NewRecordBuilder(c.db.Alloc, adbc.TableTypesSchema) - defer bldr.Release() - - bldr.Field(0).(*array.StringBuilder).AppendValues([]string{"BASE TABLE", "TEMPORARY TABLE", "VIEW"}, nil) - final := bldr.NewRecord() - defer final.Release() - return array.NewRecordReader(adbc.TableTypesSchema, []arrow.Record{final}) -} - // Commit commits any pending transactions on this connection, it should // only be used if autocommit is disabled. // // Behavior is undefined if this is mixed with SQL transaction statements. -func (c *cnxn) Commit(_ context.Context) error { - if !c.activeTransaction { - return adbc.Error{ - Msg: "no active transaction, cannot commit", - Code: adbc.StatusInvalidState, - } - } - +func (c *connectionImpl) Commit(_ context.Context) error { _, err := c.cn.ExecContext(context.Background(), "COMMIT", nil) if err != nil { return errToAdbcErr(adbc.StatusInternal, err) @@ -1032,14 +947,7 @@ func (c *cnxn) Commit(_ context.Context) error { // is disabled. // // Behavior is undefined if this is mixed with SQL transaction statements. -func (c *cnxn) Rollback(_ context.Context) error { - if !c.activeTransaction { - return adbc.Error{ - Msg: "no active transaction, cannot rollback", - Code: adbc.StatusInvalidState, - } - } - +func (c *connectionImpl) Rollback(_ context.Context) error { _, err := c.cn.ExecContext(context.Background(), "ROLLBACK", nil) if err != nil { return errToAdbcErr(adbc.StatusInternal, err) @@ -1050,7 +958,7 @@ func (c *cnxn) Rollback(_ context.Context) error { } // NewStatement initializes a new statement object tied to this connection -func (c *cnxn) NewStatement() (adbc.Statement, error) { +func (c *connectionImpl) NewStatement() (adbc.Statement, error) { defaultIngestOptions := DefaultIngestOptions() return &statement{ alloc: c.db.Alloc, @@ -1063,7 +971,7 @@ func (c *cnxn) NewStatement() (adbc.Statement, error) { } // Close closes this connection and releases any associated resources. -func (c *cnxn) Close() error { +func (c *connectionImpl) Close() error { if c.sqldb == nil || c.cn == nil { return adbc.Error{Code: adbc.StatusInvalidState} } @@ -1083,49 +991,15 @@ func (c *cnxn) Close() error { // results can then be read independently using the returned RecordReader. // // A partition can be retrieved by using ExecutePartitions on a statement. -func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte) (array.RecordReader, error) { +func (c *connectionImpl) ReadPartition(ctx context.Context, serializedPartition []byte) (array.RecordReader, error) { return nil, adbc.Error{ Code: adbc.StatusNotImplemented, Msg: "ReadPartition not yet implemented for snowflake driver", } } -func (c *cnxn) SetOption(key, value string) error { +func (c *connectionImpl) SetOption(key, value string) error { switch key { - case adbc.OptionKeyAutoCommit: - switch value { - case adbc.OptionValueEnabled: - if c.activeTransaction { - _, err := c.cn.ExecContext(context.Background(), "COMMIT", nil) - if err != nil { - return errToAdbcErr(adbc.StatusInternal, err) - } - c.activeTransaction = false - } - _, err := c.cn.ExecContext(context.Background(), "ALTER SESSION SET AUTOCOMMIT = true", nil) - return err - case adbc.OptionValueDisabled: - if !c.activeTransaction { - _, err := c.cn.ExecContext(context.Background(), "BEGIN", nil) - if err != nil { - return errToAdbcErr(adbc.StatusInternal, err) - } - c.activeTransaction = true - } - _, err := c.cn.ExecContext(context.Background(), "ALTER SESSION SET AUTOCOMMIT = false", nil) - return err - default: - return adbc.Error{ - Msg: "[Snowflake] invalid value for option " + key + ": " + value, - Code: adbc.StatusInvalidArgument, - } - } - case adbc.OptionKeyCurrentCatalog: - _, err := c.cn.ExecContext(context.Background(), "USE DATABASE ?", []driver.NamedValue{{Value: value}}) - return err - case adbc.OptionKeyCurrentDbSchema: - _, err := c.cn.ExecContext(context.Background(), "USE SCHEMA ?", []driver.NamedValue{{Value: value}}) - return err case OptionUseHighPrecision: // statements will inherit the value of the OptionUseHighPrecision // from the connection, but the option can be overridden at the @@ -1149,24 +1023,3 @@ func (c *cnxn) SetOption(key, value string) error { } } } - -func (c *cnxn) SetOptionBytes(key string, value []byte) error { - return adbc.Error{ - Msg: "[Snowflake] unknown connection option", - Code: adbc.StatusNotImplemented, - } -} - -func (c *cnxn) SetOptionInt(key string, value int64) error { - return adbc.Error{ - Msg: "[Snowflake] unknown connection option", - Code: adbc.StatusNotImplemented, - } -} - -func (c *cnxn) SetOptionDouble(key string, value float64) error { - return adbc.Error{ - Msg: "[Snowflake] unknown connection option", - Code: adbc.StatusNotImplemented, - } -} diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go index 77bdcdaac9..124c4d3887 100644 --- a/go/adbc/driver/snowflake/driver.go +++ b/go/adbc/driver/snowflake/driver.go @@ -20,19 +20,15 @@ package snowflake import ( "errors" "runtime/debug" - "strings" "github.com/apache/arrow-adbc/go/adbc" - "github.com/apache/arrow-adbc/go/adbc/driver/driverbase" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow/go/v16/arrow/memory" "github.com/snowflakedb/gosnowflake" "golang.org/x/exp/maps" ) const ( - infoDriverName = "ADBC Snowflake Driver - Go" - infoVendorName = "Snowflake" - OptionDatabase = "adbc.snowflake.sql.db" OptionSchema = "adbc.snowflake.sql.schema" OptionWarehouse = "adbc.snowflake.sql.warehouse" @@ -119,37 +115,18 @@ const ( ) var ( - infoDriverVersion string - infoDriverArrowVersion string - infoSupportedCodes []adbc.InfoCode + infoVendorVersion string ) func init() { if info, ok := debug.ReadBuildInfo(); ok { for _, dep := range info.Deps { switch { - case dep.Path == "github.com/apache/arrow-adbc/go/adbc/driver/snowflake": - infoDriverVersion = dep.Version - case strings.HasPrefix(dep.Path, "github.com/apache/arrow/go/"): - infoDriverArrowVersion = dep.Version + case dep.Path == "github.com/snowflakedb/gosnowflake": + infoVendorVersion = dep.Version } } } - // XXX: Deps not populated in tests - // https://github.com/golang/go/issues/33976 - if infoDriverVersion == "" { - infoDriverVersion = "(unknown or development build)" - } - if infoDriverArrowVersion == "" { - infoDriverArrowVersion = "(unknown or development build)" - } - - infoSupportedCodes = []adbc.InfoCode{ - adbc.InfoDriverName, - adbc.InfoDriverVersion, - adbc.InfoDriverArrowVersion, - adbc.InfoVendorName, - } } func errToAdbcErr(code adbc.Status, err error) error { @@ -192,13 +169,21 @@ type driverImpl struct { // NewDriver creates a new Snowflake driver using the given Arrow allocator. func NewDriver(alloc memory.Allocator) adbc.Driver { - return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase("Snowflake", alloc)}) + info := driverbase.DefaultDriverInfo("Snowflake") + if infoVendorVersion != "" { + if err := info.RegisterInfoCode(adbc.InfoVendorVersion, infoVendorVersion); err != nil { + panic(err) + } + } + return driverbase.NewDriver(&driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, alloc)}) } func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) { opts = maps.Clone(opts) - db := &databaseImpl{DatabaseImplBase: driverbase.NewDatabaseImplBase(&d.DriverImplBase), - useHighPrecision: true} + db := &databaseImpl{ + DatabaseImplBase: driverbase.NewDatabaseImplBase(&d.DriverImplBase), + useHighPrecision: true, + } if err := db.SetOptions(opts); err != nil { return nil, err } diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 5752ae5eec..3f93dbdb58 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -217,6 +217,10 @@ func (s *SnowflakeQuirks) GetMetadata(code adbc.InfoCode) interface{} { return "(unknown or development build)" case adbc.InfoDriverArrowVersion: return "(unknown or development build)" + case adbc.InfoVendorVersion: + return "(unknown or development build)" + case adbc.InfoVendorArrowVersion: + return "(unknown or development build)" case adbc.InfoDriverADBCVersion: return adbc.AdbcVersion1_1_0 case adbc.InfoVendorName: diff --git a/go/adbc/driver/snowflake/snowflake_database.go b/go/adbc/driver/snowflake/snowflake_database.go index 76ab4684bf..5c5f32b690 100644 --- a/go/adbc/driver/snowflake/snowflake_database.go +++ b/go/adbc/driver/snowflake/snowflake_database.go @@ -32,7 +32,7 @@ import ( "time" "github.com/apache/arrow-adbc/go/adbc" - "github.com/apache/arrow-adbc/go/adbc/driver/driverbase" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/snowflakedb/gosnowflake" "github.com/youmark/pkcs8" ) @@ -136,28 +136,7 @@ func (d *databaseImpl) GetOption(key string) (string, error) { return *val, nil } } - return "", adbc.Error{ - Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), - Code: adbc.StatusNotFound, - } -} -func (d *databaseImpl) GetOptionBytes(key string) ([]byte, error) { - return nil, adbc.Error{ - Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), - Code: adbc.StatusNotFound, - } -} -func (d *databaseImpl) GetOptionInt(key string) (int64, error) { - return 0, adbc.Error{ - Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), - Code: adbc.StatusNotFound, - } -} -func (d *databaseImpl) GetOptionDouble(key string) (float64, error) { - return 0, adbc.Error{ - Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'", key), - Code: adbc.StatusNotFound, - } + return d.DatabaseImplBase.GetOption(key) } func (d *databaseImpl) SetOptions(cnOptions map[string]string) error { @@ -176,7 +155,8 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error { } } - defaultAppName := "[ADBC][Go-" + infoDriverVersion + "]" + driverVersion, _ := d.DatabaseImplBase.DriverInfo.GetInfoDriverVersion() + defaultAppName := "[ADBC][Go-" + driverVersion + "]" // set default application name to track // unless user overrides it d.cfg.Application = defaultAppName @@ -464,15 +444,22 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) { return nil, errToAdbcErr(adbc.StatusIO, err) } - return &cnxn{ + conn := &connectionImpl{ cn: cn.(snowflakeConn), db: d, ctor: connector, sqldb: sql.OpenDB(connector), // default enable high precision // SetOption(OptionUseHighPrecision, adbc.OptionValueDisabled) to // get Int64/Float64 instead - useHighPrecision: d.useHighPrecision, - }, nil + useHighPrecision: d.useHighPrecision, + ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase), + } + + return driverbase.NewConnectionBuilder(conn). + WithAutocommitSetter(conn). + WithCurrentNamespacer(conn). + WithTableTypeLister(conn). + Connection(), nil } func (d *databaseImpl) Close() error { diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go index 8439ddfcd4..3f446662ea 100644 --- a/go/adbc/driver/snowflake/statement.go +++ b/go/adbc/driver/snowflake/statement.go @@ -42,7 +42,7 @@ const ( ) type statement struct { - cnxn *cnxn + cnxn *connectionImpl alloc memory.Allocator queueSize int prefetchConcurrency int diff --git a/go/adbc/go.mod b/go/adbc/go.mod index fe30fd7cb3..35bdc70817 100644 --- a/go/adbc/go.mod +++ b/go/adbc/go.mod @@ -83,6 +83,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/sirupsen/logrus v1.9.3 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect golang.org/x/crypto v0.21.0 // indirect golang.org/x/mod v0.16.0 // indirect diff --git a/go/adbc/go.sum b/go/adbc/go.sum index a5bea3b7f5..cf74b7baa8 100644 --- a/go/adbc/go.sum +++ b/go/adbc/go.sum @@ -130,6 +130,7 @@ github.com/snowflakedb/gosnowflake v1.8.0 h1:4bQj8eAYGMkou/nICiIEb9jSbBLDDp5cB6J github.com/snowflakedb/gosnowflake v1.8.0/go.mod h1:7yyY2MxtDti2eXgtvlZ8QxzCN6KV2B4qb1HuygMI+0U= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=