diff --git a/client/eth/connection_pool.go b/client/eth/connection_pool.go index 3881a2ec..268c4005 100644 --- a/client/eth/connection_pool.go +++ b/client/eth/connection_pool.go @@ -36,17 +36,18 @@ func NewConnectionPoolImpl(cfg ConnectionPoolConfig, logger log.Logger) (Connect cfg.HealthCheckInterval = defaultHealthCheckInterval } - cache, err := lru.NewWithEvict( - len(cfg.EthHTTPURLs), func(_ string, v *HealthCheckedClient) { - defer v.Close() - // The timeout is added so that any in progress - // requests have a chance to complete before we close. - time.Sleep(cfg.DefaultTimeout) - }) - if err != nil { - return nil, err + var ( + cache *lru.Cache[string, *HealthCheckedClient] + wsCache *lru.Cache[string, *HealthCheckedClient] + err error + ) + + // The LRU cache needs at least one URL provided for HTTP. + if len(cfg.EthHTTPURLs) == 0 { + return nil, fmt.Errorf("ConnectionPool: missing URL for HTTP clients") } - wsCache, err := lru.NewWithEvict( + + cache, err = lru.NewWithEvict( len(cfg.EthHTTPURLs), func(_ string, v *HealthCheckedClient) { defer v.Close() // The timeout is added so that any in progress @@ -57,6 +58,21 @@ func NewConnectionPoolImpl(cfg ConnectionPoolConfig, logger log.Logger) (Connect return nil, err } + if len(cfg.EthWSURLs) == 0 { + logger.Warn("ConnectionPool: missing URL for WS clients") + } else { + wsCache, err = lru.NewWithEvict( + len(cfg.EthWSURLs), func(_ string, v *HealthCheckedClient) { + defer v.Close() + // The timeout is added so that any in progress + // requests have a chance to complete before we close. + time.Sleep(cfg.DefaultTimeout) + }) + if err != nil { + return nil, err + } + } + return &ConnectionPoolImpl{ cache: cache, wsCache: wsCache, @@ -68,6 +84,11 @@ func NewConnectionPoolImpl(cfg ConnectionPoolConfig, logger log.Logger) (Connect func (c *ConnectionPoolImpl) Close() error { c.mutex.Lock() defer c.mutex.Unlock() + + if c.cache == nil { + return nil + } + for _, client := range c.cache.Keys() { if err := c.removeClient(client); err != nil { return err @@ -81,6 +102,8 @@ func (c *ConnectionPoolImpl) Dial(string) error { } func (c *ConnectionPoolImpl) DialContext(ctx context.Context, _ string) error { + // NOTE: Check the cache for the HTTP URL is not needed because it + // is guaranteed to be non-nil when a ConnectionPoolImpl is created. for _, url := range c.config.EthHTTPURLs { client := NewHealthCheckedClient(c.config.HealthCheckInterval, c.logger) if err := client.DialContextWithTimeout(ctx, url, c.config.DefaultTimeout); err != nil { @@ -88,6 +111,12 @@ func (c *ConnectionPoolImpl) DialContext(ctx context.Context, _ string) error { } c.cache.Add(url, client) } + + // Check is needed because the WS URL is optional. + if c.wsCache == nil { + return nil + } + for _, url := range c.config.EthWSURLs { client := NewHealthCheckedClient(c.config.HealthCheckInterval, c.logger) if err := client.DialContextWithTimeout(ctx, url, c.config.DefaultTimeout); err != nil { @@ -98,22 +127,32 @@ func (c *ConnectionPoolImpl) DialContext(ctx context.Context, _ string) error { return nil } +// NOTE: this function assumes the cache is non-nil +// because it is guaranteed to be non-nil when a ConnectionPoolImpl is created. func (c *ConnectionPoolImpl) GetHTTP() (Client, bool) { c.mutex.Lock() defer c.mutex.Unlock() -retry: - _, client, ok := c.cache.GetOldest() - if !client.Health() { - goto retry - } - return client, ok + + return c.getClientFrom(c.cache) } func (c *ConnectionPoolImpl) GetWS() (Client, bool) { c.mutex.Lock() defer c.mutex.Unlock() + + // Because the WS URL is optional, we need to check if it's nil. + if c.wsCache == nil { + return nil, false + } + return c.getClientFrom(c.wsCache) +} + +// NOTE: this function assumes the lock is held and cache is non-nil. +func (c *ConnectionPoolImpl) getClientFrom( + cache *lru.Cache[string, *HealthCheckedClient], +) (Client, bool) { retry: - _, client, ok := c.wsCache.GetOldest() + _, client, ok := cache.GetOldest() if !client.Health() { goto retry } diff --git a/client/eth/connection_pool_test.go b/client/eth/connection_pool_test.go new file mode 100644 index 00000000..1921bea3 --- /dev/null +++ b/client/eth/connection_pool_test.go @@ -0,0 +1,134 @@ +package eth_test + +import ( + "bytes" + "io" + "os" + "testing" + + "github.com/berachain/offchain-sdk/client/eth" + "github.com/berachain/offchain-sdk/log" + "github.com/stretchr/testify/require" +) + +var ( + HTTPURL = os.Getenv("ETH_HTTP_URL") + WSURL = os.Getenv("ETH_WS_URL") +) + +/******************************* HELPER FUNCTIONS ***************************************/ + +// NOTE: requires chain rpc url at env var `ETH_HTTP_URL` and `ETH_WS_URL`. +func checkEnv(t *testing.T) { + ethHTTPRPC := os.Getenv("ETH_HTTP_URL") + ethWSRPC := os.Getenv("ETH_WS_URL") + if ethHTTPRPC == "" || ethWSRPC == "" { + t.Skipf("Skipping test: no eth rpc url provided") + } +} + +// initConnectionPool initializes a new connection pool. +func initConnectionPool( + cfg eth.ConnectionPoolConfig, writer io.Writer, +) (eth.ConnectionPool, error) { + logger := log.NewLogger(writer, "test-runner") + return eth.NewConnectionPoolImpl(cfg, logger) +} + +// Use Init function as a setup function for the tests. +// It means each test will have to call Init function to set up the test. +func Init( + cfg eth.ConnectionPoolConfig, writer io.Writer, t *testing.T, +) (eth.ConnectionPool, error) { + checkEnv(t) + return initConnectionPool(cfg, writer) +} + +/******************************* TEST CASES ***************************************/ + +// TestNewConnectionPoolImpl_MissingURLs tests the case when the URLs are missing. +func TestNewConnectionPoolImpl_MissingURLs(t *testing.T) { + cfg := eth.ConnectionPoolConfig{} + var logBuffer bytes.Buffer + + _, err := Init(cfg, &logBuffer, t) + require.ErrorContains(t, err, "ConnectionPool: missing URL for HTTP clients") +} + +// TestNewConnectionPoolImpl_MissingWSURLs tests the case when the WS URLs are missing. +func TestNewConnectionPoolImpl_MissingWSURLs(t *testing.T) { + cfg := eth.ConnectionPoolConfig{ + EthHTTPURLs: []string{HTTPURL}, + } + var logBuffer bytes.Buffer + pool, err := Init(cfg, &logBuffer, t) + + require.NoError(t, err) + require.NotNil(t, pool) + require.Contains(t, logBuffer.String(), "ConnectionPool: missing URL for WS clients") +} + +// TestNewConnectionPoolImpl tests the case when the URLs are provided. +// It should the expected behavior. +func TestNewConnectionPoolImpl(t *testing.T) { + cfg := eth.ConnectionPoolConfig{ + EthHTTPURLs: []string{HTTPURL}, + EthWSURLs: []string{WSURL}, + } + var logBuffer bytes.Buffer + pool, err := Init(cfg, &logBuffer, t) + + require.NoError(t, err) + require.NotNil(t, pool) + require.Empty(t, logBuffer.String()) +} + +// TestGetHTTP tests the retrieval of the HTTP client when it +// has been set and the connection has been established. +func TestGetHTTP(t *testing.T) { + cfg := eth.ConnectionPoolConfig{ + EthHTTPURLs: []string{HTTPURL}, + } + var logBuffer bytes.Buffer + pool, _ := Init(cfg, &logBuffer, t) + err := pool.Dial("") + require.NoError(t, err) + + client, ok := pool.GetHTTP() + require.True(t, ok) + require.NotNil(t, client) +} + +// TestGetWS tests the retrieval of the HTTP client when it +// has been set and the connection has been established. +func TestGetWS(t *testing.T) { + cfg := eth.ConnectionPoolConfig{ + EthHTTPURLs: []string{HTTPURL}, + EthWSURLs: []string{WSURL}, + } + var logBuffer bytes.Buffer + pool, _ := Init(cfg, &logBuffer, t) + err := pool.Dial("") + + require.NoError(t, err) + + client, ok := pool.GetWS() + require.True(t, ok) + require.NotNil(t, client) +} + +// TestGetWS_WhenItIsNotSet tests the retrieval of the WS client when +// no WS URLs have been provided. +func TestGetWS_WhenItIsNotSet(t *testing.T) { + cfg := eth.ConnectionPoolConfig{ + EthHTTPURLs: []string{HTTPURL}, + } + var logBuffer bytes.Buffer + pool, _ := Init(cfg, &logBuffer, t) + err := pool.Dial("") + require.NoError(t, err) + + client, ok := pool.GetWS() + require.False(t, ok) + require.Nil(t, client) +} diff --git a/examples/listener/config.toml b/examples/listener/config.toml index 26573255..642db720 100644 --- a/examples/listener/config.toml +++ b/examples/listener/config.toml @@ -19,7 +19,7 @@ Enabled = true Namespace = "example" Subsystem = "listener_app" -[App.ConnectionPool] +[ConnectionPool] EthHTTPURLs = ["http://localhost:10545"] EthWSURLs = ["ws://localhost:10546"] DefaultTimeout = "5s"