Skip to content

Commit

Permalink
rewrite StreamPredictionFiles
Browse files: browse the repository at this point in the history
Again, no goroutines! Only Readers!
  • Loading branch information
philandstuff committed Oct 11, 2024
1 parent b5b7785 commit 38b8355
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 159 deletions.
166 changes: 49 additions & 117 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"io"
"net/http"
"strings"
"time"
"unicode/utf8"

"github.com/vincent-petithory/dataurl"
Expand Down Expand Up @@ -146,18 +145,6 @@ func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) (
return sseChan, errChan
}

// StreamPredictionFiles starts streaming the file outputs of prediction and
// returns a receive-only channel of streaming.File values.
//
// It returns an error when the prediction has no "stream" URL. Otherwise the
// files are delivered by streamFilesTo, which is started on its own goroutine
// with ctx and an empty last-event ID.
func (r *Client) StreamPredictionFiles(ctx context.Context, prediction *Prediction) (<-chan streaming.File, error) {
	if streamURL := prediction.URLs["stream"]; streamURL != "" {
		files := make(chan streaming.File)

		go r.streamFilesTo(ctx, files, streamURL, "")
		return files, nil
	}
	return nil, errors.New("streaming not supported or not enabled for this prediction")
}

type textStreamer struct {
s *sse.Streamer
ctx context.Context
Expand Down Expand Up @@ -264,125 +251,70 @@ func (h *httpURL) Body(ctx context.Context) (io.ReadCloser, error) {
return resp.Body, nil
}

type errWrapper struct {
err error
type fileStreamer struct {
s *sse.Streamer
c *http.Client
done bool
}

// Compile-time assertion that *errWrapper implements streaming.File.
var _ streaming.File = &errWrapper{}

// fileError packages err as a streaming.File by storing it in a new
// errWrapper, so an error can be passed wherever a streaming.File is expected.
func fileError(err error) streaming.File {
	wrapped := new(errWrapper)
	wrapped.err = err
	return wrapped
}

// Body never yields content: it returns a nil reader together with the error
// stored in the wrapper. The context argument is ignored.
func (e *errWrapper) Body(_ context.Context) (io.ReadCloser, error) {
	var noBody io.ReadCloser
	return noBody, e.err
}

// Close does nothing and always returns nil.
func (e *errWrapper) Close() error {
	var err error
	return err
}

func (r *Client) streamFilesTo(ctx context.Context, out chan<- streaming.File, url string, lastEventID string) {
defer close(out)
ctx, cancel := context.WithCancel(ctx)
defer cancel()

ch := make(chan event)
go r.streamEventsTo(ctx, ch, url, lastEventID)
func (f *fileStreamer) NextFile(ctx context.Context) (streaming.File, error) {
if f.done {
return nil, io.EOF
}
for {
var url string
e, err := f.s.NextEvent(ctx)
if err != nil {
return nil, err
}
switch e.Type {
case "":
// empty message, ignore
// nchan starts streams with a blank `: hi` message
continue
case SSETypeDone:
f.done = true
return nil, io.EOF
case SSETypeError:
return nil, fmt.Errorf("Error event: %s", e.Data)
case SSETypeOutput:
url = strings.TrimSuffix(e.Data, "\n")
default:
return nil, fmt.Errorf("unexpected type %s, %+v", e.Type, e)
}

for e := range ch {
url := strings.TrimSuffix(e.rawData, "\n")
switch {
case strings.HasPrefix(url, "data:"):
select {
case <-ctx.Done():
case out <- &dataURL{url: url}:
}
return &dataURL{url: url}, nil
case strings.HasPrefix(url, "http"):
select {
case <-ctx.Done():
case out <- &httpURL{c: r.c, url: url}:
}
return &httpURL{c: f.c, url: url}, nil
default:
select {
case <-ctx.Done():
case out <- fileError(fmt.Errorf("Could not parse URL: %s", url)):
}
return
return nil, fmt.Errorf("Could not parse URL: %s", url)
}
}
}

type event struct {
rawData string
err error
// Close closes the underlying SSE streamer and returns the error it reports,
// if any.
func (f *fileStreamer) Close() error {
	closeErr := f.s.Close()
	return closeErr
}

func (r *Client) streamEventsTo(ctx context.Context, out chan<- event, url string, lastEventID string) {
defer close(out)
ATTEMPT:
for attempt := 0; attempt <= r.options.retryPolicy.maxRetries; attempt++ {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
out <- event{err: err}
return
}
req.Header.Set("Accept", "text/event-stream")

if lastEventID != "" {
req.Header.Set("Last-Event-ID", lastEventID)
}

resp, err := r.c.Do(req)
if err != nil {
select {
case <-ctx.Done():
case out <- event{err: fmt.Errorf("failed to send request: %w", err)}:
}
return
}
// FileStreamer yields the file outputs of a streamed prediction one at a time.
// Callers must Close it when finished.
type FileStreamer interface {
io.Closer
// NextFile returns the next file from the stream, or a non-nil error.
// The fileStreamer implementation in this file returns io.EOF once the
// stream's done event has been seen.
NextFile(ctx context.Context) (streaming.File, error)
}

if resp.StatusCode != http.StatusOK {
out <- event{err: fmt.Errorf("received invalid status code: %d", resp.StatusCode)}
return
}
defer resp.Body.Close()
decoder := sse.NewDecoder(resp.Body)
for {
e, err := decoder.Next()
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
// retry
delay := r.options.retryPolicy.backoff.NextDelay(attempt)
time.Sleep(delay)
continue ATTEMPT
}
select {
case <-ctx.Done():
case out <- event{err: fmt.Errorf("failed to get token: %w", err)}:
}
return
}
lastEventID = e.ID
switch e.Type {
case SSETypeOutput:
select {
case <-ctx.Done():
case out <- event{rawData: e.Data}:
}
case SSETypeDone:
return
case SSETypeLogs:
// TODO
default:
select {
case <-ctx.Done():
case out <- event{err: fmt.Errorf("unknown event type %s", e.Type)}:
}
return
}
}
// StreamPredictionFiles streams prediction file output via the replicate
// streaming api. It is the caller's responsibility to close the returned
// FileStreamer to ensure connections and associated resources are cleaned up
// appropriately.
func (r *Client) StreamPredictionFiles(ctx context.Context, prediction *Prediction) (FileStreamer, error) {

Check failure on line 310 in stream.go

View workflow job for this annotation

GitHub Actions / build

unused-parameter: parameter 'ctx' seems to be unused, consider removing or renaming it as _ (revive)
url := prediction.URLs["stream"]
if url == "" {
return nil, errors.New("streaming not supported or not enabled for this prediction")
}

s := sse.NewStreamer(r.c, url, r.options.retryPolicy.maxRetries, r.options.retryPolicy.backoff)
return &fileStreamer{s: s, c: r.c}, nil
}

func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, sseChan chan SSEEvent, errChan chan error) {
Expand Down
70 changes: 28 additions & 42 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ event: done
`)
}))
defer ts.Close()
t.Cleanup(ts.Close)

p := &replicate.Prediction{
URLs: map[string]string{
Expand All @@ -33,12 +33,13 @@ event: done
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
t.Cleanup(cancel)

c, err := replicate.NewClient(replicate.WithToken("test-token"))
require.NoError(t, err)

r, err := c.StreamPredictionText(ctx, p)
t.Cleanup(func() { r.Close() })

require.NoError(t, err)

Expand All @@ -59,7 +60,7 @@ event: done
`)
}))
defer ts.Close()
t.Cleanup(ts.Close)

