Skip to content

Commit

Permalink
Propagate version header through federator
Browse files Browse the repository at this point in the history
  • Loading branch information
pcapriotti committed Dec 11, 2023
1 parent da88d00 commit f4e06f9
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ type BrigApi =
:<|> FedEndpoint "search-users" SearchRequest SearchResponse
:<|> FedEndpoint "get-user-clients" GetUserClients (UserMap (Set PubClient))
:<|> FedEndpointWithMods '[Until V1] (Versioned 'V0 "get-mls-clients") MLSClientsRequestV0 (Set ClientInfo)
:<|> FedEndpoint "get-mls-clients" MLSClientsRequest (Set ClientInfo)
:<|> FedEndpointWithMods '[From V1] "get-mls-clients" MLSClientsRequest (Set ClientInfo)
:<|> FedEndpoint "send-connection-action" NewConnectionRequest NewConnectionResponse
:<|> FedEndpoint "claim-key-packages" ClaimKeyPackageRequest (Maybe KeyPackageBundle)
:<|> FedEndpoint "get-not-fully-connected-backends" DomainSet NonConnectedBackends
Expand Down
4 changes: 3 additions & 1 deletion services/federator/src/Federator/ExternalServer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ import System.Logger.Message qualified as Log
import Wire.API.Federation.Component
import Wire.API.Federation.Domain
import Wire.API.Routes.FederationDomainConfig
import Wire.API.VersionInfo

-- | Used to get PEM encoded certificate out of an HTTP header
newtype CertHeader = CertHeader X509.Certificate
Expand Down Expand Up @@ -157,7 +158,8 @@ callInward component (RPC rpc) originDomain (CertHeader cert) wreq = do
let path = LBS.toStrict (toLazyByteString (HTTP.encodePathSegments ["federation", rpc]))

body <- embed $ Wai.lazyRequestBody wreq
resp <- serviceCall component path body validatedDomain
let headers = filter ((== versionHeader) . fst) (Wai.requestHeaders wreq)
resp <- serviceCall component path headers body validatedDomain
Log.debug $
Log.msg ("Inward Request response" :: ByteString)
. Log.field "status" (show (responseStatusCode resp))
Expand Down
6 changes: 4 additions & 2 deletions services/federator/src/Federator/Service.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import Federator.Env
import Imports
import Network.HTTP.Client
import Network.HTTP.Types qualified as HTTP
import Network.HTTP.Types.Header
import Polysemy
import Polysemy.Input
import Servant.Client.Core qualified as Servant
Expand All @@ -51,7 +52,7 @@ type ServiceStreaming = Service (SourceT IO ByteString)

data Service body m a where
-- | Returns status, headers and body, 'HTTP.Response' is not nice to work with in tests
ServiceCall :: Component -> ByteString -> LByteString -> Domain -> Service body m (Servant.ResponseF body)
ServiceCall :: Component -> ByteString -> RequestHeaders -> LByteString -> Domain -> Service body m (Servant.ResponseF body)

makeSem ''Service

Expand Down Expand Up @@ -80,7 +81,7 @@ interpretServiceHTTP ::
Sem (ServiceStreaming ': r) a ->
Sem r a
interpretServiceHTTP = interpret $ \case
ServiceCall component rpcPath body domain -> do
ServiceCall component rpcPath headers body domain -> do
Endpoint serviceHost servicePort <- inputs (view service) <*> pure component
manager <- inputs (view httpManager)
reqId <- inputs (view requestId)
Expand All @@ -96,6 +97,7 @@ interpretServiceHTTP = interpret $ \case
(originDomainHeaderName, cs (domainText domain)),
(RPC.requestIdName, RPC.unRequestId reqId)
]
<> headers
}

embed $
Expand Down
14 changes: 8 additions & 6 deletions services/federator/test/unit/Test/Federator/ExternalServer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import Control.Monad.Codensity
import Data.ByteString qualified as BS
import Data.Default
import Data.Domain
import Data.Sequence as Seq
import Data.Text.Encoding qualified as Text
import Federator.Discovery
import Federator.Error.ServerError (ServerError (..))
Expand Down Expand Up @@ -90,6 +91,7 @@ exampleRequest certFile path = do
data Call = Call
{ cComponent :: Component,
cPath :: ByteString,
cHeaders :: RequestHeaders,
cBody :: LByteString,
cDomain :: Domain
}
Expand All @@ -101,12 +103,12 @@ mockService ::
Sem (ServiceStreaming ': r) a ->
Sem r a
mockService status = interpret $ \case
ServiceCall comp path body domain -> do
output (Call comp path body domain)
ServiceCall comp path headers body domain -> do
output (Call comp path headers body domain)
pure
Servant.Response
{ Servant.responseStatusCode = status,
Servant.responseHeaders = mempty,
Servant.responseHeaders = Seq.fromList headers,
Servant.responseHttpVersion = HTTP.http11,
Servant.responseBody = source ["\"bar\""]
}
Expand Down Expand Up @@ -138,7 +140,7 @@ requestBrigSuccess =
. runInputConst noClientCertSettings
. runInputConst scaffoldingFederationDomainConfigs
$ callInward Brig (RPC "get-user-by-handle") aValidDomain (CertHeader cert) request
let expectedCall = Call Brig "/federation/get-user-by-handle" "\"foo\"" aValidDomain
let expectedCall = Call Brig "/federation/get-user-by-handle" [] "\"foo\"" aValidDomain
assertEqual "one call to brig should be made" [expectedCall] actualCalls
Wai.responseStatus res @?= HTTP.status200
body <- Wai.lazyResponseBody res
Expand Down Expand Up @@ -167,7 +169,7 @@ requestBrigFailure =
. runInputConst scaffoldingFederationDomainConfigs
$ callInward Brig (RPC "get-user-by-handle") aValidDomain (CertHeader cert) request

let expectedCall = Call Brig "/federation/get-user-by-handle" "\"foo\"" aValidDomain
let expectedCall = Call Brig "/federation/get-user-by-handle" [] "\"foo\"" aValidDomain
assertEqual "one call to brig should be made" [expectedCall] actualCalls
Wai.responseStatus res @?= HTTP.notFound404
body <- Wai.lazyResponseBody res
Expand Down Expand Up @@ -196,7 +198,7 @@ requestGalleySuccess =
. runInputConst noClientCertSettings
. runInputConst scaffoldingFederationDomainConfigs
$ callInward Galley (RPC "get-conversations") aValidDomain (CertHeader cert) request
let expectedCall = Call Galley "/federation/get-conversations" "\"foo\"" aValidDomain
let expectedCall = Call Galley "/federation/get-conversations" [] "\"foo\"" aValidDomain
embed $ assertEqual "one call to galley should be made" [expectedCall] actualCalls
embed $ Wai.responseStatus res @?= HTTP.status200
body <- embed $ Wai.lazyResponseBody res
Expand Down

0 comments on commit f4e06f9

Please sign in to comment.