diff --git a/.golangci.yml b/.golangci.yml index 78effdf9..5f112400 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -16,12 +16,18 @@ linters-settings: values: regexp: copyright-year: 20[2-9]\d - + forbidigo: + forbid: + - p: ^pgxpool\.New.*$ + msg: "Use github.com/conduitio/conduit-connector-postgres/source/cpool.New instead." issues: exclude-rules: - path: test/helper\.go linters: - gosec + - path: source/cpool/cpool\.go + linters: + - forbidigo linters: # please, do not use `enable-all`: it's deprecated and will be removed soon. @@ -38,7 +44,7 @@ linters: # - exhaustive # - exhaustivestruct - exportloopref - # - forbidigo + - forbidigo # - forcetypeassert # - funlen # - gochecknoinits diff --git a/source/cpool/cpool.go b/source/cpool/cpool.go index 224f4df3..ed077411 100644 --- a/source/cpool/cpool.go +++ b/source/cpool/cpool.go @@ -59,6 +59,10 @@ func beforeAcquireHook(ctx context.Context, conn *pgx.Conn) bool { // beforeConnectHook customizes the configuration of the new connection. func beforeConnectHook(ctx context.Context, config *pgx.ConnConfig) error { + if config.RuntimeParams["application_name"] == "" { + config.RuntimeParams["application_name"] = "conduit-connector-postgres" + } + if v := ctx.Value(replicationCtxKey{}); v != nil { config.RuntimeParams["replication"] = "database" } diff --git a/source/logrepl/cleaner.go b/source/logrepl/cleaner.go index befa20fa..db52c9fc 100644 --- a/source/logrepl/cleaner.go +++ b/source/logrepl/cleaner.go @@ -19,9 +19,9 @@ import ( "errors" "fmt" + "github.com/conduitio/conduit-connector-postgres/source/cpool" "github.com/conduitio/conduit-connector-postgres/source/logrepl/internal" sdk "github.com/conduitio/conduit-connector-sdk" - "github.com/jackc/pgx/v5/pgxpool" ) type CleanupConfig struct { @@ -35,7 +35,7 @@ type CleanupConfig struct { func Cleanup(ctx context.Context, c CleanupConfig) error { logger := sdk.Logger(ctx) - pool, err := pgxpool.New(ctx, c.URL) + pool, err := cpool.New(ctx, c.URL) if err != nil { return fmt.Errorf("failed to connect to database: %w", err) } diff --git a/source/logrepl/combined_test.go b/source/logrepl/combined_test.go index 592efebf..4abca0da 100644 --- a/source/logrepl/combined_test.go +++ b/source/logrepl/combined_test.go @@ -165,6 +165,16 @@ func TestCombinedIterator_Next(t *testing.T) { expectedRecords := testRecords() + // interrupt repl connection + var terminated bool + is.NoErr(pool.QueryRow(ctx, fmt.Sprintf( + `SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE + query ILIKE '%%CREATE_REPLICATION_SLOT %s%%' and pid <> pg_backend_pid() + `, + table, + )).Scan(&terminated)) + is.True(terminated) + // compare snapshot for id := 1; id < 5; id++ { t.Run(fmt.Sprint("next_snapshot", id), func(t *testing.T) { diff --git a/source/logrepl/internal/subscription.go b/source/logrepl/internal/subscription.go index a7e0d0da..193356ce 100644 --- a/source/logrepl/internal/subscription.go +++ b/source/logrepl/internal/subscription.go @@ -44,6 +44,7 @@ type Subscription struct { TXSnapshotID string conn *pgxpool.Conn + pool *pgxpool.Pool stop context.CancelFunc @@ -133,6 +134,7 @@ func CreateSubscription( TXSnapshotID: result.SnapshotName, conn: conn, + pool: pool, ready: make(chan struct{}), done: make(chan struct{}), @@ -144,6 +146,7 @@ func (s *Subscription) Run(ctx context.Context) error { defer s.doneReplication() if err := s.startReplication(ctx); err != nil { + close(s.ready) // ready to fail. return err } @@ -330,6 +333,18 @@ func (s *Subscription) Err() error { // startReplication starts replication with a specific start LSN. func (s *Subscription) startReplication(ctx context.Context) error { + // N.B. Snapshots may take long time and connection may timeout. + // Safer to refresh the connection before replication begins. + + s.conn.Release() + + conn, err := s.pool.Acquire(cpool.WithReplication(ctx)) + if err != nil { + return fmt.Errorf("could not establish replication connection: %w", err) + } + + s.conn = conn + pluginArgs := []string{ `"proto_version" '1'`, fmt.Sprintf(`"publication_names" '%s'`, s.Publication), diff --git a/source/schema/avro.go b/source/schema/avro.go index ac119d0f..112b853b 100644 --- a/source/schema/avro.go +++ b/source/schema/avro.go @@ -17,7 +17,6 @@ package schema import ( "cmp" "fmt" - "math" "slices" "github.com/hamba/avro/v2" @@ -27,15 +26,8 @@ import ( ) const ( - avroNS = "conduit.postgres" - // The default decimal precision is pretty generous, but it is in excess of what - // pgx provides by default. All numeric values by default are coded to float64/int64. - // Ideally in the future the decimal precision can be adjusted to fit the definition in postgres. - avroDecimalPrecision = 38 - // The size of the storage in which a decimal may be encoded depends on the underlying numeric definition. - // Unfortunately similarly to the decimal precision, this is dependent on the size of the numeric, which - // by default is constraint to 8 bytes. This default is generously allocating four times larger width. - avroDecimalFixedSize = 8 * 4 + avroNS = "conduit.postgres" + avroDecimalPadding = 8 ) var Avro = &avroExtractor{ @@ -75,38 +67,29 @@ type avroExtractor struct { } func (a avroExtractor) ExtractLogrepl(rel *pglogrepl.RelationMessage, row *pglogrepl.TupleData) (avro.Schema, error) { - var ( - fields []pgconn.FieldDescription - values []any - ) + var fields []pgconn.FieldDescription - for i, tuple := range row.Columns { + for i := range row.Columns { fields = append(fields, pgconn.FieldDescription{ - Name: rel.Columns[i].Name, - DataTypeOID: rel.Columns[i].DataType, + Name: rel.Columns[i].Name, + DataTypeOID: rel.Columns[i].DataType, + TypeModifier: rel.Columns[i].TypeModifier, }) - - v, err := a.decodeColumnValue(rel.Columns[i], tuple.Data) - if err != nil { - return nil, err - } - - values = append(values, v) } - return a.Extract(rel.RelationName, fields, values) + return a.Extract(rel.RelationName, fields) } -func (a *avroExtractor) Extract(name string, fields []pgconn.FieldDescription, values []any) (avro.Schema, error) { +func (a *avroExtractor) Extract(name string, fields []pgconn.FieldDescription) (avro.Schema, error) { var avroFields []*avro.Field - for i, f := range fields { + for _, f := range fields { t, ok := a.pgMap.TypeForOID(f.DataTypeOID) if !ok { return nil, fmt.Errorf("field %q with OID %d cannot be resolved", f.Name, f.DataTypeOID) } - s, err := a.extractType(t, values[i]) + s, err := a.extractType(t, f.TypeModifier) if err != nil { return nil, err } @@ -131,41 +114,26 @@ func (a *avroExtractor) Extract(name string, fields []pgconn.FieldDescription, v return sch, nil } -func (a *avroExtractor) extractType(t *pgtype.Type, val any) (avro.Schema, error) { +func (a *avroExtractor) extractType(t *pgtype.Type, typeMod int32) (avro.Schema, error) { if ps, ok := a.avroMap[t.Name]; ok { return ps, nil } - switch tt := val.(type) { - case pgtype.Numeric: - // N.B.: Default to 38 positions and pick the exponent as the scale. + switch t.OID { + case pgtype.NumericOID: + scale := int((typeMod - 4) & 65535) + precision := int(((typeMod - 4) >> 16) & 65535) fs, err := avro.NewFixedSchema( string(avro.Decimal), avroNS, - avroDecimalFixedSize, - avro.NewDecimalLogicalSchema(avroDecimalPrecision, int(math.Abs(float64(tt.Exp)))), + precision+scale+avroDecimalPadding, + avro.NewDecimalLogicalSchema(precision, scale), ) if err != nil { return nil, fmt.Errorf("failed to create avro.FixedSchema: %w", err) } return fs, nil default: - return nil, fmt.Errorf("cannot resolve field %q of type %T", t.Name, tt) + return nil, fmt.Errorf("cannot resolve field type %q ", t.Name) } } - -func (a *avroExtractor) decodeColumnValue(col *pglogrepl.RelationMessageColumn, data []byte) (any, error) { - var t *pgtype.Type - - t, ok := a.pgMap.TypeForOID(col.DataType) - if !ok { - t, _ = a.pgMap.TypeForOID(pgtype.UnknownOID) - } - - v, err := t.Codec.DecodeValue(a.pgMap, col.DataType, pgtype.TextFormatCode, data) - if err != nil { - return nil, fmt.Errorf("failed to decode %q tuple: %w", col.Name, err) - } - - return v, nil -} diff --git a/source/schema/avro_test.go b/source/schema/avro_test.go index 679597e0..c5d2176c 100644 --- a/source/schema/avro_test.go +++ b/source/schema/avro_test.go @@ -50,7 +50,7 @@ func Test_AvroExtract(t *testing.T) { fields := rows.FieldDescriptions() - sch, err := Avro.Extract(table, fields, values) + sch, err := Avro.Extract(table, fields) t.Run("schema is parsable", func(t *testing.T) { is := is.New(t) @@ -191,8 +191,8 @@ func avroTestSchema(t *testing.T, table string) avro.Schema { assert(avro.NewField("col_numeric", assert(avro.NewFixedSchema(string(avro.Decimal), avroNS, - avroDecimalFixedSize, - avro.NewDecimalLogicalSchema(avroDecimalPrecision, 2), + 18, + avro.NewDecimalLogicalSchema(8, 2), )))), assert(avro.NewField("col_date", avro.NewPrimitiveSchema( avro.Int, diff --git a/source/snapshot/fetch_worker.go b/source/snapshot/fetch_worker.go index 6a64a72b..b57e9696 100644 --- a/source/snapshot/fetch_worker.go +++ b/source/snapshot/fetch_worker.go @@ -286,7 +286,7 @@ func (f *FetchWorker) fetch(ctx context.Context, tx pgx.Tx) (int, error) { } if f.conf.WithAvroSchema && f.avroSchema == nil { - sch, err := schema.Avro.Extract(f.conf.Table, fields, values) + sch, err := schema.Avro.Extract(f.conf.Table, fields) if err != nil { return 0, fmt.Errorf("failed to extract schema: %w", err) }