diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx index 6ce25f6642c..948c9fb2921 100644 --- a/app/components/model-config.tsx +++ b/app/components/model-config.tsx @@ -12,6 +12,7 @@ export function ModelConfigList(props: { }) { const allModels = useAllModels(); const value = `${props.modelConfig.model}@${props.modelConfig?.providerName}`; + const compressModelValue = `${props.modelConfig.compressModel}@${props.modelConfig?.compressProviderName}`; return ( <> @@ -228,6 +229,30 @@ export function ModelConfigList(props: { } > + + + ); } diff --git a/app/store/chat.ts b/app/store/chat.ts index 58c105e7ef7..4332c2246ad 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -1,33 +1,29 @@ -import { trimTopic, getMessageTextContent } from "../utils"; +import { getMessageTextContent, trimTopic } from "../utils"; -import Locale, { getLang } from "../locales"; +import { indexedDBStorage } from "@/app/utils/indexedDB-storage"; +import { nanoid } from "nanoid"; +import type { + ClientApi, + MultimodalContent, + RequestMessage, +} from "../client/api"; +import { getClientApi } from "../client/api"; +import { ChatControllerPool } from "../client/controller"; import { showToast } from "../components/ui-lib"; -import { ModelConfig, ModelType, useAppConfig } from "./config"; -import { createEmptyMask, Mask } from "./mask"; import { DEFAULT_INPUT_TEMPLATE, DEFAULT_MODELS, DEFAULT_SYSTEM_TEMPLATE, KnowledgeCutOffDate, StoreKey, - SUMMARIZE_MODEL, - GEMINI_SUMMARIZE_MODEL, } from "../constant"; -import { getClientApi } from "../client/api"; -import type { - ClientApi, - RequestMessage, - MultimodalContent, -} from "../client/api"; -import { ChatControllerPool } from "../client/controller"; +import Locale, { getLang } from "../locales"; +import { isDalle3, safeLocalStorage } from "../utils"; import { prettyObject } from "../utils/format"; -import { estimateTokenLength } from "../utils/token"; -import { nanoid } from "nanoid"; import { createPersistStore } from "../utils/store"; -import { collectModelsWithDefaultModel } from "../utils/model"; -import { useAccessStore } from "./access"; -import { isDalle3, safeLocalStorage } from "../utils"; -import { indexedDBStorage } from "@/app/utils/indexedDB-storage"; +import { estimateTokenLength } from "../utils/token"; +import { ModelConfig, ModelType, useAppConfig } from "./config"; +import { createEmptyMask, Mask } from "./mask"; const localStorage = safeLocalStorage(); @@ -106,27 +102,6 @@ function createEmptySession(): ChatSession { }; } -function getSummarizeModel(currentModel: string) { - // if it is using gpt-* models, force to use 4o-mini to summarize - if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) { - const configStore = useAppConfig.getState(); - const accessStore = useAccessStore.getState(); - const allModel = collectModelsWithDefaultModel( - configStore.models, - [configStore.customModels, accessStore.customModels].join(","), - accessStore.defaultModel, - ); - const summarizeModel = allModel.find( - (m) => m.name === SUMMARIZE_MODEL && m.available, - ); - return summarizeModel?.name ?? currentModel; - } - if (currentModel.startsWith("gemini")) { - return GEMINI_SUMMARIZE_MODEL; - } - return currentModel; -} - function countMessages(msgs: ChatMessage[]) { return msgs.reduce( (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)), @@ -581,7 +556,7 @@ export const useChatStore = createPersistStore( return; } - const providerName = modelConfig.providerName; + const providerName = modelConfig.compressProviderName; const api: ClientApi = getClientApi(providerName); // remove error messages if any @@ -603,7 +578,7 @@ export const useChatStore = createPersistStore( api.llm.chat({ messages: topicMessages, config: { - model: getSummarizeModel(session.mask.modelConfig.model), + model: modelConfig.compressModel, stream: false, providerName, }, @@ -666,7 +641,7 @@ export const useChatStore = createPersistStore( config: { ...modelcfg, stream: true, - model: getSummarizeModel(session.mask.modelConfig.model), + model: modelConfig.compressModel, }, onUpdate(message) { session.memoryPrompt = message; @@ -715,7 +690,7 @@ export const useChatStore = createPersistStore( }, { name: StoreKey.Chat, - version: 3.1, + version: 3.2, migrate(persistedState, version) { const state = persistedState as any; const newState = JSON.parse( @@ -762,6 +737,16 @@ export const useChatStore = createPersistStore( }); } + // add default summarize model for every session + if (version < 3.2) { + newState.sessions.forEach((s) => { + const config = useAppConfig.getState(); + s.mask.modelConfig.compressModel = config.modelConfig.compressModel; + s.mask.modelConfig.compressProviderName = + config.modelConfig.compressProviderName; + }); + } + return newState as any; }, }, diff --git a/app/store/config.ts b/app/store/config.ts index e8e3c9863ef..9985b9e768c 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -50,7 +50,7 @@ export const DEFAULT_CONFIG = { models: DEFAULT_MODELS as any as LLMModel[], modelConfig: { - model: "gpt-3.5-turbo" as ModelType, + model: "gpt-4o-mini" as ModelType, providerName: "OpenAI" as ServiceProvider, temperature: 0.5, top_p: 1, @@ -60,6 +60,8 @@ export const DEFAULT_CONFIG = { sendMemory: true, historyMessageCount: 4, compressMessageLengthThreshold: 1000, + compressModel: "gpt-4o-mini" as ModelType, + compressProviderName: "OpenAI" as ServiceProvider, enableInjectSystemPrompts: true, template: config?.template ?? DEFAULT_INPUT_TEMPLATE, size: "1024x1024" as DalleSize, @@ -140,7 +142,7 @@ export const useAppConfig = createPersistStore( }), { name: StoreKey.Config, - version: 3.9, + version: 4, migrate(persistedState, version) { const state = persistedState as ChatConfig; @@ -178,6 +180,13 @@ export const useAppConfig = createPersistStore( : config?.template ?? DEFAULT_INPUT_TEMPLATE; } + if (version < 4) { + state.modelConfig.compressModel = + DEFAULT_CONFIG.modelConfig.compressModel; + state.modelConfig.compressProviderName = + DEFAULT_CONFIG.modelConfig.compressProviderName; + } + return state as any; }, },