Skip to content

Commit

Permalink
feat: add GetOrInjectBaggage (#217)
Browse files Browse the repository at this point in the history
* feat: add GetOrInjectBaggage

* fix: improve test codecov

* fix: TestServeIn_signalWatch is blocked on windows
  • Loading branch information
GGXXLL authored Dec 30, 2021
1 parent 6c60918 commit cd2a071
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
17 changes: 17 additions & 0 deletions ctxmeta/ctxmeta.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,16 @@ func (m *MetadataSet) GetBaggage(ctx context.Context) *Baggage {
return nil
}

// GetOrInjectBaggage creates and returns Baggage. using Baggage found within the context.
// If that doesn't exist it creates new Baggage. It also returns a context.Context
// object built around the returned Baggage.
func (m *MetadataSet) GetOrInjectBaggage(ctx context.Context) (*Baggage, context.Context) {
if baggage := m.GetBaggage(ctx); baggage != nil {
return baggage, ctx
}
return m.Inject(ctx)
}

// Inject constructs a Baggage object and injects it into the provided context
// under the default context key. Use the returned context for all further
// operations. The returned Data can be queried at any point for metadata
Expand All @@ -233,3 +243,10 @@ func Inject(ctx context.Context) (*Baggage, context.Context) {
func GetBaggage(ctx context.Context) *Baggage {
return DefaultMetadata.GetBaggage(ctx)
}

// GetOrInjectBaggage creates and returns Baggage. using Baggage found within the context.
// If that doesn't exist it creates new Baggage. It also returns a context.Context
// object built around the returned Baggage.
func GetOrInjectBaggage(ctx context.Context) (*Baggage, context.Context) {
return DefaultMetadata.GetOrInjectBaggage(ctx)
}
22 changes: 21 additions & 1 deletion ctxmeta/ctxmeta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestContextMeta_crud(t *testing.T) {
assert.ErrorIs(t, err, ErrNotFound)
}

func TestContextMeta_ErrNoBaggae(t *testing.T) {
func TestContextMeta_ErrNoBaggage(t *testing.T) {
t.Parallel()

ctx := context.Background()
Expand Down Expand Up @@ -157,3 +157,23 @@ func TestMetadata_global(t *testing.T) {
world, _ = baggage3.Get("hello")
assert.Equal(t, "world", world)
}

func TestMetadata_GetOrInjectBaggage(t *testing.T) {
t.Parallel()

ctx := context.Background()
baggage1, ctx := GetOrInjectBaggage(ctx)
baggage1.Set("hello", "world")

baggage2 := GetBaggage(ctx)
world, _ := baggage2.Get("hello")
assert.Equal(t, "world", world)

baggage3 := DefaultMetadata.GetBaggage(ctx)
world, _ = baggage3.Get("hello")
assert.Equal(t, "world", world)

baggage4, _ := GetOrInjectBaggage(ctx)
world, _ = baggage4.Get("hello")
assert.Equal(t, "world", world)
}
8 changes: 6 additions & 2 deletions serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"os"
"runtime"
"testing"
"time"

Expand All @@ -21,16 +22,19 @@ func TestServeIn_signalWatch(t *testing.T) {
assert.NoError(t, err)

t.Run("stop when signal received", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("TestServeIn_signalWatch/stop_when_signal_received only works on unix")
}
var group run.Group
group.Add(do, cancel)
group.Add(func() error {
time.Sleep(time.Second)
p, err := os.FindProcess(os.Getpid())
if err != nil {
t.Skip("TestServeIn_signalWatch only works on unix")
return err
}
if err := p.Signal(os.Interrupt); err != nil {
t.Skip("TestServeIn_signalWatch only works on unix")
return err
}
// trigger the signal twice should be ok.
p.Signal(os.Interrupt)
Expand Down

0 comments on commit cd2a071

Please sign in to comment.