diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitor.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitor.java index 0878ce9107..8ccb73c158 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitor.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitor.java @@ -4,6 +4,8 @@ import io.grpc.ConnectivityState; import io.grpc.ManagedChannel; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import lombok.extern.slf4j.Slf4j; @@ -32,10 +34,18 @@ public static void monitorChannelState( ConnectivityState currentState = channel.getState(true); log.info("Channel state changed to: {}", currentState); if (currentState == ConnectivityState.READY) { - onConnectionReady.run(); + if (onConnectionReady != null) { + onConnectionReady.run(); + } else { + log.debug("onConnectionReady is null"); + } } else if (currentState == ConnectivityState.TRANSIENT_FAILURE || currentState == ConnectivityState.SHUTDOWN) { - onConnectionLost.run(); + if (onConnectionLost != null) { + onConnectionLost.run(); + } else { + log.debug("onConnectionLost is null"); + } } // Re-register the state monitor to watch for the next state transition. monitorChannelState(currentState, channel, onConnectionReady, onConnectionLost); @@ -43,54 +53,39 @@ public static void monitorChannelState( } /** - * Waits for the channel to reach a desired state within a specified timeout period. + * Waits for the channel to reach the desired connectivity state within the specified timeout. * - * @param channel the ManagedChannel to monitor. - * @param desiredState the ConnectivityState to wait for. - * @param connectCallback callback invoked when the desired state is reached. - * @param timeout the maximum amount of time to wait. - * @param unit the time unit of the timeout. - * @throws InterruptedException if the current thread is interrupted while waiting. + * @param desiredState the desired {@link ConnectivityState} to wait for + * @param channel the {@link ManagedChannel} to monitor + * @param connectCallback the {@link Runnable} to execute when the desired state is reached + * @param timeout the maximum time to wait + * @param unit the time unit of the timeout argument + * @throws InterruptedException if the current thread is interrupted while waiting + * @throws GeneralError if the desired state is not reached within the timeout */ public static void waitForDesiredState( - ManagedChannel channel, ConnectivityState desiredState, - Runnable connectCallback, - long timeout, - TimeUnit unit) - throws InterruptedException { - waitForDesiredState(channel, desiredState, connectCallback, new CountDownLatch(1), timeout, unit); - } - - private static void waitForDesiredState( ManagedChannel channel, - ConnectivityState desiredState, Runnable connectCallback, - CountDownLatch latch, long timeout, TimeUnit unit) throws InterruptedException { - channel.notifyWhenStateChanged(ConnectivityState.SHUTDOWN, () -> { - try { - ConnectivityState state = channel.getState(true); - log.debug("Channel state changed to: {}", state); + CountDownLatch latch = new CountDownLatch(1); - if (state == desiredState) { - connectCallback.run(); - latch.countDown(); - return; - } - waitForDesiredState(channel, desiredState, connectCallback, latch, timeout, unit); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - log.error("Thread interrupted while waiting for desired state", e); - } catch (Exception e) { - log.error("Error occurred while waiting for desired state", e); + Runnable waitForStateTask = () -> { + ConnectivityState currentState = channel.getState(true); + if (currentState == desiredState) { + connectCallback.run(); + latch.countDown(); } - }); + }; + + ScheduledFuture scheduledFuture = Executors.newSingleThreadScheduledExecutor() + .scheduleWithFixedDelay(waitForStateTask, 0, 100, TimeUnit.MILLISECONDS); - // Await the latch or timeout for the state change - if (!latch.await(timeout, unit)) { + boolean success = latch.await(timeout, unit); + scheduledFuture.cancel(true); + if (!success) { throw new GeneralError(String.format( "Deadline exceeded. Condition did not complete within the %d " + "deadline", timeout)); } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/GrpcConnector.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/GrpcConnector.java index 9f1c2fe6d3..d5ca69aff3 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/GrpcConnector.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/GrpcConnector.java @@ -137,7 +137,7 @@ public GrpcConnector( public void initialize() throws Exception { log.info("Initializing GRPC connection..."); ChannelMonitor.waitForDesiredState( - channel, ConnectivityState.READY, this::onInitialConnect, deadline, TimeUnit.MILLISECONDS); + ConnectivityState.READY, channel, this::onInitialConnect, deadline, TimeUnit.MILLISECONDS); ChannelMonitor.monitorChannelState(ConnectivityState.READY, channel, this::onReady, this::onConnectionLost); } diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitorTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitorTest.java new file mode 100644 index 0000000000..539ef8d861 --- /dev/null +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitorTest.java @@ -0,0 +1,90 @@ +package dev.openfeature.contrib.providers.flagd.resolver.common; + +import static dev.openfeature.contrib.providers.flagd.resolver.common.ChannelMonitor.monitorChannelState; +import static dev.openfeature.contrib.providers.flagd.resolver.common.ChannelMonitor.waitForDesiredState; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import dev.openfeature.sdk.exceptions.GeneralError; +import io.grpc.ConnectivityState; +import io.grpc.ManagedChannel; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +class ChannelMonitorTest { + @Test + void testWaitForDesiredState() throws InterruptedException { + ManagedChannel channel = mock(ManagedChannel.class); + Runnable connectCallback = mock(Runnable.class); + + // Set up the desired state + ConnectivityState desiredState = ConnectivityState.READY; + when(channel.getState(anyBoolean())).thenReturn(desiredState); + + // Call the method + waitForDesiredState(desiredState, channel, connectCallback, 1, TimeUnit.SECONDS); + + // Verify that the callback was run + verify(connectCallback, times(1)).run(); + } + + @Test + void testWaitForDesiredStateTimeout() { + ManagedChannel channel = Mockito.mock(ManagedChannel.class); + Runnable connectCallback = mock(Runnable.class); + + // Set up the desired state + ConnectivityState desiredState = ConnectivityState.READY; + when(channel.getState(anyBoolean())).thenReturn(ConnectivityState.IDLE); + + // Call the method and expect a timeout + assertThrows(GeneralError.class, () -> { + waitForDesiredState(desiredState, channel, connectCallback, 1, TimeUnit.SECONDS); + }); + } + + @ParameterizedTest + @EnumSource(ConnectivityState.class) + void testMonitorChannelState(ConnectivityState state) { + ManagedChannel channel = Mockito.mock(ManagedChannel.class); + Runnable onConnectionReady = mock(Runnable.class); + Runnable onConnectionLost = mock(Runnable.class); + + // Set up the expected state + ConnectivityState expectedState = ConnectivityState.IDLE; + when(channel.getState(anyBoolean())).thenReturn(state); + + // Capture the callback + ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(Runnable.class); + doNothing().when(channel).notifyWhenStateChanged(eq(expectedState), callbackCaptor.capture()); + + // Call the method + monitorChannelState(expectedState, channel, onConnectionReady, onConnectionLost); + + // Simulate state change + callbackCaptor.getValue().run(); + + // Verify the callbacks based on the state + if (state == ConnectivityState.READY) { + verify(onConnectionReady, times(1)).run(); + verify(onConnectionLost, never()).run(); + } else if (state == ConnectivityState.TRANSIENT_FAILURE || state == ConnectivityState.SHUTDOWN) { + verify(onConnectionReady, never()).run(); + verify(onConnectionLost, times(1)).run(); + } else { + verify(onConnectionReady, never()).run(); + verify(onConnectionLost, never()).run(); + } + } +}