From 6c2e058b05ec70899c16ef1548b52d773cea5003 Mon Sep 17 00:00:00 2001 From: Bernd Warmuth Date: Wed, 18 Dec 2024 12:49:34 +0100 Subject: [PATCH] feat: refactor GrpcConnector to use grpc builtin reconnection Signed-off-by: Bernd Warmuth --- providers/flagd/pom.xml | 1 - .../contrib/providers/flagd/Config.java | 1 - .../contrib/providers/flagd/FlagdOptions.java | 15 +- .../providers/flagd/FlagdProvider.java | 63 +- .../flagd/resolver/common/ChannelMonitor.java | 169 ++++++ .../resolver/common/ConnectionEvent.java | 113 +++- .../resolver/common/ConnectionState.java | 27 + .../providers/flagd/resolver/common/Util.java | 31 +- .../resolver/grpc/EventStreamObserver.java | 92 +-- .../flagd/resolver/grpc/GrpcConnector.java | 298 +++++---- .../flagd/resolver/grpc/GrpcResolver.java | 77 ++- .../resolver/process/InProcessResolver.java | 24 +- .../providers/flagd/FlagdOptionsTest.java | 4 - .../providers/flagd/FlagdProviderTest.java | 394 ++---------- .../e2e/RunFlagdRpcReconnectCucumberTest.java | 3 +- .../e2e/reconnect/rpc/FlagdRpcSetup.java | 1 + .../resolver/common/ChannelBuilderTest.java | 159 +++++ .../grpc/EventStreamObserverTest.java | 66 +- .../resolver/grpc/GrpcConnectorTest.java | 566 ++++-------------- 19 files changed, 959 insertions(+), 1145 deletions(-) create mode 100644 providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitor.java create mode 100644 providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionState.java create mode 100644 providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilderTest.java diff --git a/providers/flagd/pom.xml b/providers/flagd/pom.xml index 33cd9327e1..b51581fffc 100644 --- a/providers/flagd/pom.xml +++ b/providers/flagd/pom.xml @@ -399,5 +399,4 @@ - diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java index 139403b685..94b28c260e 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java @@ -57,7 +57,6 @@ public final class Config { public static final String LRU_CACHE = CacheType.LRU.getValue(); static final String DEFAULT_CACHE = LRU_CACHE; - static final int DEFAULT_MAX_EVENT_STREAM_RETRIES = 7; static final int BASE_EVENT_STREAM_RETRY_BACKOFF_MS = 1000; static String fallBackToEnvOrDefault(String key, String defaultValue) { diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java index 6c1036bd13..ad0848d7c0 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java @@ -69,13 +69,6 @@ public class FlagdOptions { private int maxCacheSize = fallBackToEnvOrDefault(Config.MAX_CACHE_SIZE_ENV_VAR_NAME, Config.DEFAULT_MAX_CACHE_SIZE); - /** - * Max event stream connection retries. - */ - @Builder.Default - private int maxEventStreamRetries = fallBackToEnvOrDefault(Config.MAX_EVENT_STREAM_RETRIES_ENV_VAR_NAME, - Config.DEFAULT_MAX_EVENT_STREAM_RETRIES); - /** * Backoff interval in milliseconds. */ @@ -102,11 +95,12 @@ public class FlagdOptions { Config.DEFAULT_STREAM_DEADLINE_MS); /** - * Amount of stream retry attempts before provider moves from STALE to ERROR - * Defaults to 5 + * Grace time period in milliseconds before provider moves from STALE to ERROR. + * Defaults to 50_000 */ @Builder.Default - private int streamRetryGracePeriod = fallBackToEnvOrDefault(Config.STREAM_RETRY_GRACE_PERIOD, Config.DEFAULT_STREAM_RETRY_GRACE_PERIOD); + private int streamRetryGracePeriod = fallBackToEnvOrDefault(Config.STREAM_RETRY_GRACE_PERIOD, + Config.DEFAULT_STREAM_RETRY_GRACE_PERIOD); /** * Selector to be used with flag sync gRPC contract. **/ @@ -116,7 +110,6 @@ public class FlagdOptions { /** * gRPC client KeepAlive in milliseconds. Disabled with 0. * Defaults to 0 (disabled). - * **/ @Builder.Default private long keepAlive = fallBackToEnvOrDefault(Config.KEEP_ALIVE_MS_ENV_VAR_NAME, diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java index 7b451ec91a..771b3624b5 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java @@ -1,10 +1,5 @@ package dev.openfeature.contrib.providers.flagd; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.function.Function; - import dev.openfeature.contrib.providers.flagd.resolver.Resolver; import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionEvent; import dev.openfeature.contrib.providers.flagd.resolver.grpc.GrpcResolver; @@ -22,11 +17,16 @@ import dev.openfeature.sdk.Value; import lombok.extern.slf4j.Slf4j; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Function; + /** * OpenFeature provider for flagd. */ @Slf4j -@SuppressWarnings({ "PMD.TooManyStaticImports", "checkstyle:NoFinalizer" }) +@SuppressWarnings({"PMD.TooManyStaticImports", "checkstyle:NoFinalizer"}) public class FlagdProvider extends EventProvider { private Function contextEnricher; private static final String FLAGD_PROVIDER = "flagd"; @@ -62,7 +62,6 @@ public FlagdProvider(final FlagdOptions options) { case Config.RESOLVER_RPC: this.flagResolver = new GrpcResolver(options, new Cache(options.getCacheType(), options.getMaxCacheSize()), - this::isConnected, this::onConnectionEvent); break; default: @@ -85,7 +84,7 @@ public synchronized void initialize(EvaluationContext evaluationContext) throws } this.flagResolver.init(); - this.initialized = true; + this.initialized = this.connected = true; } @Override @@ -139,7 +138,7 @@ public ProviderEvaluation getObjectEvaluation(String key, Value defaultVa * Set on initial connection and updated with every reconnection. * see: * https://buf.build/open-feature/flagd/docs/main:flagd.sync.v1#flagd.sync.v1.FlagSyncService.GetMetadata - * + * * @return Object map representing sync metadata */ protected Structure getSyncMetadata() { @@ -148,6 +147,7 @@ protected Structure getSyncMetadata() { /** * The updated context mixed into all evaluations based on the sync-metadata. + * * @return context */ EvaluationContext getEnrichedContext() { @@ -159,33 +159,42 @@ private boolean isConnected() { } private void onConnectionEvent(ConnectionEvent connectionEvent) { - boolean previous = connected; - boolean current = connected = connectionEvent.isConnected(); + final boolean wasConnected = connected;// WHY the F*** is this false? wasconnected is false ,hence no change + // event will be sent. why is was connected false? not updated via event? + final boolean isConnected = connected = connectionEvent.isConnected(); + syncMetadata = connectionEvent.getSyncMetadata(); enrichedContext = contextEnricher.apply(connectionEvent.getSyncMetadata()); - // configuration changed - if (initialized && previous && current) { - log.debug("Configuration changed"); + if (!initialized) { + return; + } + + if (!wasConnected && isConnected) { ProviderEventDetails details = ProviderEventDetails.builder() .flagsChanged(connectionEvent.getFlagsChanged()) - .message("configuration changed").build(); - this.emitProviderConfigurationChanged(details); + .message("connected to flagd") + .build(); + this.emitProviderReady(details); return; } - // there was an error - if (initialized && previous && !current) { - log.debug("There has been an error"); - ProviderEventDetails details = ProviderEventDetails.builder().message("there has been an error").build(); - this.emitProviderError(details); + + if (wasConnected && isConnected) { + ProviderEventDetails details = ProviderEventDetails.builder() + .flagsChanged(connectionEvent.getFlagsChanged()) + .message("configuration changed") + .build(); + this.emitProviderConfigurationChanged(details); return; } - // we recovered from an error - if (initialized && !previous && current) { - log.debug("Recovered from error"); - ProviderEventDetails details = ProviderEventDetails.builder().message("recovered from error").build(); - this.emitProviderReady(details); - this.emitProviderConfigurationChanged(details); + + if (connectionEvent.isStale()) { + this.emitProviderStale(ProviderEventDetails.builder().message("there has been an error").build()); + } else { + this.emitProviderError(ProviderEventDetails.builder().message("there has been an error").build()); } } } + + + diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitor.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitor.java new file mode 100644 index 0000000000..2e169b615a --- /dev/null +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitor.java @@ -0,0 +1,169 @@ +package dev.openfeature.contrib.providers.flagd.resolver.common; + +import dev.openfeature.sdk.exceptions.GeneralError; +import io.grpc.ConnectivityState; +import io.grpc.ManagedChannel; +import lombok.extern.slf4j.Slf4j; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + + +/** + * A utility class to monitor and manage the connectivity state of a gRPC ManagedChannel. + */ +@Slf4j +public class ChannelMonitor { + + + private ChannelMonitor() { + + } + + /** + * Monitors the state of a gRPC channel and triggers the specified callbacks based on state changes. + * + * @param expectedState the initial state to monitor. + * @param channel the ManagedChannel to monitor. + * @param onConnectionReady callback invoked when the channel transitions to a READY state. + * @param onConnectionLost callback invoked when the channel transitions to a FAILURE or SHUTDOWN state. + */ + public static void monitorChannelState(ConnectivityState expectedState, ManagedChannel channel, + Runnable onConnectionReady, Runnable onConnectionLost) { + channel.notifyWhenStateChanged(expectedState, () -> { + ConnectivityState currentState = channel.getState(true); + log.info("Channel state changed to: {}", currentState); + if (currentState == ConnectivityState.READY) { + onConnectionReady.run(); + } else if (currentState == ConnectivityState.TRANSIENT_FAILURE + || currentState == ConnectivityState.SHUTDOWN) { + onConnectionLost.run(); + } + // Re-register the state monitor to watch for the next state transition. + monitorChannelState(currentState, channel, onConnectionReady, onConnectionLost); + }); + } + + + /** + * Waits for the channel to reach a desired state within a specified timeout period. + * + * @param channel the ManagedChannel to monitor. + * @param desiredState the ConnectivityState to wait for. + * @param connectCallback callback invoked when the desired state is reached. + * @param timeout the maximum amount of time to wait. + * @param unit the time unit of the timeout. + * @throws InterruptedException if the current thread is interrupted while waiting. + */ + public static void waitForDesiredState(ManagedChannel channel, + ConnectivityState desiredState, + Runnable connectCallback, + long timeout, + TimeUnit unit) throws InterruptedException { + waitForDesiredState(channel, desiredState, connectCallback, new CountDownLatch(1), timeout, unit); + } + + + private static void waitForDesiredState(ManagedChannel channel, + ConnectivityState desiredState, + Runnable connectCallback, + CountDownLatch latch, + long timeout, + TimeUnit unit) throws InterruptedException { + channel.notifyWhenStateChanged(ConnectivityState.SHUTDOWN, () -> { + try { + ConnectivityState state = channel.getState(true); + log.debug("Channel state changed to: {}", state); + + if (state == desiredState) { + connectCallback.run(); + latch.countDown(); + return; + } + waitForDesiredState(channel, desiredState, connectCallback, latch, timeout, unit); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + log.error("Thread interrupted while waiting for desired state", e); + } catch (Exception e) { + log.error("Error occurred while waiting for desired state", e); + } + }); + + // Await the latch or timeout for the state change + if (!latch.await(timeout, unit)) { + throw new GeneralError(String.format("Deadline exceeded. Condition did not complete within the %d " + + "deadline", timeout)); + } + } + + + /** + * Polls the state of a gRPC channel at regular intervals and triggers callbacks upon state changes. + * + * @param executor the ScheduledExecutorService used for polling. + * @param channel the ManagedChannel to monitor. + * @param onConnectionReady callback invoked when the channel transitions to a READY state. + * @param onConnectionLost callback invoked when the channel transitions to a FAILURE or SHUTDOWN state. + * @param pollIntervalMs the polling interval in milliseconds. + */ + public static void pollChannelState(ScheduledExecutorService executor, ManagedChannel channel, + Runnable onConnectionReady, + Runnable onConnectionLost, long pollIntervalMs) { + + AtomicReference lastState = new AtomicReference<>(ConnectivityState.READY); + + Runnable pollTask = () -> { + ConnectivityState currentState = channel.getState(true); + if (currentState != lastState.get()) { + if (currentState == ConnectivityState.READY) { + log.debug("gRPC connection became READY"); + onConnectionReady.run(); + } else if (currentState == ConnectivityState.TRANSIENT_FAILURE + || currentState == ConnectivityState.SHUTDOWN) { + log.debug("gRPC connection became TRANSIENT_FAILURE"); + onConnectionLost.run(); + } + lastState.set(currentState); + } + }; + executor.scheduleAtFixedRate(pollTask, 0, pollIntervalMs, TimeUnit.MILLISECONDS); + } + + + /** + * Polls the channel state at fixed intervals and waits for the channel to reach a desired state within a timeout + * period. + * + * @param executor the ScheduledExecutorService used for polling. + * @param channel the ManagedChannel to monitor. + * @param desiredState the ConnectivityState to wait for. + * @param connectCallback callback invoked when the desired state is reached. + * @param timeout the maximum amount of time to wait. + * @param unit the time unit of the timeout. + * @return {@code true} if the desired state was reached within the timeout period, {@code false} otherwise. + * @throws InterruptedException if the current thread is interrupted while waiting. + */ + public static boolean pollForDesiredState(ScheduledExecutorService executor, ManagedChannel channel, + ConnectivityState desiredState, Runnable connectCallback, long timeout, + TimeUnit unit) throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + + Runnable waitForStateTask = () -> { + ConnectivityState currentState = channel.getState(true); + if (currentState == desiredState) { + connectCallback.run(); + latch.countDown(); + } + }; + + ScheduledFuture scheduledFuture = executor.scheduleWithFixedDelay(waitForStateTask, 0, 100, + TimeUnit.MILLISECONDS); + + boolean success = latch.await(timeout, unit); + scheduledFuture.cancel(true); + return success; + } +} diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionEvent.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionEvent.java index d48b9e49e7..8eea3a1049 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionEvent.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionEvent.java @@ -1,69 +1,122 @@ package dev.openfeature.contrib.providers.flagd.resolver.common; -import java.util.Collections; -import java.util.List; - import dev.openfeature.sdk.ImmutableStructure; import dev.openfeature.sdk.Structure; -import lombok.AllArgsConstructor; -import lombok.Getter; + +import java.util.Collections; +import java.util.List; /** - * Event payload for a - * {@link dev.openfeature.contrib.providers.flagd.resolver.Resolver} connection - * state change event. + * Represents an event payload for a connection state change in a + * {@link dev.openfeature.contrib.providers.flagd.resolver.Resolver}. + * The event includes information about the connection status, any flags that have changed, + * and metadata associated with the synchronization process. */ -@AllArgsConstructor public class ConnectionEvent { - @Getter - private final boolean connected; + + /** + * The current state of the connection. + */ + private final ConnectionState connected; + + /** + * A list of flags that have changed due to this connection event. + */ private final List flagsChanged; + + /** + * Metadata associated with synchronization in this connection event. + */ private final Structure syncMetadata; /** - * Construct a new ConnectionEvent. - * - * @param connected status of the connection + * Constructs a new {@code ConnectionEvent} with the connection status only. + * + * @param connected {@code true} if the connection is established, otherwise {@code false}. */ public ConnectionEvent(boolean connected) { + this(connected ? ConnectionState.CONNECTED : ConnectionState.DISCONNECTED, Collections.emptyList(), + new ImmutableStructure()); + } + + /** + * Constructs a new {@code ConnectionEvent} with the specified connection state. + * + * @param connected the connection state indicating if the connection is established or not. + */ + public ConnectionEvent(ConnectionState connected) { this(connected, Collections.emptyList(), new ImmutableStructure()); } /** - * Construct a new ConnectionEvent. - * - * @param connected status of the connection - * @param flagsChanged list of flags changed + * Constructs a new {@code ConnectionEvent} with the specified connection state and changed flags. + * + * @param connected the connection state indicating if the connection is established or not. + * @param flagsChanged a list of flags that have changed due to this connection event. */ - public ConnectionEvent(boolean connected, List flagsChanged) { + public ConnectionEvent(ConnectionState connected, List flagsChanged) { this(connected, flagsChanged, new ImmutableStructure()); } /** - * Construct a new ConnectionEvent. - * - * @param connected status of the connection - * @param syncMetadata sync.getMetadata + * Constructs a new {@code ConnectionEvent} with the specified connection state and synchronization metadata. + * + * @param connected the connection state indicating if the connection is established or not. + * @param syncMetadata metadata related to the synchronization process of this event. */ - public ConnectionEvent(boolean connected, Structure syncMetadata) { + public ConnectionEvent(ConnectionState connected, Structure syncMetadata) { this(connected, Collections.emptyList(), new ImmutableStructure(syncMetadata.asMap())); } /** - * Get changed flags. - * - * @return an unmodifiable view of the changed flags + * Constructs a new {@code ConnectionEvent} with the specified connection state, changed flags, and + * synchronization metadata. + * + * @param connectionState the state of the connection. + * @param flagsChanged a list of flags that have changed due to this connection event. + * @param syncMetadata metadata related to the synchronization process of this event. + */ + public ConnectionEvent(ConnectionState connectionState, List flagsChanged, Structure syncMetadata) { + this.connected = connectionState; + this.flagsChanged = flagsChanged != null ? flagsChanged : Collections.emptyList(); // Ensure non-null list + this.syncMetadata = syncMetadata != null ? new ImmutableStructure(syncMetadata.asMap()) : + new ImmutableStructure(); // Ensure valid syncMetadata + } + + /** + * Retrieves an unmodifiable view of the list of changed flags. + * + * @return an unmodifiable list of changed flags. */ public List getFlagsChanged() { return Collections.unmodifiableList(flagsChanged); } /** - * Get changed sync metadata represented as SDK structure type. - * - * @return an unmodifiable view of the sync metadata + * Retrieves the synchronization metadata represented as an immutable SDK structure type. + * + * @return an immutable structure containing the synchronization metadata. */ public Structure getSyncMetadata() { return new ImmutableStructure(syncMetadata.asMap()); } + + /** + * Indicates whether the current connection state is connected. + * + * @return {@code true} if connected, otherwise {@code false}. + */ + public boolean isConnected() { + return this.connected == ConnectionState.CONNECTED; + } + + /** + * Indicates + * whether the current connection state is stale. + * + * @return {@code true} if stale, otherwise {@code false}. + */ + public boolean isStale() { + return this.connected == ConnectionState.STALE; + } } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionState.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionState.java new file mode 100644 index 0000000000..93ece9d118 --- /dev/null +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionState.java @@ -0,0 +1,27 @@ +package dev.openfeature.contrib.providers.flagd.resolver.common; + +/** + * Represents the possible states of a connection. + */ +public enum ConnectionState { + + /** + * The connection is active and functioning as expected. + */ + CONNECTED, + + /** + * The connection is not active and has been fully disconnected. + */ + DISCONNECTED, + + /** + * The connection is inactive or degraded but may still recover. + */ + STALE, + + /** + * The connection has encountered an error and cannot function correctly. + */ + ERROR, +} \ No newline at end of file diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/Util.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/Util.java index 3f9d8981f2..b8c609520f 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/Util.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/Util.java @@ -1,33 +1,38 @@ package dev.openfeature.contrib.providers.flagd.resolver.common; -import java.util.function.Supplier; - import dev.openfeature.sdk.exceptions.GeneralError; +import lombok.extern.slf4j.Slf4j; + +import java.util.function.Supplier; /** - * Utils for flagd resolvers. + * Utility class for managing gRPC connection states and handling synchronization operations. */ +@Slf4j public class Util { + /** + * Private constructor to prevent instantiation of utility class. + */ private Util() { } /** - * A helper to block the caller for given conditions. - * - * @param deadline number of milliseconds to block - * @param connectedSupplier func to check for status true - * @throws InterruptedException if interrupted + * A helper method to block the caller until a condition is met or a timeout occurs. + * + * @param deadline the maximum number of milliseconds to block + * @param connectedSupplier a function that evaluates to {@code true} when the desired condition is met + * @throws InterruptedException if the thread is interrupted during the waiting process + * @throws GeneralError if the deadline is exceeded before the condition is met */ - public static void busyWaitAndCheck(final Long deadline, final Supplier connectedSupplier) - throws InterruptedException { + public static void busyWaitAndCheck(final Long deadline, + final Supplier connectedSupplier) throws InterruptedException { long start = System.currentTimeMillis(); do { if (deadline <= System.currentTimeMillis() - start) { - throw new GeneralError( - String.format("Deadline exceeded. Condition did not complete within the %d deadline", - deadline)); + throw new GeneralError(String.format("Deadline exceeded. Condition did not complete within the %d " + + "deadline", deadline)); } Thread.sleep(50L); diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserver.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserver.java index 6b4efe58ef..0f02bdbb1a 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserver.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserver.java @@ -1,47 +1,52 @@ package dev.openfeature.contrib.providers.flagd.resolver.grpc; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.function.BiConsumer; -import java.util.function.Supplier; - import com.google.protobuf.Value; - import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; import dev.openfeature.flagd.grpc.evaluation.Evaluation.EventStreamResponse; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import io.grpc.stub.StreamObserver; import lombok.extern.slf4j.Slf4j; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; + /** - * EventStreamObserver handles events emitted by flagd. + * Observer for a gRPC event stream that handles notifications about flag changes and provider readiness events. + * This class updates a cache and notifies listeners via a lambda callback when events occur. */ @Slf4j @SuppressFBWarnings(justification = "cache needs to be read and write by multiple objects") class EventStreamObserver implements StreamObserver { + + /** + * A consumer to handle connection events with a flag indicating success and a list of changed flags. + */ private final BiConsumer> onConnectionEvent; - private final Supplier shouldRetrySilently; - private final Object sync; + + /** + * The cache to update based on received events. + */ private final Cache cache; /** - * Create a gRPC stream that get notified about flag changes. + * Constructs a new {@code EventStreamObserver} instance. * - * @param sync synchronization object from caller - * @param cache cache to update - * @param onConnectionEvent lambda to call to handle the response - * @param shouldRetrySilently Boolean supplier indicating if the GRPC connector will try to recover silently + * @param cache the cache to update based on received events + * @param onConnectionEvent a consumer to handle connection events with a boolean and a list of changed flags */ - EventStreamObserver(Object sync, Cache cache, BiConsumer> onConnectionEvent, - Supplier shouldRetrySilently) { - this.sync = sync; + EventStreamObserver(Cache cache, BiConsumer> onConnectionEvent) { this.cache = cache; this.onConnectionEvent = onConnectionEvent; - this.shouldRetrySilently = shouldRetrySilently; } + /** + * Called when a new event is received from the stream. + * + * @param value the event stream response containing event data + */ @Override public void onNext(EventStreamResponse value) { switch (value.getType()) { @@ -52,37 +57,38 @@ public void onNext(EventStreamResponse value) { this.handleProviderReadyEvent(); break; default: - log.debug("unhandled event type {}", value.getType()); + log.debug("Unhandled event type {}", value.getType()); } } + /** + * Called when an error occurs in the stream. + * + * @param throwable the error that occurred + */ @Override public void onError(Throwable throwable) { - if (Boolean.TRUE.equals(shouldRetrySilently.get())) { - log.debug("Event stream error, trying to recover", throwable); - } else { - log.error("Event stream error", throwable); - if (this.cache.getEnabled()) { - this.cache.clear(); - } - this.onConnectionEvent.accept(false, Collections.emptyList()); + if (this.cache.getEnabled().equals(Boolean.TRUE)) { + this.cache.clear(); } - - // handle last call of this stream - handleEndOfStream(); } + /** + * Called when the stream is completed. + */ @Override public void onCompleted() { - if (this.cache.getEnabled()) { + if (this.cache.getEnabled().equals(Boolean.TRUE)) { this.cache.clear(); } this.onConnectionEvent.accept(false, Collections.emptyList()); - - // handle last call of this stream - handleEndOfStream(); } + /** + * Handles configuration change events by updating the cache and notifying listeners about changed flags. + * + * @param value the event stream response containing configuration change data + */ private void handleConfigurationChangeEvent(EventStreamResponse value) { List changedFlags = new ArrayList<>(); boolean cachingEnabled = this.cache.getEnabled(); @@ -95,7 +101,6 @@ private void handleConfigurationChangeEvent(EventStreamResponse value) { } } else { Map flags = flagsValue.getStructValue().getFieldsMap(); - this.cache.getEnabled(); for (String flagKey : flags.keySet()) { changedFlags.add(flagKey); if (cachingEnabled) { @@ -107,16 +112,13 @@ private void handleConfigurationChangeEvent(EventStreamResponse value) { this.onConnectionEvent.accept(true, changedFlags); } + /** + * Handles provider readiness events by clearing the cache (if enabled) and notifying listeners of readiness. + */ private void handleProviderReadyEvent() { - this.onConnectionEvent.accept(true, Collections.emptyList()); - if (this.cache.getEnabled()) { + this.onConnectionEvent.accept(true, Collections.emptyList()); // TODO: check if this is needed + if (this.cache.getEnabled().equals(Boolean.TRUE)) { this.cache.clear(); } } - - private void handleEndOfStream() { - synchronized (this.sync) { - this.sync.notifyAll(); - } - } } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnector.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnector.java index 5cf10a94a5..de579f8cd2 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnector.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnector.java @@ -1,174 +1,246 @@ package dev.openfeature.contrib.providers.flagd.resolver.grpc; +import com.google.common.annotations.VisibleForTesting; import dev.openfeature.contrib.providers.flagd.FlagdOptions; import dev.openfeature.contrib.providers.flagd.resolver.common.ChannelBuilder; +import dev.openfeature.contrib.providers.flagd.resolver.common.ChannelMonitor; import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionEvent; -import dev.openfeature.contrib.providers.flagd.resolver.common.Util; -import dev.openfeature.contrib.providers.flagd.resolver.common.backoff.GrpcStreamConnectorBackoffService; -import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; -import dev.openfeature.flagd.grpc.evaluation.Evaluation.EventStreamRequest; -import dev.openfeature.flagd.grpc.evaluation.Evaluation.EventStreamResponse; -import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc; -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionState; +import dev.openfeature.sdk.ImmutableStructure; +import io.grpc.ConnectivityState; import io.grpc.ManagedChannel; -import io.grpc.stub.StreamObserver; +import io.grpc.stub.AbstractBlockingStub; +import io.grpc.stub.AbstractStub; +import lombok.Getter; import lombok.extern.slf4j.Slf4j; import java.util.Collections; -import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; -import java.util.function.Supplier; - -import static dev.openfeature.contrib.providers.flagd.resolver.common.backoff.BackoffStrategies.maxRetriesWithExponentialTimeBackoffStrategy; +import java.util.function.Function; /** - * Class that abstracts the gRPC communication with flagd. + * A generic GRPC connector that manages connection states, reconnection logic, and event streaming for + * GRPC services. + * + * @param the type of the asynchronous stub for the GRPC service + * @param the type of the blocking stub for the GRPC service */ @Slf4j -@SuppressFBWarnings(justification = "cache needs to be read and write by multiple objects") -public class GrpcConnector { - private final Object sync = new Object(); +public class GrpcConnector, K extends AbstractBlockingStub> { - private final ServiceGrpc.ServiceBlockingStub serviceBlockingStub; - private final ServiceGrpc.ServiceStub serviceStub; + /** + * The asynchronous service stub for making non-blocking GRPC calls. + */ + private final T serviceStub; + + /** + * The blocking service stub for making blocking GRPC calls. + */ + private final K blockingStub; + + /** + * The GRPC managed channel for managing the underlying GRPC connection. + */ private final ManagedChannel channel; + /** + * The deadline in milliseconds for GRPC operations. + */ private final long deadline; + + /** + * The deadline in milliseconds for event streaming operations. + */ private final long streamDeadlineMs; - private final Cache cache; + /** + * A consumer that handles connection events such as connection loss or reconnection. + */ private final Consumer onConnectionEvent; - private final Supplier connectedSupplier; - private final GrpcStreamConnectorBackoffService backoff; - // Thread responsible for event observation - private Thread eventObserverThread; + /** + * A consumer that handles GRPC service stubs for event stream handling. + */ + private final Consumer streamObserver; + + /** + * An executor service responsible for scheduling reconnection attempts. + */ + private final ScheduledExecutorService reconnectExecutor; + + /** + * The grace period in milliseconds to wait for reconnection before emitting an error event. + */ + private final long gracePeriod; + + /** + * Indicates whether the connector is currently connected to the GRPC service. + */ + @Getter + private boolean connected = false; + + /** + * A scheduled task for managing reconnection attempts. + */ + private ScheduledFuture reconnectTask; /** - * GrpcConnector creates an abstraction over gRPC communication. + * Constructs a new {@code GrpcConnector} instance with the specified options and parameters. * - * @param options flagd options - * @param cache cache to use - * @param connectedSupplier lambda providing current connection status from caller - * @param onConnectionEvent lambda which handles changes in the connection/stream - */ - public GrpcConnector(final FlagdOptions options, final Cache cache, final Supplier connectedSupplier, - Consumer onConnectionEvent) { - this.channel = ChannelBuilder.nettyChannel(options); - this.serviceStub = ServiceGrpc.newStub(channel); - this.serviceBlockingStub = ServiceGrpc.newBlockingStub(channel); + * @param options the configuration options for the GRPC connection + * @param stub a function to create the asynchronous service stub from a {@link ManagedChannel} + * @param blockingStub a function to create the blocking service stub from a {@link ManagedChannel} + * @param onConnectionEvent a consumer to handle connection events + * @param eventStreamObserver a consumer to handle the event stream + * @param channel the managed channel for the GRPC connection + */ + public GrpcConnector(final FlagdOptions options, + final Function stub, + final Function blockingStub, + final Consumer onConnectionEvent, + final Consumer eventStreamObserver, ManagedChannel channel) { + + this.channel = channel; + this.serviceStub = stub.apply(channel); + this.blockingStub = blockingStub.apply(channel); this.deadline = options.getDeadline(); this.streamDeadlineMs = options.getStreamDeadlineMs(); - this.cache = cache; this.onConnectionEvent = onConnectionEvent; - this.connectedSupplier = connectedSupplier; - this.backoff = new GrpcStreamConnectorBackoffService(maxRetriesWithExponentialTimeBackoffStrategy( - options.getMaxEventStreamRetries(), - options.getRetryBackoffMs()) - ); + this.streamObserver = eventStreamObserver; + this.gracePeriod = options.getStreamRetryGracePeriod(); + this.reconnectExecutor = Executors.newSingleThreadScheduledExecutor(); } /** - * Initialize the gRPC stream. + * Constructs a {@code GrpcConnector} instance for testing purposes. + * + * @param options the configuration options for the GRPC connection + * @param stub a function to create the asynchronous service stub from a {@link ManagedChannel} + * @param blockingStub a function to create the blocking service stub from a {@link ManagedChannel} + * @param onConnectionEvent a consumer to handle connection events + * @param eventStreamObserver a consumer to handle the event stream + */ + @VisibleForTesting + GrpcConnector(final FlagdOptions options, + final Function stub, + final Function blockingStub, + final Consumer onConnectionEvent, + final Consumer eventStreamObserver) { + this(options, stub, blockingStub, onConnectionEvent, eventStreamObserver, ChannelBuilder.nettyChannel(options)); + } + + /** + * Initializes the GRPC connection by waiting for the channel to be ready and monitoring its state. + * + * @throws Exception if the channel does not reach the desired state within the deadline */ public void initialize() throws Exception { - eventObserverThread = new Thread(this::observeEventStream); - eventObserverThread.setDaemon(true); - eventObserverThread.start(); + log.info("Initializing GRPC connection..."); + ChannelMonitor.waitForDesiredState(channel, ConnectivityState.READY, this::onInitialConnect, deadline, + TimeUnit.MILLISECONDS); + ChannelMonitor.monitorChannelState(ConnectivityState.READY, channel, this::onReady, this::onConnectionLost); + } - // block till ready - Util.busyWaitAndCheck(this.deadline, this.connectedSupplier); + /** + * Returns the blocking service stub for making blocking GRPC calls. + * + * @return the blocking service stub + */ + public K getResolver() { + return blockingStub; } /** - * Shuts down all gRPC resources. + * Shuts down the GRPC connection and cleans up associated resources. * - * @throws Exception is something goes wrong while terminating the - * communication. + * @throws InterruptedException if interrupted while waiting for termination */ - public void shutdown() throws Exception { - // first shutdown the event listener - if (this.eventObserverThread != null) { - this.eventObserverThread.interrupt(); + public void shutdown() throws InterruptedException { + log.info("Shutting down GRPC connection..."); + if (reconnectExecutor != null) { + reconnectExecutor.shutdownNow(); + reconnectExecutor.awaitTermination(deadline, TimeUnit.MILLISECONDS); } - try { - if (this.channel != null && !this.channel.isShutdown()) { - this.channel.shutdown(); - this.channel.awaitTermination(this.deadline, TimeUnit.MILLISECONDS); - } - } finally { - this.cache.clear(); - if (this.channel != null && !this.channel.isShutdown()) { - this.channel.shutdownNow(); - this.channel.awaitTermination(this.deadline, TimeUnit.MILLISECONDS); - log.warn(String.format("Unable to shut down channel by %d deadline", this.deadline)); - } + if (!channel.isShutdown()) { + channel.shutdownNow(); + channel.awaitTermination(deadline, TimeUnit.MILLISECONDS); + } + + if (connected) { this.onConnectionEvent.accept(new ConnectionEvent(false)); + connected = false; } } - /** - * Provide the object that can be used to resolve Feature Flag values. - * - * @return a {@link ServiceGrpc.ServiceBlockingStub} for running FF resolution. - */ - public ServiceGrpc.ServiceBlockingStub getResolver() { - return serviceBlockingStub.withDeadlineAfter(this.deadline, TimeUnit.MILLISECONDS); + + private synchronized void onInitialConnect() { + connected = true; + restartStream(); } /** - * Event stream observer logic. This contains blocking mechanisms, hence must be - * run in a dedicated thread. + * Handles the event when the GRPC channel becomes ready, marking the connection as established. + * Cancels any pending reconnection task and restarts the event stream. */ - private void observeEventStream() { - while (backoff.shouldRetry()) { - final StreamObserver responseObserver = new EventStreamObserver(sync, this.cache, - this::onConnectionEvent, backoff::shouldRetrySilently); - - ServiceGrpc.ServiceStub localServiceStub = this.serviceStub; + private synchronized void onReady() { + connected = true; - if (this.streamDeadlineMs > 0) { - localServiceStub = localServiceStub.withDeadlineAfter(this.streamDeadlineMs, TimeUnit.MILLISECONDS); - } + if (reconnectTask != null && !reconnectTask.isCancelled()) { + reconnectTask.cancel(false); + log.debug("Reconnection task cancelled as connection became READY."); + } + restartStream(); + this.onConnectionEvent.accept(new ConnectionEvent(true)); + } - localServiceStub.eventStream(EventStreamRequest.getDefaultInstance(), responseObserver); - - try { - synchronized (sync) { - sync.wait(); - } - } catch (InterruptedException e) { - // Interruptions are considered end calls for this observer, hence log and - // return - // Note - this is the most common interruption when shutdown, hence the log - // level debug - log.debug("interruption while waiting for condition", e); - Thread.currentThread().interrupt(); - } + /** + * Handles the event when the GRPC channel loses its connection, marking the connection as lost. + * Schedules a reconnection task after a grace period and emits a stale connection event. + */ + private synchronized void onConnectionLost() { + log.debug("Connection lost. Emit STALE event..."); + log.debug("Waiting {}ms for connection to become available...", gracePeriod); + connected = false; + + this.onConnectionEvent.accept( + new ConnectionEvent( + ConnectionState.STALE, + Collections.emptyList(), + new ImmutableStructure())); + + if (reconnectTask != null && !reconnectTask.isCancelled()) { + reconnectTask.cancel(false); + } - try { - backoff.waitUntilNextAttempt(); - } catch (InterruptedException e) { - // Interruptions are considered end calls for this observer, hence log and - // return - log.warn("interrupted while restarting gRPC Event Stream"); - Thread.currentThread().interrupt(); - } + if (!reconnectExecutor.isShutdown()) { + reconnectTask = reconnectExecutor.schedule(() -> { + log.debug("Provider did not reconnect successfully within {}ms. Emit ERROR event...", gracePeriod); + this.onConnectionEvent.accept( + new ConnectionEvent(false)); + }, gracePeriod, TimeUnit.MILLISECONDS); } - log.error("failed to connect to event stream, exhausted retries"); - this.onConnectionEvent(false, Collections.emptyList()); } - private void onConnectionEvent(final boolean connected, final List changedFlags) { - // reset reconnection states + /** + * Restarts the event stream using the asynchronous service stub, applying a deadline if configured. + * Emits a connection event if the restart is successful. + */ + private synchronized void restartStream() { if (connected) { - backoff.reset(); + log.debug("(Re)initializing event stream."); + T localServiceStub = this.serviceStub; + if (streamDeadlineMs > 0) { + localServiceStub = localServiceStub.withDeadlineAfter(this.streamDeadlineMs, TimeUnit.MILLISECONDS); + } + streamObserver.accept(localServiceStub); + return; } - - // chain to initiator - this.onConnectionEvent.accept(new ConnectionEvent(connected, changedFlags)); + log.debug("Stream restart skipped. Not connected."); } } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcResolver.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcResolver.java index 9fcede67e8..6e04c24177 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcResolver.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcResolver.java @@ -1,30 +1,22 @@ package dev.openfeature.contrib.providers.flagd.resolver.grpc; -import static dev.openfeature.contrib.providers.flagd.resolver.common.Convert.convertContext; -import static dev.openfeature.contrib.providers.flagd.resolver.common.Convert.convertObjectResponse; -import static dev.openfeature.contrib.providers.flagd.resolver.common.Convert.getField; -import static dev.openfeature.contrib.providers.flagd.resolver.common.Convert.getFieldDescriptor; - -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; - import com.google.protobuf.Message; import com.google.protobuf.Struct; - import dev.openfeature.contrib.providers.flagd.Config; import dev.openfeature.contrib.providers.flagd.FlagdOptions; import dev.openfeature.contrib.providers.flagd.resolver.Resolver; import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionEvent; +import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionState; import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; import dev.openfeature.contrib.providers.flagd.resolver.grpc.strategy.ResolveFactory; import dev.openfeature.contrib.providers.flagd.resolver.grpc.strategy.ResolveStrategy; +import dev.openfeature.flagd.grpc.evaluation.Evaluation; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveBooleanRequest; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveFloatRequest; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveIntRequest; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveObjectRequest; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveStringRequest; +import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc; import dev.openfeature.sdk.EvaluationContext; import dev.openfeature.sdk.ImmutableMetadata; import dev.openfeature.sdk.ProviderEvaluation; @@ -38,6 +30,15 @@ import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; + +import static dev.openfeature.contrib.providers.flagd.resolver.common.Convert.convertContext; +import static dev.openfeature.contrib.providers.flagd.resolver.common.Convert.convertObjectResponse; +import static dev.openfeature.contrib.providers.flagd.resolver.common.Convert.getField; +import static dev.openfeature.contrib.providers.flagd.resolver.common.Convert.getFieldDescriptor; + /** * Resolves flag values using https://buf.build/open-feature/flagd/docs/main:flagd.evaluation.v1. * Flags are evaluated remotely. @@ -46,28 +47,35 @@ @SuppressFBWarnings(justification = "cache needs to be read and write by multiple objects") public final class GrpcResolver implements Resolver { - private final GrpcConnector connector; + private final GrpcConnector connector; private final Cache cache; private final ResolveStrategy strategy; - private final Supplier connectedSupplier; /** * Resolves flag values using https://buf.build/open-feature/flagd/docs/main:flagd.evaluation.v1. * Flags are evaluated remotely. - * - * @param options flagd options - * @param cache cache to use - * @param connectedSupplier lambda providing current connection status from caller + * + * @param options flagd options + * @param cache cache to use * @param onConnectionEvent lambda which handles changes in the connection/stream */ - public GrpcResolver(final FlagdOptions options, final Cache cache, final Supplier connectedSupplier, - final Consumer onConnectionEvent) { + public GrpcResolver(final FlagdOptions options, final Cache cache, + final Consumer onConnectionEvent) { this.cache = cache; - this.connectedSupplier = connectedSupplier; this.strategy = ResolveFactory.getStrategy(options); - this.connector = new GrpcConnector(options, cache, connectedSupplier, onConnectionEvent); + this.connector = new GrpcConnector<>(options, + ServiceGrpc::newStub, + ServiceGrpc::newBlockingStub, + onConnectionEvent, + stub -> stub.eventStream(Evaluation.EventStreamRequest.getDefaultInstance(), + new EventStreamObserver(cache, + (k, e) -> onConnectionEvent.accept(new ConnectionEvent(ConnectionState.CONNECTED, + e))))); + + } + /** * Initialize Grpc resolver. */ @@ -86,41 +94,44 @@ public void shutdown() throws Exception { * Boolean evaluation from grpc resolver. */ public ProviderEvaluation booleanEvaluation(String key, Boolean defaultValue, - EvaluationContext ctx) { + EvaluationContext ctx) { ResolveBooleanRequest request = ResolveBooleanRequest.newBuilder().buildPartial(); - return resolve(key, ctx, request, this.connector.getResolver()::resolveBoolean, null); + + return resolve(key, ctx, request, connector.getResolver()::resolveBoolean, + null); } /** * String evaluation from grpc resolver. */ public ProviderEvaluation stringEvaluation(String key, String defaultValue, - EvaluationContext ctx) { + EvaluationContext ctx) { ResolveStringRequest request = ResolveStringRequest.newBuilder().buildPartial(); - - return resolve(key, ctx, request, this.connector.getResolver()::resolveString, null); + return resolve(key, ctx, request, connector.getResolver()::resolveString, + null); } /** * Double evaluation from grpc resolver. */ public ProviderEvaluation doubleEvaluation(String key, Double defaultValue, - EvaluationContext ctx) { + EvaluationContext ctx) { ResolveFloatRequest request = ResolveFloatRequest.newBuilder().buildPartial(); - return resolve(key, ctx, request, this.connector.getResolver()::resolveFloat, null); + return resolve(key, ctx, request, connector.getResolver()::resolveFloat, + null); } /** * Integer evaluation from grpc resolver. */ public ProviderEvaluation integerEvaluation(String key, Integer defaultValue, - EvaluationContext ctx) { + EvaluationContext ctx) { ResolveIntRequest request = ResolveIntRequest.newBuilder().buildPartial(); - return resolve(key, ctx, request, this.connector.getResolver()::resolveInt, + return resolve(key, ctx, request, connector.getResolver()::resolveInt, (Object value) -> ((Long) value).intValue()); } @@ -128,11 +139,11 @@ public ProviderEvaluation integerEvaluation(String key, Integer default * Object evaluation from grpc resolver. */ public ProviderEvaluation objectEvaluation(String key, Value defaultValue, - EvaluationContext ctx) { + EvaluationContext ctx) { ResolveObjectRequest request = ResolveObjectRequest.newBuilder().buildPartial(); - return resolve(key, ctx, request, this.connector.getResolver()::resolveObject, + return resolve(key, ctx, request, connector.getResolver()::resolveObject, (Object value) -> convertObjectResponse((Struct) value)); } @@ -197,7 +208,7 @@ private Boolean isEvaluationCacheable(ProviderEvaluation evaluation) { } private Boolean cacheAvailable() { - return this.cache.getEnabled() && this.connectedSupplier.get(); + return this.cache.getEnabled() && this.connector.isConnected(); } private static ImmutableMetadata metadataFromResponse(Message response) { diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java index 39c77f01b2..50b6098896 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java @@ -1,13 +1,9 @@ package dev.openfeature.contrib.providers.flagd.resolver.process; -import static dev.openfeature.contrib.providers.flagd.resolver.process.model.FeatureFlag.EMPTY_TARGETING_STRING; - -import java.util.function.Consumer; -import java.util.function.Supplier; - import dev.openfeature.contrib.providers.flagd.FlagdOptions; import dev.openfeature.contrib.providers.flagd.resolver.Resolver; import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionEvent; +import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionState; import dev.openfeature.contrib.providers.flagd.resolver.common.Util; import dev.openfeature.contrib.providers.flagd.resolver.process.model.FeatureFlag; import dev.openfeature.contrib.providers.flagd.resolver.process.storage.FlagStore; @@ -28,6 +24,11 @@ import dev.openfeature.sdk.exceptions.TypeMismatchError; import lombok.extern.slf4j.Slf4j; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static dev.openfeature.contrib.providers.flagd.resolver.process.model.FeatureFlag.EMPTY_TARGETING_STRING; + /** * Resolves flag values using * https://buf.build/open-feature/flagd/docs/main:flagd.sync.v1. @@ -46,7 +47,7 @@ public class InProcessResolver implements Resolver { * Resolves flag values using * https://buf.build/open-feature/flagd/docs/main:flagd.sync.v1. * Flags are evaluated locally. - * + * * @param options flagd options * @param connectedSupplier lambda providing current connection status from * caller @@ -54,7 +55,7 @@ public class InProcessResolver implements Resolver { * connection/stream */ public InProcessResolver(FlagdOptions options, final Supplier connectedSupplier, - Consumer onConnectionEvent) { + Consumer onConnectionEvent) { this.flagStore = new FlagStore(getConnector(options)); this.deadline = options.getDeadline(); this.onConnectionEvent = onConnectionEvent; @@ -62,8 +63,8 @@ public InProcessResolver(FlagdOptions options, final Supplier connected this.connectedSupplier = connectedSupplier; this.metadata = options.getSelector() == null ? null : ImmutableMetadata.builder() - .addString("scope", options.getSelector()) - .build(); + .addString("scope", options.getSelector()) + .build(); } /** @@ -77,7 +78,8 @@ public void init() throws Exception { final StorageStateChange storageStateChange = flagStore.getStateQueue().take(); switch (storageStateChange.getStorageState()) { case OK: - onConnectionEvent.accept(new ConnectionEvent(true, storageStateChange.getChangedFlagsKeys(), + onConnectionEvent.accept(new ConnectionEvent(ConnectionState.CONNECTED, + storageStateChange.getChangedFlagsKeys(), storageStateChange.getSyncMetadata())); break; case ERROR: @@ -114,7 +116,7 @@ public void shutdown() throws InterruptedException { * Resolve a boolean flag. */ public ProviderEvaluation booleanEvaluation(String key, Boolean defaultValue, - EvaluationContext ctx) { + EvaluationContext ctx) { return resolve(Boolean.class, key, ctx); } diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java index 71adc687f5..e0eb453a99 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java @@ -27,7 +27,6 @@ void TestDefaults() { assertNull(builder.getSocketPath()); assertEquals(DEFAULT_CACHE, builder.getCacheType()); assertEquals(DEFAULT_MAX_CACHE_SIZE, builder.getMaxCacheSize()); - assertEquals(DEFAULT_MAX_EVENT_STREAM_RETRIES, builder.getMaxEventStreamRetries()); assertNull(builder.getSelector()); assertNull(builder.getOpenTelemetry()); assertNull(builder.getCustomConnector()); @@ -48,7 +47,6 @@ void TestBuilderOptions() { .certPath("etc/cert/ca.crt") .cacheType("lru") .maxCacheSize(100) - .maxEventStreamRetries(1) .selector("app=weatherApp") .offlineFlagSourcePath("some-path") .openTelemetry(openTelemetry) @@ -64,7 +62,6 @@ void TestBuilderOptions() { assertEquals("etc/cert/ca.crt", flagdOptions.getCertPath()); assertEquals("lru", flagdOptions.getCacheType()); assertEquals(100, flagdOptions.getMaxCacheSize()); - assertEquals(1, flagdOptions.getMaxEventStreamRetries()); assertEquals("app=weatherApp", flagdOptions.getSelector()); assertEquals("some-path", flagdOptions.getOfflineFlagSourcePath()); assertEquals(openTelemetry, flagdOptions.getOpenTelemetry()); @@ -137,7 +134,6 @@ void usesSetOldAndNewName() { } - @Test void testInProcessProvider_noPortConfigured_defaultsToCorrectPort() { FlagdOptions flagdOptions = FlagdOptions.builder() diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java index 2a5850172f..240d2cd958 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java @@ -1,47 +1,9 @@ package dev.openfeature.contrib.providers.flagd; -import static dev.openfeature.contrib.providers.flagd.Config.CACHED_REASON; -import static dev.openfeature.contrib.providers.flagd.Config.STATIC_REASON; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockConstruction; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.lang.reflect.Field; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; -import java.util.concurrent.atomic.AtomicReference; -import java.util.Collections; -import java.util.Optional; - - -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.mockito.MockedConstruction; -import org.mockito.MockedStatic; - import com.google.protobuf.Struct; - import dev.openfeature.contrib.providers.flagd.resolver.Resolver; import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionEvent; +import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionState; import dev.openfeature.contrib.providers.flagd.resolver.grpc.GrpcConnector; import dev.openfeature.contrib.providers.flagd.resolver.grpc.GrpcResolver; import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; @@ -56,9 +18,7 @@ import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveIntResponse; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveObjectResponse; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveStringResponse; -import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc; import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceBlockingStub; -import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceStub; import dev.openfeature.sdk.EvaluationContext; import dev.openfeature.sdk.FlagEvaluationDetails; import dev.openfeature.sdk.FlagValueType; @@ -72,8 +32,37 @@ import dev.openfeature.sdk.Structure; import dev.openfeature.sdk.Value; import io.cucumber.java.AfterAll; -import io.grpc.Channel; -import io.grpc.Deadline; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.mockito.MockedConstruction; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.function.Function; + +import static dev.openfeature.contrib.providers.flagd.Config.CACHED_REASON; +import static dev.openfeature.contrib.providers.flagd.Config.STATIC_REASON; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; class FlagdProviderTest { private static final String FLAG_KEY = "some-key"; @@ -92,7 +81,7 @@ class FlagdProviderTest { private static final Double DOUBLE_VALUE = .5d; private static final String INNER_STRUCT_KEY = "inner_key"; private static final String INNER_STRUCT_VALUE = "inner_value"; - private static final com.google.protobuf.Struct PROTOBUF_STRUCTURE_VALUE = Struct.newBuilder() + private static final Struct PROTOBUF_STRUCTURE_VALUE = Struct.newBuilder() .putFields(INNER_STRUCT_KEY, com.google.protobuf.Value.newBuilder().setStringValue(INNER_STRUCT_VALUE) .build()) @@ -382,18 +371,18 @@ void context_is_parsed_and_passed_to_grpc_service() { return STRING_ATTR_VALUE.equals(valueMap.get(STRING_ATTR_KEY).getStringValue()) && INT_ATTR_VALUE == valueMap.get(INT_ATTR_KEY).getNumberValue() && DOUBLE_ATTR_VALUE == valueMap.get(DOUBLE_ATTR_KEY) - .getNumberValue() + .getNumberValue() && valueMap.get(BOOLEAN_ATTR_KEY).getBoolValue() && "MY_TARGETING_KEY".equals( - valueMap.get("targetingKey").getStringValue()) + valueMap.get("targetingKey").getStringValue()) && LIST_ATTR_VALUE.get(0).asInteger() == valueMap - .get(LIST_ATTR_KEY).getListValue() - .getValuesList().get(0).getNumberValue() + .get(LIST_ATTR_KEY).getListValue() + .getValuesList().get(0).getNumberValue() && STRUCT_ATTR_INNER_VALUE.equals( - valueMap.get(STRUCT_ATTR_KEY).getStructValue() - .getFieldsMap() - .get(STRUCT_ATTR_INNER_KEY) - .getStringValue()); + valueMap.get(STRUCT_ATTR_KEY).getStructValue() + .getFieldsMap() + .get(STRUCT_ATTR_INNER_KEY) + .getStringValue()); }))).thenReturn(booleanResponse); GrpcConnector grpc = mock(GrpcConnector.class); @@ -470,143 +459,9 @@ void reason_mapped_correctly_if_unknown() { FlagEvaluationDetails booleanDetails = api.getClient() .getBooleanDetails(FLAG_KEY, false, new MutableContext()); assertEquals(Reason.UNKNOWN.toString(), booleanDetails.getReason()); // reason should be converted to - // UNKNOWN + // UNKNOWN } - @Test - void invalidate_cache() { - ResolveBooleanResponse booleanResponse = ResolveBooleanResponse.newBuilder() - .setValue(true) - .setVariant(BOOL_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveStringResponse stringResponse = ResolveStringResponse.newBuilder() - .setValue(STRING_VALUE) - .setVariant(STRING_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveIntResponse intResponse = ResolveIntResponse.newBuilder() - .setValue(INT_VALUE) - .setVariant(INT_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveFloatResponse floatResponse = ResolveFloatResponse.newBuilder() - .setValue(DOUBLE_VALUE) - .setVariant(DOUBLE_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveObjectResponse objectResponse = ResolveObjectResponse.newBuilder() - .setValue(PROTOBUF_STRUCTURE_VALUE) - .setVariant(OBJECT_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ServiceBlockingStub serviceBlockingStubMock = mock(ServiceBlockingStub.class); - ServiceStub serviceStubMock = mock(ServiceStub.class); - when(serviceStubMock.withWaitForReady()).thenReturn(serviceStubMock); - doNothing().when(serviceStubMock).eventStream(any(), any()); - when(serviceStubMock.withDeadline(any(Deadline.class))) - .thenReturn(serviceStubMock); - when(serviceBlockingStubMock.withWaitForReady()).thenReturn(serviceBlockingStubMock); - when(serviceBlockingStubMock - .withDeadline(any(Deadline.class))) - .thenReturn(serviceBlockingStubMock); - when(serviceBlockingStubMock.withDeadlineAfter(anyLong(), any(TimeUnit.class))) - .thenReturn(serviceBlockingStubMock); - when(serviceBlockingStubMock - .resolveBoolean(argThat(x -> FLAG_KEY_BOOLEAN.equals(x.getFlagKey())))) - .thenReturn(booleanResponse); - when(serviceBlockingStubMock - .resolveFloat(argThat(x -> FLAG_KEY_DOUBLE.equals(x.getFlagKey())))) - .thenReturn(floatResponse); - when(serviceBlockingStubMock - .resolveInt(argThat(x -> FLAG_KEY_INTEGER.equals(x.getFlagKey())))) - .thenReturn(intResponse); - when(serviceBlockingStubMock - .resolveString(argThat(x -> FLAG_KEY_STRING.equals(x.getFlagKey())))) - .thenReturn(stringResponse); - when(serviceBlockingStubMock - .resolveObject(argThat(x -> FLAG_KEY_OBJECT.equals(x.getFlagKey())))) - .thenReturn(objectResponse); - - GrpcConnector grpc; - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newBlockingStub(any(Channel.class))) - .thenReturn(serviceBlockingStubMock); - mockStaticService.when(() -> ServiceGrpc.newStub(any())) - .thenReturn(serviceStubMock); - - final Cache cache = new Cache("lru", 5); - - class NoopInitGrpcConnector extends GrpcConnector { - public NoopInitGrpcConnector(FlagdOptions options, Cache cache, - Supplier connectedSupplier, - Consumer onConnectionEvent) { - super(options, cache, connectedSupplier, onConnectionEvent); - } - - public void initialize() throws Exception { - }; - } - - grpc = new NoopInitGrpcConnector(FlagdOptions.builder().build(), cache, () -> true, - (connectionEvent) -> { - }); - } - - FlagdProvider provider = createProvider(grpc); - OpenFeatureAPI.getInstance().setProviderAndWait(provider); - - HashMap flagsMap = new HashMap(); - HashMap structMap = new HashMap(); - - flagsMap.put(FLAG_KEY_BOOLEAN, com.google.protobuf.Value.newBuilder().setStringValue("foo").build()); - flagsMap.put(FLAG_KEY_STRING, com.google.protobuf.Value.newBuilder().setStringValue("foo").build()); - flagsMap.put(FLAG_KEY_INTEGER, com.google.protobuf.Value.newBuilder().setStringValue("foo").build()); - flagsMap.put(FLAG_KEY_DOUBLE, com.google.protobuf.Value.newBuilder().setStringValue("foo").build()); - flagsMap.put(FLAG_KEY_OBJECT, com.google.protobuf.Value.newBuilder().setStringValue("foo").build()); - - structMap.put("flags", com.google.protobuf.Value.newBuilder() - .setStructValue(Struct.newBuilder().putAllFields(flagsMap)).build()); - - // should cache results - FlagEvaluationDetails booleanDetails; - FlagEvaluationDetails stringDetails; - FlagEvaluationDetails intDetails; - FlagEvaluationDetails floatDetails; - FlagEvaluationDetails objectDetails; - - // assert cache has been invalidated - booleanDetails = api.getClient().getBooleanDetails(FLAG_KEY_BOOLEAN, false); - assertTrue(booleanDetails.getValue()); - assertEquals(BOOL_VARIANT, booleanDetails.getVariant()); - assertEquals(STATIC_REASON, booleanDetails.getReason()); - - stringDetails = api.getClient().getStringDetails(FLAG_KEY_STRING, "wrong"); - assertEquals(STRING_VALUE, stringDetails.getValue()); - assertEquals(STRING_VARIANT, stringDetails.getVariant()); - assertEquals(STATIC_REASON, stringDetails.getReason()); - - intDetails = api.getClient().getIntegerDetails(FLAG_KEY_INTEGER, 0); - assertEquals(INT_VALUE, intDetails.getValue()); - assertEquals(INT_VARIANT, intDetails.getVariant()); - assertEquals(STATIC_REASON, intDetails.getReason()); - - floatDetails = api.getClient().getDoubleDetails(FLAG_KEY_DOUBLE, 0.1); - assertEquals(DOUBLE_VALUE, floatDetails.getValue()); - assertEquals(DOUBLE_VARIANT, floatDetails.getVariant()); - assertEquals(STATIC_REASON, floatDetails.getReason()); - - objectDetails = api.getClient().getObjectDetails(FLAG_KEY_OBJECT, new Value()); - assertEquals(INNER_STRUCT_VALUE, objectDetails.getValue().asStructure() - .asMap().get(INNER_STRUCT_KEY).asString()); - assertEquals(OBJECT_VARIANT, objectDetails.getVariant()); - assertEquals(STATIC_REASON, objectDetails.getReason()); - } private void do_resolvers_cache_responses(String reason, Boolean eventStreamAlive, Boolean shouldCache) { String expectedReason = CACHED_REASON; @@ -665,7 +520,9 @@ private void do_resolvers_cache_responses(String reason, Boolean eventStreamAliv GrpcConnector grpc = mock(GrpcConnector.class); when(grpc.getResolver()).thenReturn(serviceBlockingStubMock); - FlagdProvider provider = createProvider(grpc, () -> eventStreamAlive); + when(grpc.isConnected()).thenReturn(eventStreamAlive); + FlagdProvider provider = createProvider(grpc); + // provider.setState(eventStreamAlive); // caching only available when event // stream is alive OpenFeatureAPI.getInstance().setProviderAndWait(provider); @@ -674,7 +531,7 @@ private void do_resolvers_cache_responses(String reason, Boolean eventStreamAliv false); booleanDetails = api.getClient() .getBooleanDetails(FLAG_KEY_BOOLEAN, false); // should retrieve from cache on second - // invocation + // invocation assertTrue(booleanDetails.getValue()); assertEquals(BOOL_VARIANT, booleanDetails.getVariant()); assertEquals(expectedReason, booleanDetails.getReason()); @@ -707,147 +564,6 @@ private void do_resolvers_cache_responses(String reason, Boolean eventStreamAliv assertEquals(expectedReason, objectDetails.getReason()); } - @Test - void disabled_cache() { - ResolveBooleanResponse booleanResponse = ResolveBooleanResponse.newBuilder() - .setValue(true) - .setVariant(BOOL_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveStringResponse stringResponse = ResolveStringResponse.newBuilder() - .setValue(STRING_VALUE) - .setVariant(STRING_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveIntResponse intResponse = ResolveIntResponse.newBuilder() - .setValue(INT_VALUE) - .setVariant(INT_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveFloatResponse floatResponse = ResolveFloatResponse.newBuilder() - .setValue(DOUBLE_VALUE) - .setVariant(DOUBLE_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveObjectResponse objectResponse = ResolveObjectResponse.newBuilder() - .setValue(PROTOBUF_STRUCTURE_VALUE) - .setVariant(OBJECT_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ServiceBlockingStub serviceBlockingStubMock = mock(ServiceBlockingStub.class); - ServiceStub serviceStubMock = mock(ServiceStub.class); - when(serviceStubMock.withWaitForReady()).thenReturn(serviceStubMock); - when(serviceStubMock.withDeadline(any(Deadline.class))) - .thenReturn(serviceStubMock); - when(serviceBlockingStubMock.withWaitForReady()).thenReturn(serviceBlockingStubMock); - when(serviceBlockingStubMock.withDeadline(any(Deadline.class))) - .thenReturn(serviceBlockingStubMock); - when(serviceBlockingStubMock.withDeadlineAfter(anyLong(), any(TimeUnit.class))) - .thenReturn(serviceBlockingStubMock); - when(serviceBlockingStubMock - .resolveBoolean(argThat(x -> FLAG_KEY_BOOLEAN.equals(x.getFlagKey())))) - .thenReturn(booleanResponse); - when(serviceBlockingStubMock - .resolveFloat(argThat(x -> FLAG_KEY_DOUBLE.equals(x.getFlagKey())))) - .thenReturn(floatResponse); - when(serviceBlockingStubMock - .resolveInt(argThat(x -> FLAG_KEY_INTEGER.equals(x.getFlagKey())))) - .thenReturn(intResponse); - when(serviceBlockingStubMock - .resolveString(argThat(x -> FLAG_KEY_STRING.equals(x.getFlagKey())))) - .thenReturn(stringResponse); - when(serviceBlockingStubMock - .resolveObject(argThat(x -> FLAG_KEY_OBJECT.equals(x.getFlagKey())))) - .thenReturn(objectResponse); - - // disabled cache - final Cache cache = new Cache("disabled", 0); - - GrpcConnector grpc; - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newBlockingStub(any(Channel.class))) - .thenReturn(serviceBlockingStubMock); - mockStaticService.when(() -> ServiceGrpc.newStub(any())) - .thenReturn(serviceStubMock); - - class NoopInitGrpcConnector extends GrpcConnector { - public NoopInitGrpcConnector(FlagdOptions options, Cache cache, - Supplier connectedSupplier, - Consumer onConnectionEvent) { - super(options, cache, connectedSupplier, onConnectionEvent); - } - - public void initialize() throws Exception { - }; - } - - grpc = new NoopInitGrpcConnector(FlagdOptions.builder().build(), cache, () -> true, - (connectionEvent) -> { - }); - } - - FlagdProvider provider = createProvider(grpc, cache, () -> true); - - try { - provider.initialize(null); - } catch (Exception e) { - // ignore exception if any - } - - OpenFeatureAPI.getInstance().setProviderAndWait(provider); - - HashMap flagsMap = new HashMap<>(); - HashMap structMap = new HashMap<>(); - - flagsMap.put("foo", com.google.protobuf.Value.newBuilder().setStringValue("foo") - .build()); // assert that a configuration_change event works - - structMap.put("flags", com.google.protobuf.Value.newBuilder() - .setStructValue(Struct.newBuilder().putAllFields(flagsMap)).build()); - - // should not cache results - FlagEvaluationDetails booleanDetails = api.getClient().getBooleanDetails(FLAG_KEY_BOOLEAN, - false); - FlagEvaluationDetails stringDetails = api.getClient().getStringDetails(FLAG_KEY_STRING, - "wrong"); - FlagEvaluationDetails intDetails = api.getClient().getIntegerDetails(FLAG_KEY_INTEGER, 0); - FlagEvaluationDetails floatDetails = api.getClient().getDoubleDetails(FLAG_KEY_DOUBLE, 0.1); - FlagEvaluationDetails objectDetails = api.getClient().getObjectDetails(FLAG_KEY_OBJECT, - new Value()); - - // assert values are not cached - booleanDetails = api.getClient().getBooleanDetails(FLAG_KEY_BOOLEAN, false); - assertTrue(booleanDetails.getValue()); - assertEquals(BOOL_VARIANT, booleanDetails.getVariant()); - assertEquals(STATIC_REASON, booleanDetails.getReason()); - - stringDetails = api.getClient().getStringDetails(FLAG_KEY_STRING, "wrong"); - assertEquals(STRING_VALUE, stringDetails.getValue()); - assertEquals(STRING_VARIANT, stringDetails.getVariant()); - assertEquals(STATIC_REASON, stringDetails.getReason()); - - intDetails = api.getClient().getIntegerDetails(FLAG_KEY_INTEGER, 0); - assertEquals(INT_VALUE, intDetails.getValue()); - assertEquals(INT_VARIANT, intDetails.getVariant()); - assertEquals(STATIC_REASON, intDetails.getReason()); - - floatDetails = api.getClient().getDoubleDetails(FLAG_KEY_DOUBLE, 0.1); - assertEquals(DOUBLE_VALUE, floatDetails.getValue()); - assertEquals(DOUBLE_VARIANT, floatDetails.getVariant()); - assertEquals(STATIC_REASON, floatDetails.getReason()); - - objectDetails = api.getClient().getObjectDetails(FLAG_KEY_OBJECT, new Value()); - assertEquals(INNER_STRUCT_VALUE, objectDetails.getValue().asStructure() - .asMap().get(INNER_STRUCT_KEY).asString()); - assertEquals(OBJECT_VARIANT, objectDetails.getVariant()); - assertEquals(STATIC_REASON, objectDetails.getReason()); - } - @Test void initializationAndShutdown() throws Exception { // given @@ -899,7 +615,7 @@ void contextEnrichment() throws Exception { // callback doAnswer(invocation -> { onConnectionEvent.accept( - new ConnectionEvent(true, metadata)); + new ConnectionEvent(ConnectionState.CONNECTED, metadata)); return null; }).when(mock).init(); })) { @@ -935,7 +651,7 @@ void updatesSyncMetadataWithCallback() throws Exception { // callback doAnswer(invocation -> { onConnectionEvent.accept( - new ConnectionEvent(true, metadata)); + new ConnectionEvent(ConnectionState.CONNECTED, metadata)); return null; }).when(mock).init(); })) { @@ -957,23 +673,17 @@ void updatesSyncMetadataWithCallback() throws Exception { } // test helper - - // create provider with given grpc connector - private FlagdProvider createProvider(GrpcConnector grpc) { - return createProvider(grpc, () -> true); - } - // create provider with given grpc provider and state supplier - private FlagdProvider createProvider(GrpcConnector grpc, Supplier getConnected) { + private FlagdProvider createProvider(GrpcConnector grpc) { final Cache cache = new Cache("lru", 5); - return createProvider(grpc, cache, getConnected); + return createProvider(grpc, cache); } // create provider with given grpc provider, cache and state supplier - private FlagdProvider createProvider(GrpcConnector grpc, Cache cache, Supplier getConnected) { + private FlagdProvider createProvider(GrpcConnector grpc, Cache cache) { final FlagdOptions flagdOptions = FlagdOptions.builder().build(); - final GrpcResolver grpcResolver = new GrpcResolver(flagdOptions, cache, getConnected, + final GrpcResolver grpcResolver = new GrpcResolver(flagdOptions, cache, (connectionEvent) -> { }); diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/RunFlagdRpcReconnectCucumberTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/RunFlagdRpcReconnectCucumberTest.java index fa226c1a69..966d701e50 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/RunFlagdRpcReconnectCucumberTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/RunFlagdRpcReconnectCucumberTest.java @@ -7,8 +7,8 @@ import org.junit.platform.suite.api.Suite; import org.testcontainers.junit.jupiter.Testcontainers; -import static io.cucumber.junit.platform.engine.Constants.PLUGIN_PROPERTY_NAME; import static io.cucumber.junit.platform.engine.Constants.GLUE_PROPERTY_NAME; +import static io.cucumber.junit.platform.engine.Constants.PLUGIN_PROPERTY_NAME; /** * Class for running the reconnection tests for the RPC provider @@ -17,6 +17,7 @@ @Suite @IncludeEngines("cucumber") @SelectClasspathResource("features/flagd-reconnect.feature") +@SelectClasspathResource("features/events.feature") @ConfigurationParameter(key = PLUGIN_PROPERTY_NAME, value = "pretty") @ConfigurationParameter(key = GLUE_PROPERTY_NAME, value = "dev.openfeature.contrib.providers.flagd.e2e.reconnect.rpc,dev.openfeature.contrib.providers.flagd.e2e.reconnect.steps") @Testcontainers diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/reconnect/rpc/FlagdRpcSetup.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/reconnect/rpc/FlagdRpcSetup.java index 6601a5dd3c..ea0a197401 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/reconnect/rpc/FlagdRpcSetup.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/reconnect/rpc/FlagdRpcSetup.java @@ -31,6 +31,7 @@ public static void setupTest() throws InterruptedException { .resolverType(Config.Resolver.RPC) .port(flagdContainer.getFirstMappedPort()) .deadline(1000) + .streamRetryGracePeriod(1) .streamDeadlineMs(0) // this makes reconnect tests more predictable .cacheType(CacheType.DISABLED.getValue()) .build()); diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilderTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilderTest.java new file mode 100644 index 0000000000..7740ca6b66 --- /dev/null +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilderTest.java @@ -0,0 +1,159 @@ +package dev.openfeature.contrib.providers.flagd.resolver.common; + +import dev.openfeature.contrib.providers.flagd.FlagdOptions; +import io.grpc.ManagedChannel; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.epoll.Epoll; +import io.netty.channel.epoll.EpollDomainSocketChannel; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.unix.DomainSocketAddress; +import io.netty.handler.ssl.SslContextBuilder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnOs; +import org.junit.jupiter.api.condition.OS; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.MockedStatic; + +import javax.net.ssl.SSLKeyException; +import java.io.File; +import java.util.concurrent.TimeUnit; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class ChannelBuilderTest { + + @Test + @EnabledOnOs(OS.LINUX) + void testNettyChannel_withSocketPath() { + try (MockedStatic epollMock = mockStatic(Epoll.class); + MockedStatic nettyMock = mockStatic(NettyChannelBuilder.class)) { + + // Mocks + epollMock.when(Epoll::isAvailable).thenReturn(true); + NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class); + ManagedChannel mockChannel = mock(ManagedChannel.class); + + nettyMock.when(() -> NettyChannelBuilder.forAddress(any(DomainSocketAddress.class))) + .thenReturn(mockBuilder); + + when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder); + when(mockBuilder.eventLoopGroup(any(EpollEventLoopGroup.class))).thenReturn(mockBuilder); + when(mockBuilder.channelType(EpollDomainSocketChannel.class)).thenReturn(mockBuilder); + when(mockBuilder.usePlaintext()).thenReturn(mockBuilder); + when(mockBuilder.build()).thenReturn(mockChannel); + + // Input options + FlagdOptions options = FlagdOptions.builder() + .socketPath("/path/to/socket") + .keepAlive(1000) + .build(); + + // Call method under test + ManagedChannel channel = ChannelBuilder.nettyChannel(options); + + // Assertions + assertThat(channel).isEqualTo(mockChannel); + nettyMock.verify(() -> NettyChannelBuilder.forAddress(new DomainSocketAddress("/path/to/socket"))); + verify(mockBuilder).keepAliveTime(1000, TimeUnit.MILLISECONDS); + verify(mockBuilder).eventLoopGroup(any(EpollEventLoopGroup.class)); + verify(mockBuilder).channelType(EpollDomainSocketChannel.class); + verify(mockBuilder).usePlaintext(); + verify(mockBuilder).build(); + } + } + + @Test + void testNettyChannel_withTlsAndCert() { + try (MockedStatic nettyMock = mockStatic(NettyChannelBuilder.class)) { + // Mocks + NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class); + ManagedChannel mockChannel = mock(ManagedChannel.class); + nettyMock.when(() -> NettyChannelBuilder.forTarget("localhost:8080")) + .thenReturn(mockBuilder); + + when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder); + when(mockBuilder.sslContext(any())).thenReturn(mockBuilder); + when(mockBuilder.build()).thenReturn(mockChannel); + + File mockCert = mock(File.class); + when(mockCert.exists()).thenReturn(true); + String path = "test-harness/ssl/custom-root-cert.crt"; + + File file = new File(path); + String absolutePath = file.getAbsolutePath(); + // Input options + FlagdOptions options = FlagdOptions.builder() + .host("localhost") + .port(8080) + .keepAlive(5000) + .tls(true) + .certPath(absolutePath) + .build(); + + // Call method under test + ManagedChannel channel = ChannelBuilder.nettyChannel(options); + + // Assertions + assertThat(channel).isEqualTo(mockChannel); + nettyMock.verify(() -> NettyChannelBuilder.forTarget("localhost:8080")); + verify(mockBuilder).keepAliveTime(5000, TimeUnit.MILLISECONDS); + verify(mockBuilder).sslContext(any()); + verify(mockBuilder).build(); + } + } + + @ParameterizedTest + @ValueSource(strings = {"/incorrect/{uri}/;)"}) + void testNettyChannel_withInvalidTargetUri(String uri) { + FlagdOptions options = FlagdOptions.builder() + .targetUri(uri) + .build(); + + assertThatThrownBy(() -> ChannelBuilder.nettyChannel(options)) + .isInstanceOf(GenericConfigException.class) + .hasMessageContaining("Error with gRPC target string configuration"); + } + + @Test + void testNettyChannel_epollNotAvailable() { + try (MockedStatic epollMock = mockStatic(Epoll.class)) { + epollMock.when(Epoll::isAvailable).thenReturn(false); + + FlagdOptions options = FlagdOptions.builder().socketPath("/path/to/socket").build(); + + assertThatThrownBy(() -> ChannelBuilder.nettyChannel(options)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("unix socket cannot be used"); + } + } + + @Test + void testNettyChannel_sslException() throws Exception { + try (MockedStatic nettyMock = mockStatic(NettyChannelBuilder.class)) { + NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class); + nettyMock.when(() -> NettyChannelBuilder.forTarget(anyString())).thenReturn(mockBuilder); + try (MockedStatic sslmock = mockStatic(GrpcSslContexts.class)) { + SslContextBuilder sslMockBuilder = mock(SslContextBuilder.class); + sslmock.when(GrpcSslContexts::forClient).thenReturn(sslMockBuilder); + when(sslMockBuilder.build()).thenThrow(new SSLKeyException("Test SSL error")); + when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder); + + FlagdOptions options = FlagdOptions.builder().tls(true).build(); + + assertThatThrownBy(() -> ChannelBuilder.nettyChannel(options)) + .isInstanceOf(SslConfigException.class) + .hasMessageContaining("Error with SSL configuration"); + } + } + } +} diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserverTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserverTest.java index 2f42d4fd38..02c480ccc9 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserverTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserverTest.java @@ -1,34 +1,27 @@ package dev.openfeature.contrib.providers.flagd.resolver.grpc; -import static org.junit.Assert.assertFalse; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; +import dev.openfeature.flagd.grpc.evaluation.Evaluation.EventStreamResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.function.Supplier; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Nested; -import org.junit.jupiter.api.Test; - -import com.google.protobuf.Struct; -import com.google.protobuf.Value; - -import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; -import dev.openfeature.flagd.grpc.evaluation.Evaluation.EventStreamResponse; -import io.grpc.Status; -import io.grpc.StatusRuntimeException; - class EventStreamObserverTest { @Nested @@ -39,18 +32,14 @@ class StateChange { EventStreamObserver stream; Runnable reconnect; Object sync; - Supplier shouldRetrySilently; @BeforeEach void setUp() { states = new ArrayList<>(); - sync = new Object(); cache = mock(Cache.class); reconnect = mock(Runnable.class); when(cache.getEnabled()).thenReturn(true); - shouldRetrySilently = mock(Supplier.class); - when(shouldRetrySilently.get()).thenReturn(true, false); // 1st time we should retry silently, subsequent calls should not - stream = new EventStreamObserver(sync, cache, (state, changed) -> states.add(state), shouldRetrySilently); + stream = new EventStreamObserver(cache, (state, changed) -> states.add(state)); } @Test @@ -80,35 +69,6 @@ public void ready() { verify(cache, atLeast(1)).clear(); } - @Test - public void noReconnectionOnFirstError() { - stream.onError(new Throwable("error")); - // we flush the cache - verify(cache, never()).clear(); - // we notify the error - assertEquals(0, states.size()); - } - - @Test - public void reconnections() { - stream.onError(new Throwable("error 1")); - stream.onError(new Throwable("error 2")); - // we flush the cache - verify(cache, atLeast(1)).clear(); - // we notify the error - assertEquals(1, states.size()); - assertFalse(states.get(0)); - } - - @Test - public void deadlineExceeded() { - stream.onError(new StatusRuntimeException(Status.DEADLINE_EXCEEDED)); - // we flush the cache - verify(cache, never()).clear(); - // we notify the error - assertEquals(0, states.size()); - } - @Test public void cacheBustingForKnownKeys() { final String key1 = "myKey1"; diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnectorTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnectorTest.java index 7e552d05dd..894557bccf 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnectorTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnectorTest.java @@ -1,496 +1,142 @@ package dev.openfeature.contrib.providers.flagd.resolver.grpc; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.Mockito.*; - -import java.lang.reflect.Field; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledOnOs; -import org.junit.jupiter.api.condition.OS; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.MockedConstruction; -import org.mockito.MockedStatic; -import org.mockito.invocation.InvocationOnMock; +import com.google.common.collect.Lists; import dev.openfeature.contrib.providers.flagd.FlagdOptions; import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionEvent; -import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; -import dev.openfeature.flagd.grpc.evaluation.Evaluation.EventStreamResponse; +import dev.openfeature.flagd.grpc.evaluation.Evaluation; import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc; -import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceBlockingStub; -import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceStub; -import io.grpc.Channel; -import io.grpc.Status; -import io.grpc.StatusRuntimeException; -import io.grpc.netty.NettyChannelBuilder; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.epoll.EpollEventLoopGroup; -import io.netty.channel.unix.DomainSocketAddress; -import uk.org.webcompere.systemstubs.environment.EnvironmentVariables; - -class GrpcConnectorTest { - - @ParameterizedTest - @ValueSource(ints = { 1, 2, 3 }) - void validate_retry_calls(int retries) throws Exception { - final int backoffMs = 100; - - final FlagdOptions options = FlagdOptions.builder() - // shorter backoff for testing - .retryBackoffMs(backoffMs) - .maxEventStreamRetries(retries) - .build(); - - final Cache cache = new Cache("disabled", 0); - - final ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - doAnswer(invocation -> null).when(mockStub).eventStream(any(), any()); - - final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, - (connectionEvent) -> { - }); - - Field serviceStubField = GrpcConnector.class.getDeclaredField("serviceStub"); - serviceStubField.setAccessible(true); - serviceStubField.set(connector, mockStub); - - final Object syncObject = new Object(); - - Field syncField = GrpcConnector.class.getDeclaredField("sync"); - syncField.setAccessible(true); - syncField.set(connector, syncObject); - - assertDoesNotThrow(connector::initialize); - - for (int i = 1; i < retries; i++) { - // verify invocation with enough timeout value - verify(mockStub, timeout(2L * i * backoffMs).times(i)).eventStream(any(), any()); - - synchronized (syncObject) { - syncObject.notify(); - } - } - } - - @Test - void initialization_succeed_with_connected_status() { - final Cache cache = new Cache("disabled", 0); - final ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - Consumer onConnectionEvent = mock(Consumer.class); - doAnswer((InvocationOnMock invocation) -> { - EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1); - eventStreamObserver - .onNext(EventStreamResponse.newBuilder().setType(Constants.PROVIDER_READY).build()); - return null; - }).when(mockStub).eventStream(any(), any()); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newStub(any())) - .thenReturn(mockStub); - - // pass true in connected lambda - final GrpcConnector connector = new GrpcConnector(FlagdOptions.builder().build(), cache, () -> { - try { - Thread.sleep(100); - return true; - } catch (Exception e) { - } - return false; - - }, - onConnectionEvent); - - assertDoesNotThrow(connector::initialize); - // assert that onConnectionEvent is connected - verify(onConnectionEvent).accept(argThat(arg -> arg.isConnected())); - } - } - - @Test - void stream_does_not_fail_on_first_error() { - final Cache cache = new Cache("disabled", 0); - final ServiceStub mockStub = createServiceStubMock(); - Consumer onConnectionEvent = mock(Consumer.class); - doAnswer((InvocationOnMock invocation) -> { - EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1); - eventStreamObserver - .onError(new Exception("fake")); - return null; - }).when(mockStub).eventStream(any(), any()); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newStub(any())) - .thenReturn(mockStub); - - // pass true in connected lambda - final GrpcConnector connector = new GrpcConnector(FlagdOptions.builder().build(), cache, - () -> { - try { - Thread.sleep(100); - return true; - } catch (Exception e) { - } - return false; - - }, - onConnectionEvent); - - assertDoesNotThrow(connector::initialize); - // assert that onConnectionEvent is connected gets not called - verify(onConnectionEvent, timeout(300).times(0)).accept(any()); - } - } - - @Test - void stream_fails_on_second_error_in_a_row() throws Exception { - final FlagdOptions options = FlagdOptions.builder() - // shorter backoff for testing - .retryBackoffMs(0) - .build(); - - final Cache cache = new Cache("disabled", 0); - Consumer onConnectionEvent = mock(Consumer.class); - - final ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - doAnswer((InvocationOnMock invocation) -> { - EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1); - eventStreamObserver - .onError(new Exception("fake")); - return null; - }).when(mockStub).eventStream(any(), any()); - - final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, onConnectionEvent); +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Server; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.stub.StreamObserver; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; - Field serviceStubField = GrpcConnector.class.getDeclaredField("serviceStub"); - serviceStubField.setAccessible(true); - serviceStubField.set(connector, mockStub); +import java.io.IOException; +import java.util.ArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; - final Object syncObject = new Object(); +class GrpcConnectorTest { - Field syncField = GrpcConnector.class.getDeclaredField("sync"); - syncField.setAccessible(true); - syncField.set(connector, syncObject); + private ManagedChannel testChannel; + private Server testServer; + private static final boolean CONNECTED = true; + private static final boolean DISCONNECTED = false; - assertDoesNotThrow(connector::initialize); + @Mock + private EventStreamObserver mockEventStreamObserver; - // 1st try - verify(mockStub, timeout(300).times(1)).eventStream(any(), any()); - verify(onConnectionEvent, timeout(300).times(0)).accept(any()); - synchronized (syncObject) { - syncObject.notify(); + private final ServiceGrpc.ServiceImplBase testServiceImpl = new ServiceGrpc.ServiceImplBase() { + @Override + public void eventStream(Evaluation.EventStreamRequest request, + StreamObserver responseObserver) { + // noop } + }; - // 2nd try - verify(mockStub, timeout(300).times(2)).eventStream(any(), any()); - verify(onConnectionEvent, timeout(300).times(1)).accept(argThat(arg -> !arg.isConnected())); + @BeforeEach + void setUp() throws Exception { + MockitoAnnotations.openMocks(this); + setupTestGrpcServer(); } - @Test - void stream_does_not_fail_when_message_between_errors() throws Exception { - final FlagdOptions options = FlagdOptions.builder() - // shorter backoff for testing - .retryBackoffMs(0) + private void setupTestGrpcServer() throws IOException { + testServer = NettyServerBuilder.forPort(8080) + .addService(testServiceImpl) .build(); + testServer.start(); - final Cache cache = new Cache("disabled", 0); - Consumer onConnectionEvent = mock(Consumer.class); - - final AtomicBoolean successMessage = new AtomicBoolean(false); - final ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - doAnswer((InvocationOnMock invocation) -> { - EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1); - - if (successMessage.get()) { - eventStreamObserver - .onNext(EventStreamResponse.newBuilder().setType(Constants.PROVIDER_READY).build()); - } else { - eventStreamObserver - .onError(new Exception("fake")); - } - return null; - }).when(mockStub).eventStream(any(), any()); - - final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, onConnectionEvent); - - Field serviceStubField = GrpcConnector.class.getDeclaredField("serviceStub"); - serviceStubField.setAccessible(true); - serviceStubField.set(connector, mockStub); - - final Object syncObject = new Object(); - - Field syncField = GrpcConnector.class.getDeclaredField("sync"); - syncField.setAccessible(true); - syncField.set(connector, syncObject); - - assertDoesNotThrow(connector::initialize); - - // 1st message with error - verify(mockStub, timeout(300).times(1)).eventStream(any(), any()); - verify(onConnectionEvent, timeout(300).times(0)).accept(any()); - - synchronized (syncObject) { - successMessage.set(true); - syncObject.notify(); - } - - // 2nd message with provider ready - verify(mockStub, timeout(300).times(2)).eventStream(any(), any()); - verify(onConnectionEvent, timeout(300).times(1)).accept(argThat(arg -> arg.isConnected())); - synchronized (syncObject) { - successMessage.set(false); - syncObject.notify(); - } - - - // 3nd message with error - verify(mockStub, timeout(300).times(2)).eventStream(any(), any()); - verify(onConnectionEvent, timeout(300).times(0)).accept(argThat(arg -> !arg.isConnected())); - } - - @Test - void stream_does_not_fail_with_deadline_error() throws Exception { - final Cache cache = new Cache("disabled", 0); - final ServiceStub mockStub = createServiceStubMock(); - Consumer onConnectionEvent = mock(Consumer.class); - doAnswer((InvocationOnMock invocation) -> { - EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1); - eventStreamObserver - .onError(new StatusRuntimeException(Status.DEADLINE_EXCEEDED)); - return null; - }).when(mockStub).eventStream(any(), any()); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newStub(any())) - .thenReturn(mockStub); - // pass true in connected lambda - final GrpcConnector connector = new GrpcConnector(FlagdOptions.builder().build(), cache, () -> { - try { - Thread.sleep(100); - return true; - } catch (Exception e) { - } - return false; - - }, - onConnectionEvent); - - assertDoesNotThrow(connector::initialize); - // this should not call the connection event - verify(onConnectionEvent, never()).accept(any()); + if (testChannel == null) { + testChannel = ManagedChannelBuilder.forAddress("localhost", 8080) + .usePlaintext() + .build(); } - } - - @Test - void host_and_port_arg_should_build_tcp_socket() { - final String host = "host.com"; - final int port = 1234; - final String targetUri = String.format("%s:%s", host, port); - - ServiceGrpc.ServiceBlockingStub mockBlockingStub = mock(ServiceGrpc.ServiceBlockingStub.class); - ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket(); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newBlockingStub(any(Channel.class))) - .thenReturn(mockBlockingStub); - mockStaticService.when(() -> ServiceGrpc.newStub(any())) - .thenReturn(mockStub); - try (MockedStatic mockStaticChannelBuilder = mockStatic(NettyChannelBuilder.class)) { - mockStaticChannelBuilder.when(() -> NettyChannelBuilder - .forTarget(anyString())).thenReturn(mockChannelBuilder); - - final FlagdOptions flagdOptions = FlagdOptions.builder().host(host).port(port).tls(false).build(); - new GrpcConnector(flagdOptions, null, null, null); - - // verify host/port matches - mockStaticChannelBuilder.verify(() -> NettyChannelBuilder - .forTarget(String.format(targetUri)), times(1)); - } - } } - @Test - void no_args_host_and_port_env_set_should_build_tcp_socket() throws Exception { - final String host = "server.com"; - final int port = 4321; - final String targetUri = String.format("%s:%s", host, port); - - new EnvironmentVariables("FLAGD_HOST", host, "FLAGD_PORT", String.valueOf(port)).execute(() -> { - ServiceGrpc.ServiceBlockingStub mockBlockingStub = mock(ServiceGrpc.ServiceBlockingStub.class); - ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket(); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newBlockingStub(any(Channel.class))) - .thenReturn(mockBlockingStub); - mockStaticService.when(() -> ServiceGrpc.newStub(any())) - .thenReturn(mockStub); - - try (MockedStatic mockStaticChannelBuilder = mockStatic( - NettyChannelBuilder.class)) { - - mockStaticChannelBuilder.when(() -> NettyChannelBuilder - .forTarget(anyString())).thenReturn(mockChannelBuilder); - - new GrpcConnector(FlagdOptions.builder().build(), null, null, null); - - // verify host/port matches & called times(= 1 as we rely on reusable channel) - mockStaticChannelBuilder.verify(() -> NettyChannelBuilder.forTarget(targetUri), times(1)); - } - } - }); + @AfterEach + void tearDown() throws Exception { + tearDownGrpcServer(); } - /** - * OS Specific test - This test is valid only on Linux system as it rely on - * epoll availability - */ - @Test - @EnabledOnOs(OS.LINUX) - void path_arg_should_build_domain_socket_with_correct_path() { - final String path = "/some/path"; - - ServiceGrpc.ServiceBlockingStub mockBlockingStub = mock(ServiceGrpc.ServiceBlockingStub.class); - ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket(); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newBlockingStub(any(Channel.class))) - .thenReturn(mockBlockingStub); - mockStaticService.when(() -> ServiceGrpc.newStub(any())) - .thenReturn(mockStub); - - try (MockedStatic mockStaticChannelBuilder = mockStatic(NettyChannelBuilder.class)) { - - try (MockedConstruction mockEpollEventLoopGroup = mockConstruction( - EpollEventLoopGroup.class, - (mock, context) -> { - })) { - when(NettyChannelBuilder.forAddress(any(DomainSocketAddress.class))).thenReturn(mockChannelBuilder); - - new GrpcConnector(FlagdOptions.builder().socketPath(path).build(), null, null, null); - - // verify path matches - mockStaticChannelBuilder.verify(() -> NettyChannelBuilder - .forAddress(argThat((DomainSocketAddress d) -> { - assertEquals(d.path(), path); // path should match - return true; - })), times(1)); - } - } + private void tearDownGrpcServer() throws InterruptedException { + if (testServer != null) { + testServer.shutdownNow(); + testServer.awaitTermination(); } } - /** - * OS Specific test - This test is valid only on Linux system as it rely on - * epoll availability - */ - @Test - @EnabledOnOs(OS.LINUX) - void no_args_socket_env_should_build_domain_socket_with_correct_path() throws Exception { - final String path = "/some/other/path"; - - new EnvironmentVariables("FLAGD_SOCKET_PATH", path).execute(() -> { - - ServiceBlockingStub mockBlockingStub = mock(ServiceBlockingStub.class); - ServiceStub mockStub = mock(ServiceStub.class); - NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket(); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newBlockingStub(any(Channel.class))) - .thenReturn(mockBlockingStub); - mockStaticService.when(() -> ServiceGrpc.newStub(any())) - .thenReturn(mockStub); - - try (MockedStatic mockStaticChannelBuilder = mockStatic( - NettyChannelBuilder.class)) { - - try (MockedConstruction mockEpollEventLoopGroup = mockConstruction( - EpollEventLoopGroup.class, - (mock, context) -> { - })) { - mockStaticChannelBuilder.when(() -> NettyChannelBuilder - .forAddress(any(DomainSocketAddress.class))).thenReturn(mockChannelBuilder); - - new GrpcConnector(FlagdOptions.builder().build(), null, null, null); - - // verify path matches & called times(= 1 as we rely on reusable channel) - mockStaticChannelBuilder.verify(() -> NettyChannelBuilder - .forAddress(argThat((DomainSocketAddress d) -> { - return d.path() == path; - })), times(1)); - } - } - } - }); - } @Test - void initialization_with_stream_deadline() throws NoSuchFieldException, IllegalAccessException { - final FlagdOptions options = FlagdOptions.builder() - .streamDeadlineMs(16983) - .build(); - - final Cache cache = new Cache("disabled", 0); - final ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub); - - final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, null); - - assertDoesNotThrow(connector::initialize); - verify(mockStub).withDeadlineAfter(16983, TimeUnit.MILLISECONDS); - } + void whenShuttingDownAndRestartingGrpcServer_ConsumerReceivesDisconnectedAndConnectedEvent() throws Exception { + CountDownLatch sync = new CountDownLatch(2); + ArrayList connectionStateChanges = Lists.newArrayList(); + Consumer testConsumer = event -> { + connectionStateChanges.add(event.isConnected()); + sync.countDown(); + }; + + GrpcConnector instance = + new GrpcConnector<>(FlagdOptions.builder().build(), + ServiceGrpc::newStub, + ServiceGrpc::newBlockingStub, + testConsumer, + stub -> stub.eventStream(Evaluation.EventStreamRequest.getDefaultInstance(), + mockEventStreamObserver) + , testChannel); + + instance.initialize(); + + // when shutting down server + testServer.shutdown(); + testServer.awaitTermination(5, TimeUnit.SECONDS); + + // when restarting server + setupTestGrpcServer(); + + // then consumer received DISCONNECTED and CONNECTED event + boolean finished = sync.await(10, TimeUnit.SECONDS); + Assertions.assertTrue(finished); + Assertions.assertEquals(Lists.newArrayList(DISCONNECTED, CONNECTED), connectionStateChanges); } @Test - void initialization_without_stream_deadline() throws NoSuchFieldException, IllegalAccessException { - final FlagdOptions options = FlagdOptions.builder() - .streamDeadlineMs(0) - .build(); - - final Cache cache = new Cache("disabled", 0); - final ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub); - - final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, null); - - assertDoesNotThrow(connector::initialize); - verify(mockStub, never()).withDeadlineAfter(16983, TimeUnit.MILLISECONDS); - } - } - - private static ServiceStub createServiceStubMock() { - final ServiceStub mockStub = mock(ServiceStub.class); - when(mockStub.withDeadlineAfter(anyLong(), any())).thenReturn(mockStub); - return mockStub; - } - - private NettyChannelBuilder getMockChannelBuilderSocket() { - NettyChannelBuilder mockChannelBuilder = mock(NettyChannelBuilder.class); - when(mockChannelBuilder.eventLoopGroup(any(EventLoopGroup.class))).thenReturn(mockChannelBuilder); - when(mockChannelBuilder.channelType(any(Class.class))).thenReturn(mockChannelBuilder); - when(mockChannelBuilder.usePlaintext()).thenReturn(mockChannelBuilder); - when(mockChannelBuilder.keepAliveTime(anyLong(), any())).thenReturn(mockChannelBuilder); - when(mockChannelBuilder.build()).thenReturn(null); - return mockChannelBuilder; + void whenShuttingDownGrpcConnector_ConsumerReceivesDisconnectedEvent() throws Exception { + CountDownLatch sync = new CountDownLatch(1); + ArrayList connectionStateChanges = Lists.newArrayList(); + Consumer testConsumer = event -> { + connectionStateChanges.add(event.isConnected()); + sync.countDown(); + }; + + GrpcConnector instance = + new GrpcConnector<>(FlagdOptions.builder().build(), + ServiceGrpc::newStub, + ServiceGrpc::newBlockingStub, + testConsumer, + stub -> stub.eventStream(Evaluation.EventStreamRequest.getDefaultInstance(), + mockEventStreamObserver) + , testChannel); + + instance.initialize(); + // when shutting grpc connector + instance.shutdown(); + + // then consumer received DISCONNECTED and CONNECTED event + boolean finished = sync.await(10, TimeUnit.SECONDS); + Assertions.assertTrue(finished); + Assertions.assertEquals(Lists.newArrayList(DISCONNECTED), connectionStateChanges); } } +