Skip to content

Commit

Permalink
kafka: implement basic SASL auth (#1677)
Browse files Browse the repository at this point in the history
  • Loading branch information
Commelina authored Nov 13, 2023
1 parent ff7b2b6 commit be42734
Show file tree
Hide file tree
Showing 16 changed files with 515 additions and 32 deletions.
87 changes: 64 additions & 23 deletions hstream-kafka/HStream/Kafka/Network.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
module HStream.Kafka.Network
( -- * Server
ServerOptions (..)
, SaslOptions (..)
, defaultServerOpts
, runServer
-- * Client
Expand All @@ -18,24 +19,26 @@ module HStream.Kafka.Network
) where

import Control.Concurrent
import qualified Control.Exception as E
import qualified Control.Exception as E
import Control.Monad
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import Data.Int
import Data.List (find, intersperse)
import Data.Maybe (fromMaybe, isJust,
isNothing)
import qualified Network.Socket as N
import qualified Network.Socket.ByteString as N
import qualified Network.Socket.ByteString.Lazy as NL
import Numeric (showHex, showInt)
import Data.List (find, intersperse)
import Data.Maybe (fromMaybe, isJust,
isNothing)
import qualified Network.Socket as N
import qualified Network.Socket.ByteString as N
import qualified Network.Socket.ByteString.Lazy as NL
import Numeric (showHex, showInt)

import HStream.Kafka.Common.OffsetManager (initOffsetReader)
import HStream.Kafka.Server.Types (ServerContext (..))
import qualified HStream.Logger as Log
import HStream.Kafka.Common.KafkaException (ErrorCodeException (..))
import HStream.Kafka.Common.OffsetManager (initOffsetReader)
import HStream.Kafka.Server.Types (ServerContext (..))
import qualified HStream.Logger as Log
import Kafka.Protocol.Encoding
import qualified Kafka.Protocol.Error as K
import Kafka.Protocol.Message
import Kafka.Protocol.Service

Expand All @@ -45,18 +48,23 @@ import Kafka.Protocol.Service
-- TODO
data SslOptions

-- TODO
data SaslOptions = SaslOptions

data ServerOptions = ServerOptions
{ serverHost :: !String
, serverPort :: !Int
, serverSslOptions :: !(Maybe SslOptions)
, serverOnStarted :: !(Maybe (IO ()))
{ serverHost :: !String
, serverPort :: !Int
, serverSslOptions :: !(Maybe SslOptions)
, serverSaslOptions :: !(Maybe SaslOptions)
, serverOnStarted :: !(Maybe (IO ()))
}

defaultServerOpts :: ServerOptions
defaultServerOpts = ServerOptions
{ serverHost = "0.0.0.0"
, serverPort = 9092
, serverSslOptions = Nothing
, serverSaslOptions = Nothing
, serverOnStarted = Nothing
}

Expand All @@ -66,34 +74,66 @@ runServer
:: ServerOptions
-> ServerContext
-> (ServerContext -> [ServiceHandler])
-> (ServerContext -> [ServiceHandler])
-> IO ()
runServer opts sc mkHandlers =
runServer opts sc mkPreAuthedHandlers mkAuthedHandlers =
startTCPServer opts $ \(s, peer) -> do
-- Since the Reader is thread-unsafe, for each connection we create a new
-- Reader.
om <- initOffsetReader $ scOffsetManager sc
let sc' = sc{scOffsetManager = om}
i <- N.recv s 1024
talk (peer, (mkHandlers sc')) i Nothing s
-- Decide if we require SASL authentication
case (serverSaslOptions opts) of
Nothing -> talk (peer, mkAuthedHandlers sc') i Nothing s
_ -> do i' <- authTalk (peer, (mkPreAuthedHandlers sc')) i Nothing s
talk (peer, mkAuthedHandlers sc') i' Nothing s
where
authTalk _ "" _ _ = pure mempty -- client exit
authTalk !(peer, hds) i m_more s = do
reqBsResult <- case m_more of
Nothing -> runParser @ByteString get i
Just mf -> mf i
case reqBsResult of
Done "" reqBs -> do authed <- newEmptyMVar
respBs <- runHandler peer hds reqBs authed
NL.sendAll s respBs
tryTakeMVar authed >>= \case
Just True -> N.recv s 1024
Just False -> E.throwIO $ ErrorCodeException K.SASL_AUTHENTICATION_FAILED
Nothing -> do msg <- N.recv s 1024
authTalk (peer, hds) msg Nothing s
Done l reqBs -> do authed <- newEmptyMVar
respBs <- runHandler peer hds reqBs authed
NL.sendAll s respBs
tryTakeMVar authed >>= \case
Just True -> return l
Just False -> E.throwIO $ ErrorCodeException K.SASL_AUTHENTICATION_FAILED
Nothing -> authTalk (peer, hds) l Nothing s
More f -> do msg <- N.recv s 1024
authTalk (peer, hds) msg (Just f) s
Fail _ err -> E.throwIO $ DecodeError $ "Fail, " <> err

talk _ "" _ _ = pure () -- client exit
talk !(peer, hds) i m_more s = do
reqBsResult <- case m_more of
Nothing -> runParser @ByteString get i
Just mf -> mf i
case reqBsResult of
Done "" reqBs -> do respBs <- runHandler peer hds reqBs
Done "" reqBs -> do respBs <- runHandler peer hds reqBs undefined -- FIXME: unused 'authed' as 'undefined'. Better way?
NL.sendAll s respBs
msg <- N.recv s 1024
talk (peer, hds) msg Nothing s
Done l reqBs -> do respBs <- runHandler peer hds reqBs
Done l reqBs -> do respBs <- runHandler peer hds reqBs undefined -- FIXME: unused 'authed' as 'undefined'. Better way?
NL.sendAll s respBs
talk (peer, hds) l Nothing s
More f -> do msg <- N.recv s 1024
talk (peer, hds) msg (Just f) s
Fail _ err -> E.throwIO $ DecodeError $ "Fail, " <> err

runHandler peer handlers reqBs = do
-- 'authed :: MVar Bool'. Empty at the beginning and is only used by SASL auth.
-- FIXME: better way?
runHandler peer handlers reqBs authed = do
headerResult <- runParser @RequestHeader get reqBs
case headerResult of
Done l RequestHeader{..} -> do
Expand All @@ -112,6 +152,7 @@ runServer opts sc mkHandlers =
RequestContext
{ clientId = requestClientId
, clientHost = showSockAddrHost peer
, clientAuthDone = authed
}
resp <- rpcHandler' reqContext req
Log.debug $ "Server response: " <> Log.buildString' resp
Expand Down
6 changes: 6 additions & 0 deletions hstream-kafka/HStream/Kafka/Server/Config/FromCli.hs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ cliOptionsParser = do

cliStoreCompression <- optional storeCompressionParser

cliEnableSaslAuth <- enableSaslAuthParser

return CliOptions{..}

-------------------------------------------------------------------------------
Expand Down Expand Up @@ -249,6 +251,10 @@ storeConfigPathParser = strOption
<> metavar "PATH" <> value "/data/store/logdevice.conf"
<> help "Storage config path"

enableSaslAuthParser :: O.Parser Bool
enableSaslAuthParser = flag False True
$ long "enable-sasl"
<> help "Enable SASL authentication"
-------------------------------------------------------------------------------

parserOpt :: (Text -> Either String a) -> O.Mod O.OptionFields a -> O.Parser a
Expand Down
4 changes: 4 additions & 0 deletions hstream-kafka/HStream/Kafka/Server/Config/FromJson.hs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ parseJSONToOptions CliOptions{..} obj = do
let !_securityProtocolMap = defaultProtocolMap tlsConfig
let !_listenersSecurityProtocolMap = Map.union cliListenersSecurityProtocolMap nodeListenersSecurityProtocolMap

-- SASL config
nodeEnableSaslAuth <- nodeCfgObj .:? "enable-sasl" .!= False
let !_enableSaslAuth = cliEnableSaslAuth || nodeEnableSaslAuth

return ServerOpts {..}

-------------------------------------------------------------------------------
Expand Down
5 changes: 5 additions & 0 deletions hstream-kafka/HStream/Kafka/Server/Config/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ data ServerOpts = ServerOpts
, _ldConfigPath :: !CBytes

, _compression :: !Compression

, _enableSaslAuth :: !Bool
} deriving (Show, Eq)

-------------------------------------------------------------------------------
Expand Down Expand Up @@ -117,6 +119,9 @@ data CliOptions = CliOptions

-- Internal options
, cliStoreCompression :: !(Maybe Compression)

-- SASL Authentication
, cliEnableSaslAuth :: !Bool
} deriving Show

-------------------------------------------------------------------------------
Expand Down
28 changes: 24 additions & 4 deletions hstream-kafka/HStream/Kafka/Server/Handler.hsc
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}

module HStream.Kafka.Server.Handler (handlers) where
module HStream.Kafka.Server.Handler
( handlers
, unAuthedHandlers
) where

import HStream.Kafka.Server.Handler.Basic
import HStream.Kafka.Server.Handler.Consume
import HStream.Kafka.Server.Handler.Group
import HStream.Kafka.Server.Handler.Offset
import HStream.Kafka.Server.Handler.Produce
import HStream.Kafka.Server.Handler.Security
import HStream.Kafka.Server.Handler.Topic
import HStream.Kafka.Server.Types (ServerContext (..))
import qualified Kafka.Protocol.Message as K
import qualified Kafka.Protocol.Service as K
import HStream.Kafka.Server.Types (ServerContext (..))
import qualified Kafka.Protocol.Message as K
import qualified Kafka.Protocol.Service as K

-------------------------------------------------------------------------------

Expand Down Expand Up @@ -54,6 +58,9 @@ import qualified Kafka.Protocol.Service as K
#cv_handler ApiVersions, 0, 3
#cv_handler Fetch, 0, 2

#cv_handler SaslHandshake, 0, 1
#cv_handler SaslAuthenticate, 0, 0

handlers :: ServerContext -> [K.ServiceHandler]
handlers sc =
[ #mk_handler ApiVersions, 0, 3
Expand Down Expand Up @@ -91,4 +98,17 @@ handlers sc =
, K.hd (K.RPC :: K.RPC K.HStreamKafkaV0 "heartbeat") (handleHeartbeatV0 sc)
, K.hd (K.RPC :: K.RPC K.HStreamKafkaV0 "listGroups") (handleListGroupsV0 sc)
, K.hd (K.RPC :: K.RPC K.HStreamKafkaV0 "describeGroups") (handleDescribeGroupsV0 sc)

, K.hd (K.RPC :: K.RPC K.HStreamKafkaV0 "saslHandshake") (handleAfterAuthSaslHandshakeV0 sc)
, K.hd (K.RPC :: K.RPC K.HStreamKafkaV1 "saslHandshake") (handleAfterAuthSaslHandshakeV1 sc)

, K.hd (K.RPC :: K.RPC K.HStreamKafkaV0 "saslAuthenticate") (handleAfterAuthSaslAuthenticateV0 sc)
]

unAuthedHandlers :: ServerContext -> [K.ServiceHandler]
unAuthedHandlers sc =
[ #mk_handler ApiVersions, 0, 3

, #mk_handler SaslHandshake, 0, 1
, #mk_handler SaslAuthenticate, 0, 0
]
104 changes: 104 additions & 0 deletions hstream-kafka/HStream/Kafka/Server/Handler/Security.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
module HStream.Kafka.Server.Handler.Security
( handleSaslHandshake
, handleSaslAuthenticate

, handleAfterAuthSaslHandshakeV0
, handleAfterAuthSaslHandshakeV1
, handleAfterAuthSaslAuthenticateV0
) where

import Control.Concurrent.MVar
import Control.Monad.IO.Class (liftIO)
import Data.ByteString (ByteString)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Vector as V
import Network.Protocol.SASL.GNU

import HStream.Kafka.Server.Types (ServerContext (..))
import qualified HStream.Logger as Log
import qualified Kafka.Protocol.Encoding as K
import qualified Kafka.Protocol.Error as K
import qualified Kafka.Protocol.Message as K
import qualified Kafka.Protocol.Service as K

-------------------------------------------------------------------------------
saslPlainCallback :: Property -> Session Progress
saslPlainCallback PropertyAuthID = do
authID <- getProperty PropertyAuthID
liftIO . Log.debug $ "SASL PLAIN invoke callback with PropertyAuthID. I got " <> Log.buildString' authID
return Complete
saslPlainCallback PropertyPassword = do
password <- getProperty PropertyPassword
liftIO . Log.debug $ "SASL PLAIN invoke callback with PropertyPassword. I got" <> Log.buildString' password
return Complete
saslPlainCallback PropertyAuthzID = do
authzID <- getProperty PropertyAuthzID
liftIO . Log.debug $ "SASL PLAIN invoke callback with PropertyAuthzID. I got" <> Log.buildString' authzID
return Complete
saslPlainCallback ValidateSimple = do
liftIO . Log.debug $ "SASL PLAIN invoke callback with ValidateSimple..."
authID <- getProperty PropertyAuthID
password <- getProperty PropertyPassword
if authID == Just "admin" && password == Just "passwd" -- FIXME: do actual check
then return Complete
else throw AuthenticationError
saslPlainCallback prop = do
liftIO . Log.warning $ "SASL PLAIN invoke callback with " <> Log.buildString' prop <> ". But I do not know how to handle it..."
return Complete

saslPlainSession :: ByteString -> Session ByteString
saslPlainSession input = do
mechanism <- mechanismName
liftIO . Log.debug $ "SASL: I am using " <> Log.buildString' mechanism
liftIO . Log.debug $ "SASL PLAIN: I got C: " <> Log.build input
(serverMsg, prog) <- step input
case prog of
Complete -> do
liftIO . Log.debug $ "SASL PLAIN: Complete. S: " <> Log.build serverMsg
return serverMsg
NeedsMore -> do
liftIO . Log.warning $ "SASL PLAIN: I need more... But why? S: " <> Log.build serverMsg
throw AuthenticationError

saslPlain :: ByteString -> SASL (Either Error ByteString)
saslPlain input = do
setCallback saslPlainCallback
runServer (Mechanism "PLAIN") (saslPlainSession input)

-------------------------------------------------------------------------------
handleSaslHandshake :: ServerContext -> K.RequestContext -> K.SaslHandshakeRequest -> IO K.SaslHandshakeResponse
handleSaslHandshake _ _ K.SaslHandshakeRequest{..} = do
let reqMechanism = Mechanism (T.encodeUtf8 mechanism)
isMechSupported <- runSASL (serverSupports reqMechanism)
if isMechSupported then do
Log.debug $ "SASL: client requests " <> Log.buildString' mechanism
return $ K.SaslHandshakeResponse K.NONE (K.KaArray $ Just (V.singleton "PLAIN"))
else do
Log.warning $ "SASL: client requests " <> Log.buildString' mechanism <> ", but I do not support it..."
return $ K.SaslHandshakeResponse K.UNSUPPORTED_SASL_MECHANISM (K.KaArray $ Just (V.singleton "PLAIN"))

handleSaslAuthenticate :: ServerContext -> K.RequestContext -> K.SaslAuthenticateRequest -> IO K.SaslAuthenticateResponse
handleSaslAuthenticate _ reqCtx K.SaslAuthenticateRequest{..} = do
respBytes_e <- runSASL (saslPlain authBytes)
case respBytes_e of
Left err -> do
Log.warning $ "SASL: auth failed, " <> Log.buildString' err
putMVar (K.clientAuthDone reqCtx) False
return $ K.SaslAuthenticateResponse K.SASL_AUTHENTICATION_FAILED (Just . T.pack $ show err) mempty
Right respBytes -> do
putMVar (K.clientAuthDone reqCtx) True
return $ K.SaslAuthenticateResponse K.NONE mempty respBytes

-------------------------------------------------------------------------------
handleAfterAuthSaslHandshakeV0 :: ServerContext -> K.RequestContext -> K.SaslHandshakeRequestV0 -> IO K.SaslHandshakeResponseV0
handleAfterAuthSaslHandshakeV0 _ _ _ = return $ K.SaslHandshakeResponseV0 K.ILLEGAL_SASL_STATE (K.KaArray Nothing)

handleAfterAuthSaslHandshakeV1 :: ServerContext -> K.RequestContext -> K.SaslHandshakeRequestV1 -> IO K.SaslHandshakeResponseV1
handleAfterAuthSaslHandshakeV1 = handleAfterAuthSaslHandshakeV0

handleAfterAuthSaslAuthenticateV0 :: ServerContext -> K.RequestContext -> K.SaslAuthenticateRequestV0 -> IO K.SaslAuthenticateResponseV0
handleAfterAuthSaslAuthenticateV0 _ _ _ =
return $ K.SaslAuthenticateResponseV0 K.ILLEGAL_SASL_STATE
(Just "SaslAuthenticate request received after successful authentication")
mempty
2 changes: 2 additions & 0 deletions hstream-kafka/hstream-kafka.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ library
HStream.Kafka.Server.Handler.Group
HStream.Kafka.Server.Handler.Offset
HStream.Kafka.Server.Handler.Produce
HStream.Kafka.Server.Handler.Security
HStream.Kafka.Server.Handler.Topic

cxx-sources: cbits/hs_kafka_client.cpp
Expand All @@ -145,6 +146,7 @@ library
, containers
, directory
, foreign
, gsasl >=0.3.0
, hashable
, hashtables
, haskeline
Expand Down
Loading

0 comments on commit be42734

Please sign in to comment.