Skip to content

Commit

Permalink
Avoid allocations in TLSEngine when logging is disabled (#2462)
Browse files Browse the repository at this point in the history
* Avoid allocations in TLSEngine when logging is disabled

* Scalafmt

* Scalafmt

* Fix 2.12 build

* Mima exclusions

* Scalafmt
  • Loading branch information
mpilquist authored Jul 3, 2021
1 parent 516f33d commit a502dd1
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 89 deletions.
11 changes: 10 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,16 @@ ThisBuild / mimaBinaryIssueFilters ++= Seq(
ProblemFilters.exclude[MissingClassProblem]("fs2.Pull$CloseScope$"),
ProblemFilters.exclude[ReversedAbstractMethodProblem]("fs2.Pull#CloseScope.*"),
ProblemFilters.exclude[Problem]("fs2.io.Watcher#Registration.*"),
ProblemFilters.exclude[Problem]("fs2.io.Watcher#DefaultWatcher.*")
ProblemFilters.exclude[Problem]("fs2.io.Watcher#DefaultWatcher.*"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("fs2.io.net.tls.TLSContext.clientBuilder"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("fs2.io.net.tls.TLSContext.serverBuilder"),
ProblemFilters.exclude[ReversedMissingMethodProblem](
"fs2.io.net.tls.TLSContext.dtlsClientBuilder"
),
ProblemFilters.exclude[ReversedMissingMethodProblem](
"fs2.io.net.tls.TLSContext.dtlsServerBuilder"
),
ProblemFilters.exclude[Problem]("fs2.io.net.tls.TLSEngine*")
)

lazy val root = project
Expand Down
3 changes: 1 addition & 2 deletions core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import scala.concurrent.duration._
import scala.concurrent.TimeoutException
import cats.effect.{IO, SyncIO}
import cats.effect.kernel.Ref
import cats.effect.std.Queue
import cats.effect.std.Semaphore
import cats.syntax.all._
import org.scalacheck.Gen
Expand Down Expand Up @@ -1316,7 +1315,7 @@ class StreamCombinatorsSuite extends Fs2Suite {
val action =
Vector.fill(streamSize)(Deferred[IO, Unit]).sequence.map { seenArr =>
def peek(ind: Int)(f: Option[Unit] => Boolean) =
seenArr.get(ind).fold(true.pure[IO])(_.tryGet.map(f))
seenArr.get(ind.toLong).fold(true.pure[IO])(_.tryGet.map(f))

Stream
.emits(0 until streamSize)
Expand Down
187 changes: 118 additions & 69 deletions io/src/main/scala/fs2/io/net/tls/TLSContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ import javax.net.ssl.{
SSLContext,
SSLEngine,
TrustManagerFactory,
X509ExtendedTrustManager,
X509TrustManager
X509ExtendedTrustManager
}
import cats.Applicative
import cats.effect.kernel.{Async, Resource}
Expand All @@ -47,43 +46,85 @@ import java.util.function.BiFunction
*/
sealed trait TLSContext[F[_]] {

/** Creates a `TLSSocket` in client mode, using the supplied parameters.
* Internal debug logging of the session can be enabled by passing a logger.
*/
/** Creates a `TLSSocket` builder in client mode. */
def client(socket: Socket[F]): Resource[F, TLSSocket[F]] =
clientBuilder(socket).build

/** Creates a `TLSSocket` builder in client mode, allowing optional parameters to be configured. */
def clientBuilder(socket: Socket[F]): TLSContext.SocketBuilder[F, TLSSocket]

@deprecated("Use client(socket) or clientBuilder(socket).with(...).build", "3.0.6")
def client(
socket: Socket[F],
params: TLSParameters = TLSParameters.Default,
logger: Option[String => F[Unit]] = None
): Resource[F, TLSSocket[F]]
): Resource[F, TLSSocket[F]] =
clientBuilder(socket).withParameters(params).withOldLogging(logger).build

/** Creates a `TLSSocket` builder in server mode. */
def server(socket: Socket[F]): Resource[F, TLSSocket[F]] =
serverBuilder(socket).build

/** Creates a `TLSSocket` in server mode, using the supplied parameters.
* Internal debug logging of the session can be enabled by passing a logger.
*/
/** Creates a `TLSSocket` builder in server mode, allowing optional parameters to be configured. */
def serverBuilder(socket: Socket[F]): TLSContext.SocketBuilder[F, TLSSocket]

@deprecated("Use server(socket) or serverBuilder(socket).with(...).build", "3.0.6")
def server(
socket: Socket[F],
params: TLSParameters = TLSParameters.Default,
logger: Option[String => F[Unit]] = None
): Resource[F, TLSSocket[F]]
): Resource[F, TLSSocket[F]] =
serverBuilder(socket).withParameters(params).withOldLogging(logger).build

/** Creates a `DTLSSocket` builder in client mode. */
def dtlsClient(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress]
): Resource[F, DTLSSocket[F]] =
dtlsClientBuilder(socket, remoteAddress).build

/** Creates a `DTLSSocket` builder in client mode, allowing optional parameters to be configured. */
def dtlsClientBuilder(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress]
): TLSContext.SocketBuilder[F, DTLSSocket]

/** Creates a `DTLSSocket` in client mode, using the supplied parameters.
* Internal debug logging of the session can be enabled by passing a logger.
*/
@deprecated(
"Use dtlsClient(socket, remoteAddress) or dtlsClientBuilder(socket, remoteAddress).with(...).build",
"3.0.6"
)
def dtlsClient(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress],
params: TLSParameters = TLSParameters.Default,
logger: Option[String => F[Unit]] = None
): Resource[F, DTLSSocket[F]]
): Resource[F, DTLSSocket[F]] =
dtlsClientBuilder(socket, remoteAddress).withParameters(params).withOldLogging(logger).build

