Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pipeline route for mqttproxy #453

Merged
merged 15 commits into from
Jan 12, 2022
25 changes: 10 additions & 15 deletions pkg/context/mqttcontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type (
MQTTContext interface {
Context
Client() MQTTClient
Backend() MQTTBackend
// Backend() MQTTBackend
Cancel(error)
Canceled() bool
Duration() time.Duration
Expand All @@ -52,10 +52,10 @@ type (
EarlyStop() bool // if early stop is true, pipeline will skip following filters and return
}

// MQTTBackend is backend of MQTT proxy
MQTTBackend interface {
Publish(target string, data []byte, headers map[string]string) error
}
// // MQTTBackend is backend of MQTT proxy
// MQTTBackend interface {
// Publish(target string, data []byte, headers map[string]string) error
// }

// MQTTClient contains client info that send this packet
MQTTClient interface {
Expand All @@ -74,10 +74,10 @@ type (
ctx stdcontext.Context
cancelFunc stdcontext.CancelFunc

startTime time.Time
endTime time.Time
client MQTTClient
backend MQTTBackend
startTime time.Time
endTime time.Time
client MQTTClient
// backend MQTTBackend
packet packets.ControlPacket
packetType MQTTPacketType

Expand Down Expand Up @@ -116,15 +116,14 @@ const (
var _ MQTTContext = (*mqttContext)(nil)

// NewMQTTContext create new MQTTContext
func NewMQTTContext(ctx stdcontext.Context, backend MQTTBackend, client MQTTClient, packet packets.ControlPacket) MQTTContext {
func NewMQTTContext(ctx stdcontext.Context, client MQTTClient, packet packets.ControlPacket) MQTTContext {
stdctx, cancelFunc := stdcontext.WithCancel(ctx)
startTime := time.Now()
mqttCtx := &mqttContext{
ctx: stdctx,
cancelFunc: cancelFunc,
startTime: startTime,
client: client,
backend: backend,
}

switch packet.(type) {
Expand Down Expand Up @@ -265,7 +264,3 @@ func (ctx *mqttContext) SetEarlyStop() {
func (ctx *mqttContext) EarlyStop() bool {
return atomic.LoadInt32(&ctx.earlyStop) == 1
}

func (ctx *mqttContext) Backend() MQTTBackend {
return ctx.backend
}
24 changes: 0 additions & 24 deletions pkg/context/mqttmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,3 @@ func (m *MockMQTTClient) Store(key interface{}, value interface{}) {
func (m *MockMQTTClient) Delete(key interface{}) {
m.MockKVMap.Delete(key)
}

// MockMQTTMsg is message send to MockMQTTBackend
type MockMQTTMsg struct {
Target string
Data []byte
Headers map[string]string
}

// MockMQTTBackend is mocked MQTT backend
type MockMQTTBackend struct {
Messages map[string]MockMQTTMsg
}

var _ MQTTBackend = (*MockMQTTBackend)(nil)

// Publish publish msg to MockMQTTBackend
func (m *MockMQTTBackend) Publish(target string, data []byte, headers map[string]string) error {
m.Messages[target] = MockMQTTMsg{
Target: target,
Data: data,
Headers: headers,
}
return nil
}
155 changes: 155 additions & 0 deletions pkg/filter/authentication/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright (c) 2017, MegaEase
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package authentication

import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"

"github.com/eclipse/paho.mqtt.golang/packets"
"github.com/megaease/easegress/pkg/context"
"github.com/megaease/easegress/pkg/logger"
"github.com/megaease/easegress/pkg/object/pipeline"
)

const (
// Kind is the kind of Authentication
Kind = "Authentication"

resultAuthFail = "AuthFail"
)

var errAuthFail = errors.New(resultAuthFail)

func init() {
pipeline.Register(&Authentication{})
}

type (
// Authentication is used to check authentication for MQTT client
Authentication struct {
filterSpec *pipeline.FilterSpec
spec *Spec
authMap map[string]string
}

// Spec is spec for Authentication
Spec struct {
Auth []Auth `yaml:"auth" jsonschema:"required"`
}

// Auth describes username and password for MQTTProxy
// passSha256 make sure customer's password is safe.
Auth struct {
UserName string `yaml:"userName" jsonschema:"required"`
PassBase64 string `yaml:"passBase64" jsonschema:"required"`
}
)

var _ pipeline.Filter = (*Authentication)(nil)
var _ pipeline.MQTTFilter = (*Authentication)(nil)

// Kind return kind of Authentication
func (a *Authentication) Kind() string {
return Kind
}

// DefaultSpec return default spec of Authentication
func (a *Authentication) DefaultSpec() interface{} {
return &Spec{}
}

// Description return description of Authentication
func (a *Authentication) Description() string {
return "Authentication can check MQTT client's username and password"
}

// Results return possible results of Authentication
func (a *Authentication) Results() []string {
return []string{resultAuthFail}
}

// Init init Authentication
func (a *Authentication) Init(filterSpec *pipeline.FilterSpec) {
if filterSpec.Protocol() != context.MQTT {
panic("filter ConnectControl only support MQTT protocol for now")
}
a.filterSpec, a.spec = filterSpec, filterSpec.FilterSpec().(*Spec)
a.authMap = make(map[string]string)

for _, auth := range a.spec.Auth {
passwd, err := base64.StdEncoding.DecodeString(auth.PassBase64)
if err != nil {
logger.Errorf("auth with name %v, base64 password %v decode failed: %v", auth.UserName, auth.PassBase64, err)
continue
}
a.authMap[auth.UserName] = sha256Sum(passwd)
}
if len(a.authMap) == 0 {
logger.Errorf("empty valid authentication for MQTT filter %v", filterSpec.Name())
}
}

// Inherit init Authentication based on previous generation
func (k *Authentication) Inherit(filterSpec *pipeline.FilterSpec, previousGeneration pipeline.Filter) {
previousGeneration.Close()
k.Init(filterSpec)
}

// Close close Authentication
func (a *Authentication) Close() {
}

// Status return status of Authentication
func (a *Authentication) Status() interface{} {
return nil
}

func sha256Sum(data []byte) string {
sha256Bytes := sha256.Sum256(data)
return hex.EncodeToString(sha256Bytes[:])
}

func (a *Authentication) checkAuth(connect *packets.ConnectPacket) error {
if connect.ClientIdentifier == "" {
return errAuthFail
}
pass, ok := a.authMap[connect.Username]
if !ok {
return errAuthFail
}
if pass != sha256Sum(connect.Password) {
return errAuthFail
}
return nil
}

// HandleMQTT handle MQTT context
func (a *Authentication) HandleMQTT(ctx context.MQTTContext) *context.MQTTResult {
if ctx.PacketType() != context.MQTTConnect {
return &context.MQTTResult{}
}
err := a.checkAuth(ctx.ConnectPacket())
if err != nil {
ctx.SetDisconnect()
return &context.MQTTResult{Err: errAuthFail}
}
return &context.MQTTResult{}
}
118 changes: 118 additions & 0 deletions pkg/filter/authentication/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Copyright (c) 2017, MegaEase
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package authentication

import (
stdcontext "context"
"encoding/base64"
"fmt"
"testing"

"github.com/eclipse/paho.mqtt.golang/packets"
"github.com/megaease/easegress/pkg/context"
"github.com/megaease/easegress/pkg/logger"
"github.com/megaease/easegress/pkg/object/pipeline"
"github.com/stretchr/testify/assert"
)

func init() {
logger.InitNop()
}

func newContext(cid, username, password string) context.MQTTContext {
client := &context.MockMQTTClient{
MockClientID: cid,
}
packet := packets.NewControlPacket(packets.Connect).(*packets.ConnectPacket)
packet.ClientIdentifier = cid
packet.Username = username
packet.Password = []byte(password)
return context.NewMQTTContext(stdcontext.Background(), client, packet)
}

func defaultFilterSpec(spec *Spec) *pipeline.FilterSpec {
meta := &pipeline.FilterMetaSpec{
Name: "connect-demo",
Kind: Kind,
Pipeline: "pipeline-demo",
Protocol: context.MQTT,
}
filterSpec := pipeline.MockFilterSpec(nil, nil, "", meta, spec)
return filterSpec
}

func base64Encode(text string) string {
return base64.StdEncoding.EncodeToString([]byte(text))
}

func TestAuth(t *testing.T) {
assert := assert.New(t)
spec := &Spec{
Auth: []Auth{
{UserName: "test", PassBase64: base64Encode("test")},
{UserName: "admin", PassBase64: base64Encode("admin")},
},
}
fmt.Printf("auth %+v", spec.Auth)
filterSpec := defaultFilterSpec(spec)
auth := &Authentication{}
auth.Init(filterSpec)

assert.Equal(Kind, auth.Kind())
assert.Equal(&Spec{}, auth.DefaultSpec())
assert.NotEmpty(auth.Description())
assert.Equal(1, len(auth.Results()), "please update this case if add more results")
assert.Nil(auth.Status(), "please update this case if return status")

newAuth := &Authentication{}
newAuth.Inherit(filterSpec, auth)
newAuth.Close()
}

func TestAuthMQTTClient(t *testing.T) {
assert := assert.New(t)
spec := &Spec{
Auth: []Auth{
{UserName: "test", PassBase64: base64Encode("test")},
{UserName: "admin", PassBase64: base64Encode("admin")},
},
}

filterSpec := defaultFilterSpec(spec)
auth := &Authentication{}
auth.Init(filterSpec)

tests := []struct {
cid string
name string
pass string
disconnect bool
}{
{"client1", "test", "test", false},
{"client2", "admin", "admin", false},
{"client3", "fake", "test", true},
{"client4", "test", "wrongPass", true},
{"", "test", "test", true},
}

for _, test := range tests {
ctx := newContext(test.cid, test.name, test.pass)
auth.HandleMQTT(ctx)
assert.Equal(test.disconnect, ctx.Disconnect(), fmt.Errorf("test case %+v got wrong result", test))
}
}
5 changes: 1 addition & 4 deletions pkg/filter/connectcontrol/connectcontrol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,12 @@ func init() {
}

func newContext(cid string, topic string) context.MQTTContext {
backend := &context.MockMQTTBackend{
Messages: make(map[string]context.MockMQTTMsg),
}
client := &context.MockMQTTClient{
MockClientID: cid,
}
packet := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
packet.TopicName = topic
return context.NewMQTTContext(stdcontext.Background(), backend, client, packet)
return context.NewMQTTContext(stdcontext.Background(), client, packet)
}

func defaultFilterSpec(spec *Spec) *pipeline.FilterSpec {
Expand Down
Loading