From 32d631eab6040e3d30270cb28f44578f83218708 Mon Sep 17 00:00:00 2001
From: Steven Sun
Date: Wed, 14 Feb 2024 14:52:41 -0500
Subject: [PATCH] fix: context has no system or prompt state

---
 src/index.ts     | 39 ++++++++++++++++----------------
 src/loadModel.ts | 59 ++++++++++++++++++------------------------------
 src/types.ts     | 16 ++++++-------
 3 files changed, 49 insertions(+), 65 deletions(-)

diff --git a/src/index.ts b/src/index.ts
index ec241c0..15ea3d7 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -1,7 +1,5 @@
 import type {
   ModelSpec,
-  GenerateOptions,
-  CommonOptions,
   LoadedModel,
   LoadReport,
   TemplateExpression,
@@ -18,8 +16,8 @@ let cachedModelAndSpec: { spec: ModelSpec, model: LoadedModel } | undefined;
 // NOTE this currently only works for Llama 2 variations due to different wasm naming conventions
 const guessModelSpecFromPrebuiltId = (id: string) => ({ // TODO generally works for currently known prebuilts
-  modelWeightsConfigUrl: `https://huggingface.co/mlc-ai/mlc-chat-${id}/resolve/main/`,
-  modelLibWasmUrl: `https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/${id}-webgpu.wasm`
+  modelWeightsConfigUrl: `https://huggingface.co/mlc-ai/mlc-chat-${id}/resolve/main/`,
+  modelLibWasmUrl: `https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/${id}-webgpu.wasm`
 })
 
 /**
  * Load a model spec or try to guess it from a prebuilt-id.
@@ -64,7 +62,7 @@ export const loadModel = async (
   })
 
   if (cachedModelAndSpec?.spec.modelLibWasmUrl == spec.modelLibWasmUrl
-    && cachedModelAndSpec?.spec.modelWeightsConfigUrl == cachedModelAndSpec?.spec.modelWeightsConfigUrl) {
+    && cachedModelAndSpec?.spec.modelWeightsConfigUrl == cachedModelAndSpec?.spec.modelWeightsConfigUrl) {
     await cachedModelAndSpec.model.cancel()
     return cachedModelAndSpec.model
   }
@@ -148,7 +146,7 @@ const asOp = (
 
   const contextPreword = configPreword ?? (
     expr.preword === 'a' ? 'Generate' :
-    expr.preword === 'the' ? 'What is' :
+    expr.preword === 'the' ? 'What is' :
     ''
   )
 
@@ -173,7 +171,7 @@ const expandIfRefOp = (op: Exclude, ref: (id: string) => string | un
       prompt: expr,
       stop: op.stop
     }
-  }
+  }
 
   return {
     ...expr,
@@ -191,7 +189,7 @@ export type Template = {
   /** Collect the template as a string - optionally with a streaming handler */
   collect: (stream?: GenerationStreamHandler) => Promise<string>
   /** Like collect but returns the completion and refs */
-  collect_refs: (stream?: GenerationStreamHandler) => Promise<{completion: string, refs: Record<string, string>}>
+  collect_refs: (stream?: GenerationStreamHandler) => Promise<{ completion: string, refs: Record<string, string> }>
   model: LoadedModel // TODO refactor to just cancel() -- this is only used to cancel the underlying model
 }
 
