Skip to content

Commit

Permalink
feat: add DBReplica and use it in ReadWriteConnResolver
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Feb 10, 2025
1 parent 41abcd3 commit 95c825e
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package bun

import (
"context"
"crypto/rand"
cryptorand "crypto/rand"
"database/sql"
"encoding/hex"
"fmt"
"math/rand/v2"
"reflect"
"strings"
"sync/atomic"
Expand Down Expand Up @@ -633,7 +634,7 @@ func (tx Tx) Begin() (Tx, error) {
func (tx Tx) BeginTx(ctx context.Context, _ *sql.TxOptions) (Tx, error) {
// mssql savepoint names are limited to 32 characters
sp := make([]byte, 14)
_, err := rand.Read(sp)
_, err := cryptorand.Read(sp)
if err != nil {
return Tx{}, err
}
Expand Down Expand Up @@ -753,13 +754,19 @@ type ConnResolver interface {
Close() error
}

type DBReplica interface {
IConn
PingContext(context.Context) error
Close() error
}

// TODO:
// - make monitoring interval configurable
// - make ping timeout configutable
// - allow adding read/write replicas for multi-master replication
type ReadWriteConnResolver struct {
replicas []*sql.DB // read-only replicas
healthyReplicas atomic.Pointer[[]*sql.DB]
replicas []DBReplica // read-only replicas
healthyReplicas atomic.Pointer[[]DBReplica]
nextReplica atomic.Int64
closed atomic.Bool
}
Expand All @@ -774,14 +781,18 @@ func NewReadWriteConnResolver(opts ...ReadWriteConnResolverOption) *ReadWriteCon
if len(r.replicas) > 0 {
r.healthyReplicas.Store(&r.replicas)
go r.monitor()

// Start with a random replica.
rnd := rand.IntN(len(r.replicas))
r.nextReplica.Store(int64(rnd))
}

return r
}

type ReadWriteConnResolverOption func(r *ReadWriteConnResolver)

func WithReadOnlyReplica(dbs ...*sql.DB) ReadWriteConnResolverOption {
func WithReadOnlyReplica(dbs ...DBReplica) ReadWriteConnResolverOption {
return func(r *ReadWriteConnResolver) {
r.replicas = append(r.replicas, dbs...)
}
Expand Down Expand Up @@ -831,7 +842,7 @@ func isReadOnlyQuery(query Query) bool {
return true
}

func (r *ReadWriteConnResolver) loadHealthyReplicas() []*sql.DB {
func (r *ReadWriteConnResolver) loadHealthyReplicas() []DBReplica {
if ptr := r.healthyReplicas.Load(); ptr != nil {
return *ptr
}
Expand All @@ -841,7 +852,7 @@ func (r *ReadWriteConnResolver) loadHealthyReplicas() []*sql.DB {
func (r *ReadWriteConnResolver) monitor() {
const interval = 5 * time.Second
for !r.closed.Load() {
healthy := make([]*sql.DB, 0, len(r.replicas))
healthy := make([]DBReplica, 0, len(r.replicas))

for _, replica := range r.replicas {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
Expand Down

0 comments on commit 95c825e

Please sign in to comment.