p := &replicate.Prediction{
URLs: map[string]string{
Expand All @@ -68,12 +69,13 @@ event: done
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
t.Cleanup(cancel)

c, err := replicate.NewClient(replicate.WithToken("test-token"))
require.NoError(t, err)

r, err := c.StreamPredictionText(ctx, p)
t.Cleanup(func() { r.Close() })

require.NoError(t, err)

Expand Down Expand Up @@ -114,7 +116,7 @@ id: 3
`)
}))
defer ts.Close()
t.Cleanup(ts.Close)

p := &replicate.Prediction{
URLs: map[string]string{
Expand All @@ -123,14 +125,14 @@ id: 3
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
t.Cleanup(cancel)

c, err := replicate.NewClient(replicate.WithToken("test-token"))
require.NoError(t, err)

r, err := c.StreamPredictionText(ctx, p)

assert.NoError(t, err)
require.NoError(t, err)
t.Cleanup(func() { r.Close() })

text, err := io.ReadAll(r)
assert.NoError(t, err)
Expand All @@ -157,7 +159,7 @@ event: done
`)
}))
defer ts.Close()
t.Cleanup(ts.Close)
baseURL = ts.URL

p := &replicate.Prediction{
Expand All @@ -167,55 +169,39 @@ event: done
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
t.Cleanup(cancel)

c, err := replicate.NewClient(replicate.WithToken("test-token"))
require.NoError(t, err)

files, err := c.StreamPredictionFiles(ctx, p)

assert.NoError(t, err)
require.NoError(t, err)

var body io.Reader
// first file is a data URI
select {
case file := <-files:
require.NotNil(t, file)
body, err = file.Body(ctx)
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Timed out waiting for file")
return
}
file, err := files.NextFile(ctx)
require.NoError(t, err)
body, err := file.Body(ctx)
require.NoError(t, err)
content1, err := io.ReadAll(body)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "banana", string(content1))

// second file is a base64'd data URI
select {
case file := <-files:
require.NotNil(t, file)
body, err = file.Body(ctx)
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Timed out waiting for file")
return
}
file, err = files.NextFile(ctx)
require.NoError(t, err)
body, err = file.Body(ctx)
require.NoError(t, err)
content2, err := io.ReadAll(body)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "apple", string(content2))

// third file is an http URI
select {
case file := <-files:
require.NotNil(t, file)
body, err = file.Body(ctx)
require.NoError(t, err)
case <-time.After(time.Second):
assert.Fail(t, "Timed out waiting for file")
return
}
file, err = files.NextFile(ctx)
require.NoError(t, err)
body, err = file.Body(ctx)
require.NoError(t, err)
content3, err := io.ReadAll(body)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "mango\n", string(content3))
}

0 comments on commit 38b8355

Please sign in to comment.