diff --git a/source/snapshot/fetch_worker.go b/source/snapshot/fetch_worker.go index 18cb904..07b90b2 100644 --- a/source/snapshot/fetch_worker.go +++ b/source/snapshot/fetch_worker.go @@ -257,11 +257,9 @@ func (f *FetchWorker) updateSnapshotEnd(ctx context.Context, tx pgx.Tx) error { return nil } - if err := tx.QueryRow( - ctx, - fmt.Sprintf("SELECT max(%s) FROM %s", f.conf.Key, f.conf.Table), - ).Scan(&f.snapshotEnd); err != nil { - return fmt.Errorf("failed to query max on %q.%q: %w", f.conf.Table, f.conf.Key, err) + query := fmt.Sprintf("SELECT COALESCE(max(%s), 0) FROM %s", f.conf.Key, f.conf.Table) + if err := tx.QueryRow(ctx, query).Scan(&f.snapshotEnd); err != nil { + return fmt.Errorf("failed to get snapshot end with query %q: %w", query, err) } return nil diff --git a/source/snapshot/fetch_worker_test.go b/source/snapshot/fetch_worker_test.go index e9a7595..a29fcd7 100644 --- a/source/snapshot/fetch_worker_test.go +++ b/source/snapshot/fetch_worker_test.go @@ -201,6 +201,40 @@ func Test_FetcherValidate(t *testing.T) { }) } +func Test_FetcherRun_EmptySnapshot(t *testing.T) { + var ( + is = is.New(t) + ctx = test.Context(t) + pool = test.ConnectPool(context.Background(), t, test.RegularConnString) + table = test.SetupEmptyTestTable(context.Background(), t, pool) + out = make(chan FetchData) + testTomb = &tomb.Tomb{} + ) + + f := NewFetchWorker(pool, out, FetchConfig{ + Table: table, + Key: "id", + }) + + testTomb.Go(func() error { + ctx = testTomb.Context(ctx) + defer close(out) + + if err := f.Validate(ctx); err != nil { + return err + } + return f.Run(ctx) + }) + + var gotFetchData []FetchData + for data := range out { + gotFetchData = append(gotFetchData, data) + } + + is.NoErr(testTomb.Err()) + is.True(len(gotFetchData) == 0) +} + func Test_FetcherRun_Initial(t *testing.T) { var ( pool = test.ConnectPool(context.Background(), t, test.RegularConnString) @@ -226,13 +260,13 @@ func Test_FetcherRun_Initial(t *testing.T) { return f.Run(ctx) }) - var dd []FetchData + var gotFetchData []FetchData for data := range out { - dd = append(dd, data) + gotFetchData = append(gotFetchData, data) } is.NoErr(tt.Err()) - is.True(len(dd) == 4) + is.True(len(gotFetchData) == 4) expectedMatch := []opencdc.StructuredData{ {"id": int64(1), "key": []uint8{49}, "column1": "foo", "column2": int32(123), "column3": false, "column4": 12.2, "column5": int64(4)}, @@ -241,17 +275,17 @@ func Test_FetcherRun_Initial(t *testing.T) { {"id": int64(4), "key": []uint8{52}, "column1": nil, "column2": nil, "column3": nil, "column4": 91.1, "column5": nil}, } - for i, d := range dd { + for i, got := range gotFetchData { t.Run(fmt.Sprintf("payload_%d", i+1), func(t *testing.T) { is := is.New(t) - is.Equal(d.Key, opencdc.StructuredData{"id": int64(i + 1)}) - is.Equal("", cmp.Diff(expectedMatch[i], d.Payload)) + is.Equal(got.Key, opencdc.StructuredData{"id": int64(i + 1)}) + is.Equal("", cmp.Diff(expectedMatch[i], got.Payload)) - is.Equal(d.Position, position.SnapshotPosition{ + is.Equal(got.Position, position.SnapshotPosition{ LastRead: int64(i + 1), SnapshotEnd: 4, }) - is.Equal(d.Table, table) + is.Equal(got.Table, table) }) } } diff --git a/test/helper.go b/test/helper.go index 9ebc8e1..41d9385 100644 --- a/test/helper.go +++ b/test/helper.go @@ -173,7 +173,7 @@ func ConnectSimple(ctx context.Context, t *testing.T, connString string) *pgx.Co } // SetupTestTable creates a new table and returns its name. -func SetupTestTable(ctx context.Context, t *testing.T, conn Querier) string { +func SetupEmptyTestTable(ctx context.Context, t *testing.T, conn Querier) string { is := is.New(t) table := RandomIdentifier(t) @@ -189,14 +189,22 @@ func SetupTestTable(ctx context.Context, t *testing.T, conn Querier) string { is.NoErr(err) }) - query = ` + return table +} + +// SetupTestTable creates a new table and returns its name. +func SetupTestTable(ctx context.Context, t *testing.T, conn Querier) string { + is := is.New(t) + table := SetupEmptyTestTable(ctx, t, conn) + + query := ` INSERT INTO %s (key, column1, column2, column3, column4, column5) VALUES ('1', 'foo', 123, false, 12.2, 4), ('2', 'bar', 456, true, 13.42, 8), ('3', 'baz', 789, false, null, 9), ('4', null, null, null, 91.1, null)` query = fmt.Sprintf(query, table) - _, err = conn.Exec(ctx, query) + _, err := conn.Exec(ctx, query) is.NoErr(err) return table