@@ -211,7 +209,7 @@ export type Template = {
 
 type CreateTemplateContext = (
   literals: TemplateStringsArray,
-  ...expressions: (TemplateExpression|WithRef)[]
+  ...expressions: (TemplateExpression | WithRef)[]
 ) => Template
 
 /**
@@ -284,7 +282,7 @@ export const ad = (model: LoadedModel): CreateTemplate => {
       options,
       preword: null,
     }),
-    context: (system: string, preprompt?: string, config?: TemplateContextOptions): CreateTemplateContext =>(literals, ...expressions) => {
+    context: (system: string, preprompt?: string, config?: TemplateContextOptions): CreateTemplateContext => (literals, ...expressions) => {
      let refs: Record<string, string> = {}
 
      const ref = (id: string): string | undefined => refs[id]
@@ -300,17 +298,15 @@ export const ad = (model: LoadedModel): CreateTemplate => {
      }
 
      const collect = async (stream?: GenerationStreamHandler) => {
-        await model.setContext(system, preprompt)
-
        if (stream) {
          stream({
-            content: ops.reduce( (completion, op) => {
-              if (typeof(op) === 'string') {
+            content: ops.reduce((completion, op) => {
+              if (typeof (op) === 'string') {
                return completion + op
              } else {
                if ('refExpr' in op) {
                  const expr = op.refExpr(x => `(ref: ${x})`)
-                  console.log('template refExpr', {expr})
+                  console.log('template refExpr', { expr })
                  return completion + `\${'${typeof expr === 'string' ? expr : expr.prompt}'}`
                } else {
                  return completion + `\${'${op.prompt}'}`
@@ -331,7 +327,7 @@ export const ad = (model: LoadedModel): CreateTemplate => {
        return ops.reduce<Promise<string>>(async (completion_, op) => {
          const completion = await completion_
 
-          if (typeof(op) === 'string') {
+          if (typeof (op) === 'string') {
            stream?.({
              content: op,
              type: 'lit'
@@ -340,10 +336,13 @@ export const ad = (model: LoadedModel): CreateTemplate => {
          } else {
            const { options, prompt, stop } = expandIfRefOp(op, ref)
 
-            const generated = await model.generate(
+            const generated = await model.generate({
              prompt,
-              completion,
-              [stop, ...(options?.stops ?? [])],
+              preprompt,
+              system,
+              priorCompletion: completion,
+              stops: [stop, ...(options?.stops ?? [])],
+            },
              {
                stream,
                ...config,
@@ -367,7 +366,7 @@ export const ad = (model: LoadedModel): CreateTemplate => {
 
        return {
          completion,
-          refs
+          refs // FIXME refs should be scoped to collect, subsequent contexts will have stale refs
        }
      },
      model
diff --git a/src/loadModel.ts b/src/loadModel.ts
index 736e23c..9e72b2a 100644
--- a/src/loadModel.ts
+++ b/src/loadModel.ts
@@ -58,11 +58,11 @@ const perf = (() => {
     get entries() { return entries },
     summarize: () => {
       const sums = Object.fromEntries(Object.entries(entries).map(
-        ([label, results]) => [label, results.reduce((a,x) => a + x, 0) ]
+        ([label, results]) => [label, results.reduce((a, x) => a + x, 0)]
       ))
 
       const averages = Object.fromEntries(Object.entries(sums).map(
-        ([label, sum]) => [label, sum / entries[label].length ]
+        ([label, sum]) => [label, sum / entries[label].length]
       ))
 
       console.debug('#perf', { sums, averages, entries })
@@ -98,17 +98,17 @@ export default async (
   const device = targetDevice === TargetDevice.GPU ? tvm.webgpu() : tvm.cpu()
 
   if (targetDevice === TargetDevice.GPU) {
-    updateReport({detectGPU: 'waiting'})
+    updateReport({ detectGPU: 'waiting' })
 
     const gpu = await detectGPUDevice()
 
     if (gpu == undefined) {
-      updateReport({detectGPU: 'failed'})
+      updateReport({ detectGPU: 'failed' })
       throw Error('Cannot find GPU in environment')
     }
 
     updateReport({ detectGPU: gpu.adapterInfo.vendor })
-    tvm.initWebGPU(gpu.device)
+    tvm.initWebGPU(gpu.device)
   }
 
   let isLoadingGpuShaders = false
@@ -146,7 +146,7 @@ export default async (
     'tokenizer.model': Tokenizer.fromSentencePiece,
     'tokenizer.json': Tokenizer.fromJSON
   }).find(([file, _create]) => config.tokenizer_files.includes(file))
-  // preference comes from the order of tokenizer_files -- seems like .json is preferred over .model
+  // preference comes from the order of tokenizer_files -- seems like .json is preferred over .model
 
   if (configTokenizerFiles == undefined) {
     const err = `Cant handle tokenizer files ${config.tokenizer_files}`;
@@ -158,7 +158,7 @@ export default async (
 
   const tokenizerResult = await cacheScope('model')
-    .fetchWithCache(new URL(path, spec.modelWeightsConfigUrl).href)
+    .fetchWithCache(new URL(path, spec.modelWeightsConfigUrl).href)
 
   const tokenizer = await create(await tokenizerResult.arrayBuffer())
 
   const w = window as any
@@ -192,7 +192,6 @@ export default async (
 
   const getMetadata = vm.getFunction('get_metadata')
   const metadata = JSON.parse(tvm.detachFromCurrentScope(getMetadata()).toString())
-  console.info({metadata})
   const stopTokens: number[] = metadata.stop_tokens
 
   const createKvCache = vm.getFunction('create_kv_cache')
@@ -308,8 +307,6 @@ export default async (
   }
 
   let modelState: ModelState = ModelState.Waiting
-  let system_ = '<<SYS>>You are a helpful assistant<</SYS>>\n\n'
-  let preprompt_ = '[INST]'
 
   let totalTokenCount = 0
 
@@ -333,10 +330,6 @@ export default async (
     totalTokenCount += 1
 
     if (stopTokens.includes(token)) {
-      console.info('stop token', token, {
-        tokens,
-        completion
-      })
       return false
     }
 
@@ -374,10 +367,13 @@ export default async (
 
   }
 
-  const generate = async (
-    prompt: string,
-    priorCompletion: string,
-    stops: string[],
+  const generate: LoadedModel['generate'] = async ({
+    prompt,
+    priorCompletion,
+    stops,
+    system,
+    preprompt
+  },
     options?: GenerateOptions
   ): Promise<string> => {
     modelState = ModelState.Running as ModelState
@@ -394,11 +390,11 @@ export default async (
     const buildSampler = options_?.sampler
     const sample = buildSampler
-      ? buildSampler(priorCompletion, stops, options_.temperature, options_.top_p)
-      : (logits: CpuNDArray) => sampleTokenFromLogits(logits, options_.temperature, options_.top_p)
+      ? buildSampler(priorCompletion, stops, options_.temperature, options_.top_p)
+      : (logits: CpuNDArray) => sampleTokenFromLogits(logits, options_.temperature, options_.top_p)
 
-    const prefillText = `${system_}${preprompt_} ${prompt} [/INST] ${priorCompletion}`
-    console.info('[generate:start]', prompt, {...options_, prefillText})
+    const prefillText = `<<SYS>>${system ?? 'You are a helpful assistant'}<</SYS>>\n\n[INST]${preprompt ? ` ${preprompt}` : ''} ${prompt} [/INST] ${priorCompletion}`
+    console.info('[generate:start]', prompt, { ...options_, prefillText })
 
     if (filledKvCacheLength > 0) {
       unfill()
     }
@@ -457,9 +453,9 @@ export default async (
           content: accepted.completion
         })
 
-        console.log({failedValidation: accepted.completion})
+        console.info('[validation-failed]', accepted.completion)
 
-        return await generate(prompt, priorCompletion, stops, {
+        return await generate({ prompt, priorCompletion, stops }, {
           ...options_,
           validate: {
             ...options_.validate,
@@ -509,9 +505,9 @@ export default async (
   updateReport({ ready: true })
 
   return {
-    generate: async (prompt, priorCompletion, stops, options?) => {
+    generate: async (params, options?) => {
      try {
-        return await generate(prompt, priorCompletion, stops, options)
+        return await generate(params, options)
      } catch (e) {
        unfill()
        modelState = ModelState.Waiting
@@ -519,17 +515,6 @@ export default async (
      }
    },
    bias,
-    setContext: async (system: string, preprompt?: string) => {
-      system_ = `<<SYS>>${system}<</SYS>>\n\n`
-      preprompt_ = preprompt ? `[INST] ${preprompt}` : preprompt_
-
-      console.log('[context]', system, preprompt)
-
-      // TODO prefill here, save kvCache, reset kvCache on each generate as necessary
-      // Is that possible? can I prefill with existing kvCache?
-      // This only saves prefilling the system + preprompt anyway - it won't do anything for generates since the generate prompt
-      // goes before the completion body
-    },
    cancel: async () => {
      if (modelState === ModelState.Running) {
        modelState = ModelState.Cancelling
diff --git a/src/types.ts b/src/types.ts
index ce6d7e7..5afb9e7 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -62,8 +62,8 @@ export type TemplateExpressionOptions = {
 
 export type TemplateContextOptions = Omit<TemplateExpressionOptions, 'preword'> & {
-  preword?: string
-}
+    preword?: string
+  }
 
 export type TemplateExpression = {
   prompt: string
@@ -80,17 +80,17 @@ export type TemplateExpression = {
  * ```
  */
 export type LoadedModel = {
-  setContext: (system: string, preprompt?: string) => Promise<void>
-  generate: (
+  generate: (params: {
     prompt: string,
-    completion: string,
+    priorCompletion: string,
     stops: string[],
-    config?: GenerateOptions
-  ) => Promise<string>
+    system?: string,
+    preprompt?: string
+  }, config?: GenerateOptions) => Promise<string>
   cancel: () => Promise<void>
   bias: Bias
   totalTokenCount: number
-}
+}
 
 /**
  * Specifies where to retrieve model weights and configuration