Skip to content

Commit

Permalink
Merge pull request #52 from input-output-hk/coot/typed-protocols-new-api
Browse files Browse the repository at this point in the history
New API for typed-protocols
  • Loading branch information
coot authored Sep 16, 2024
2 parents f3277f6 + f451040 commit 52a4afd
Show file tree
Hide file tree
Showing 59 changed files with 4,571 additions and 1,286 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/haskell.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ jobs:
- name: typed-protocols-examples [test]
run: cabal run typed-protocols-examples:test

- name: typed-protocols-doc [test]
run: cabal test typed-protocols-doc
# - name: typed-protocols-doc [test]
# run: cabal test typed-protocols-doc

stylish-haskell:
runs-on: ubuntu-22.04
Expand Down
6 changes: 4 additions & 2 deletions cabal.project
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ repository cardano-haskell-packages

index-state:
hackage.haskell.org 2024-08-27T18:06:30Z
, cardano-haskell-packages 2024-06-27T10:53:24Z
, cardano-haskell-packages 2024-07-24T14:16:32Z

packages: ./typed-protocols
./typed-protocols-cborg
./typed-protocols-stateful
./typed-protocols-stateful-cborg
./typed-protocols-examples
./typed-protocols-doc
-- ./typed-protocols-doc

test-show-details: direct
4 changes: 2 additions & 2 deletions scripts/check-stylish.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ export LC_ALL=C.UTF-8

[[ -x '/usr/bin/fd' ]] && FD="fd" || FD="fdfind"

$FD . './typed-protocols' -e hs -E Setup.hs -X stylish-haskell -c .stylish-haskell.yaml -i
$FD . './typed-protocols' -e hs -E Setup.hs -E Core.hs -X stylish-haskell -c .stylish-haskell.yaml -i
$FD . './typed-protocols-cborg' -e hs -E Setup.hs -X stylish-haskell -c .stylish-haskell.yaml -i
$FD . './typed-protocols-examples' -e hs -E Setup.hs -X stylish-haskell -c .stylish-haskell.yaml -i
$FD . './typed-protocols-examples' -e hs -E Setup.hs -E Channel.hs -X stylish-haskell -c .stylish-haskell.yaml -i
64 changes: 40 additions & 24 deletions typed-protocols-cborg/src/Network/TypedProtocol/Codec/CBOR.hs
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.TypedProtocol.Codec.CBOR
( module Network.TypedProtocol.Codec
, DeserialiseFailure
, mkCodecCborLazyBS
, mkCodecCborStrictBS
, convertCborDecoderBS
, convertCborDecoderLBS
-- * Re-exports
, CBOR.DeserialiseFailure (..)
) where

import Control.Monad.Class.MonadST (MonadST (..))
Expand All @@ -27,8 +33,6 @@ import Network.TypedProtocol.Codec
import Network.TypedProtocol.Core


type DeserialiseFailure = CBOR.DeserialiseFailure

-- | Construct a 'Codec' for a CBOR based serialisation format, using strict
-- 'BS.ByteString's.
--
Expand All @@ -44,19 +48,23 @@ type DeserialiseFailure = CBOR.DeserialiseFailure
mkCodecCborStrictBS
:: forall ps m. MonadST m