/** Creates a `DTLSSocket` in server mode, using the supplied parameters.
* Internal debug logging of the session can be enabled by passing a logger.
*/
/** Creates a `DTLSSocket` builder in server mode. */
def dtlsServer(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress]
): Resource[F, DTLSSocket[F]] =
dtlsServerBuilder(socket, remoteAddress).build

/** Creates a `DTLSSocket` builder in client mode, allowing optional parameters to be configured. */
def dtlsServerBuilder(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress]
): TLSContext.SocketBuilder[F, DTLSSocket]

@deprecated(
"Use dtlsServer(socket, remoteAddress) or dtlsClientBuilder(socket, remoteAddress).with(...).build",
"3.0.6"
)
def dtlsServer(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress],
params: TLSParameters = TLSParameters.Default,
logger: Option[String => F[Unit]] = None
): Resource[F, DTLSSocket[F]]
): Resource[F, DTLSSocket[F]] =
dtlsServerBuilder(socket, remoteAddress).withParameters(params).withOldLogging(logger).build
}

object TLSContext {
Expand Down Expand Up @@ -128,35 +169,17 @@ object TLSContext {
ctx: SSLContext
): TLSContext[F] =
new TLSContext[F] {
def client(
socket: Socket[F],
params: TLSParameters,
logger: Option[String => F[Unit]]
): Resource[F, TLSSocket[F]] =
mkSocket(
socket,
true,
params,
logger
)

def server(
socket: Socket[F],
params: TLSParameters,
logger: Option[String => F[Unit]]
): Resource[F, TLSSocket[F]] =
mkSocket(
socket,
false,
params,
logger
)
def clientBuilder(socket: Socket[F]) =
SocketBuilder((p, l) => mkSocket(socket, true, p, l))

def serverBuilder(socket: Socket[F]) =
SocketBuilder((p, l) => mkSocket(socket, false, p, l))

private def mkSocket(
socket: Socket[F],
clientMode: Boolean,
params: TLSParameters,
logger: Option[String => F[Unit]]
logger: TLSLogger[F]
): Resource[F, TLSSocket[F]] =
Resource
.eval(
Expand All @@ -174,40 +197,24 @@ object TLSContext {
)
.flatMap(engine => TLSSocket(socket, engine))

def dtlsClient(
def dtlsClientBuilder(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress],
params: TLSParameters,
logger: Option[String => F[Unit]]
): Resource[F, DTLSSocket[F]] =
mkDtlsSocket(
socket,
remoteAddress,
true,
params,
logger
)

def dtlsServer(
remoteAddress: SocketAddress[IpAddress]
) =
SocketBuilder((p, l) => mkDtlsSocket(socket, remoteAddress, true, p, l))

def dtlsServerBuilder(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress],
params: TLSParameters,
logger: Option[String => F[Unit]]
): Resource[F, DTLSSocket[F]] =
mkDtlsSocket(
socket,
remoteAddress,
false,
params,
logger
)
remoteAddress: SocketAddress[IpAddress]
) =
SocketBuilder((p, l) => mkDtlsSocket(socket, remoteAddress, false, p, l))

private def mkDtlsSocket(
socket: DatagramSocket[F],
remoteAddress: SocketAddress[IpAddress],
clientMode: Boolean,
params: TLSParameters,
logger: Option[String => F[Unit]]
logger: TLSLogger[F]
): Resource[F, DTLSSocket[F]] =
Resource
.eval(
Expand All @@ -230,7 +237,7 @@ object TLSContext {
binding: TLSEngine.Binding[F],
clientMode: Boolean,
params: TLSParameters,
logger: Option[String => F[Unit]]
logger: TLSLogger[F]
): F[TLSEngine[F]] = {
val sslEngine = Async[F].blocking {
val engine = ctx.createSSLEngine()
Expand Down Expand Up @@ -345,4 +352,46 @@ object TLSContext {
.map(fromSSLContext(_))
}
}

sealed trait SocketBuilder[F[_], S[_[_]]] {
def withParameters(params: TLSParameters): SocketBuilder[F, S]
def withLogging(log: (=> String) => F[Unit]): SocketBuilder[F, S]
def withoutLogging: SocketBuilder[F, S]
def withLogger(logger: TLSLogger[F]): SocketBuilder[F, S]
private[TLSContext] def withOldLogging(log: Option[String => F[Unit]]): SocketBuilder[F, S]
def build: Resource[F, S[F]]
}

object SocketBuilder {
private[TLSContext] type Build[F[_], S[_[_]]] =
(TLSParameters, TLSLogger[F]) => Resource[F, S[F]]

private[TLSContext] def apply[F[_], S[_[_]]](
mkSocket: Build[F, S]
): SocketBuilder[F, S] =
instance(mkSocket, TLSParameters.Default, TLSLogger.Disabled)

private def instance[F[_], S[_[_]]](
mkSocket: Build[F, S],
params: TLSParameters,
logger: TLSLogger[F]
): SocketBuilder[F, S] =
new SocketBuilder[F, S] {
def withParameters(params: TLSParameters): SocketBuilder[F, S] =
instance(mkSocket, params, logger)
def withLogging(log: (=> String) => F[Unit]): SocketBuilder[F, S] =
withLogger(TLSLogger.Enabled(log))
def withoutLogging: SocketBuilder[F, S] =
withLogger(TLSLogger.Disabled)
def withLogger(logger: TLSLogger[F]): SocketBuilder[F, S] =
instance(mkSocket, params, logger)
private[TLSContext] def withOldLogging(
log: Option[String => F[Unit]]
): SocketBuilder[F, S] =
log.map(f => withLogging(m => f(m))).getOrElse(withoutLogging)
def build: Resource[F, S[F]] =
mkSocket(params, logger)
}
}

}
11 changes: 8 additions & 3 deletions io/src/main/scala/fs2/io/net/tls/TLSEngine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ private[tls] object TLSEngine {
def apply[F[_]: Async](
engine: SSLEngine,
binding: Binding[F],
logger: Option[String => F[Unit]] = None
logger: TLSLogger[F]
): F[TLSEngine[F]] =
for {
wrapBuffer <- InputOutputBuffer[F](
Expand All @@ -70,8 +70,13 @@ private[tls] object TLSEngine {
handshakeSemaphore <- Semaphore[F](1)
sslEngineTaskRunner = SSLEngineTaskRunner[F](engine)
} yield new TLSEngine[F] {
private def log(msg: String): F[Unit] =
logger.map(_(msg)).getOrElse(Applicative[F].unit)
private val doLog: (() => String) => F[Unit] =
logger match {
case e: TLSLogger.Enabled[_] => msg => e.log(msg())
case TLSLogger.Disabled => _ => Applicative[F].unit
}

private def log(msg: => String): F[Unit] = doLog(() => msg)

def beginHandshake = Sync[F].delay(engine.beginHandshake())
def session = Sync[F].delay(engine.getSession())
Expand Down
29 changes: 29 additions & 0 deletions io/src/main/scala/fs2/io/net/tls/TLSLogger.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2013 Functional Streams for Scala
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

package fs2.io.net.tls

sealed trait TLSLogger[+F[_]]

object TLSLogger {
case object Disabled extends TLSLogger[Nothing]
case class Enabled[F[_]](log: (=> String) => F[Unit]) extends TLSLogger[F]
}
10 changes: 8 additions & 2 deletions io/src/test/scala/fs2/io/net/tls/DTLSSocketSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,14 @@ class DTLSSocketSuite extends TLSSuite {
serverAddress <- address(serverSocket)
clientSocket <- Network[IO].openDatagramSocket()
clientAddress <- address(clientSocket)
tlsServerSocket <- tlsContext.dtlsServer(serverSocket, clientAddress, logger = logger)
tlsClientSocket <- tlsContext.dtlsClient(clientSocket, serverAddress, logger = logger)
tlsServerSocket <- tlsContext
.dtlsServerBuilder(serverSocket, clientAddress)
.withLogger(logger)
.build
tlsClientSocket <- tlsContext
.dtlsClientBuilder(clientSocket, serverAddress)
.withLogger(logger)
.build
} yield (tlsServerSocket, tlsClientSocket, serverAddress)

Stream
Expand Down
5 changes: 3 additions & 2 deletions io/src/test/scala/fs2/io/net/tls/TLSDebugExample.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ object TLSDebug {
host.resolve.flatMap { socketAddress =>
Network[F].client(socketAddress).use { rawSocket =>
tlsContext
.client(
rawSocket,
.clientBuilder(rawSocket)
.withParameters(
TLSParameters(serverNames = Some(List(new SNIHostName(host.host.toString))))
)
.build
.use { tlsSocket =>
tlsSocket.write(Chunk.empty) >>
tlsSocket.session.map { session =>
Expand Down
Loading

0 comments on commit a502dd1

Please sign in to comment.