diff --git a/.github/workflows/jan-electron-linter-and-test.yml b/.github/workflows/jan-electron-linter-and-test.yml
index 96258e7097..55c3308da6 100644
--- a/.github/workflows/jan-electron-linter-and-test.yml
+++ b/.github/workflows/jan-electron-linter-and-test.yml
@@ -22,6 +22,7 @@ on:
     branches:
       - main
       - dev
+      - release/**
     paths:
      - "electron/**"
      - .github/workflows/jan-electron-linter-and-test.yml
diff --git a/core/package.json b/core/package.json
index 2f4f6b576f..c4d0d475df 100644
--- a/core/package.json
+++ b/core/package.json
@@ -46,7 +46,7 @@
   },
   "devDependencies": {
     "@types/jest": "^29.5.12",
-    "@types/node": "^12.0.2",
+    "@types/node": "^20.11.4",
     "eslint": "8.57.0",
     "eslint-plugin-jest": "^27.9.0",
     "jest": "^29.7.0",
diff --git a/core/src/api/index.ts b/core/src/api/index.ts
index f3b4fe10f2..8e41da0d17 100644
--- a/core/src/api/index.ts
+++ b/core/src/api/index.ts
@@ -96,6 +96,7 @@ export enum FileManagerRoute {
   fileStat = 'fileStat',
   writeBlob = 'writeBlob',
   mkdir = 'mkdir',
+  rm = 'rm',
 }
 
 export type ApiFunction = (...args: any[]) => any
diff --git a/core/src/extension.ts b/core/src/extension.ts
index 22accb4b47..973d4778a7 100644
--- a/core/src/extension.ts
+++ b/core/src/extension.ts
@@ -19,6 +19,7 @@ export interface Compatibility {
 const ALL_INSTALLATION_STATE = [
   'NotRequired', // not required.
   'Installed', // require and installed. Good to go.
+  'Updatable', // required and installed, but needs to be updated.
   'NotInstalled', // require to be installed.
   'Corrupted', // require but corrupted. Need to redownload.
 ] as const
@@ -59,6 +60,13 @@ export abstract class BaseExtension implements ExtensionType {
     return undefined
   }
 
+  /**
+   * Determine if the extension is updatable.
+   */
+  updatable(): boolean {
+    return false
+  }
+
   /**
    * Determine if the prerequisites for the extension are installed.
    *
diff --git a/core/src/fs.ts b/core/src/fs.ts
index 1c6d96ef01..dacdbb6d6f 100644
--- a/core/src/fs.ts
+++ b/core/src/fs.ts
@@ -45,6 +45,9 @@ const mkdir = (...args: any[]) => global.core.api?.mkdir(...args)
  */
 const rmdirSync = (...args: any[]) =>
   global.core.api?.rmdirSync(...args, { recursive: true, force: true })
+
+const rm = (path: string) => global.core.api?.rm(path)
+
 /**
  * Deletes a file from the local file system.
  * @param {string} path - The path of the file to delete.
@@ -96,6 +99,7 @@ export const fs = {
   mkdirSync,
   mkdir,
   rmdirSync,
+  rm,
   unlinkSync,
   appendFileSync,
   copyFileSync,
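To illustrate the new route end to end: `fs.rm` in the renderer-facing facade forwards to the `rm` handler in `FSExt` below, which performs a recursive delete on the node side. A minimal usage sketch (the `deleteModelFolder` helper is hypothetical; `fs`, `joinPath`, and `getJanDataFolderPath` are the `@janhq/core` exports this PR already uses elsewhere):

```ts
import { fs, joinPath, getJanDataFolderPath } from '@janhq/core'

// Hypothetical helper: remove a downloaded model folder via the new rm route.
async function deleteModelFolder(modelId: string): Promise<void> {
  const janDataFolderPath = await getJanDataFolderPath()
  const modelPath = await joinPath([janDataFolderPath, 'models', modelId])
  if (await fs.existsSync(modelPath)) {
    await fs.rm(modelPath) // handled node-side with { recursive: true }
  }
}
```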
diff --git a/core/src/node/api/processors/fsExt.ts b/core/src/node/api/processors/fsExt.ts
index 0f7dde6d9c..9b88cfef9f 100644
--- a/core/src/node/api/processors/fsExt.ts
+++ b/core/src/node/api/processors/fsExt.ts
@@ -100,4 +100,16 @@ export class FSExt implements Processor {
       })
     })
   }
+
+  rm(path: string): Promise<void> {
+    return new Promise((resolve, reject) => {
+      fs.rm(path, { recursive: true }, (err) => {
+        if (err) {
+          reject(err)
+        } else {
+          resolve()
+        }
+      })
+    })
+  }
 }
diff --git a/core/src/node/helper/config.ts b/core/src/node/helper/config.ts
index 06f2b03cd7..0a4c8cc367 100644
--- a/core/src/node/helper/config.ts
+++ b/core/src/node/helper/config.ts
@@ -82,26 +82,34 @@ export const getJanExtensionsPath = (): string => {
  */
 export const physicalCpuCount = async (): Promise<number> => {
   const platform = os.platform()
-  if (platform === 'linux') {
-    const output = await exec('lscpu -p | egrep -v "^#" | sort -u -t, -k 2,4 | wc -l')
-    return parseInt(output.trim(), 10)
-  } else if (platform === 'darwin') {
-    const output = await exec('sysctl -n hw.physicalcpu_max')
-    return parseInt(output.trim(), 10)
-  } else if (platform === 'win32') {
-    const output = await exec('WMIC CPU Get NumberOfCores')
-    return output
-      .split(os.EOL)
-      .map((line: string) => parseInt(line))
-      .filter((value: number) => !isNaN(value))
-      .reduce((sum: number, number: number) => sum + number, 1)
-  } else {
-    const cores = os.cpus().filter((cpu: any, index: number) => {
-      const hasHyperthreading = cpu.model.includes('Intel')
-      const isOdd = index % 2 === 1
-      return !hasHyperthreading || isOdd
-    })
-    return cores.length
+  try {
+    if (platform === 'linux') {
+      const output = await exec('lscpu -p | egrep -v "^#" | sort -u -t, -k 2,4 | wc -l')
+      return parseInt(output.trim(), 10)
+    } else if (platform === 'darwin') {
+      const output = await exec('sysctl -n hw.physicalcpu_max')
+      return parseInt(output.trim(), 10)
+    } else if (platform === 'win32') {
+      const output = await exec('WMIC CPU Get NumberOfCores')
+      return output
+        .split(os.EOL)
+        .map((line: string) => parseInt(line))
+        .filter((value: number) => !isNaN(value))
+        .reduce((sum: number, number: number) => sum + number, 1)
+    } else {
+      const cores = os.cpus().filter((cpu: any, index: number) => {
+        const hasHyperthreading = cpu.model.includes('Intel')
+        const isOdd = index % 2 === 1
+        return !hasHyperthreading || isOdd
+      })
+      return cores.length
+    }
+  } catch (err) {
+    console.warn('Failed to get physical CPU count', err)
+    // Divide by 2 to account for hyperthreading
+    const coreCount = Math.ceil(os.cpus().length / 2)
+    console.debug('Using node API to get physical CPU count:', coreCount)
+    return coreCount
   }
 }
diff --git a/core/src/node/helper/resource.ts b/core/src/node/helper/resource.ts
index c79a63688b..faaaace05e 100644
--- a/core/src/node/helper/resource.ts
+++ b/core/src/node/helper/resource.ts
@@ -1,6 +1,6 @@
 import { SystemResourceInfo } from '../../types'
 import { physicalCpuCount } from './config'
-import { log, logServer } from './log'
+import { log } from './log'
 
 export const getSystemResourceInfo = async (): Promise<SystemResourceInfo> => {
   const cpu = await physicalCpuCount()
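The `physicalCpuCount` change wraps the platform-specific shell probes (`lscpu`, `sysctl`, `WMIC`) in a try/catch so a missing or failing utility no longer crashes startup. A standalone sketch of the fallback path, assuming only Node's `os` module:

```ts
import os from 'os'

// os.cpus() counts logical cores; halving approximates physical cores
// on SMT/hyperthreaded CPUs. On chips without SMT this undercounts,
// which errs on the conservative side when sizing thread pools.
function physicalCpuCountFallback(): number {
  return Math.ceil(os.cpus().length / 2)
}

console.log(`~${physicalCpuCountFallback()} physical cores`)
```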
diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts
index e2970b8f9d..6072758842 100644
--- a/extensions/model-extension/src/index.ts
+++ b/extensions/model-extension/src/index.ts
@@ -38,7 +38,7 @@ export default class JanModelExtension extends ModelExtension {
   private static readonly _tensorRtEngineFormat = '.engine'
   private static readonly _configDirName = 'config'
   private static readonly _defaultModelFileName = 'default-model.json'
-  private static readonly _supportedGpuArch = ['turing', 'ampere', 'ada']
+  private static readonly _supportedGpuArch = ['ampere', 'ada']
 
   /**
    * Called when the extension is loaded.
diff --git a/extensions/monitoring-extension/src/node/index.ts b/extensions/monitoring-extension/src/node/index.ts
index 00fa7d0f64..ca767d348f 100644
--- a/extensions/monitoring-extension/src/node/index.ts
+++ b/extensions/monitoring-extension/src/node/index.ts
@@ -181,8 +181,7 @@ const updateNvidiaDriverInfo = async () =>
 const getGpuArch = (gpuName: string): string => {
   if (!gpuName.toLowerCase().includes('nvidia')) return 'unknown'
 
-  if (gpuName.includes('20')) return 'turing'
-  else if (gpuName.includes('30')) return 'ampere'
+  if (gpuName.includes('30')) return 'ampere'
   else if (gpuName.includes('40')) return 'ada'
   else return 'unknown'
 }
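Dropping `'turing'` from the supported list means RTX 20-series boards no longer pass the architecture check; the substring heuristic now maps 30-series names to Ampere and 40-series to Ada only. Expected behavior of the revised `getGpuArch` (example inputs are illustrative):

```ts
getGpuArch('NVIDIA GeForce RTX 3090') // 'ampere'
getGpuArch('NVIDIA GeForce RTX 4080') // 'ada'
getGpuArch('NVIDIA GeForce RTX 2080') // 'unknown' (Turing support removed)
getGpuArch('AMD Radeon RX 7900')      // 'unknown' (non-NVIDIA short-circuits)
```

Note that the heuristic keys off marketing names: a string such as 'NVIDIA GeForce GTX 1030' also contains '30' and would be misclassified as 'ampere'.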
"https://delta.jan.ai/dist/models///TinyJensen-1.1B-Chat-fp16/tokenizer.model" + "url": "https://delta.jan.ai/dist/models///tensorrt-llm-v0.7.1/TinyJensen-1.1B-Chat-fp16/tokenizer.model" }, { "filename": "special_tokens_map.json", - "url": "https://delta.jan.ai/dist/models///TinyJensen-1.1B-Chat-fp16/special_tokens_map.json" + "url": "https://delta.jan.ai/dist/models///tensorrt-llm-v0.7.1/TinyJensen-1.1B-Chat-fp16/special_tokens_map.json" }, { "filename": "tokenizer.json", - "url": "https://delta.jan.ai/dist/models///TinyJensen-1.1B-Chat-fp16/tokenizer.json" + "url": "https://delta.jan.ai/dist/models///tensorrt-llm-v0.7.1/TinyJensen-1.1B-Chat-fp16/tokenizer.json" }, { "filename": "tokenizer_config.json", - "url": "https://delta.jan.ai/dist/models///TinyJensen-1.1B-Chat-fp16/tokenizer_config.json" + "url": "https://delta.jan.ai/dist/models///tensorrt-llm-v0.7.1/TinyJensen-1.1B-Chat-fp16/tokenizer_config.json" + }, + { + "filename": "model.cache", + "url": "https://delta.jan.ai/dist/models///tensorrt-llm-v0.7.1/TinyJensen-1.1B-Chat-fp16/model.cache" } ], "id": "tinyjensen-1.1b-chat-fp16", @@ -92,5 +100,57 @@ "size": 2151000000 }, "engine": "nitro-tensorrt-llm" + }, + { + "sources": [ + { + "filename": "config.json", + "url": "https://delta.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/config.json" + }, + { + "filename": "mistral_float16_tp1_rank0.engine", + "url": "https://delta.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/mistral_float16_tp1_rank0.engine" + }, + { + "filename": "tokenizer.model", + "url": "https://delta.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/tokenizer.model" + }, + { + "filename": "special_tokens_map.json", + "url": "https://delta.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/special_tokens_map.json" + }, + { + "filename": "tokenizer.json", + "url": "https://delta.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/tokenizer.json" + }, + { + "filename": "tokenizer_config.json", + "url": "https://delta.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/tokenizer_config.json" + }, + { + "filename": "model.cache", + "url": "https://delta.jan.ai/dist/models///tensorrt-llm-v0.7.1/Mistral-7B-Instruct-v0.1-int4/model.cache" + } + ], + "id": "mistral-7b-instruct-int4", + "object": "model", + "name": "Mistral 7B Instruct v0.1 INT4", + "version": "1.0", + "description": "Mistral 7B Instruct v0.1 INT4", + "format": "TensorRT-LLM", + "settings": { + "ctx_len": 2048, + "text_model": false, + "prompt_template": "[INST] {prompt} [/INST]" + }, + "parameters": { + "max_tokens": 4096 + }, + "metadata": { + "author": "MistralAI", + "tags": ["TensorRT-LLM", "7B", "Finetuned"], + "size": 3840000000 + }, + "engine": "nitro-tensorrt-llm" } ] diff --git a/extensions/tensorrt-llm-extension/package.json b/extensions/tensorrt-llm-extension/package.json index ec54a82c15..d3521669e2 100644 --- a/extensions/tensorrt-llm-extension/package.json +++ b/extensions/tensorrt-llm-extension/package.json @@ -18,7 +18,7 @@ "0.1.0" ] }, - "tensorrtVersion": "0.1.6", + "tensorrtVersion": "0.1.8", "provider": "nitro-tensorrt-llm", "scripts": { "build": "tsc --module commonjs && rollup -c rollup.config.ts", diff --git a/extensions/tensorrt-llm-extension/rollup.config.ts b/extensions/tensorrt-llm-extension/rollup.config.ts index ee8d050d3f..e602bc7205 100644 --- a/extensions/tensorrt-llm-extension/rollup.config.ts +++ b/extensions/tensorrt-llm-extension/rollup.config.ts @@ 
diff --git a/extensions/tensorrt-llm-extension/rollup.config.ts b/extensions/tensorrt-llm-extension/rollup.config.ts
index ee8d050d3f..e602bc7205 100644
--- a/extensions/tensorrt-llm-extension/rollup.config.ts
+++ b/extensions/tensorrt-llm-extension/rollup.config.ts
@@ -21,7 +21,7 @@ export default [
       DOWNLOAD_RUNNER_URL:
         process.platform === 'win32'
           ? JSON.stringify(
-              'https://github.com/janhq/nitro-tensorrt-llm/releases/download/windows-v/nitro-windows-v-amd64-tensorrt-llm-.tar.gz'
+              'https://github.com/janhq/nitro-tensorrt-llm/releases/download/windows-v-tensorrt-llm-v0.7.1/nitro-windows-v-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'
             )
           : JSON.stringify(
               'https://github.com/janhq/nitro-tensorrt-llm/releases/download/linux-v/nitro-linux-v-amd64-tensorrt-llm-.tar.gz'
diff --git a/extensions/tensorrt-llm-extension/src/index.ts b/extensions/tensorrt-llm-extension/src/index.ts
index f8e2f775ed..d2d08e8a71 100644
--- a/extensions/tensorrt-llm-extension/src/index.ts
+++ b/extensions/tensorrt-llm-extension/src/index.ts
@@ -39,8 +39,9 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
   override inferenceUrl = INFERENCE_URL
   override nodeModule = NODE
 
-  private supportedGpuArch = ['turing', 'ampere', 'ada']
+  private supportedGpuArch = ['ampere', 'ada']
   private supportedPlatform = ['win32', 'linux']
+  private isUpdateAvailable = false
 
   compatibility() {
     return COMPATIBILITY as unknown as Compatibility
@@ -56,6 +57,8 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
   }
 
   override async install(): Promise<void> {
+    await this.removePopulatedModels()
+
     const info = await systemInformation()
     console.debug(
       `TensorRTLLMExtension installing pre-requisites... ${JSON.stringify(info)}`
@@ -141,6 +144,22 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
     events.on(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess)
   }
 
+  async removePopulatedModels(): Promise<void> {
+    console.debug(`removePopulatedModels`, JSON.stringify(models))
+    const janDataFolderPath = await getJanDataFolderPath()
+    const modelFolderPath = await joinPath([janDataFolderPath, 'models'])
+
+    for (const model of models) {
+      const modelPath = await joinPath([modelFolderPath, model.id])
+      console.debug(`modelPath: ${modelPath}`)
+      if (await fs.existsSync(modelPath)) {
+        console.debug(`Removing model ${modelPath}`)
+        await fs.rmdirSync(modelPath)
+      }
+    }
+    events.emit(ModelEvent.OnModelsUpdate, {})
+  }
+
   async onModelInit(model: Model): Promise<void> {
     if (model.engine !== this.provider) return
 
@@ -156,6 +175,10 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
     }
   }
 
+  updatable() {
+    return this.isUpdateAvailable
+  }
+
   override async installationState(): Promise<InstallationState> {
     const info = await systemInformation()
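With `updatable()` overridden to reflect `isUpdateAvailable`, the settings UI can distinguish an extension that merely is installed from one with a pending runner update. A sketch of how the new `'Updatable'` state might be derived (the `resolveInstallState` helper is hypothetical, and the import assumes `InstallationState` is exported from `@janhq/core` alongside `BaseExtension`; `installationState()` and `updatable()` are the real APIs from this PR):

```ts
import { BaseExtension, InstallationState } from '@janhq/core'

// Hypothetical glue: prefer 'Updatable' when an installed extension
// reports that a newer bundled runtime is available.
async function resolveInstallState(
  extension: BaseExtension
): Promise<InstallationState> {
  const state = await extension.installationState()
  return state === 'Installed' && extension.updatable() ? 'Updatable' : state
}
```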
diff --git a/extensions/tensorrt-llm-extension/src/node/index.ts b/extensions/tensorrt-llm-extension/src/node/index.ts
index 3766b5524c..1afebb950f 100644
--- a/extensions/tensorrt-llm-extension/src/node/index.ts
+++ b/extensions/tensorrt-llm-extension/src/node/index.ts
@@ -5,12 +5,13 @@ import fetchRT from 'fetch-retry'
 import { log, getJanDataFolderPath } from '@janhq/core/node'
 import decompress from 'decompress'
 import { SystemInformation } from '@janhq/core'
+import { PromptTemplate } from '@janhq/core'
 
 // Polyfill fetch with retry
 const fetchRetry = fetchRT(fetch)
 
 const supportedPlatform = (): string[] => ['win32', 'linux']
-const supportedGpuArch = (): string[] => ['turing', 'ampere', 'ada']
+const supportedGpuArch = (): string[] => ['ampere', 'ada']
 
 /**
  * The response object for model init operation.
@@ -35,9 +36,21 @@ async function loadModel(
   // e.g. ~/jan/models/llama-2
   let modelFolder = params.modelFolder
 
+  if (params.model.settings.prompt_template) {
+    const promptTemplate = params.model.settings.prompt_template
+    const prompt = promptTemplateConverter(promptTemplate)
+    if (prompt?.error) {
+      return Promise.reject(prompt.error)
+    }
+    params.model.settings.system_prompt = prompt.system_prompt
+    params.model.settings.user_prompt = prompt.user_prompt
+    params.model.settings.ai_prompt = prompt.ai_prompt
+  }
+
   const settings: ModelLoadParams = {
     engine_path: modelFolder,
     ctx_len: params.model.settings.ctx_len ?? 2048,
+    ...params.model.settings,
   }
   if (!systemInfo) {
     throw new Error('Cannot get system info. Unable to start nitro x tensorrt.')
@@ -220,6 +233,52 @@ const decompressRunner = async (zipPath: string, output: string) => {
   }
 }
 
+/**
+ * Parse a prompt template into args settings
+ * @param promptTemplate Template as string
+ * @returns
+ */
+function promptTemplateConverter(promptTemplate: string): PromptTemplate {
+  // Split the string using the markers
+  const systemMarker = '{system_message}'
+  const promptMarker = '{prompt}'
+
+  if (
+    promptTemplate.includes(systemMarker) &&
+    promptTemplate.includes(promptMarker)
+  ) {
+    // Find the indices of the markers
+    const systemIndex = promptTemplate.indexOf(systemMarker)
+    const promptIndex = promptTemplate.indexOf(promptMarker)
+
+    // Extract the parts of the string
+    const system_prompt = promptTemplate.substring(0, systemIndex)
+    const user_prompt = promptTemplate.substring(
+      systemIndex + systemMarker.length,
+      promptIndex
+    )
+    const ai_prompt = promptTemplate.substring(
+      promptIndex + promptMarker.length
+    )
+
+    // Return the split parts
+    return { system_prompt, user_prompt, ai_prompt }
+  } else if (promptTemplate.includes(promptMarker)) {
+    // Extract the parts of the string for the case where only promptMarker is present
+    const promptIndex = promptTemplate.indexOf(promptMarker)
+    const user_prompt = promptTemplate.substring(0, promptIndex)
+    const ai_prompt = promptTemplate.substring(
+      promptIndex + promptMarker.length
+    )
+
+    // Return the split parts
+    return { user_prompt, ai_prompt }
+  }
+
+  // Return an error if none of the conditions are met
+  return { error: 'Cannot split prompt template' }
+}
+
 export default {
   supportedPlatform,
   supportedGpuArch,
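To make the converter's behavior concrete, here is how it splits the Mistral template shipped in `models.json` above, plus a two-marker case (the second template string is a made-up illustration, not one from this PR):

```ts
// Only {prompt} present: split around it (second branch).
promptTemplateConverter('[INST] {prompt} [/INST]')
// => { user_prompt: '[INST] ', ai_prompt: ' [/INST]' }

// Both {system_message} and {prompt} present: split around both (first branch).
promptTemplateConverter('SYSTEM: {system_message} USER: {prompt} ASSISTANT:')
// => {
//      system_prompt: 'SYSTEM: ',
//      user_prompt: ' USER: ',
//      ai_prompt: ' ASSISTANT:',
//    }

// Neither marker present: loadModel rejects with this error.
promptTemplateConverter('no markers here')
// => { error: 'Cannot split prompt template' }
```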
diff --git a/web/containers/Providers/index.tsx b/web/containers/Providers/index.tsx
index e70a56ca87..10c6c7547a 100644
--- a/web/containers/Providers/index.tsx
+++ b/web/containers/Providers/index.tsx
@@ -1,6 +1,6 @@
 'use client'
 
-import { PropsWithChildren, useEffect, useState } from 'react'
+import { PropsWithChildren, useCallback, useEffect, useState } from 'react'
 
 import { Toaster } from 'react-hot-toast'
 
@@ -37,7 +37,7 @@ const Providers = (props: PropsWithChildren) => {
   const [activated, setActivated] = useState(false)
   const [settingUp, setSettingUp] = useState(false)
 
-  async function setupExtensions() {
+  const setupExtensions = useCallback(async () => {
     // Register all active extensions
     await extensionManager.registerActive()
 
@@ -57,7 +57,7 @@ const Providers = (props: PropsWithChildren) => {
       setSettingUp(false)
       setActivated(true)
     }, 500)
-  }
+  }, [pathname])
 
   // Services Setup
   useEffect(() => {
@@ -78,7 +78,7 @@ const Providers = (props: PropsWithChildren) => {
         setActivated(true)
       }
     }
-  }, [setupCore])
+  }, [setupCore, setupExtensions])
 
   return (
diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts
index 11a57a5988..0bbc779a6f 100644
--- a/web/hooks/useSendChatMessage.ts
+++ b/web/hooks/useSendChatMessage.ts
@@ -102,7 +102,6 @@ export default function useSendChatMessage() {
       console.error('No active thread')
       return
     }
-    setIsGeneratingResponse(true)
     updateThreadWaiting(activeThreadRef.current.id, true)
 
     const messages: ChatCompletionMessage[] = [
       activeThreadRef.current.assistants[0]?.instructions,
@@ -148,7 +147,7 @@ export default function useSendChatMessage() {
       await waitForModelStarting(modelId)
       setQueuedMessage(false)
     }
-
+    setIsGeneratingResponse(true)
     if (currentMessage.role !== ChatCompletionRole.User) {
       // Delete last response before regenerating
       deleteMessage(currentMessage.id ?? '')
@@ -171,7 +170,6 @@ export default function useSendChatMessage() {
       console.error('No active thread')
      return
    }
-    setIsGeneratingResponse(true)
 
     if (engineParamsUpdate) setReloadModel(true)
 
@@ -361,7 +359,7 @@ export default function useSendChatMessage() {
       await waitForModelStarting(modelId)
       setQueuedMessage(false)
     }
-
+    setIsGeneratingResponse(true)
     events.emit(MessageEvent.OnMessageSent, messageRequest)
 
     setReloadModel(false)
diff --git a/web/hooks/useSettings.ts b/web/hooks/useSettings.ts
index 9ff89827e8..378ca33faf 100644
--- a/web/hooks/useSettings.ts
+++ b/web/hooks/useSettings.ts
@@ -70,11 +70,6 @@ export const useSettings = () => {
       }
     }
     await fs.writeFileSync(settingsFile, JSON.stringify(settings))
-
-    // Relaunch to apply settings
-    if (vulkan != null) {
-      window.location.reload()
-    }
   }
 
   return {
diff --git a/web/screens/Settings/Advanced/index.tsx b/web/screens/Settings/Advanced/index.tsx
index 3cc43e744e..67ebf81d52 100644
--- a/web/screens/Settings/Advanced/index.tsx
+++ b/web/screens/Settings/Advanced/index.tsx
@@ -90,12 +90,38 @@ const Advanced = () => {
     [setPartialProxy, setProxy]
   )
 
-  const updateQuickAskEnabled = async (e: boolean) => {
+  const updateQuickAskEnabled = async (
+    e: boolean,
+    relaunch: boolean = true
+  ) => {
     const appConfiguration: AppConfiguration =
       await window.core?.api?.getAppConfigurations()
     appConfiguration.quick_ask = e
     await window.core?.api?.updateAppConfiguration(appConfiguration)
-    window.core?.api?.relaunch()
+    if (relaunch) window.core?.api?.relaunch()
+  }
+
+  const updateVulkanEnabled = async (e: boolean, relaunch: boolean = true) => {
+    toaster({
+      title: 'Reload',
+      description: 'Vulkan settings updated. Reload now to apply the changes.',
+    })
+    stopModel()
+    setVulkanEnabled(e)
+    await saveSettings({ vulkan: e, gpusInUse: [] })
+    // Relaunch to apply settings
+    if (relaunch) window.location.reload()
+  }
+
+  const updateExperimentalEnabled = async (e: boolean) => {
+    setExperimentalEnabled(e)
+    if (e) return
+
+    // It affects other settings, so we need to reset them
+    const isRelaunch = quickAskEnabled || vulkanEnabled
+    if (quickAskEnabled) await updateQuickAskEnabled(false, false)
+    if (vulkanEnabled) await updateVulkanEnabled(false, false)
+    if (isRelaunch) window.core?.api?.relaunch()
+  }
 
   useEffect(() => {
@@ -179,7 +205,7 @@ const Advanced = () => {
@@ -381,16 +407,7 @@ const Advanced = () => {
-                  onCheckedChange={(e) => {
-                    toaster({
-                      title: 'Reload',
-                      description:
-                        'Vulkan settings updated. Reload now to apply the changes.',
-                    })
-                    stopModel()
-                    saveSettings({ vulkan: e, gpusInUse: [] })
-                    setVulkanEnabled(e)
-                  }}
+                  onCheckedChange={(e) => updateVulkanEnabled(e)}
                 />
               )}
diff --git a/web/screens/Settings/CoreExtensions/TensorRtExtensionItem.tsx b/web/screens/Settings/CoreExtensions/TensorRtExtensionItem.tsx
index 60677b1850..fb0214536a 100644
--- a/web/screens/Settings/CoreExtensions/TensorRtExtensionItem.tsx
+++ b/web/screens/Settings/CoreExtensions/TensorRtExtensionItem.tsx
@@ -23,6 +23,8 @@ import { useAtomValue } from 'jotai'
 
 import { Marked, Renderer } from 'marked'
 
+import UpdateExtensionModal from './UpdateExtensionModal'
+
 import { extensionManager } from '@/extension'
 import Extension from '@/extension/Extension'
 import { installingExtensionAtom } from '@/helpers/atoms/Extension.atom'
@@ -39,7 +41,7 @@ const TensorRtExtensionItem: React.FC<Props> = ({ item }) => {
     useState<InstallationState>('NotRequired')
   const installingExtensions = useAtomValue(installingExtensionAtom)
   const [isGpuSupported, setIsGpuSupported] = useState(false)
-
+  const [promptUpdateModal, setPromptUpdateModal] = useState(false)
   const isInstalling = installingExtensions.some(
     (e) => e.extensionId === item.name
   )
@@ -69,7 +71,7 @@ const TensorRtExtensionItem: React.FC<Props> = ({ item }) => {
       return
     }
 
-    const supportedGpuArch = ['turing', 'ampere', 'ada']
+    const supportedGpuArch = ['ampere', 'ada']
     setIsGpuSupported(supportedGpuArch.includes(arch))
   }
   getSystemInfos()
@@ -138,6 +140,7 @@ const TensorRtExtensionItem: React.FC<Props> = ({ item }) => {
           installProgress={progress}
           installState={installState}
           onInstallClick={onInstallClick}
+          onUpdateClick={() => setPromptUpdateModal(true)}
           onCancelClick={onCancelInstallingClick}
         />
 
@@ -177,6 +180,9 @@ const TensorRtExtensionItem: React.FC<Props> = ({ item }) => {
       )}
+      {promptUpdateModal && (
+
+      )}
   )
 }
 
@@ -185,6 +191,7 @@ type InstallStateProps = {
   installProgress: number
   installState: InstallationState
   onInstallClick: () => void
+  onUpdateClick: () => void
   onCancelClick: () => void
 }
 
@@ -192,6 +199,7 @@ const InstallStateIndicator: React.FC<InstallStateProps> = ({
   installProgress,
   installState,
   onInstallClick,
+  onUpdateClick,
   onCancelClick,
 }) => {
   if (installProgress !== -1) {
@@ -218,6 +226,12 @@ const InstallStateIndicator: React.FC<InstallStateProps> = ({
         Installed
       )
+    case 'Updatable':
+      return (
+
+      )
     case 'NotInstalled':
       return (
diff --git a/web/screens/Settings/CoreExtensions/UpdateExtensionModal.tsx b/web/screens/Settings/CoreExtensions/UpdateExtensionModal.tsx
new file mode 100644
[The new component's JSX markup was stripped during extraction; only its closing lines survive.]
+  )
+}
+
+export default React.memo(UpdateExtensionModal)
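Since the `UpdateExtensionModal` body did not survive extraction, a rough sketch for orientation only, matching its observed wiring (`React.memo` export, opened via `onUpdateClick` and `setPromptUpdateModal(true)`). The element names and props below are assumptions, not the PR's actual markup:

```tsx
import React from 'react'

type Props = {
  onUpdate: () => void
  onCancel: () => void
}

// Illustrative stand-in for the stripped component body.
const UpdateExtensionModal: React.FC<Props> = ({ onUpdate, onCancel }) => {
  return (
    <div className="modal">
      <p>
        Updating the TensorRT-LLM extension replaces its runner and removes
        previously downloaded models. Continue?
      </p>
      <button onClick={onCancel}>Cancel</button>
      <button onClick={onUpdate}>Update</button>
    </div>
  )
}

export default React.memo(UpdateExtensionModal)
```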