Skip to content

Commit

Permalink
participant update target state
Browse files Browse the repository at this point in the history
  • Loading branch information
xyuanlu committed Sep 10, 2024
1 parent a257614 commit d3b6c68
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.helix.gateway.api.service.HelixGatewayServiceChannel;
Expand All @@ -52,10 +51,10 @@ public class HelixGatewayServiceGrpcService extends HelixGatewayServiceGrpc.Heli
private static final Logger logger = LoggerFactory.getLogger(HelixGatewayServiceGrpcService.class);

// Map to store the observer for each instance
private final Map<String, StreamObserver<ShardChangeRequests>> _observerMap = new ConcurrentHashMap<>();
private final Map<String, StreamObserver<ShardChangeRequests>> _observerMap = new HashMap<>();
// A reverse map to store the instance name for each observer. It is used to find the instance when connection is closed.
// map<observer, pair<instance, cluster>>
private final Map<StreamObserver<ShardChangeRequests>, Pair<String, String>> _reversedObserverMap = new ConcurrentHashMap<>();
private final Map<StreamObserver<ShardChangeRequests>, Pair<String, String>> _reversedObserverMap = new HashMap<>();

private final GatewayServiceManager _manager;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ public class HelixGatewayParticipant implements HelixManagerStateListener {
private final HelixGatewayServiceChannel _gatewayServiceChannel;
private final HelixManager _helixManager;
private final Runnable _onDisconnectedCallback;
private final Map<String, Map<String, String>> _shardStateMap;

private final Map<String, CompletableFuture<String>> _stateTransitionResultMap;

private final GatewayServiceManager _gatewayServiceManager;
Expand All @@ -59,7 +57,6 @@ private HelixGatewayParticipant(HelixGatewayServiceChannel gatewayServiceChannel
_gatewayServiceChannel = gatewayServiceChannel;
_helixManager = helixManager;
_onDisconnectedCallback = onDisconnectedCallback;
_shardStateMap = initialShardStateMap;
_stateTransitionResultMap = new ConcurrentHashMap<>();
_gatewayServiceManager = gatewayServiceManager;
}
Expand All @@ -71,16 +68,14 @@ public void processStateTransitionMessage(Message message) throws Exception {
String concatenatedShardName = resourceId + shardId;

try {
if (isCurrentStateAlreadyTarget(resourceId, shardId, toState)) {
return;
}

CompletableFuture<String> future = new CompletableFuture<>();

// update the target state in cache
_gatewayServiceManager.updateTargetState(_helixManager.getClusterName(), _helixManager.getInstanceName(),
resourceId, shardId, toState);

if (isCurrentStateAlreadyTarget(resourceId, shardId, toState)) {
return;
}
CompletableFuture<String> future = new CompletableFuture<>();
_stateTransitionResultMap.put(concatenatedShardName, future);
_gatewayServiceChannel.sendStateChangeRequests(_helixManager.getInstanceName(),
StateTransitionMessageTranslateUtil.translateSTMsgToShardChangeRequests(message));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,10 @@ public void onGatewayServiceEvent(GatewayServiceEvent event) {
}
}

private GatewayCurrentStateCache getCache(String clusterName) {
return _currentStateCacheMap.computeIfAbsent(clusterName, k -> new GatewayCurrentStateCache(clusterName));
}

public void resetTargetStateCache(String clusterName, String instanceName) {
getCache(clusterName).resetTargetStateCache(instanceName);
}

public void addInstanceToCache(String clusterName, String instanceName) {
getCache(clusterName).addInstanceToCache(instanceName);
}

/**
* Overwrite the current state cache with the new current state map, and return the diff of the change.
* @param clusterName
Expand All @@ -127,14 +119,19 @@ public String serializeTargetState() {
return targetStateNode.toString();
}

public void updateTargetState(String clusterName, String instanceName, String resourceId, String shardId, String toState) {
getCache(clusterName).updateTargetStateWithDiff(instanceName, Map.of(resourceId, Map.of(shardId, toState)));
public void updateTargetState(String clusterName, String instanceName, String resourceId, String shardId,
String toState) {
getCache(clusterName).updateTargetStateOfExistingInstance(instanceName, resourceId, shardId, toState);
}

public String getCurrentState(String clusterName, String instanceName, String resourceId, String shardId) {
return getCache(clusterName).getCurrentState(instanceName, resourceId, shardId);
}

public String getTargetState(String clusterName, String instanceName, String resourceId, String shardId) {
return getCache(clusterName).getTargetState(instanceName, resourceId, shardId);
}

/**
* Update in memory shard state
*/
Expand Down Expand Up @@ -214,12 +211,16 @@ private void removeHelixGatewayParticipant(String clusterName, String instanceNa
participant.disconnect();
_helixGatewayParticipantMap.get(clusterName).remove(instanceName);
}
_currentStateCacheMap.get(clusterName).removeInstanceFromCache(instanceName);
_currentStateCacheMap.get(clusterName).removeInstanceTargetDataFromCache(instanceName);
}

private HelixGatewayParticipant getHelixGatewayParticipant(String clusterName,
String instanceName) {
return _helixGatewayParticipantMap.getOrDefault(clusterName, Collections.emptyMap())
.get(instanceName);
}

private GatewayCurrentStateCache getCache(String clusterName) {
return _currentStateCacheMap.computeIfAbsent(clusterName, k -> new GatewayCurrentStateCache(clusterName));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.helix.gateway.channel.HelixGatewayServiceGrpcService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -39,16 +38,16 @@ public class GatewayCurrentStateCache {

// A cache of current state. It should be updated by the HelixGatewayServiceChannel
// instance -> resource state (resource -> shard -> target state)
final Map<String, ShardStateMap> _currentStateMap;
Map<String, ShardStateMap> _currentStateMap;

// A cache of target state.
// instance -> resource state (resource -> shard -> target state)
final Map<String, ShardStateMap> _targetStateMap;

public GatewayCurrentStateCache(String clusterName) {
_clusterName = clusterName;
_currentStateMap = new ConcurrentHashMap<>();
_targetStateMap = new ConcurrentHashMap<>();
_currentStateMap = new HashMap<>();
_targetStateMap = new HashMap<>();
}

public String getCurrentState(String instance, String resource, String shard) {
Expand All @@ -63,111 +62,104 @@ public String getTargetState(String instance, String resource, String shard) {

/**
* Update the cached current state of instances in a cluster, and return the diff of the change.
* @param newCurrentStateMap The new current state map of instances in the cluster
* @param userCurrentStateMap The new current state map of instances in the cluster
* @return
*/
public Map<String, Map<String, Map<String, String>>> updateCacheWithNewCurrentStateAndGetDiff(
Map<String, Map<String, Map<String, String>>> newCurrentStateMap) {
Map<String, Map<String, Map<String, String>>> userCurrentStateMap) {
Map<String, ShardStateMap> newCurrentStateMap = new HashMap<>(_currentStateMap);
Map<String, Map<String, Map<String, String>>> diff = new HashMap<>();
for (String instance : newCurrentStateMap.keySet()) {
if (!_currentStateMap.containsKey(instance)) {
logger.warn("Instance {} is not in the state map, skip updating", instance);
continue;
}
Map<String, Map<String, String>> newCurrentState = newCurrentStateMap.get(instance);
Map<String, Map<String, String>> resourceStateDiff =
_currentStateMap.computeIfAbsent(instance, k -> new ShardStateMap(new HashMap<>()))
.updateAndGetDiff(newCurrentState);
if (resourceStateDiff != null && !resourceStateDiff.isEmpty()) {
diff.put(instance, resourceStateDiff);
for (String instance : userCurrentStateMap.keySet()) {
ShardStateMap oldStateMap = _currentStateMap.get(instance);
Map<String, Map<String, String>> instanceDiff = oldStateMap == null ? userCurrentStateMap.get(instance)
: oldStateMap.getDiff(userCurrentStateMap.get(instance));
if (!instanceDiff.isEmpty()) {
diff.put(instance, instanceDiff);
}
newCurrentStateMap.put(instance, new ShardStateMap(userCurrentStateMap.get(instance)));
}
logger.info("Update current state cache for instances: {}", diff.keySet());
_currentStateMap = newCurrentStateMap;
return diff;
}

/**
* Update the current state with the changed current state maps.
*/
public void updateCurrentStateOfExistingInstance(String instance, String resource, String shard, String shardState) {
updateShardStateMapWithDiff(_currentStateMap, instance, Map.of(resource, Map.of(shard, shardState)));
logger.info("Update current state of instance: {}, resource: {}, shard: {}, state: {}", instance, resource, shard,
shardState);
updateShardStateMapWithDiff(_currentStateMap, instance, resource, shard, shardState);
}

/**
* Update the target state with the changed target state maps.
* All existing target states remains the same
* @param diff
*/
public void updateTargetStateWithDiff(String instance, Map<String, Map<String, String>> diff) {
updateShardStateMapWithDiff(_targetStateMap, instance, diff);
public void updateTargetStateOfExistingInstance(String instance, String resource, String shard, String shardState) {
logger.info("Update target state of instance: {}, resource: {}, shard: {}, state: {}", instance, resource, shard,
shardState);
updateShardStateMapWithDiff(_targetStateMap, instance, resource, shard, shardState);
}

private void updateShardStateMapWithDiff(Map<String, ShardStateMap> stateMap, String instance,
String resource, String shard, String shardState) {
ShardStateMap curStateMap = stateMap.get(instance);
if (curStateMap == null) {
logger.warn("Instance {} is not in the state map, skip updating", instance);
return;
}
curStateMap.updateWithShardState(resource, shard, shardState);
}

/**
* Serialize the target state assignments to a JSON Node.
* example : {"instance1":{"resource1":{"shard1":"ONLINE","shard2":"OFFLINE"}}}}
*/
public ObjectNode serializeTargetAssignmentsToJSONNode() {
public synchronized ObjectNode serializeTargetAssignmentsToJSONNode() {
ObjectNode root = mapper.createObjectNode();
for (Map.Entry<String, ShardStateMap> entry : _targetStateMap.entrySet()) {
root.set(entry.getKey(), entry.getValue().toJSONNode());
}
return root;
}

public void removeInstanceFromCache(String instance) {
_currentStateMap.remove(instance);
/**
* Remove the target state data of an instance from the cache.
*/
public synchronized void removeInstanceTargetDataFromCache(String instance) {
logger.info("Remove instance target data from cache for instance: {}", instance);
_targetStateMap.remove(instance);
}

public void addInstanceToCache(String instance) {
_currentStateMap.put(instance, new ShardStateMap(new HashMap<>()));
_targetStateMap.put(instance, new ShardStateMap(new HashMap<>()));
}

private void updateShardStateMapWithDiff(Map<String, ShardStateMap> stateMap, String instance,
Map<String, Map<String, String>> diffMap) {
if (diffMap == null || diffMap.isEmpty()) {
return;
}
if (!stateMap.containsKey(instance)) {
logger.warn("Instance {} is not in the state map, skip updating", instance);
}
stateMap.get(instance).updateWithDiff(diffMap);
}

public void resetTargetStateCache(String instance) {
/**
* Remove the current state data of an instance from the cache to an empty map.
*/
public synchronized void resetTargetStateCache(String instance) {
logger.info("Reset target state cache for instance: {}", instance);
_targetStateMap.put(instance, new ShardStateMap(new HashMap<>()));
}

public static class ShardStateMap {
Map<String, Map<String, String>> _stateMap;
final Object _lock = new Object();

public ShardStateMap(Map<String, Map<String, String>> stateMap) {
_stateMap = new ConcurrentHashMap<>(stateMap);
}

public String getState(String instance, String shard) {
Map<String, String> shardStateMap = _stateMap.get(instance);
public String getState(String resource, String shard) {
Map<String, String> shardStateMap = _stateMap.get(resource);
return shardStateMap == null ? null : shardStateMap.get(shard);
}

private void updateWithDiff(Map<String, Map<String, String>> diffMap) {
for (Map.Entry<String, Map<String, String>> diffEntry : diffMap.entrySet()) {
String resource = diffEntry.getKey();
Map<String, String> diffCurrentState = diffEntry.getValue();
if (_stateMap.get(resource) != null) {
_stateMap.get(resource).entrySet().forEach(currentMapEntry -> {
String shard = currentMapEntry.getKey();
if (diffCurrentState.get(shard) != null) {
currentMapEntry.setValue(diffCurrentState.get(shard));
}
});
} else {
_stateMap.put(resource, diffCurrentState);
}
public void updateWithShardState(String resource, String shard, String shardState) {
logger.info("Update ShardStateMap of resource: {}, shard: {}, state: {}", resource, shard, shardState);
synchronized (_lock) {
_stateMap.computeIfAbsent(resource, k -> new HashMap<>()).put(shard, shardState);
}
}

private Map<String, Map<String, String>> updateAndGetDiff(Map<String, Map<String, String>> newCurrentStateMap) {
private Map<String, Map<String, String>> getDiff(Map<String, Map<String, String>> newCurrentStateMap) {
Map<String, Map<String, String>> diff = new HashMap<>();
for (Map.Entry<String, Map<String, String>> entry : newCurrentStateMap.entrySet()) {
String resource = entry.getKey();
Expand All @@ -185,7 +177,6 @@ private Map<String, Map<String, String>> updateAndGetDiff(Map<String, Map<String
}
}
}
_stateMap = new ConcurrentHashMap<>(newCurrentStateMap);
return diff;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,11 @@ public static GatewayServiceEvent translateShardStateMessageToEventAndUpdateCach
.put(state.getShardName(), state.getCurrentState());
}
}
manager.addInstanceToCache(shardState.getClusterName(), shardState.getInstanceName());
// update current state cache. We always overwrite the current state map for initial connection
manager.updateCacheWithNewCurrentStateAndGetDiff(shardState.getClusterName(), Map.of(shardState.getInstanceName(), shardStateMap));
Map<String, Map<String, Map<String, String>>> newShardStateMap = new HashMap<>();
newShardStateMap.put(shardState.getInstanceName(), shardStateMap);
manager.updateCacheWithNewCurrentStateAndGetDiff(shardState.getClusterName(), newShardStateMap);

builder = new GatewayServiceEvent.GateWayServiceEventBuilder(GatewayServiceEventType.CONNECT).setClusterName(
shardState.getClusterName()).setParticipantName(shardState.getInstanceName())
.setShardStateMap(shardStateMap);
Expand Down
Loading

0 comments on commit d3b6c68

Please sign in to comment.