Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updating to support dedicated serving mode with fine-tuned models #34

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions app/src/components/content/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ type Chat = {
answer?: string;
loading?: string;
};
// Metadata for a Generative AI model as surfaced by the backend's
// /api/genai/models listing (fields mirror the Java GenAiModel DAO).
// NOTE(review): duplicate of the Model type in settings.tsx and not
// referenced in the visible portion of this file — confirm usage, and
// consider moving it to a shared types module instead of duplicating.
type Model = {
id: string; // model OCID
name: string; // display name
vendor: string; // e.g. "cohere"
version: string;
capabilities: Array<string>; // e.g. "TEXT_GENERATION", "FINE_TUNE"
timeCreated: string;
};

// UI selections persisted across reloads; fall back to "text"/"java"
// when localStorage has no saved choice (first visit).
const defaultServiceType: string = localStorage.getItem("service") || "text";
const defaultBackendType: string = localStorage.getItem("backend") || "java";
Expand All @@ -46,6 +54,7 @@ const Content = () => {
const question = useRef<string>();
const chatData = useRef<Array<object>>([]);
const socket = useRef<WebSocket>();
const finetune = useRef<boolean>(false);
const [client, setClient] = useState<Client | null>(null);

const messagesDP = useRef(
Expand Down Expand Up @@ -167,7 +176,13 @@ const Content = () => {
JSON.stringify({ msgType: "question", data: question.current })
);
} else {
sendPrompt(client, question.current!, modelId!, conversationId!);
sendPrompt(
client,
question.current!,
modelId!,
conversationId!,
finetune.current
);
}
}
};
Expand Down Expand Up @@ -199,9 +214,9 @@ const Content = () => {
localStorage.setItem("backend", backend);
location.reload();
};
const modelIdChangeHandler = (event: CustomEvent) => {
console.log("model Id: ", event.detail.value);
if (event.detail.value != null) setModelId(event.detail.value);
// Records the model selected in Settings: stores the id in component state
// and remembers whether it is a fine-tuned model (drives serving mode).
const modelIdChangeHandler = (value: string, modelType: boolean) => {
  finetune.current = modelType;
  if (value != null) {
    setModelId(value);
  }
};
const clearSummary = () => {
setSummaryResults("");
Expand Down
72 changes: 67 additions & 5 deletions app/src/components/content/settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import "oj-c/select-single";
import "ojs/ojlistitemlayout";
import "ojs/ojhighlighttext";
import MutableArrayDataProvider = require("ojs/ojmutablearraydataprovider");
import { ojSelectSingle } from "@oracle/oraclejet/ojselectsingle";

type ServiceTypeVal = "text" | "summary" | "sim";
type BackendTypeVal = "java" | "python";
Expand All @@ -17,7 +18,7 @@ type Props = {
backendType: BackendTypeVal;
aiServiceChange: (service: ServiceTypeVal) => void;
backendChange: (backend: BackendTypeVal) => void;
modelIdChange: (modelName: any) => void;
modelIdChange: (modelId: any, modelData: any) => void;
};

const serviceTypes = [
Expand All @@ -30,6 +31,21 @@ const backendTypes = [
{ value: "java", label: "Java" },
{ value: "python", label: "Python" },
];
// Metadata for a Generative AI model as returned by /api/genai/models
// (fields mirror the Java GenAiModel DAO).
type Model = {
id: string; // model OCID
name: string; // display name
vendor: string; // e.g. "cohere"
version: string;
capabilities: Array<string>; // e.g. "TEXT_GENERATION", "FINE_TUNE"
timeCreated: string;
};
// A dedicated-serving endpoint as returned by /api/genai/endpoints
// (fields mirror the Java GenAiEndpoint DAO). `model` is the OCID of the
// fine-tuned model the endpoint serves.
type Endpoint = {
id: string; // endpoint OCID
name: string;
state: string; // lifecycle state
model: string; // OCID of the model served by this endpoint
timeCreated: string;
};
const serviceOptionsDP = new MutableArrayDataProvider<
Services["value"],
Services
Expand All @@ -50,8 +66,11 @@ export const Settings = (props: Props) => {
};

const modelDP = useRef(
new MutableArrayDataProvider<string, {}>([], { keyAttributes: "id" })
new MutableArrayDataProvider<string, {}>([], {
keyAttributes: "id",
})
);
const endpoints = useRef<Array<Endpoint>>();

const fetchModels = async () => {
try {
Expand All @@ -60,9 +79,8 @@ export const Settings = (props: Props) => {
throw new Error(`Response status: ${response.status}`);
}
const json = await response.json();
const result = json.filter((model: any) => {
const result = json.filter((model: Model) => {
if (
// model.capabilities.includes("FINE_TUNE") &&
model.capabilities.includes("TEXT_GENERATION") &&
(model.vendor == "cohere" || model.vendor == "") &&
model.version != "14.2"
Expand All @@ -77,11 +95,55 @@ export const Settings = (props: Props) => {
);
}
};
// Loads the dedicated-serving endpoints from the Java backend and caches
// them in the `endpoints` ref so modelChangeHandler can map a fine-tuned
// model OCID to its endpoint OCID. Failures are logged and swallowed on
// purpose: the Python backend does not expose this route.
const fetchEndpoints = async () => {
  try {
    const response = await fetch("/api/genai/endpoints");
    if (!response.ok) {
      throw new Error(`Response status: ${response.status}`);
    }
    const json = await response.json();
    // Keep every non-null endpoint for now; tighten this predicate
    // (e.g. ACTIVE state only) when real filtering is needed. The
    // original callback returned the object itself, which filter()
    // only treated as a truthy value — use an explicit boolean.
    const result = json.filter((endpoint: Endpoint) => endpoint != null);
    endpoints.current = result;
  } catch (error: unknown) {
    // Narrow the unknown before touching .message (strict TS).
    const message = error instanceof Error ? error.message : String(error);
    console.log(
      "Java service not available for fetching list of Endpoints: ",
      message
    );
  }
};

// On mount, kick off both listings. Neither call is awaited, so the model
// list (modelDP) and the endpoint cache (endpoints ref) load independently.
useEffect(() => {
fetchEndpoints();
fetchModels();
}, []);

const modelChangeHandler = async (
event: ojSelectSingle.valueChanged<string, {}>
) => {
let selected = event.detail.value;
let finetune = false;
const asyncIterator = modelDP.current.fetchFirst()[Symbol.asyncIterator]();
let result = await asyncIterator.next();
let value = result.value;
let data = value.data as Array<Model>;
let idx = data.find((e: Model) => {
if (e.id === selected) return e;
});
if (idx?.capabilities.includes("FINE_TUNE")) {
finetune = true;
let endpointId = endpoints.current?.find((e: Endpoint) => {
if (e.model === event.detail.value) {
return e.id;
}
});
selected = endpointId ? endpointId.id : event.detail.value;
}
props.modelIdChange(selected, finetune);
};

const modelTemplate = (item: any) => {
return (
<oj-list-item-layout class="oj-listitemlayout-padding-off">
Expand Down Expand Up @@ -134,7 +196,7 @@ export const Settings = (props: Props) => {
data={modelDP.current}
labelHint={"Model"}
itemText={"name"}
onvalueChanged={props.modelIdChange}
onvalueChanged={modelChangeHandler}
>
<template slot="itemTemplate" render={modelTemplate}></template>
</oj-c-select-single>
Expand Down
4 changes: 3 additions & 1 deletion app/src/components/content/stomp-interface.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ export const sendPrompt = (
client: Client | null,
prompt: string,
modelId: string,
convoId: string
convoId: string,
finetune: boolean
) => {
if (client?.connected) {
console.log("Sending prompt: ", prompt);
Expand All @@ -134,6 +135,7 @@ export const sendPrompt = (
conversationId: convoId,
content: prompt,
modelId: modelId,
finetune: finetune,
}),
});
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
import com.oracle.bmc.generativeai.GenerativeAiClient;
import com.oracle.bmc.generativeai.model.ModelCapability;
import com.oracle.bmc.generativeai.requests.ListModelsRequest;
import com.oracle.bmc.generativeai.requests.ListEndpointsRequest;
import com.oracle.bmc.generativeai.responses.ListModelsResponse;
import com.oracle.bmc.generativeai.responses.ListEndpointsResponse;
import com.oracle.bmc.generativeai.model.EndpointSummary;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiModel;
import dev.victormartin.oci.genai.backend.backend.dao.GenAiEndpoint;
import dev.victormartin.oci.genai.backend.backend.service.GenerativeAiClientService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -33,11 +37,25 @@ public List<GenAiModel> getModels() {
GenerativeAiClient client = generativeAiClientService.getClient();
ListModelsResponse response = client.listModels(listModelsRequest);
return response.getModelCollection().getItems().stream().map(m -> {
List<String> capabilities = m.getCapabilities().stream().map(ModelCapability::getValue).collect(Collectors.toList());
GenAiModel model = new GenAiModel(m.getId(),m.getDisplayName(), m.getVendor(), m.getVersion(),
capabilities,
m.getTimeCreated());
List<String> capabilities = m.getCapabilities().stream().map(ModelCapability::getValue)
.collect(Collectors.toList());
GenAiModel model = new GenAiModel(m.getId(), m.getDisplayName(), m.getVendor(), m.getVersion(),
capabilities, m.getTimeCreated());
return model;
}).collect(Collectors.toList());
}

/**
 * Lists the Generative AI dedicated-serving endpoints in the configured
 * compartment.
 *
 * @return one {@link GenAiEndpoint} per endpoint summary returned by the
 *         OCI SDK
 */
@GetMapping("/api/genai/endpoints")
public List<GenAiEndpoint> getEndpoints() {
    logger.info("getEndpoints()");
    ListEndpointsRequest request = ListEndpointsRequest.builder()
            .compartmentId(COMPARTMENT_ID)
            .build();
    GenerativeAiClient client = generativeAiClientService.getClient();
    ListEndpointsResponse response = client.listEndpoints(request);
    // Map each SDK EndpointSummary onto the lightweight DAO the UI consumes.
    return response.getEndpointCollection().getItems().stream()
            .map(e -> new GenAiEndpoint(e.getId(), e.getDisplayName(), e.getLifecycleState(),
                    e.getModelId(), e.getTimeCreated()))
            .collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public PromptController(InteractionRepository interactionRepository, OCIGenAISer
@SendToUser("/queue/answer")
public Answer handlePrompt(Prompt prompt) {
String promptEscaped = HtmlUtils.htmlEscape(prompt.content());
boolean finetune = prompt.finetune();
String activeModel = (prompt.modelId() == null) ? hardcodedChatModelId : prompt.modelId();
logger.info("Prompt " + promptEscaped + " received, on model " + activeModel);

Expand All @@ -59,11 +60,8 @@ public Answer handlePrompt(Prompt prompt) {
if (prompt.content().isEmpty()) {
throw new InvalidPromptRequest();
}
// if (prompt.modelId() == null ||
// !prompt.modelId().startsWith("ocid1.generativeaimodel.")) { throw new
// InvalidPromptRequest(); }
saved.setDatetimeResponse(new Date());
String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel);
String responseFromGenAI = genAI.resolvePrompt(promptEscaped, activeModel, finetune);
saved.setResponse(responseFromGenAI);
interactionRepository.save(saved);
return new Answer(responseFromGenAI, "");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package dev.victormartin.oci.genai.backend.backend.dao;

import java.util.Date;
import com.oracle.bmc.generativeai.model.Endpoint;

/**
 * Lightweight DAO for a Generative AI dedicated-serving endpoint, built by
 * {@code OCIGenAIController#getEndpoints()} from the SDK's EndpointSummary.
 *
 * @param id          endpoint OCID
 * @param name        endpoint display name
 * @param state       lifecycle state of the endpoint
 * @param model       OCID of the (fine-tuned) model this endpoint serves
 * @param timeCreated creation timestamp
 */
public record GenAiEndpoint(String id, String name, Endpoint.LifecycleState state, String model, Date timeCreated) {
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
package dev.victormartin.oci.genai.backend.backend.dao;

public record Prompt(String content, String conversationId, String modelId) {};
/**
 * WebSocket payload carrying a chat prompt from the UI.
 *
 * @param content        the user's prompt text
 * @param conversationId client-side conversation identifier
 * @param modelId        model OCID, or endpoint OCID when {@code finetune} is true
 * @param finetune       true to use dedicated serving mode (fine-tuned model
 *                       behind an endpoint) instead of on-demand serving
 */
public record Prompt(String content, String conversationId, String modelId, boolean finetune) {
};
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,56 @@

@Service
public class OCIGenAIService {
@Value("${genai.compartment_id}")
private String COMPARTMENT_ID;
@Value("${genai.compartment_id}")
private String COMPARTMENT_ID;

@Autowired
private GenerativeAiInferenceClientService generativeAiInferenceClientService;
@Autowired
private GenerativeAiInferenceClientService generativeAiInferenceClientService;

public String resolvePrompt(String input, String modelId) {
// Build generate text request, send, and get response
CohereLlmInferenceRequest llmInferenceRequest =
CohereLlmInferenceRequest.builder()
.prompt(input)
.maxTokens(600)
.temperature((double)1)
.frequencyPenalty((double)0)
.topP((double)0.75)
.isStream(false)
.isEcho(false)
.build();
public String resolvePrompt(String input, String modelId, boolean finetune) {
// Build generate text request, send, and get response
CohereLlmInferenceRequest llmInferenceRequest = CohereLlmInferenceRequest.builder()
.prompt(input)
.maxTokens(600)
.temperature((double) 1)
.frequencyPenalty((double) 0)
.topP((double) 0.75)
.isStream(false)
.isEcho(false)
.build();

GenerateTextDetails generateTextDetails = GenerateTextDetails.builder()
.servingMode(OnDemandServingMode.builder().modelId(modelId).build())
.compartmentId(COMPARTMENT_ID)
.inferenceRequest(llmInferenceRequest)
.build();
GenerateTextRequest generateTextRequest = GenerateTextRequest.builder()
.generateTextDetails(generateTextDetails)
.build();
GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
GenerateTextResponse generateTextResponse = client.generateText(generateTextRequest);
CohereLlmInferenceResponse response =
(CohereLlmInferenceResponse) generateTextResponse.getGenerateTextResult().getInferenceResponse();
String responseTexts = response.getGeneratedTexts()
.stream()
.map(t -> t.getText())
.collect(Collectors.joining(","));
return responseTexts;
}
GenerateTextDetails generateTextDetails = GenerateTextDetails.builder()
.servingMode(finetune ? DedicatedServingMode.builder().endpointId(modelId).build()
: OnDemandServingMode.builder().modelId(modelId).build())
.compartmentId(COMPARTMENT_ID)
.inferenceRequest(llmInferenceRequest)
.build();
GenerateTextRequest generateTextRequest = GenerateTextRequest.builder()
.generateTextDetails(generateTextDetails)
.build();
GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
GenerateTextResponse generateTextResponse = client.generateText(generateTextRequest);
CohereLlmInferenceResponse response = (CohereLlmInferenceResponse) generateTextResponse
.getGenerateTextResult().getInferenceResponse();
String responseTexts = response.getGeneratedTexts()
.stream()
.map(t -> t.getText())
.collect(Collectors.joining(","));
return responseTexts;
}

public String summaryText(String input, String modelId) {
SummarizeTextDetails summarizeTextDetails = SummarizeTextDetails.builder()
.servingMode(OnDemandServingMode.builder().modelId(modelId).build())
.compartmentId(COMPARTMENT_ID)
.input(input)
.build();
SummarizeTextRequest request = SummarizeTextRequest.builder()
.summarizeTextDetails(summarizeTextDetails)
.build();
GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
SummarizeTextResponse summarizeTextResponse = client.summarizeText(request);
String summaryText = summarizeTextResponse.getSummarizeTextResult().getSummary();
return summaryText;
}
/**
 * Summarizes {@code input} using the given model via on-demand serving.
 *
 * @param input   the text to summarize (sent as-is to the service)
 * @param modelId OCID of the on-demand summarization model
 * @return the summary produced by the Generative AI inference service
 */
public String summaryText(String input, String modelId) {
    SummarizeTextDetails details = SummarizeTextDetails.builder()
            .servingMode(OnDemandServingMode.builder().modelId(modelId).build())
            .compartmentId(COMPARTMENT_ID)
            .input(input)
            .build();
    SummarizeTextRequest request = SummarizeTextRequest.builder()
            .summarizeTextDetails(details)
            .build();
    GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
    return client.summarizeText(request).getSummarizeTextResult().getSummary();
}
}