=> (forall (pr :: PeerRole) (st :: ps) (st' :: ps).
PeerHasAgency pr st
-> Message ps st st' -> CBOR.Encoding)
=> (forall (st :: ps) (st' :: ps).
StateTokenI st
=> ActiveState st
=> Message ps st st' -> CBOR.Encoding)
-- ^ cbor encoder

-> (forall (pr :: PeerRole) (st :: ps) s.
PeerHasAgency pr st
-> (forall (st :: ps) s.
ActiveState st
=> StateToken st
-> CBOR.Decoder s (SomeMessage st))
-- ^ cbor decoder

-> Codec ps DeserialiseFailure m BS.ByteString
-> Codec ps CBOR.DeserialiseFailure m BS.ByteString
mkCodecCborStrictBS cborMsgEncode cborMsgDecode =
Codec {
encode = \stok msg -> convertCborEncoder (cborMsgEncode stok) msg,
decode = \stok -> convertCborDecoder (cborMsgDecode stok)
encode = \msg -> convertCborEncoder cborMsgEncode msg,
decode = \stok -> convertCborDecoder (cborMsgDecode stok)
}
where
convertCborEncoder :: (a -> CBOR.Encoding) -> a -> BS.ByteString
Expand All @@ -66,20 +74,22 @@ mkCodecCborStrictBS cborMsgEncode cborMsgDecode =

convertCborDecoder
:: (forall s. CBOR.Decoder s a)
-> m (DecodeStep BS.ByteString DeserialiseFailure m a)
-> m (DecodeStep BS.ByteString CBOR.DeserialiseFailure m a)
convertCborDecoder cborDecode =
convertCborDecoderBS cborDecode stToIO

convertCborDecoderBS
:: forall s m a. Functor m
=> (CBOR.Decoder s a)
=> CBOR.Decoder s a
-- ^ cbor decoder
-> (forall b. ST s b -> m b)
-> m (DecodeStep BS.ByteString DeserialiseFailure m a)
-- ^ lift ST computation (e.g. 'Control.Monad.ST.stToIO', 'stToPrim', etc)
-> m (DecodeStep BS.ByteString CBOR.DeserialiseFailure m a)
convertCborDecoderBS cborDecode liftST =
go <$> liftST (CBOR.deserialiseIncremental cborDecode)
where
go :: CBOR.IDecode s a
-> DecodeStep BS.ByteString DeserialiseFailure m a
-> DecodeStep BS.ByteString CBOR.DeserialiseFailure m a
go (CBOR.Done trailing _ x)
| BS.null trailing = DecodeDone x Nothing
| otherwise = DecodeDone x (Just trailing)
Expand All @@ -98,19 +108,23 @@ convertCborDecoderBS cborDecode liftST =
mkCodecCborLazyBS
:: forall ps m. MonadST m

=> (forall (pr :: PeerRole) (st :: ps) (st' :: ps).
PeerHasAgency pr st
-> Message ps st st' -> CBOR.Encoding)
=> (forall (st :: ps) (st' :: ps).
StateTokenI st
=> ActiveState st
=> Message ps st st' -> CBOR.Encoding)
-- ^ cbor encoder

-> (forall (pr :: PeerRole) (st :: ps) s.
PeerHasAgency pr st
-> (forall (st :: ps) s.
ActiveState st
=> StateToken st
-> CBOR.Decoder s (SomeMessage st))
-- ^ cbor decoder

-> Codec ps CBOR.DeserialiseFailure m LBS.ByteString
mkCodecCborLazyBS cborMsgEncode cborMsgDecode =
Codec {
encode = \stok msg -> convertCborEncoder (cborMsgEncode stok) msg,
decode = \stok -> convertCborDecoder (cborMsgDecode stok)
encode = \msg -> convertCborEncoder cborMsgEncode msg,
decode = \stok -> convertCborDecoder (cborMsgDecode stok)
}
where
convertCborEncoder :: (a -> CBOR.Encoding) -> a -> LBS.ByteString
Expand All @@ -127,8 +141,10 @@ mkCodecCborLazyBS cborMsgEncode cborMsgDecode =

convertCborDecoderLBS
:: forall s m a. Monad m
=> (CBOR.Decoder s a)
=> CBOR.Decoder s a
-- ^ cbor decoder
-> (forall b. ST s b -> m b)
-- ^ lift ST computation (e.g. 'Control.Monad.ST.stToIO', 'stToPrim', etc)
-> m (DecodeStep LBS.ByteString CBOR.DeserialiseFailure m a)
convertCborDecoderLBS cborDecode liftST =
go [] =<< liftST (CBOR.deserialiseIncremental cborDecode)
Expand All @@ -148,7 +164,7 @@ convertCborDecoderLBS cborDecode liftST =
-- We keep a bunch of chunks and supply the CBOR decoder with them
-- until we run out, when we go get another bunch.
go (c:cs) (CBOR.Partial k) = go cs =<< liftST (k (Just c))
go [] (CBOR.Partial k) = return $ DecodePartial $ \mbs -> case mbs of
go [] (CBOR.Partial k) = return $ DecodePartial $ \case
Nothing -> go [] =<< liftST (k Nothing)
Just bs -> go cs (CBOR.Partial k)
where cs = LBS.toChunks bs
Expand Down
3 changes: 2 additions & 1 deletion typed-protocols-cborg/typed-protocols-cborg.cabal
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cabal-version: 3.0
name: typed-protocols-cborg
version: 0.1.0.4
version: 0.2.0.0
synopsis: CBOR codecs for typed-protocols
-- description:
license: Apache-2.0
Expand All @@ -21,6 +21,7 @@ library
build-depends: base >=4.12 && <4.21,
bytestring >=0.10 && <0.13,
cborg >=0.2.1 && <0.3,
singletons,

io-classes ^>=1.5,
typed-protocols
Expand Down
71 changes: 65 additions & 6 deletions typed-protocols-examples/src/Network/TypedProtocol/Channel.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Network.TypedProtocol.Channel
( Channel (..)
Expand All @@ -10,8 +12,12 @@ module Network.TypedProtocol.Channel
, fixedInputChannel
, mvarsAsChannel
, handlesAsChannel
#if !defined(mingw32_HOST_OS)
, socketAsChannel
#endif
, createConnectedChannels
, createConnectedBufferedChannels
, createConnectedBufferedChannelsUnbounded
, createPipelineTestChannels
, channelEffect
, delayChannel
Expand All @@ -25,8 +31,14 @@ import Control.Monad.Class.MonadTimer.SI
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.ByteString.Lazy.Internal (smallChunkSize)
import Data.Proxy
import Numeric.Natural

#if !defined(mingw32_HOST_OS)
import Network.Socket (Socket)
import qualified Network.Socket.ByteString.Lazy as Socket
#endif

import qualified System.IO as IO (Handle, hFlush, hIsEOF)


Expand Down Expand Up @@ -119,12 +131,20 @@ mvarsAsChannel bufferRead bufferWrite =
--
-- This is primarily useful for testing protocols.
--
createConnectedChannels :: MonadSTM m => m (Channel m a, Channel m a)
createConnectedChannels :: forall m a. (MonadLabelledSTM m, MonadTraceSTM m, Show a) => m (Channel m a, Channel m a)
createConnectedChannels = do
-- Create two TMVars to act as the channel buffer (one for each direction)
-- and use them to make both ends of a bidirectional channel
bufferA <- atomically $ newEmptyTMVar
bufferB <- atomically $ newEmptyTMVar
bufferA <- atomically $ do
v <- newEmptyTMVar
labelTMVar v "buffer-a"
traceTMVar (Proxy @m) v $ \_ a -> pure $ TraceString ("buffer-a: " ++ show a)
return v
bufferB <- atomically $ do
v <- newEmptyTMVar
traceTMVar (Proxy @m) v $ \_ a -> pure $ TraceString ("buffer-b: " ++ show a)
labelTMVar v "buffer-b"
return v

return (mvarsAsChannel bufferB bufferA,
mvarsAsChannel bufferA bufferB)
Expand Down Expand Up @@ -156,11 +176,32 @@ createConnectedBufferedChannels sz = do
recv = atomically (Just <$> readTBQueue bufferRead)


-- | Create a pair of channels that are connected via two unbounded buffers.
--
-- This is primarily useful for testing protocols.
--
createConnectedBufferedChannelsUnbounded :: forall m a. MonadSTM m
=> m (Channel m a, Channel m a)
createConnectedBufferedChannelsUnbounded = do
-- Create two TQueues to act as the channel buffers (one for each
-- direction) and use them to make both ends of a bidirectional channel
bufferA <- newTQueueIO
bufferB <- newTQueueIO

return (queuesAsChannel bufferB bufferA,
queuesAsChannel bufferA bufferB)
where
queuesAsChannel bufferRead bufferWrite =
Channel{send, recv}
where
send x = atomically (writeTQueue bufferWrite x)
recv = atomically ( Just <$> readTQueue bufferRead)

-- | Create a pair of channels that are connected via N-place buffers.
--
-- This variant /fails/ when 'send' would exceed the maximum buffer size.
-- Use this variant when you want the 'PeerPipelined' to limit the pipelining
-- itself, and you want to check that it does not exceed the expected level of
-- Use this variant when you want the 'Peer' to limit the pipelining itself,
-- and you want to check that it does not exceed the expected level of
-- pipelining.
--
-- This is primarily useful for testing protocols.
Expand Down Expand Up @@ -194,7 +235,8 @@ createPipelineTestChannels sz = do
--
-- The Handles should be open in the appropriate read or write mode, and in
-- binary mode. Writes are flushed after each write, so it is safe to use
-- a buffering mode.
-- a buffering mode. On unix named pipes can be used, see
-- 'Network.TypedProtocol.ReqResp.Test.prop_namedPipePipelined_IO'
--
-- For bidirectional handles it is safe to pass the same handle for both.
--
Expand Down Expand Up @@ -251,6 +293,23 @@ delayChannel delay = channelEffect (\_ -> return ())
(\_ -> threadDelay delay)


#if !defined(mingw32_HOST_OS)
socketAsChannel :: Socket
-> Channel IO LBS.ByteString
socketAsChannel sock =
Channel{send, recv}
where
send :: LBS.ByteString -> IO ()
send = Socket.sendAll sock

recv :: IO (Maybe LBS.ByteString)
recv = do
bs <- Socket.recv sock (fromIntegral smallChunkSize)
if LBS.null bs
then return Nothing
else return (Just bs)
#endif

-- | Channel which logs sent and received messages.
--
loggingChannel :: ( MonadSay m
Expand Down
Loading

0 comments on commit 52a4afd

Please sign in to comment.