Skip to content
This repository has been archived by the owner on Feb 9, 2024. It is now read-only.

ordered pipeline exec #64

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func benchmarkWASM(wasmFile, testPayloadFile string, step *protos.PipelineStep,

s := &Streamdal{
pipelinesMtx: &sync.RWMutex{},
pipelines: map[string]map[string]*protos.Command{},
pipelines: map[string][]*protos.Command{},
audiencesMtx: &sync.RWMutex{},
audiences: map[string]struct{}{},
}
Expand Down Expand Up @@ -204,7 +204,7 @@ func inferSchema(fileName string) (*protos.WASMResponse, error) {

s := &Streamdal{
pipelinesMtx: &sync.RWMutex{},
pipelines: map[string]map[string]*protos.Command{},
pipelines: map[string][]*protos.Command{},
audiencesMtx: &sync.RWMutex{},
audiences: map[string]struct{}{},
}
Expand Down
34 changes: 16 additions & 18 deletions go_sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ type IStreamdal interface {
type Streamdal struct {
config *Config
functions map[string]*function
pipelines map[string]map[string]*protos.Command // k1: audienceStr k2: pipelineID
pipelinesPaused map[string]map[string]*protos.Command // k1: audienceStr k2: pipelineID
pipelines map[string][]*protos.Command // k1: audienceStr k2: pipelineID
pipelinesPaused map[string][]*protos.Command // k1: audienceStr k2: pipelineID
functionsMtx *sync.RWMutex
pipelinesMtx *sync.RWMutex
pipelinesPausedMtx *sync.RWMutex
Expand Down Expand Up @@ -233,9 +233,9 @@ func New(cfg *Config) (*Streamdal, error) {
functions: make(map[string]*function),
functionsMtx: &sync.RWMutex{},
serverClient: serverClient,
pipelines: make(map[string]map[string]*protos.Command),
pipelines: make(map[string][]*protos.Command),
pipelinesMtx: &sync.RWMutex{},
pipelinesPaused: make(map[string]map[string]*protos.Command),
pipelinesPaused: make(map[string][]*protos.Command),
pipelinesPausedMtx: &sync.RWMutex{},
audiences: map[string]struct{}{},
audiencesMtx: &sync.RWMutex{},
Expand Down Expand Up @@ -416,12 +416,13 @@ func (s *Streamdal) pullInitialPipelines(ctx context.Context) error {
}

for _, cmd := range cmds.Paused {
s.config.Logger.Debugf("Pipeline '%s' is paused", cmd.GetAttachPipeline().Pipeline.Name)
s.config.Logger.Debugf("Pipeline '%s' is paused - skipping initial attach", cmd.GetAttachPipeline().Pipeline.Name)

if _, ok := s.pipelinesPaused[audToStr(cmd.Audience)]; !ok {
s.pipelinesPaused[audToStr(cmd.Audience)] = make(map[string]*protos.Command)
s.pipelinesPaused[audToStr(cmd.Audience)] = make([]*protos.Command, 0)
}

s.pipelinesPaused[audToStr(cmd.Audience)][cmd.GetAttachPipeline().Pipeline.Id] = cmd
s.pipelinesPaused[audToStr(cmd.Audience)] = append(s.pipelinesPaused[audToStr(cmd.Audience)], cmd)
}

return nil
Expand Down Expand Up @@ -512,15 +513,15 @@ func (s *Streamdal) runStep(ctx context.Context, aud *protos.Audience, step *pro
return resp, nil
}

func (s *Streamdal) getPipelines(ctx context.Context, aud *protos.Audience) map[string]*protos.Command {
func (s *Streamdal) getPipelines(ctx context.Context, aud *protos.Audience) []*protos.Command {
s.pipelinesMtx.RLock()
defer s.pipelinesMtx.RUnlock()

s.addAudience(ctx, aud)

pipelines, ok := s.pipelines[audToStr(aud)]
if !ok {
return make(map[string]*protos.Command)
return make([]*protos.Command, 0)
}

return pipelines
Expand Down Expand Up @@ -595,6 +596,10 @@ func (s *Streamdal) Process(ctx context.Context, req *ProcessRequest) *ProcessRe

pipelines := s.getPipelines(ctx, aud)

for pIndex, p := range pipelines {
s.config.Logger.Warnf("pIndex %d, pipeline name '%s'", pIndex, p.GetAttachPipeline().GetPipeline().Name)
}

// WARNING: This case will (usually) only "hit" for the first <100ms of
// running the SDK - after that, the server will have sent us at least one,
// "hidden" pipeline - "infer schema". All of this happens asynchronously
Expand Down Expand Up @@ -624,15 +629,10 @@ func (s *Streamdal) Process(ctx context.Context, req *ProcessRequest) *ProcessRe
}

totalPipelines := len(pipelines)
var (
pIndex int
sIndex int
)

PIPELINE:
for _, p := range pipelines {
for pIndex, p := range pipelines {
var isr *protos.InterStepResult
pIndex += 1

pipelineTimeoutCtx, pipelineTimeoutCxl := context.WithTimeout(ctx, s.config.PipelineTimeout)

Expand All @@ -649,9 +649,7 @@ PIPELINE:

totalSteps := len(pipeline.Steps)

for _, step := range pipeline.Steps {
sIndex += 1

for sIndex, step := range pipeline.Steps {
stepTimeoutCtx, stepTimeoutCxl := context.WithTimeout(ctx, s.config.StepTimeout)

stepStatus := &protos.StepStatus{
Expand Down
20 changes: 14 additions & 6 deletions go_sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ var _ = Describe("Streamdal", func() {

s := &Streamdal{
pipelinesMtx: &sync.RWMutex{},
pipelines: map[string]map[string]*protos.Command{},
pipelines: make(map[string][]*protos.Command),
serverClient: fakeClient,
audiencesMtx: &sync.RWMutex{},
audiences: map[string]struct{}{},
Expand All @@ -182,8 +182,16 @@ var _ = Describe("Streamdal", func() {
})

It("returns a single pipeline", func() {
s.pipelines[audToStr(aud)] = map[string]*protos.Command{
uuid.New().String(): {},
s.pipelines[audToStr(aud)] = []*protos.Command{
&protos.Command{
Command: &protos.Command_AttachPipeline{
AttachPipeline: &protos.AttachPipelineCommand{
Pipeline: &protos.Pipeline{
Id: uuid.New().String(),
},
},
},
},
}
Expect(len(s.getPipelines(ctx, aud))).To(Equal(1))
})
Expand Down Expand Up @@ -536,9 +544,9 @@ func createStreamdalClientFull(serviceName string, aud *protos.Audience, pipelin
tails: map[string]map[string]*Tail{},
tailsMtx: &sync.RWMutex{},
pipelinesMtx: &sync.RWMutex{},
pipelines: map[string]map[string]*protos.Command{
pipelines: map[string][]*protos.Command{
audToStr(aud): {
pipeline.Id: {
{
Audience: aud,
Command: &protos.Command_AttachPipeline{
AttachPipeline: &protos.AttachPipelineCommand{
Expand All @@ -564,7 +572,7 @@ func createStreamdalClient() (*Streamdal, *kv.KV, error) {

return &Streamdal{
pipelinesMtx: &sync.RWMutex{},
pipelines: map[string]map[string]*protos.Command{},
pipelines: map[string][]*protos.Command{},
audiencesMtx: &sync.RWMutex{},
audiences: map[string]struct{}{},
kv: kvClient,
Expand Down
146 changes: 101 additions & 45 deletions register.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,20 +286,59 @@ func (s *Streamdal) attachPipeline(_ context.Context, cmd *protos.Command) error
return ErrEmptyCommand
}

s.config.Logger.Warnf("Received attach pipeline command for audience '%s', pipeline name '%s'", audToStr(cmd.Audience), cmd.GetAttachPipeline().Pipeline.Name)

s.pipelinesMtx.Lock()
defer s.pipelinesMtx.Unlock()

// If first time seeing audience, create pipeline (command) slice
if _, ok := s.pipelines[audToStr(cmd.Audience)]; !ok {
s.pipelines[audToStr(cmd.Audience)] = make(map[string]*protos.Command)
s.pipelines[audToStr(cmd.Audience)] = make([]*protos.Command, 0)
}

s.pipelines[audToStr(cmd.Audience)][cmd.GetAttachPipeline().Pipeline.Id] = cmd
// Only append pipeline if it doesn't already exist
pipelineIndex := getPipelineIndex(s.pipelines[audToStr(cmd.Audience)], cmd.GetAttachPipeline().Pipeline.Id)

// Debugging indexes
//s.config.Logger.Warnf("pipelineIndex is %d for pipeline ID %s\n", pipelineIndex, cmd.GetAttachPipeline().Pipeline.Id)
//s.config.Logger.Warnf("known pipeline length: %d", len(s.pipelines[audToStr(cmd.Audience)]))
//
//for pIndex, p := range s.pipelines[audToStr(cmd.Audience)] {
// s.config.Logger.Warnf("pIndex %d for pipeline %s\n", pIndex, p.GetAttachPipeline().Pipeline.Id)
//}

if pipelineIndex == -1 {
// Pipeline does not exist, append it
s.config.Logger.Debugf("Attached new pipeline %s", cmd.GetAttachPipeline().Pipeline.Id)
s.pipelines[audToStr(cmd.Audience)] = append(s.pipelines[audToStr(cmd.Audience)], cmd)
} else {
// Avoid potential panic
if pipelineIndex > len(s.pipelines[audToStr(cmd.Audience)])-1 { // len-1 because of 0-indexing
errMsg := fmt.Errorf("bug? invalid pipeline index: %d", pipelineIndex)
s.config.Logger.Error(errMsg)
return errors.New(errMsg.Error())
}

// Pipeline already exists, update it
s.config.Logger.Debugf("Updated attached pipeline %s (index %d)", cmd.GetAttachPipeline().Pipeline.Id, pipelineIndex)

s.config.Logger.Debugf("Attached pipeline %s", cmd.GetAttachPipeline().Pipeline.Id)
s.pipelines[audToStr(cmd.Audience)][pipelineIndex] = cmd
}

return nil
}

// Looks for pipelineID in pipeline slice and returns index if found, -1 otherwise
func getPipelineIndex(pipelines []*protos.Command, pipelineID string) int {
for i, p := range pipelines {
if p.GetAttachPipeline().Pipeline.Id == pipelineID {
return i
}
}

return -1
}

func (s *Streamdal) detachPipeline(_ context.Context, cmd *protos.Command) error {
if cmd == nil {
return ErrEmptyCommand
Expand All @@ -311,21 +350,34 @@ func (s *Streamdal) detachPipeline(_ context.Context, cmd *protos.Command) error
audStr := audToStr(cmd.Audience)

if _, ok := s.pipelines[audStr]; !ok {
s.config.Logger.Debugf("Attempted to detach pipeline %s, but no pipelines exist for audience %s", cmd.GetDetachPipeline().PipelineId, audStr)

return nil
}

delete(s.pipelines[audStr], cmd.GetDetachPipeline().PipelineId)
if index := getPipelineIndex(s.pipelines[audStr], cmd.GetDetachPipeline().PipelineId); index != -1 {
s.config.Logger.Debugf("Detaching pipeline %s (index %d)", cmd.GetDetachPipeline().PipelineId, index)
s.pipelines[audStr] = append(s.pipelines[audStr][:index], s.pipelines[audStr][index+1:]...)

if len(s.pipelines[audStr]) == 0 {
delete(s.pipelines, audStr)
return nil
}

s.config.Logger.Debugf("Detached pipeline %s", cmd.GetDetachPipeline().PipelineId)
s.config.Logger.Debugf("Pipeline '%s' not attached for audience '%s' - nothing to do", cmd.GetDetachPipeline().PipelineId, audStr)

return nil
}

// TODO: Refactor to pause/unpause
func (s *Streamdal) pausePipeline(_ context.Context, cmd *protos.Command) error {
return s.pauseResumePipeline(nil, cmd, true)
}

func (s *Streamdal) resumePipeline(_ context.Context, cmd *protos.Command) error {
return s.pauseResumePipeline(nil, cmd, false)
}

// Helper method that handles pause/unpause logic. Used by pausePipeline and resumePipeline
func (s *Streamdal) pauseResumePipeline(_ context.Context, cmd *protos.Command, pause bool) error {
if cmd == nil {
return ErrEmptyCommand
}
Expand All @@ -335,64 +387,68 @@ func (s *Streamdal) pausePipeline(_ context.Context, cmd *protos.Command) error
s.pipelinesPausedMtx.Lock()
defer s.pipelinesPausedMtx.Unlock()

var (
action string
src map[string][]*protos.Command
dst map[string][]*protos.Command
)

if pause {
action = "pause"
src = s.pipelines
dst = s.pipelinesPaused
} else {
action = "resume"
src = s.pipelinesPaused
dst = s.pipelines
}

audStr := audToStr(cmd.Audience)

if _, ok := s.pipelines[audStr]; !ok {
return ErrPipelineNotActive
}

pipeline, ok := s.pipelines[audStr][cmd.GetPausePipeline().PipelineId]
if !ok {
// Is this audience known?
if _, ok := src[audStr]; !ok {
s.config.Logger.Debugf("Attempted to %s pipeline %s for audience %s but no such audience known", action, cmd.GetPausePipeline().PipelineId, audStr)
return ErrPipelineNotActive
}

if _, ok := s.pipelinesPaused[audStr]; !ok {
s.pipelinesPaused[audStr] = make(map[string]*protos.Command)
}

s.pipelinesPaused[audStr][cmd.GetPausePipeline().PipelineId] = pipeline

delete(s.pipelines[audStr], cmd.GetPausePipeline().PipelineId)

if len(s.pipelines[audStr]) == 0 {
delete(s.pipelines, audStr)
}
// Audience is known; is pipeline known?
srcPipelineIndex := getPipelineIndex(src[audStr], cmd.GetPausePipeline().PipelineId)

return nil
}

func (s *Streamdal) resumePipeline(_ context.Context, cmd *protos.Command) error {
if cmd == nil {
return ErrEmptyCommand
if srcPipelineIndex == -1 {
s.config.Logger.Debugf("Attempted to %s pipeline %s for audience %s but no such pipeline known", action, cmd.GetPausePipeline().PipelineId, audStr)
return ErrPipelineNotActive
}

s.pipelinesMtx.Lock()
defer s.pipelinesMtx.Unlock()
s.pipelinesPausedMtx.Lock()
defer s.pipelinesPausedMtx.Unlock()

audStr := audToStr(cmd.Audience)

if _, ok := s.pipelinesPaused[audStr]; !ok {
return ErrPipelineNotPaused
// Audience and pipeline exist - if dst map does not contain audience, create pipeline slice
if _, ok := dst[audStr]; !ok {
dst[audStr] = make([]*protos.Command, 0)
}

pipeline, ok := s.pipelinesPaused[audStr][cmd.GetResumePipeline().PipelineId]
if !ok {
return ErrPipelineNotPaused
}
dstPipelineIndex := getPipelineIndex(s.pipelinesPaused[audStr], cmd.GetPausePipeline().PipelineId)

if _, ok := s.pipelines[audStr]; !ok {
s.pipelines[audStr] = make(map[string]*protos.Command)
if dstPipelineIndex != -1 {
// Pipeline already paused, nothing to do
s.config.Logger.Debugf("Attempted to %s pipeline %s for audience %s but pipeline already paused", action, cmd.GetPausePipeline().PipelineId, audStr)
return nil
}

s.pipelines[audStr][cmd.GetResumePipeline().PipelineId] = pipeline
// Pipeline not in dst map, add it
dst[audStr] = append(dst[audStr], src[audStr][srcPipelineIndex])

delete(s.pipelinesPaused[audStr], cmd.GetResumePipeline().PipelineId)
// Remove pipeline from src pipelines map
src[audStr] = append(src[audStr][:srcPipelineIndex], src[audStr][srcPipelineIndex+1:]...)

if len(s.pipelinesPaused[audStr]) == 0 {
delete(s.pipelinesPaused, audStr)
// If src has no pipelines for this audience, remove the audience
if len(src[audStr]) == 0 {
s.config.Logger.Debugf("No active pipelines left for audience %s during %s, removing audience", audStr, action)
delete(src, audStr)
}

s.config.Logger.Debugf("Successful %s for pipeline %s for audience %s", action, cmd.GetPausePipeline().PipelineId, audStr)

return nil
}
Loading
Loading