Skip to content

Commit

Permalink
Add option for setting additional HTTP headers in WebSocket handshake
Browse files Browse the repository at this point in the history
Certain services (such as AWS IOT when using a custom authorizer) require
additional headers to be sent during the WebSocket opening handshake.  This
adds a function SetHTTPHeaders() to the ClientOptions that enables setting
additional headers.

Signed-off-by: Scott Talbert <[email protected]>
  • Loading branch information
Scott Talbert committed Mar 23, 2018
1 parent 9ec68b7 commit 4c524ad
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 3 deletions.
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (c *client) Connect() Token {
c.options.ProtocolVersion = protocolVersion
CONN:
DEBUG.Println(CLI, "about to write new connect msg")
c.conn, err = openConnection(broker, &c.options.TLSConfig, c.options.ConnectTimeout)
c.conn, err = openConnection(broker, &c.options.TLSConfig, c.options.ConnectTimeout, c.options.HTTPHeaders)
if err == nil {
DEBUG.Println(CLI, "socket connected to broker")
switch c.options.ProtocolVersion {
Expand Down Expand Up @@ -320,7 +320,7 @@ func (c *client) reconnect() {

for _, broker := range c.options.Servers {
DEBUG.Println(CLI, "about to write new connect msg")
c.conn, err = openConnection(broker, &c.options.TLSConfig, c.options.ConnectTimeout)
c.conn, err = openConnection(broker, &c.options.TLSConfig, c.options.ConnectTimeout, c.options.HTTPHeaders)
if err == nil {
DEBUG.Println(CLI, "socket connected to broker")
switch c.options.ProtocolVersion {
Expand Down
4 changes: 3 additions & 1 deletion net.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"reflect"
Expand All @@ -37,7 +38,7 @@ func signalError(c chan<- error, err error) {
}
}

func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration) (net.Conn, error) {
func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header) (net.Conn, error) {
switch uri.Scheme {
case "ws":
conn, err := websocket.Dial(uri.String(), "mqtt", fmt.Sprintf("http://%s", uri.Host))
Expand All @@ -50,6 +51,7 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration) (net.
config, _ := websocket.NewConfig(uri.String(), fmt.Sprintf("https://%s", uri.Host))
config.Protocol = []string{"mqtt"}
config.TlsConfig = tlsc
config.Header = headers
conn, err := websocket.DialConfig(config)
if err != nil {
return nil, err
Expand Down
10 changes: 10 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package mqtt

import (
"crypto/tls"
"net/http"
"net/url"
"strings"
"time"
Expand Down Expand Up @@ -69,6 +70,7 @@ type ClientOptions struct {
OnConnectionLost ConnectionLostHandler
WriteTimeout time.Duration
MessageChannelDepth uint
HTTPHeaders http.Header
}

// NewClientOptions will create a new ClientClientOptions type with some
Expand Down Expand Up @@ -106,6 +108,7 @@ func NewClientOptions() *ClientOptions {
OnConnectionLost: DefaultConnectionLostHandler,
WriteTimeout: 0, // 0 represents timeout disabled
MessageChannelDepth: 100,
HTTPHeaders: make(map[string][]string),
}
return o
}
Expand Down Expand Up @@ -318,3 +321,10 @@ func (o *ClientOptions) SetMessageChannelDepth(s uint) *ClientOptions {
o.MessageChannelDepth = s
return o
}

// SetHTTPHeaders sets the additional HTTP headers that will be sent in the WebSocket
// opening handshake.
func (o *ClientOptions) SetHTTPHeaders(h http.Header) *ClientOptions {
o.HTTPHeaders = h
return o
}
6 changes: 6 additions & 0 deletions options_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package mqtt

import (
"crypto/tls"
"net/http"
"net/url"
"time"
)
Expand Down Expand Up @@ -135,3 +136,8 @@ func (r *ClientOptionsReader) MessageChannelDepth() uint {
s := r.options.MessageChannelDepth
return s
}

func (r *ClientOptionsReader) HTTPHeaders() http.Header {
h := r.options.HTTPHeaders
return h
}

0 comments on commit 4c524ad

Please sign in to comment.