Skip to content

Commit

Permalink
Merge pull request linkedin-sna#8 from jhartman/master
Browse files Browse the repository at this point in the history
Clean up the entire channel pool every K requests
  • Loading branch information
ruiwang-linkedin committed Sep 1, 2012
2 parents 4a90fd0 + 5c835ae commit 57625a3
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ package object norbertutils {
def atomicCreateIfAbsent[K, V](map: ConcurrentMap[K, V], key: K)(fn: K => V): V = {
val oldValue = map.get(key)
if(oldValue == null) {
val newValue = fn(key)
map.putIfAbsent(key, newValue)
map.get(key)
map.synchronized {
val newValue = fn(key)
map.putIfAbsent(key, newValue)
map.get(key)
}
} else {
oldValue
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ object NetworkDefaults {
*/
val STALE_REQUEST_TIMEOUT_MINS = 1

/**
* How long to keep a channel alive before we'll toss it away
*/
val CLOSE_CHANNEL_TIMEOUT_MILLIS = 30000L

/**
* The amount of time before a request is considered "timed out" by the processing queue. If for some reason (perhaps a GC), when the request
* is pulled from the queue and has been sitting in the queue for longer than this time, a HeavyLoadException is thrown to the client, signalling a throttle.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ class NetworkClientConfig {
var connectTimeoutMillis = NetworkDefaults.CONNECT_TIMEOUT_MILLIS
var writeTimeoutMillis = NetworkDefaults.WRITE_TIMEOUT_MILLIS
var maxConnectionsPerNode = NetworkDefaults.MAX_CONNECTIONS_PER_NODE

var staleRequestTimeoutMins = NetworkDefaults.STALE_REQUEST_TIMEOUT_MINS
var staleRequestCleanupFrequenceMins = NetworkDefaults.STALE_REQUEST_CLEANUP_FREQUENCY_MINS
var closeChannelTimeMillis = NetworkDefaults.CLOSE_CHANNEL_TIMEOUT_MILLIS

var requestStatisticsWindow = NetworkDefaults.REQUEST_STATISTICS_WINDOW

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,25 @@ case class TimeoutIterator[ResponseMsg](inner: ResponseIterator[ResponseMsg], ti

def next: ResponseMsg = {
val before = System.currentTimeMillis
val res = inner.next(timeLeft.get, TimeUnit.MILLISECONDS)
val time = (System.currentTimeMillis - before).asInstanceOf[Int]

timeLeft.addAndGet(-time)
res
try {
return inner.next(timeLeft.get, TimeUnit.MILLISECONDS)
} finally {
val time = (System.currentTimeMillis - before).asInstanceOf[Int]
timeLeft.addAndGet(-time)
}
}

def next(t: Long, unit: TimeUnit): ResponseMsg = {
val before = System.currentTimeMillis
val methodTimeout = unit.toMillis(t)
val res = inner.next(math.min(methodTimeout, timeLeft.get), TimeUnit.MILLISECONDS)
val time = (System.currentTimeMillis - before).asInstanceOf[Int]

timeLeft.addAndGet(-time)
res
try {
return inner.next(math.min(methodTimeout, timeLeft.get), TimeUnit.MILLISECONDS)
} finally {
val time = (System.currentTimeMillis - before).asInstanceOf[Int]
timeLeft.addAndGet(-time)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@ import jmx.JMX
import logging.Logging
import cluster.{Node, ClusterClient}
import java.util.concurrent.atomic.{AtomicLong, AtomicBoolean, AtomicInteger}
import norbertutils.{SystemClock}
import norbertutils.{Clock, SystemClock}
import java.io.IOException
import common.{BackoffStrategy, SimpleBackoffStrategy}
import java.util

class ChannelPoolClosedException extends Exception

class ChannelPoolFactory(maxConnections: Int, openTimeoutMillis: Int, writeTimeoutMillis: Int, bootstrap: ClientBootstrap, errorStrategy: Option[BackoffStrategy]) {
class ChannelPoolFactory(maxConnections: Int, openTimeoutMillis: Int, writeTimeoutMillis: Int,
bootstrap: ClientBootstrap,
errorStrategy: Option[BackoffStrategy],
closeChannelTimeMillis: Long) {

def newChannelPool(address: InetSocketAddress): ChannelPool = {
val group = new DefaultChannelGroup("norbert-client [%s]".format(address))
Expand All @@ -44,23 +48,35 @@ class ChannelPoolFactory(maxConnections: Int, openTimeoutMillis: Int, writeTimeo
writeTimeoutMillis = writeTimeoutMillis,
bootstrap = bootstrap,
channelGroup = group,
errorStrategy = errorStrategy)
closeChannelTimeMillis = closeChannelTimeMillis,
errorStrategy = errorStrategy,
clock = SystemClock)
}

def shutdown: Unit = {
bootstrap.releaseExternalResources
}
}

class ChannelPool(address: InetSocketAddress, maxConnections: Int, openTimeoutMillis: Int, writeTimeoutMillis: Int, bootstrap: ClientBootstrap,
channelGroup: ChannelGroup, val errorStrategy: Option[BackoffStrategy]) extends Logging {
private val pool = new ArrayBlockingQueue[Channel](maxConnections)
class ChannelPool(address: InetSocketAddress, maxConnections: Int, openTimeoutMillis: Int, writeTimeoutMillis: Int,
bootstrap: ClientBootstrap,
channelGroup: ChannelGroup,
closeChannelTimeMillis: Long,
val errorStrategy: Option[BackoffStrategy],
clock: Clock) extends Logging {

case class PoolEntry(channel: Channel, creationTime: Long) {
def age = System.currentTimeMillis() - creationTime

def isFresh(closeChannelTimeMillis: Long) = closeChannelTimeMillis > 0 && age < closeChannelTimeMillis
}

private val pool = new ArrayBlockingQueue[PoolEntry](maxConnections)
private val waitingWrites = new LinkedBlockingQueue[Request[_, _]]
private val poolSize = new AtomicInteger(0)
private val closed = new AtomicBoolean
private val requestsSent = new AtomicInteger(0)
private var channelBufferRecycleFrequence = 1000
private val channelBufferRecycleCounter = new AtomicInteger(channelBufferRecycleFrequence)
private val lock = new java.util.concurrent.locks.ReentrantReadWriteLock(true)

private val jmxHandle = JMX.register(new MBean(classOf[ChannelPoolMBean], "address=%s,port=%d".format(address.getHostName, address.getPort)) with ChannelPoolMBean {
import scala.math._
Expand All @@ -71,19 +87,15 @@ class ChannelPool(address: InetSocketAddress, maxConnections: Int, openTimeoutMi
def getMaxChannels = maxConnections

def getNumberRequestsSent = requestsSent.get.abs

def getChannelBufferRecycleFrequence = channelBufferRecycleFrequence

def setChannelBufferRecycleFrequence(noReqsPerRecycle: Int) {channelBufferRecycleFrequence = max(noReqsPerRecycle, 10) }
})

def sendRequest[RequestMsg, ResponseMsg](request: Request[RequestMsg, ResponseMsg]): Unit = if (closed.get) {
throw new ChannelPoolClosedException
} else {
checkoutChannel match {
case Some(channel) =>
writeRequestToChannel(request, channel)
checkinChannel(channel)
case Some(poolEntry) =>
writeRequestToChannel(request, poolEntry.channel)
checkinChannel(poolEntry)

case None =>
waitingWrites.offer(request)
Expand All @@ -98,55 +110,48 @@ class ChannelPool(address: InetSocketAddress, maxConnections: Int, openTimeoutMi
}
}

private def checkinChannel(channel: Channel, isFirstWriteToChannel: Boolean = false) {
private def checkinChannel(poolEntry: PoolEntry, isFirstWriteToChannel: Boolean = false) {
while (!waitingWrites.isEmpty) {
waitingWrites.poll match {
case null => // do nothing

case request =>
val timeout = if (isFirstWriteToChannel) writeTimeoutMillis + openTimeoutMillis else writeTimeoutMillis
if((System.currentTimeMillis - request.timestamp) < timeout)
writeRequestToChannel(request, channel)
writeRequestToChannel(request, poolEntry.channel)
else
request.onFailure(new TimeoutException("Timed out while waiting to write"))
}
}

if(!isFirstWriteToChannel && ((channelBufferRecycleCounter.incrementAndGet % channelBufferRecycleFrequence) == 0))
{
val pipeline = channel.getPipeline
try {
pipeline.remove("frameDecoder")
pipeline.addBefore("protobufDecoder", "frameDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4, 0, 4))
pool.offer(channel)
} catch {
case e: Exception => log.warn("error while replacing frameDecoder, discarding channel")
}
} else
{
pool.offer(channel)
}
if(poolEntry.isFresh(closeChannelTimeMillis))
pool.offer(poolEntry)
else
poolEntry.channel.close()
}

private def checkoutChannel: Option[Channel] = {
private def checkoutChannel: Option[PoolEntry] = {
var poolEntry: PoolEntry = null
var found = false
var channel: Channel = null

while (!pool.isEmpty && !found) {
pool.poll match {
case null => // do nothing

case c =>
if (c.isConnected) {
channel = c
found = true
case pe =>
if (pe.channel.isConnected) {
if(pe.isFresh(closeChannelTimeMillis)) {
poolEntry = pe
found = true
} else {
pe.channel.close()
}
} else {
poolSize.decrementAndGet
}
}
}

Option(channel)
Option(poolEntry)
}

private def openChannel(request: Request[_, _]) {
Expand All @@ -163,7 +168,9 @@ class ChannelPool(address: InetSocketAddress, maxConnections: Int, openTimeoutMi
log.debug("Opened a channel to: %s".format(address))

channelGroup.add(channel)
checkinChannel(channel, isFirstWriteToChannel = true)

val poolEntry = PoolEntry(channel, System.currentTimeMillis())
checkinChannel(poolEntry, isFirstWriteToChannel = true)
} else {
log.error(openFuture.getCause, "Error when opening channel to: %s, marking offline".format(address))
errorStrategy.foreach(_.notifyFailure(request.node))
Expand Down Expand Up @@ -199,6 +206,4 @@ trait ChannelPoolMBean {
def getMaxChannels: Int
def getWriteQueueSize: Int
def getNumberRequestsSent: Int
def getChannelBufferRecycleFrequence: Int
def setChannelBufferRecycleFrequence(noReqsPerRecycle: Int): Unit
}
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class ClientChannelHandler(clientName: Option[String],

def shutdown: Unit = {
responseHandler.shutdown
cleanupExecutor.shutdownNow
statsJMX.foreach { JMX.unregister(_) }
serverErrorStrategyJMX.foreach { JMX.unregister(_) }
clientStatsStrategyJMX.foreach { JMX.unregister(_) }
Expand Down Expand Up @@ -237,4 +238,4 @@ class ClientStatisticsRequestStrategyMBeanImpl(clientName: Option[String], servi
def setOutlierConstant(c: Double) = { strategy.outlierConstant = c}

def getTotalNodesMarkedDown = strategy.totalNodesMarkedDown.get.abs
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ abstract class BaseNettyNetworkClient(clientConfig: NetworkClientConfig) extends
openTimeoutMillis = clientConfig.connectTimeoutMillis,
writeTimeoutMillis = clientConfig.writeTimeoutMillis,
bootstrap = bootstrap,
closeChannelTimeMillis = clientConfig.closeChannelTimeMillis,
errorStrategy = Some(channelPoolStrategy))

val clusterIoClient = new NettyClusterIoClient(channelPoolFactory, strategy)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ import org.jboss.netty.channel.{Channel, ChannelFutureListener, ChannelFuture}
import com.google.protobuf.Message
import java.util.concurrent.{TimeoutException, TimeUnit}
import java.net.InetSocketAddress
import norbertutils.MockClock

class ChannelPoolSpec extends Specification with Mockito {
val channelGroup = mock[ChannelGroup]
val bootstrap = mock[ClientBootstrap]
val address = new InetSocketAddress("localhost", 31313)
val channelPool = new ChannelPool(address, 1, 100, 100, bootstrap, channelGroup, None)

val channelPool = new ChannelPool(address, 1, 100, 100, bootstrap, channelGroup,
closeChannelTimeMillis = 10000, errorStrategy = None, clock = MockClock)

"ChannelPool" should {
"close the ChannelGroup when close is called" in {
Expand Down Expand Up @@ -122,6 +125,29 @@ class ChannelPoolSpec extends Specification with Mockito {
}
}

"open a new channel if a channel has expired" in {
val channel = mock[Channel]
val future = new TestChannelFuture(channel, true)
bootstrap.connect(address) returns future
channelGroup.add(channel) returns true
channel.write(any[Request[_, _]]) returns future

val request = mock[Request[_, _]]
channelPool.sendRequest(request)
future.listener.operationComplete(future)

MockClock.currentTime = 20000L

channelPool.sendRequest(request)
future.listener.operationComplete(future)

got {
two(channelGroup).add(channel)
two(bootstrap).connect(address)
}
}


"write all queued requests" in {
val channel = mock[Channel]
val request = mock[Request[_, _]]
Expand Down

0 comments on commit 57625a3

Please sign in to comment.