Skip to content

Commit

Permalink
Update SDK to automatically pull runtime URL when a session starts
Browse files Browse the repository at this point in the history
  • Loading branch information
bharatsuri97 committed Apr 11, 2023
1 parent 210a52a commit 5091e11
Show file tree
Hide file tree
Showing 6 changed files with 318 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,20 @@
*/
public class InMemoryCache implements Cache {

private final com.google.common.cache.Cache<String, String> cache;
private final com.google.common.cache.Cache<String, Object> cache;

public InMemoryCache(long ttlSeconds) {
cache = CacheBuilder.newBuilder().expireAfterAccess(ttlSeconds, TimeUnit.SECONDS).build();
}

@Override
public Optional<String> get(String key) {
String val = cache.getIfPresent(key);
String val = (String) cache.getIfPresent(key);
return Optional.ofNullable(val);
}

public Optional<Object> getObject(String key) {
Object val = cache.getIfPresent(key);
return Optional.ofNullable(val);
}

Expand All @@ -34,12 +39,20 @@ public void set(String key, String val) {
cache.put(key, val);
}

public void setObject(String key, Object val) {
cache.put(key, val);
}

/**
* This method does not respect the ttlSeconds parameter.
*/
@Override
public void set(String key, String val, long ttlSeconds) {
cache.put(key, val);
set(key, val);
}

public void setObject(String key, Object val, long ttlSeconds) {
setObject(key, val);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class RedisCache implements Cache {
private static final Long DEFAULT_TTL_SECONDS = 259140L; // 2 days, 23 hours, 59 minutes

private JedisPool jedisPool;
private long ttlSeconds;
private final long ttlSeconds;

/**
* This constructor will use the default ttl of 259,140 seconds and will assume standard Redis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,15 @@
import static com.salesforce.einsteinbot.sdk.client.model.BotResponseBuilder.fromChatMessageResponseEnvelopeResponseEntity;
import static com.salesforce.einsteinbot.sdk.client.util.RequestFactory.buildChatMessageEnvelope;
import static com.salesforce.einsteinbot.sdk.client.util.RequestFactory.buildInitMessageEnvelope;
import static com.salesforce.einsteinbot.sdk.util.WebClientUtil.createErrorResponseProcessor;
import static com.salesforce.einsteinbot.sdk.util.WebClientUtil.createFilter;
import static com.salesforce.einsteinbot.sdk.util.WebClientUtil.createLoggingRequestProcessor;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.salesforce.einsteinbot.sdk.api.BotApi;
import com.salesforce.einsteinbot.sdk.api.HealthApi;
import com.salesforce.einsteinbot.sdk.api.VersionsApi;
import com.salesforce.einsteinbot.sdk.auth.AuthMechanism;
import com.salesforce.einsteinbot.sdk.cache.InMemoryCache;
import com.salesforce.einsteinbot.sdk.client.model.BotEndSessionRequest;
import com.salesforce.einsteinbot.sdk.client.model.BotRequest;
import com.salesforce.einsteinbot.sdk.client.model.BotResponse;
Expand All @@ -28,35 +27,33 @@
import com.salesforce.einsteinbot.sdk.client.model.ExternalSessionId;
import com.salesforce.einsteinbot.sdk.client.model.RequestConfig;
import com.salesforce.einsteinbot.sdk.client.model.RuntimeSessionId;
import com.salesforce.einsteinbot.sdk.exception.ChatbotResponseException;
import com.salesforce.einsteinbot.sdk.client.util.ClientFactory;
import com.salesforce.einsteinbot.sdk.client.util.ClientFactory.ClientWrapper;
import com.salesforce.einsteinbot.sdk.exception.UnsupportedSDKException;
import com.salesforce.einsteinbot.sdk.handler.ApiClient;
import com.salesforce.einsteinbot.sdk.model.ChatMessageEnvelope;
import com.salesforce.einsteinbot.sdk.model.EndSessionReason;
import com.salesforce.einsteinbot.sdk.model.InitMessageEnvelope;
import com.salesforce.einsteinbot.sdk.model.Status;
import com.salesforce.einsteinbot.sdk.model.SupportedVersions;
import com.salesforce.einsteinbot.sdk.model.SupportedVersionsVersions;
import com.salesforce.einsteinbot.sdk.model.SupportedVersionsVersions.StatusEnum;
import com.salesforce.einsteinbot.sdk.util.LoggingJsonEncoder;
import com.salesforce.einsteinbot.sdk.util.ReleaseInfo;
import com.salesforce.einsteinbot.sdk.util.UtilFunctions;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.Objects;
import java.util.Optional;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;

import com.salesforce.einsteinbot.sdk.util.WebClientUtil;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ClientCodecConfigurer;
import org.springframework.http.codec.json.Jackson2JsonDecoder;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpHost;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.utils.URIUtils;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

/**
* This is a basic implementation of {@link BasicChatbotClient}. It does not perform session
Expand All @@ -67,41 +64,44 @@
*/
public class BasicChatbotClientImpl implements BasicChatbotClient {

protected BotApi botApi;
protected HealthApi healthApi;
protected VersionsApi versionsApi;
protected ApiClient apiClient;
private static final Long DEFAULT_TTL_SECONDS = 259140L;
private static final String API_INFO_URI = "/services/data/v58.0/connect/bots/api-info";

protected InMemoryCache cache;
protected String basePath;
protected WebClient.Builder webClientBuilder;
protected AuthMechanism authMechanism;
protected ReleaseInfo releaseInfo = ReleaseInfo.getInstance();
protected ClientWrapper clientWrapper;

protected BasicChatbotClientImpl(String basePath,
AuthMechanism authMechanism,
WebClient.Builder webClientBuilder) {

this.authMechanism = authMechanism;
this.apiClient = new ApiClient(createWebClient(webClientBuilder), UtilFunctions.getMapper(),
UtilFunctions
.createDefaultDateFormat());
apiClient.setBasePath(basePath);
apiClient.setUserAgent(releaseInfo.getAsUserAgent());
botApi = new BotApi(apiClient);
healthApi = new HealthApi(apiClient);
versionsApi = new VersionsApi(apiClient);
this.basePath = basePath;
this.webClientBuilder = webClientBuilder;
this.clientWrapper = ClientFactory.createClient(basePath, webClientBuilder);
this.cache = new InMemoryCache(DEFAULT_TTL_SECONDS);
}

@VisibleForTesting
void setBotApi(BotApi botApi) {
this.botApi = botApi;
this.clientWrapper.setBotApi(botApi);
}

@VisibleForTesting
void setHealthApi(HealthApi healthApi) {
this.healthApi = healthApi;
this.clientWrapper.setHealthApi(healthApi);
}

@VisibleForTesting
void setVersionsApi(VersionsApi versionsApi) {
this.versionsApi = versionsApi;
this.clientWrapper.setVersionsApi(versionsApi);
}

@VisibleForTesting
void setCache(InMemoryCache cache) {
this.cache = cache;
}

@Override
Expand All @@ -112,15 +112,28 @@ public BotResponse startChatSession(RequestConfig config,
if (!isApiVersionSupported()) {
throw new UnsupportedSDKException(getCurrentApiVersion(), getLatestApiVersion());
}

String basePath = getRuntimeUrl(config.getForceConfigEndpoint());
Optional<Object> clientOptional = this.cache.getObject(basePath);
ClientWrapper clientWrapper = ClientFactory.createClient(basePath, webClientBuilder);
if (clientOptional.isPresent()) {
clientWrapper = (ClientWrapper) clientOptional.get();
}
this.clientWrapper = clientWrapper;

InitMessageEnvelope initMessageEnvelope = createInitMessageEnvelope(config, sessionId,
botSendMessageRequest);

notifyRequestEnvelopeInterceptor(botSendMessageRequest, initMessageEnvelope);
CompletableFuture<BotResponse> futureResponse = invokeEstablishChatSession(config,
initMessageEnvelope,
botSendMessageRequest);
botSendMessageRequest,
clientWrapper);
try {
return futureResponse.get();
BotResponse botResponse = futureResponse.get();
this.cache.set(botResponse.getResponseEnvelope().getSessionId(), basePath);
this.cache.setObject(basePath, clientWrapper);
return botResponse;
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
Expand All @@ -141,11 +154,13 @@ public BotResponse sendMessage(RequestConfig config,

ChatMessageEnvelope chatMessageEnvelope = createChatMessageEnvelope(botSendMessageRequest);

ClientWrapper clientWrapper = getCachedClientWrapper(sessionId);
notifyRequestEnvelopeInterceptor(botSendMessageRequest, chatMessageEnvelope);
CompletableFuture<BotResponse> futureResponse = invokeContinueChatSession(config.getOrgId(),
sessionId.getValue(),
chatMessageEnvelope,
botSendMessageRequest);
botSendMessageRequest,
clientWrapper);

try {
return futureResponse.get();
Expand All @@ -166,11 +181,14 @@ public BotResponse endChatSession(RequestConfig config,
BotEndSessionRequest botEndSessionRequest) {

EndSessionReason endSessionReason = botEndSessionRequest.getEndSessionReason();

ClientWrapper clientWrapper = getCachedClientWrapper(sessionId);
notifyRequestEnvelopeInterceptor(botEndSessionRequest, "EndSessionReason: " + endSessionReason);
CompletableFuture<BotResponse> futureResponse = invokeEndChatSession(config.getOrgId(),
sessionId.getValue(),
endSessionReason,
botEndSessionRequest);
botEndSessionRequest,
clientWrapper);
try {
return futureResponse.get();
} catch (InterruptedException | ExecutionException e) {
Expand All @@ -183,11 +201,23 @@ protected void notifyRequestEnvelopeInterceptor(BotRequest botRequest, Object re
.accept(requestEnvelope);
}

private ClientWrapper getCachedClientWrapper(RuntimeSessionId sessionId) {
Optional<String> basePath = this.cache.get(sessionId.getValue());
if (!basePath.isPresent()) {
throw new RuntimeException("No base path found in cache for session ID: " + sessionId.getValue());
}
Optional<Object> clientOptional = this.cache.getObject(basePath.get());
if (!clientOptional.isPresent()) {
throw new RuntimeException("No client implementation found in cache for base path: " + basePath.get());
}
return (ClientWrapper) clientOptional.get();
}

protected CompletableFuture<BotResponse> invokeEndChatSession(String orgId, String sessionId,
EndSessionReason endSessionReason, BotEndSessionRequest botRequest) {
EndSessionReason endSessionReason, BotEndSessionRequest botRequest, ClientWrapper clientWrapper) {

apiClient.setBearerToken(authMechanism.getToken());
CompletableFuture<BotResponse> futureResponse = botApi
clientWrapper.getApiClient().setBearerToken(authMechanism.getToken());
CompletableFuture<BotResponse> futureResponse = clientWrapper.getBotApi()
.endSessionWithHttpInfo(sessionId,
orgId,
endSessionReason,
Expand All @@ -202,10 +232,11 @@ protected CompletableFuture<BotResponse> invokeEndChatSession(String orgId, Stri

protected CompletableFuture<BotResponse> invokeEstablishChatSession(RequestConfig config,
InitMessageEnvelope initMessageEnvelope,
BotSendMessageRequest botRequest) {
BotSendMessageRequest botRequest,
ClientWrapper clientWrapper) {

apiClient.setBearerToken(authMechanism.getToken());
CompletableFuture<BotResponse> futureResponse = botApi
clientWrapper.getApiClient().setBearerToken(authMechanism.getToken());
CompletableFuture<BotResponse> futureResponse = clientWrapper.getBotApi()
.startSessionWithHttpInfo(config.getBotId(), config.getOrgId(),
initMessageEnvelope, botRequest.getRequestId().orElse(null))
.toFuture()
Expand All @@ -216,10 +247,11 @@ protected CompletableFuture<BotResponse> invokeEstablishChatSession(RequestConfi

protected CompletableFuture<BotResponse> invokeContinueChatSession(String orgId, String sessionId,
ChatMessageEnvelope messageEnvelope,
BotSendMessageRequest botRequest) {
BotSendMessageRequest botRequest,
ClientWrapper clientWrapper) {

apiClient.setBearerToken(authMechanism.getToken());
CompletableFuture<BotResponse> futureResponse = botApi
clientWrapper.getApiClient().setBearerToken(authMechanism.getToken());
CompletableFuture<BotResponse> futureResponse = clientWrapper.getBotApi()
.continueSessionWithHttpInfo(sessionId,
orgId,
messageEnvelope,
Expand All @@ -233,7 +265,7 @@ protected CompletableFuture<BotResponse> invokeContinueChatSession(String orgId,
}

public Status getHealthStatus() {
CompletableFuture<Status> statusFuture = healthApi.checkHealthStatus().toFuture();
CompletableFuture<Status> statusFuture = this.clientWrapper.getHealthApi().checkHealthStatus().toFuture();

try {
return statusFuture.get();
Expand All @@ -243,7 +275,7 @@ public Status getHealthStatus() {
}

public SupportedVersions getSupportedVersions() {
CompletableFuture<SupportedVersions> versionsFuture = versionsApi.getAPIVersions().toFuture();
CompletableFuture<SupportedVersions> versionsFuture = this.clientWrapper.getVersionsApi().getAPIVersions().toFuture();

try {
SupportedVersions versions = versionsFuture.get();
Expand All @@ -256,30 +288,23 @@ public SupportedVersions getSupportedVersions() {
}
}

private WebClient createWebClient(WebClient.Builder webClientBuilder) {

return webClientBuilder
.codecs(createCodecsConfiguration(UtilFunctions.getMapper()))
.filter(createFilter(clientRequest -> createLoggingRequestProcessor(clientRequest),
clientResponse -> createErrorResponseProcessor(clientResponse, this::mapErrorResponse)))
.build();
}

private Consumer<ClientCodecConfigurer> createCodecsConfiguration(ObjectMapper mapper) {
return clientDefaultCodecsConfigurer -> {
clientDefaultCodecsConfigurer.defaultCodecs()
.jackson2JsonEncoder(new LoggingJsonEncoder(mapper, MediaType.APPLICATION_JSON, false));
clientDefaultCodecsConfigurer.defaultCodecs()
.jackson2JsonDecoder(new Jackson2JsonDecoder(mapper, MediaType.APPLICATION_JSON));
};
}

private Mono<ClientResponse> mapErrorResponse(ClientResponse clientResponse) {
return clientResponse
.body(WebClientUtil.errorBodyExtractor())
.flatMap(errorDetails -> Mono
.error(new ChatbotResponseException(clientResponse.statusCode(), errorDetails,
clientResponse.headers())));
private String getRuntimeUrl(String forceEndpoint) {
try {
URI uri = URI.create(forceEndpoint);
HttpHost forceHost = URIUtils.extractHost(uri);
String infoPath = uri.getRawPath().replace("/$", "") + API_INFO_URI;
HttpGet httpGet = new HttpGet(forceHost.toString() + infoPath);
httpGet.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + this.authMechanism.getToken());
try (CloseableHttpClient httpClient = HttpClients.createDefault()) {
return httpClient.execute(httpGet, httpResponse -> {
String response = EntityUtils.toString(httpResponse.getEntity());
JsonNode node = new ObjectMapper().readValue(response, JsonNode.class);
return node.get("runtimeBaseUrl").asText();
});
}
} catch (Exception ex) {
throw new RuntimeException(ex);
}
}

private String getCurrentApiVersion() {
Expand Down
Loading

0 comments on commit 5091e11

Please sign in to comment.