diff --git a/v2/distributed_gobreaker.go b/v2/distributed_gobreaker.go index ce61d11..221d4ac 100644 --- a/v2/distributed_gobreaker.go +++ b/v2/distributed_gobreaker.go @@ -96,178 +96,60 @@ func (dcb *DistributedCircuitBreaker[T]) setSharedState(ctx context.Context, sta return dcb.store.SetData(ctx, dcb.sharedStateKey(), data) } -// State returns the State of DistributedCircuitBreaker. -func (dcb *DistributedCircuitBreaker[T]) State(ctx context.Context) (State, error) { - state, err := dcb.getSharedState(ctx) - if err != nil { - return state.State, err - } - - now := time.Now() - currentState, _ := dcb.currentState(state, now) - - // update the state if it has changed - if currentState != state.State { - state.State = currentState - if err := dcb.setSharedState(ctx, state); err != nil { - return state.State, err - } - } +func (dcb *DistributedCircuitBreaker[T]) inject(shared SharedState) { + dcb.mutex.Lock() + defer dcb.mutex.Unlock() - return state.State, nil + dcb.state = shared.State + dcb.generation = shared.Generation + dcb.counts = shared.Counts + dcb.expiry = shared.Expiry } -// Execute runs the given request if the DistributedCircuitBreaker accepts it. -func (dcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func() (T, error)) (t T, err error) { - generation, err := dcb.beforeRequest(ctx) - if err != nil { - var defaultValue T - return defaultValue, err - } - - defer func() { - e := recover() - if e != nil { - ae := dcb.afterRequest(ctx, generation, false) - if err == nil { - err = ae - } - panic(e) - } - }() +func (dcb *DistributedCircuitBreaker[T]) extract() SharedState { + dcb.mutex.Lock() + defer dcb.mutex.Unlock() - result, err := req() - ae := dcb.afterRequest(ctx, generation, dcb.isSuccessful(err)) - if err == nil { - err = ae + return SharedState{ + State: dcb.state, + Generation: dcb.generation, + Counts: dcb.counts, + Expiry: dcb.expiry, } - return result, err } -func (dcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uint64, error) { - state, err := dcb.getSharedState(ctx) +// State returns the State of DistributedCircuitBreaker. +func (dcb *DistributedCircuitBreaker[T]) State(ctx context.Context) (State, error) { + shared, err := dcb.getSharedState(ctx) if err != nil { - return 0, err - } - - now := time.Now() - currentState, generation := dcb.currentState(state, now) - - if currentState != state.State { - dcb.setState(&state, currentState, now) - err = dcb.setSharedState(ctx, state) - if err != nil { - return 0, err - } + return shared.State, err } - if currentState == StateOpen { - return generation, ErrOpenState - } else if currentState == StateHalfOpen && state.Counts.Requests >= dcb.maxRequests { - return generation, ErrTooManyRequests - } + dcb.inject(shared) + state := dcb.CircuitBreaker.State() + shared = dcb.extract() - state.Counts.onRequest() - err = dcb.setSharedState(ctx, state) - if err != nil { - return 0, err - } - - return generation, nil + err = dcb.setSharedState(ctx, shared) + return state, err } -func (dcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, before uint64, success bool) error { - state, err := dcb.getSharedState(ctx) +// Execute runs the given request if the DistributedCircuitBreaker accepts it. +func (dcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func() (T, error)) (T, error) { + shared, err := dcb.getSharedState(ctx) if err != nil { - return err - } - - now := time.Now() - currentState, generation := dcb.currentState(state, now) - if generation != before { - return nil - } - - if success { - dcb.onSuccess(&state, currentState, now) - } else { - dcb.onFailure(&state, currentState, now) - } - return dcb.setSharedState(ctx, state) -} - -func (dcb *DistributedCircuitBreaker[T]) onSuccess(state *SharedState, currentState State, now time.Time) { - if state.State == StateOpen { - state.State = currentState - } - - switch currentState { - case StateClosed: - state.Counts.onSuccess() - case StateHalfOpen: - state.Counts.onSuccess() - if state.Counts.ConsecutiveSuccesses >= dcb.maxRequests { - dcb.setState(state, StateClosed, now) - } - } -} - -func (dcb *DistributedCircuitBreaker[T]) onFailure(state *SharedState, currentState State, now time.Time) { - switch currentState { - case StateClosed: - state.Counts.onFailure() - if dcb.readyToTrip(state.Counts) { - dcb.setState(state, StateOpen, now) - } - case StateHalfOpen: - dcb.setState(state, StateOpen, now) - } -} - -func (dcb *DistributedCircuitBreaker[T]) currentState(state SharedState, now time.Time) (State, uint64) { - switch state.State { - case StateClosed: - if !state.Expiry.IsZero() && state.Expiry.Before(now) { - dcb.toNewGeneration(&state, now) - } - case StateOpen: - if state.Expiry.Before(now) { - dcb.setState(&state, StateHalfOpen, now) - } - } - return state.State, state.Generation -} - -func (dcb *DistributedCircuitBreaker[T]) setState(state *SharedState, newState State, now time.Time) { - if state.State == newState { - return + var defaultValue T + return defaultValue, err } - prev := state.State - state.State = newState - - dcb.toNewGeneration(state, now) + dcb.inject(shared) + t, e := dcb.CircuitBreaker.Execute(req) + shared = dcb.extract() - if dcb.onStateChange != nil { - dcb.onStateChange(dcb.name, prev, newState) + err = dcb.setSharedState(ctx, shared) + if err != nil { + var defaultValue T + return defaultValue, err } -} -func (dcb *DistributedCircuitBreaker[T]) toNewGeneration(state *SharedState, now time.Time) { - state.Generation++ - state.Counts.clear() - - var zero time.Time - switch state.State { - case StateClosed: - if dcb.interval == 0 { - state.Expiry = zero - } else { - state.Expiry = now.Add(dcb.interval) - } - case StateOpen: - state.Expiry = now.Add(dcb.timeout) - default: // StateHalfOpen - state.Expiry = zero - } + return t, e }