Skip to content

Commit

Permalink
feat: upgrade to use gorilla/websocket
Browse files Browse the repository at this point in the history
Closes #12
  • Loading branch information
dignifiedquire committed Jul 1, 2017
1 parent fa8d0d2 commit 398fcd5
Showing 1 changed file with 89 additions and 17 deletions.
106 changes: 89 additions & 17 deletions p2p/transport/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"net"
"net/http"
"net/url"
"time"

wsGorilla "github.com/gorilla/websocket"
tpt "github.com/libp2p/go-libp2p-transport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr-net"
Expand All @@ -29,6 +31,9 @@ var WsCodec = &manet.NetCodec{
ParseNetAddr: ParseWebsocketNetAddr,
}

// Default gorilla upgrader
var upgrader = wsGorilla.Upgrader{}

func init() {
err := ma.AddProtocol(WsProtocol)
if err != nil {
Expand Down Expand Up @@ -107,12 +112,15 @@ func (d *dialer) DialContext(ctx context.Context, raddr ma.Multiaddr) (tpt.Conn,
return nil, err
}

wscon, err := ws.Dial(wsurl, "", "http://127.0.0.1:0/")
// TODO: figure out origins, probably don't work for us
// header := http.Header{}
// header.Set("Origin", "http://127.0.0.1:0/")
wscon, _, err := wsGorilla.DefaultDialer.Dial(wsurl, nil)
if err != nil {
return nil, err
}

mnc, err := manet.WrapNetConn(wscon)
mnc, err := manet.WrapNetConn(NewGorillaNetConn(wscon))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -141,17 +149,19 @@ type listener struct {
incoming chan *conn

tpt tpt.Transport

origin *url.URL
}

type conn struct {
*ws.Conn
*GorillaNetConn

done func()
}

func (c *conn) Close() error {
c.done()
return c.Conn.Close()
return c.GorillaNetConn.Close()
}

func (t *WebsocketTransport) Listen(a ma.Multiaddr) (tpt.Listener, error) {
Expand All @@ -160,38 +170,41 @@ func (t *WebsocketTransport) Listen(a ma.Multiaddr) (tpt.Listener, error) {
return nil, err
}

tlist := t.wrapListener(list)

u, err := url.Parse("ws://" + list.Addr().String())
u, err := url.Parse("http://" + list.Addr().String())
if err != nil {
return nil, err
}

s := &ws.Server{
Handler: tlist.handleWsConn,
Config: ws.Config{Origin: u},
}
tlist := t.wrapListener(list, u)

go http.Serve(list.NetListener(), s)
http.HandleFunc("/", tlist.handleWsConn)
go http.Serve(list.NetListener(), nil)

return tlist, nil
}

func (t *WebsocketTransport) wrapListener(l manet.Listener) *listener {
func (t *WebsocketTransport) wrapListener(l manet.Listener, origin *url.URL) *listener {
return &listener{
Listener: l,
incoming: make(chan *conn),
tpt: t,
origin: origin,
}
}

func (l *listener) handleWsConn(s *ws.Conn) {
func (l *listener) handleWsConn(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
http.Error(w, "Failed to upgrade websocket", 400)
return
}

ctx, cancel := context.WithCancel(context.Background())
s.PayloadType = ws.BinaryFrame

wrapped := NewGorillaNetConn(c)
l.incoming <- &conn{
Conn: s,
done: cancel,
GorillaNetConn: &wrapped,
done: cancel,
}

// wait until conn gets closed, otherwise the handler closes it early
Expand Down Expand Up @@ -225,3 +238,62 @@ func (l *listener) Multiaddr() ma.Multiaddr {
}

var _ tpt.Transport = (*WebsocketTransport)(nil)

type GorillaNetConn struct {
Inner *wsGorilla.Conn
DefaultMessageType int
}

func (c GorillaNetConn) Read(b []byte) (n int, err error) {
fmt.Println("reading")
_, r, err := c.Inner.NextReader()
if err != nil {
return 0, err
}

return r.Read(b)
}

func (c GorillaNetConn) Write(b []byte) (n int, err error) {
fmt.Printf("write %s\n", string(b))
if err := c.Inner.WriteMessage(c.DefaultMessageType, b); err != nil {
return 0, err
}

return len(b), nil
}

func (c GorillaNetConn) Close() error {
return c.Inner.Close()
}

func (c GorillaNetConn) LocalAddr() net.Addr {
return c.Inner.LocalAddr()
}

func (c GorillaNetConn) RemoteAddr() net.Addr {
return c.Inner.RemoteAddr()
}

func (c GorillaNetConn) SetDeadline(t time.Time) error {
if err := c.SetReadDeadline(t); err != nil {
return err
}

return c.SetReadDeadline(t)
}

func (c GorillaNetConn) SetReadDeadline(t time.Time) error {
return c.Inner.SetReadDeadline(t)
}

func (c GorillaNetConn) SetWriteDeadline(t time.Time) error {
return c.Inner.SetWriteDeadline(t)
}

func NewGorillaNetConn(raw *wsGorilla.Conn) GorillaNetConn {
return GorillaNetConn{
Inner: raw,
DefaultMessageType: wsGorilla.BinaryMessage,
}
}

0 comments on commit 398fcd5

Please sign in to comment.