Commit

Simplify logic.
qingyang-hu committed Jan 4, 2024
1 parent 97d4d33 commit 10bd0ef
Showing 5 changed files with 120 additions and 121 deletions.
71 changes: 68 additions & 3 deletions internal/integration/unified/entity.go
@@ -112,6 +112,70 @@ func newCollectionEntityOptions(id string, databaseID string, collectionName str
return options
}

type task struct {
name string
execute func() error
}

type backgroundRoutine struct {
tasks chan *task
wg sync.WaitGroup
err error
}

func (b *backgroundRoutine) start() {
b.wg.Add(1)

go func() {
defer b.wg.Done()

for t := range b.tasks {
if b.err != nil {
continue
}

ch := make(chan error)
go func(task *task) {
ch <- task.execute()
}(t)
select {
case err := <-ch:
if err != nil {
b.err = fmt.Errorf("error running operation %s: %v", t.name, err)
}
case <-time.After(10 * time.Second):
b.err = fmt.Errorf("timed out after 10 seconds")
}
}
}()
}

func (b *backgroundRoutine) stop() error {
close(b.tasks)
b.wg.Wait()
return b.err
}

func (b *backgroundRoutine) addTask(name string, execute func() error) bool {
select {
case b.tasks <- &task{
name: name,
execute: execute,
}:
return true
default:
return false
}
}

func newBackgroundRoutine() *backgroundRoutine {
routine := &backgroundRoutine{
tasks: make(chan *task, 10),
}

return routine
}

type clientEncryptionOpts struct {
KeyVaultClient string `bson:"keyVaultClient"`
KeyVaultNamespace string `bson:"keyVaultNamespace"`
@@ -136,7 +200,7 @@ type EntityMap struct
successValues map[string]int32
iterationValues map[string]int32
clientEncryptionEntities map[string]*mongo.ClientEncryption
waitChans map[string]chan error
routinesMap sync.Map // maps thread name to *backgroundRoutine
evtLock sync.Mutex
closed atomic.Value
// keyVaultClientIDs tracks IDs of clients used as a keyVaultClient in ClientEncryption objects.
@@ -168,7 +232,6 @@ func newEntityMap() *EntityMap {
successValues: make(map[string]int32),
iterationValues: make(map[string]int32),
clientEncryptionEntities: make(map[string]*mongo.ClientEncryption),
waitChans: make(map[string]chan error),
keyVaultClientIDs: make(map[string]bool),
}
em.setClosed(false)
@@ -286,7 +349,9 @@ func (em *EntityMap) addEntity(ctx context.Context, entityType string, entityOpt
case "session":
err = em.addSessionEntity(entityOptions)
case "thread":
em.waitChans[entityOptions.ID] = make(chan error)
routine := newBackgroundRoutine()
em.routinesMap.Store(entityOptions.ID, routine)
routine.start()
case "bucket":
err = em.addGridFSBucketEntity(entityOptions)
case "clientEncryption":
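
The new backgroundRoutine gives each "thread" entity one long-lived worker goroutine: addTask enqueues without blocking and reports false when the 10-slot buffer is full, the worker records the first task error (or a 10-second timeout) and skips everything queued after it, and stop closes the queue and waits for the drain. A minimal usage sketch, assuming the declarations above are in the same package (the task name and body are illustrative):

	package main

	import (
		"fmt"
		"log"
	)

	func main() {
		routine := newBackgroundRoutine()
		routine.start()

		// addTask never blocks: false means the 10-slot buffer was full
		// and the task was dropped.
		if !routine.addTask("sayHello", func() error {
			fmt.Println("hello from the worker goroutine")
			return nil
		}) {
			log.Fatal("task queue full")
		}

		// stop closes the queue, waits for the worker to drain it, and
		// returns the first error (or timeout) it recorded.
		if err := routine.stop(); err != nil {
			log.Fatalf("background routine failed: %v", err)
		}
	}
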
26 changes: 13 additions & 13 deletions internal/integration/unified/testrunner_operation.go
@@ -196,22 +196,22 @@ func executeTestRunnerOperation(ctx context.Context, op *operation, loopDone <-c
if err := operationRaw.Unmarshal(threadOp); err != nil {
return fmt.Errorf("error unmarshaling 'operation' argument: %v", err)
}
ch := entities(ctx).waitChans[lookupString(args, "thread")]
go func(op *operation) {
err := op.execute(ctx, loopDone)
ch <- err
}(threadOp)
thread := lookupString(args, "thread")
routine, ok := entities(ctx).routinesMap.Load(thread)
if !ok {
return fmt.Errorf("run on unknown thread: %s", thread)
}
routine.(*backgroundRoutine).addTask(threadOp.Name, func() error {
return threadOp.execute(ctx, loopDone)
})
return nil
case "waitForThread":
if ch, ok := entities(ctx).waitChans[lookupString(args, "thread")]; ok {
select {
case err := <-ch:
return err
case <-time.After(10 * time.Second):
return fmt.Errorf("timed out after 10 seconds")
}
thread := lookupString(args, "thread")
routine, ok := entities(ctx).routinesMap.Load(thread)
if !ok {
return fmt.Errorf("wait for unknown thread: %s", thread)
}
return nil
return routine.(*backgroundRoutine).stop()
case "waitForEvent":
var wfeArgs waitForEventArguments
if err := bson.Unmarshal(op.Arguments, &wfeArgs); err != nil {
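
With waitChans gone, both operations resolve the named thread through the sync.Map registry, and an unknown thread is now an explicit error rather than the silent nil the old waitForThread lookup fell through to. A compact sketch of the lookup-and-assert pattern (a fragment: package and imports elided, and it assumes the backgroundRoutine type from entity.go):

	// runOnThread/waitForThread reduced to their essence; the registry is
	// the equivalent of EntityMap.routinesMap (thread name -> *backgroundRoutine).
	func runOnThread(routines *sync.Map, thread, name string, fn func() error) error {
		v, ok := routines.Load(thread)
		if !ok {
			return fmt.Errorf("run on unknown thread: %s", thread)
		}
		v.(*backgroundRoutine).addTask(name, fn) // queued; errors surface from stop()
		return nil
	}

	func waitForThread(routines *sync.Map, thread string) error {
		v, ok := routines.Load(thread)
		if !ok {
			return fmt.Errorf("wait for unknown thread: %s", thread)
		}
		return v.(*backgroundRoutine).stop()
	}
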
10 changes: 0 additions & 10 deletions x/mongo/driver/topology/connection.go
@@ -48,7 +48,6 @@ type connection struct {
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
// - suggested layout: https://go101.org/article/memory-layout.html
state int64
err error

id string
nc net.Conn // When nil, the connection is closed.
@@ -338,7 +337,6 @@ func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
if atomic.LoadInt64(&c.state) != connConnected {
return ConnectionError{
ConnectionID: c.id,
Wrapped: c.err,
message: "connection is closed",
}
}
@@ -393,7 +391,6 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
if atomic.LoadInt64(&c.state) != connConnected {
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: c.err,
message: "connection is closed",
}
}
@@ -494,13 +491,6 @@ func (c *connection) close() error {
return err
}

func (c *connection) closeWithErr(err error) error {
c.err = err
c.closeConnectContext()
c.wait() // Make sure that the connection has finished connecting.
return c.close()
}

func (c *connection) closed() bool {
return atomic.LoadInt64(&c.state) == connDisconnected
}
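
Dropping err leaves state as the first field of connection, which is what the surviving comment is about: 64-bit values passed to sync/atomic must be 64-bit aligned on 32-bit platforms, and the first word of an allocated struct is guaranteed to be. A standalone illustration of that layout convention (not driver code):

	package main

	import (
		"fmt"
		"sync/atomic"
	)

	// demoConn mimics the layout rule: the atomically accessed int64
	// sits first so it is 64-bit aligned even on 32-bit platforms.
	// See https://pkg.go.dev/sync/atomic#pkg-note-BUG.
	type demoConn struct {
		state int64 // must stay first
		id    string
	}

	func main() {
		c := &demoConn{id: "conn-1"}
		atomic.StoreInt64(&c.state, 1)
		fmt.Println(c.id, atomic.LoadInt64(&c.state))
	}
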
121 changes: 39 additions & 82 deletions x/mongo/driver/topology/pool.go
@@ -228,8 +228,6 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) *pool {
}
pool.connOpts = append(pool.connOpts, withGenerationNumberFn(func(_ generationNumberFn) generationNumberFn { return pool.getGenerationForNewConnection }))

pool.generation.connect()

// Create a Context with cancellation that's used to signal the createConnections() and
// maintain() background goroutines to stop. Also create a "backgroundDone" WaitGroup that is
// used to wait for the background goroutines to return.
@@ -278,7 +276,9 @@ func (p *pool) stale(conn *connection) bool {
if conn == nil {
return true
}
if atomic.LoadInt64(&p.generation.state) == generationDisconnected {
p.stateMu.RLock()
defer p.stateMu.RUnlock()
if p.state == poolClosed {
return true
}
if generation, ok := p.generation.getGeneration(conn.desc.ServiceID); ok {
@@ -353,8 +353,6 @@ func (p *pool) close(ctx context.Context) {
// Wait for all background goroutines to exit.
p.backgroundDone.Wait()

p.generation.disconnect()

if ctx == nil {
ctx = context.Background()
}
@@ -751,59 +749,38 @@ func (p *pool) removeConnection(conn *connection, reason reason, err error) erro
return nil
}

// checkIn returns an idle connection to the pool. It calls checkInWithCallback internally.
// checkIn returns an idle connection to the pool. If the connection is perished or the pool is
// closed, it is removed from the connection pool and closed.
func (p *pool) checkIn(conn *connection) error {
return p.checkInWithCallback(conn, func() (reason, bool) {
if mustLogPoolMessage(p) {
keysAndValues := logger.KeyValues{
logger.KeyDriverConnectionID, conn.driverConnectionID,
}
if conn == nil {
return nil
}
if conn.pool != p {
return ErrWrongPool
}

logPoolMessage(p, logger.ConnectionCheckedIn, keysAndValues...)
if mustLogPoolMessage(p) {
keysAndValues := logger.KeyValues{
logger.KeyDriverConnectionID, conn.driverConnectionID,
}

if p.monitor != nil {
p.monitor.Event(&event.PoolEvent{
Type: event.ConnectionCheckedIn,
ConnectionID: conn.driverConnectionID,
Address: conn.addr.String(),
})
}
logPoolMessage(p, logger.ConnectionCheckedIn, keysAndValues...)
}

r, perished := connectionPerished(conn)
if !perished && conn.pool.getState() == poolClosed {
perished = true
r = reason{
loggerConn: logger.ReasonConnClosedPoolClosed,
event: event.ReasonPoolClosed,
}
}
return r, perished
})
if p.monitor != nil {
p.monitor.Event(&event.PoolEvent{
Type: event.ConnectionCheckedIn,
ConnectionID: conn.driverConnectionID,
Address: conn.addr.String(),
})
}

return p.checkInNoEvent(conn)
}

// checkInNoEvent returns a connection to the pool. It behaves identically to checkIn except it does
// not publish events. It is only intended for use by pool-internal functions.
func (p *pool) checkInNoEvent(conn *connection) error {
return p.checkInWithCallback(conn, func() (reason, bool) {
r, perished := connectionPerished(conn)
if !perished && conn.pool.getState() == poolClosed {
perished = true
r = reason{
loggerConn: logger.ReasonConnClosedPoolClosed,
event: event.ReasonPoolClosed,
}
}
return r, perished
})
}

// checkInWithCallback returns a connection to the pool. If the connection is perished or the pool is
// closed, it is removed from the connection pool and closed.
// The callback parameter is expected to return a reason for the check-in and a boolean value to
// indicate whether the connection is perished.
// Events and logs can also be added in the callback function.
func (p *pool) checkInWithCallback(conn *connection, callback func() (reason, bool)) error {
if conn == nil {
return nil
}
@@ -819,10 +796,13 @@ func (p *pool) checkInWithCallback(conn *connection, callback func() (reason, bo
// connection should never be perished due to max idle time.
conn.bumpIdleDeadline()

var r reason
var perished bool
if callback != nil {
r, perished = callback()
r, perished := connectionPerished(conn)
if !perished && conn.pool.getState() == poolClosed {
perished = true
r = reason{
loggerConn: logger.ReasonConnClosedPoolClosed,
event: event.ReasonPoolClosed,
}
}
if perished {
_ = p.removeConnection(conn, r, nil)
@@ -868,36 +848,13 @@ func (p *pool) clearAll(err error, serviceID *primitive.ObjectID) {
// interruptConnections interrupts the input connections.
func (p *pool) interruptConnections(conns []*connection) {
for _, conn := range conns {
_ = conn.closeWithErr(poolClearedError{
err: fmt.Errorf("interrupted"),
address: p.address,
})
_ = p.checkInWithCallback(conn, func() (reason, bool) {
if mustLogPoolMessage(p) {
keysAndValues := logger.KeyValues{
logger.KeyDriverConnectionID, conn.driverConnectionID,
}

logPoolMessage(p, logger.ConnectionCheckedIn, keysAndValues...)
}

if p.monitor != nil {
p.monitor.Event(&event.PoolEvent{
Type: event.ConnectionCheckedIn,
ConnectionID: conn.driverConnectionID,
Address: conn.addr.String(),
})
}

r, ok := connectionPerished(conn)
if ok {
r = reason{
loggerConn: logger.ReasonConnClosedStale,
event: event.ReasonStale,
}
}
return r, ok
})
_ = p.removeConnection(conn, reason{
loggerConn: logger.ReasonConnClosedStale,
event: event.ReasonStale,
}, nil)
go func(c *connection) {
_ = p.closeConnection(c)
}(conn)
}
}

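
The checkInWithCallback indirection is gone: checkIn validates the connection, publishes the log message and pool event itself, and hands the keep-or-discard decision to checkInNoEvent, which internal callers use directly to avoid double events. A stripped-down sketch of that wrapper-over-silent-core split, using hypothetical types rather than the driver's:

	package main

	import "fmt"

	type conn struct{ id int }
	type pool struct{ closed bool }

	// checkIn is the observable wrapper: publish events, then delegate.
	func (p *pool) checkIn(c *conn) error {
		if c == nil {
			return nil
		}
		fmt.Printf("ConnectionCheckedIn: %d\n", c.id) // stand-in for logger/monitor
		return p.checkInNoEvent(c)
	}

	// checkInNoEvent is the silent core: it alone decides whether the
	// connection is kept or discarded, so internal callers skip the events.
	func (p *pool) checkInNoEvent(c *conn) error {
		if p.closed {
			fmt.Printf("closing conn %d: pool closed\n", c.id)
			return nil
		}
		fmt.Printf("returning conn %d to idle set\n", c.id)
		return nil
	}

	func main() {
		p := &pool{}
		_ = p.checkIn(&conn{id: 1})
	}
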
13 changes: 0 additions & 13 deletions x/mongo/driver/topology/pool_generation_counter.go
@@ -8,7 +8,6 @@ package topology

import (
"sync"
"sync/atomic"

"go.mongodb.org/mongo-driver/bson/primitive"
)
@@ -30,10 +29,6 @@ type generationStats struct
// load balancer, there is only one service ID: primitive.NilObjectID. For load-balanced deployments, each server behind
// the load balancer will have a unique service ID.
type poolGenerationMap struct {
// state must be accessed using the atomic package and should be at the beginning of the struct.
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
// - suggested layout: https://go101.org/article/memory-layout.html
state int64
generationMap map[primitive.ObjectID]*generationStats

sync.Mutex
@@ -47,14 +42,6 @@ func newPoolGenerationMap() *poolGenerationMap {
return pgm
}

func (p *poolGenerationMap) connect() {
atomic.StoreInt64(&p.state, generationConnected)
}

func (p *poolGenerationMap) disconnect() {
atomic.StoreInt64(&p.state, generationDisconnected)
}

// addConnection increments the connection count for the generation associated with the given service ID and returns the
// generation number for the connection.
func (p *poolGenerationMap) addConnection(serviceIDPtr *primitive.ObjectID) uint64 {
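
With connect and disconnect removed, poolGenerationMap no longer mirrors the pool's lifecycle in its own atomic flag; stale now consults the pool state under stateMu and otherwise relies on per-service generation comparison. A toy model of that generation-comparison idea, heavily simplified from the real map:

	package main

	import (
		"fmt"
		"sync"
	)

	// generations is a toy per-service generation counter: clearing a
	// service bumps its generation, and a connection is stale when its
	// recorded generation lags the current one.
	type generations struct {
		mu sync.Mutex
		m  map[string]uint64
	}

	func (g *generations) get(key string) uint64 {
		g.mu.Lock()
		defer g.mu.Unlock()
		return g.m[key]
	}

	func (g *generations) clear(key string) {
		g.mu.Lock()
		defer g.mu.Unlock()
		g.m[key]++
	}

	func main() {
		g := &generations{m: make(map[string]uint64)}
		connGen := g.get("svcA") // connection created at generation 0
		g.clear("svcA")          // pool cleared for this service
		fmt.Println("stale:", connGen < g.get("svcA")) // true
	}
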
