Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wip Pr for thread safety #276

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ type chain struct {
context AssertionContext
handler AssertionHandler
severity AssertionSeverity
failure *AssertionFailure
}

// If enabled, chain will panic if used incorrectly or gets illformed AssertionFailure.
Expand Down Expand Up @@ -340,12 +341,15 @@ func (c *chain) replace(name string, args ...interface{}) *chain {
// Must be called after enter().
// Chain can't be used after this call.
func (c *chain) leave() {
println("leaving now")
var (
context AssertionContext
handler AssertionHandler
parent *chain
reportSuccess bool
reportFailure bool
failure *AssertionFailure
flags chainFlags
)

func() {
Expand All @@ -366,11 +370,16 @@ func (c *chain) leave() {
reportFailure = true
}
}()

context = c.context
handler = c.handler
failure = c.failure
flags = c.flags
if reportSuccess {
handler.Success(&context)
}

if flags&flagFailed == 1 && failure != nil {
handler.Failure(&context, failure)
}
if reportFailure {
parent.mu.Lock()
parent.flags |= flagFailed
Expand All @@ -389,10 +398,8 @@ func (c *chain) leave() {

// Report assertion failure and mark chain as failed.
// Must be called between enter() and leave().
func (c *chain) fail(failure AssertionFailure) {
func (c *chain) fail(AssertFailure AssertionFailure) {
var (
context AssertionContext
handler AssertionHandler
reportFailure bool
)

Expand All @@ -409,21 +416,19 @@ func (c *chain) fail(failure AssertionFailure) {
}
c.flags |= flagFailed

failure.Severity = c.severity
AssertFailure.Severity = c.severity
if c.severity == SeverityError {
failure.IsFatal = true
AssertFailure.IsFatal = true
}

context = c.context
handler = c.handler
reportFailure = true
}()

if reportFailure {
handler.Failure(&context, &failure)
c.failure = &AssertFailure

if chainValidation {
if err := validateAssertion(&failure); err != nil {
if err := validateAssertion(&AssertFailure); err != nil {
panic(err)
}
}
Expand Down
40 changes: 37 additions & 3 deletions environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package httpexpect

import (
"errors"
"sync"
"time"
)

Expand All @@ -13,9 +14,9 @@ import (
// env.Put("key", "value")
// value := env.GetString("key")
type Environment struct {
noCopy noCopy
chain *chain
data map[string]interface{}
mu sync.RWMutex
chain *chain
data map[string]interface{}
}

// NewEnvironment returns a new Environment.
Expand Down Expand Up @@ -58,6 +59,9 @@ func (e *Environment) Put(key string, value interface{}) {
opChain := e.chain.enter("Put(%q)", key)
defer opChain.leave()

e.mu.Lock()
defer e.mu.Unlock()

e.data[key] = value
}

Expand All @@ -72,6 +76,9 @@ func (e *Environment) Delete(key string) {
opChain := e.chain.enter("Delete(%q)", key)
defer opChain.leave()

e.mu.Lock()
defer e.mu.Unlock()

delete(e.data, key)
}

Expand All @@ -86,6 +93,9 @@ func (e *Environment) Has(key string) bool {
opChain := e.chain.enter("Has(%q)", key)
defer opChain.leave()

e.mu.RLock()
defer e.mu.RUnlock()

_, ok := e.data[key]
return ok
}
Expand All @@ -102,6 +112,9 @@ func (e *Environment) Get(key string) interface{} {
opChain := e.chain.enter("Get(%q)", key)
defer opChain.leave()

e.mu.RLock()
defer e.mu.RUnlock()

value, _ := envValue(opChain, e.data, key)

return value
Expand All @@ -118,6 +131,9 @@ func (e *Environment) GetBool(key string) bool {
opChain := e.chain.enter("GetBool(%q)", key)
defer opChain.leave()

e.mu.RLock()
defer e.mu.RUnlock()

value, ok := envValue(opChain, e.data, key)
if !ok {
return false
Expand Down Expand Up @@ -150,6 +166,9 @@ func (e *Environment) GetInt(key string) int {
opChain := e.chain.enter("GetInt(%q)", key)
defer opChain.leave()

e.mu.RLock()
defer e.mu.RUnlock()

value, ok := envValue(opChain, e.data, key)
if !ok {
return 0
Expand Down Expand Up @@ -235,6 +254,9 @@ func (e *Environment) GetFloat(key string) float64 {
opChain := e.chain.enter("GetFloat(%q)", key)
defer opChain.leave()

e.mu.RLock()
defer e.mu.RUnlock()

value, ok := envValue(opChain, e.data, key)
if !ok {
return 0
Expand Down Expand Up @@ -275,6 +297,9 @@ func (e *Environment) GetString(key string) string {
opChain := e.chain.enter("GetString(%q)", key)
defer opChain.leave()

e.mu.RLock()
defer e.mu.RUnlock()

value, ok := envValue(opChain, e.data, key)
if !ok {
return ""
Expand Down Expand Up @@ -306,6 +331,9 @@ func (e *Environment) GetBytes(key string) []byte {
opChain := e.chain.enter("GetBytes(%q)", key)
defer opChain.leave()

e.mu.RLock()
defer e.mu.RUnlock()

value, ok := envValue(opChain, e.data, key)
if !ok {
return nil
Expand Down Expand Up @@ -338,6 +366,9 @@ func (e *Environment) GetDuration(key string) time.Duration {
opChain := e.chain.enter("GetDuration(%q)", key)
defer opChain.leave()

e.mu.RLock()
defer e.mu.RUnlock()

value, ok := envValue(opChain, e.data, key)
if !ok {
return time.Duration(0)
Expand Down Expand Up @@ -370,6 +401,9 @@ func (e *Environment) GetTime(key string) time.Time {
opChain := e.chain.enter("GetTime(%q)", key)
defer opChain.leave()

e.mu.RLock()
defer e.mu.RUnlock()

value, ok := envValue(opChain, e.data, key)
if !ok {
return time.Unix(0, 0)
Expand Down
13 changes: 13 additions & 0 deletions environment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ func TestEnvironment_Basic(t *testing.T) {
env.chain.assertFailed(t)
}

func TestEnvironment_Reentrant(t *testing.T) {
reporter := newMockReporter(t)

env := NewEnvironment(reporter)

reporter.reportCb = func() {
env.Put("good_key", 123)
}

env.Get("bad_key")
env.chain.assertFailed(t)
}

func TestEnvironment_Delete(t *testing.T) {
env := newEnvironment(newMockChain(t))

Expand Down
5 changes: 5 additions & 0 deletions mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ func (l *mockLogger) Logf(message string, args ...interface{}) {
type mockReporter struct {
testing *testing.T
reported bool
reportCb func()
}

func newMockReporter(t *testing.T) *mockReporter {
Expand All @@ -188,6 +189,10 @@ func newMockReporter(t *testing.T) *mockReporter {
func (r *mockReporter) Errorf(message string, args ...interface{}) {
r.testing.Logf("Fail: "+message, args...)
r.reported = true

if r.reportCb != nil {
r.reportCb()
}
}

type mockFormatter struct {
Expand Down