fix: context has no system or prompt state

gsuuon committed Feb 14, 2024
1 parent d16b30d commit 32d631e

Showing 3 changed files with 49 additions and 65 deletions.
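
At a glance, the change this commit makes to the calling convention: system/preprompt are no longer stored on the model via `setContext`, they travel with each `generate` call. A sketch based on the diffs below (import path and example values are assumptions):

```ts
import type { LoadedModel, GenerateOptions } from './types' // path assumed

// Sketch: system/preprompt used to be set once as model state, now they ride
// along with every generate call.
const example = async (model: LoadedModel, options?: GenerateOptions) => {
  // Before this commit (old API, shown for contrast):
  //   await model.setContext('You are a helpful assistant', 'Answer briefly.')
  //   await model.generate('What is WebGPU?', '', ['\n'], options)

  // After this commit:
  return model.generate(
    {
      prompt: 'What is WebGPU?',
      priorCompletion: '',           // text generated so far in the template
      stops: ['\n'],
      system: 'You are a helpful assistant',
      preprompt: 'Answer briefly.',
    },
    options
  )
}
```
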
39 changes: 19 additions & 20 deletions src/index.ts
@@ -1,7 +1,5 @@
import type {
ModelSpec,
GenerateOptions,
CommonOptions,
LoadedModel,
LoadReport,
TemplateExpression,
@@ -18,8 +16,8 @@ let cachedModelAndSpec: { spec: ModelSpec, model: LoadedModel } | undefined;

// NOTE this currently only works for Llama 2 variations due to different wasm naming conventions
const guessModelSpecFromPrebuiltId = (id: string) => ({ // TODO generally works for currently known prebuilts
modelWeightsConfigUrl: `https://huggingface.co/mlc-ai/mlc-chat-${id}/resolve/main/`,
modelLibWasmUrl: `https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/${id}-webgpu.wasm`
modelWeightsConfigUrl: `https://huggingface.co/mlc-ai/mlc-chat-${id}/resolve/main/`,
modelLibWasmUrl: `https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/${id}-webgpu.wasm`
})
/**
* Load a model spec or try to guess it from a prebuilt-id.
@@ -64,7 +62,7 @@ export const loadModel = async (
})

if (cachedModelAndSpec?.spec.modelLibWasmUrl == spec.modelLibWasmUrl
&& cachedModelAndSpec?.spec.modelWeightsConfigUrl == cachedModelAndSpec?.spec.modelWeightsConfigUrl) {
&& cachedModelAndSpec?.spec.modelWeightsConfigUrl == cachedModelAndSpec?.spec.modelWeightsConfigUrl) {
await cachedModelAndSpec.model.cancel()
return cachedModelAndSpec.model
}
@@ -148,7 +146,7 @@ const asOp = (

const contextPreword = configPreword ?? (
expr.preword === 'a' ? 'Generate' :
expr.preword === 'the' ? 'What is' :
expr.preword === 'the' ? 'What is' :
''
)

@@ -173,7 +171,7 @@ const expandIfRefOp = (op: Exclude<Op, string>, ref: (id: string) => string | un
prompt: expr,
stop: op.stop
}
}
}

return {
...expr,
@@ -191,7 +189,7 @@ export type Template = {
/** Collect the template as a string - optionally with a streaming handler */
collect: (stream?: GenerationStreamHandler) => Promise<string>
/** Like collect but returns the completion and refs */
collect_refs: (stream?: GenerationStreamHandler) => Promise<{completion: string, refs: Record<string, string>}>
collect_refs: (stream?: GenerationStreamHandler) => Promise<{ completion: string, refs: Record<string, string> }>
model: LoadedModel // TODO refactor to just cancel() -- this is only used to cancel the underlying model
}

@@ -211,7 +209,7 @@ export type Template = {
type CreateTemplateContext =
(
literals: TemplateStringsArray,
...expressions: (TemplateExpression|WithRef<TemplateExpression>)[]
...expressions: (TemplateExpression | WithRef<TemplateExpression>)[]
) => Template

/**
@@ -284,7 +282,7 @@ export const ad = (model: LoadedModel): CreateTemplate => {
options,
preword: null,
}),
context: (system: string, preprompt?: string, config?: TemplateContextOptions): CreateTemplateContext =>(literals, ...expressions) => {
context: (system: string, preprompt?: string, config?: TemplateContextOptions): CreateTemplateContext => (literals, ...expressions) => {
let refs: Record<string, string> = {}
const ref = (id: string): string | undefined => refs[id]

@@ -300,17 +298,15 @@ export const ad = (model: LoadedModel): CreateTemplate => {
}

const collect = async (stream?: GenerationStreamHandler) => {
await model.setContext(system, preprompt)

if (stream) {
stream({
content: ops.reduce<string>( (completion, op) => {
if (typeof(op) === 'string') {
content: ops.reduce<string>((completion, op) => {
if (typeof (op) === 'string') {
return completion + op
} else {
if ('refExpr' in op) {
const expr = op.refExpr(x => `(ref: ${x})`)
console.log('template refExpr', {expr})
console.log('template refExpr', { expr })
return completion + `\${'${typeof expr === 'string' ? expr : expr.prompt}'}`
} else {
return completion + `\${'${op.prompt}'}`
@@ -331,7 +327,7 @@ export const ad = (model: LoadedModel): CreateTemplate => {
return ops.reduce<Promise<string>>(async (completion_, op) => {
const completion = await completion_

if (typeof(op) === 'string') {
if (typeof (op) === 'string') {
stream?.({
content: op,
type: 'lit'
@@ -340,10 +336,13 @@ export const ad = (model: LoadedModel): CreateTemplate => {
} else {
const { options, prompt, stop } = expandIfRefOp(op, ref)

const generated = await model.generate(
const generated = await model.generate({
prompt,
completion,
[stop, ...(options?.stops ?? [])],
preprompt,
system,
priorCompletion: completion,
stops: [stop, ...(options?.stops ?? [])],
},
{
stream,
...config,
@@ -367,7 +366,7 @@ export const ad = (model: LoadedModel): CreateTemplate => {

return {
completion,
refs
refs // FIXME refs should be scoped to collect, subsequent contexts will have stale refs
}
},
model
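
Downstream, the template context built by `ad(model)` now passes its system/preprompt into each expression's generate call instead of mutating model state. A rough usage sketch, assuming the surface API in this file (expression fields and import paths are assumptions, not verbatim from this repo):

```ts
import type { LoadedModel } from './types' // import paths are assumptions
import { ad } from './index'

// With no context state kept on the model, templates with different system
// prompts built from the same model shouldn't clobber each other anymore.
const demo = async (model: LoadedModel) => {
  const { context } = ad(model)

  const pirate = context('You are a pirate')`Greeting: ${{ prompt: 'a short greeting' }}`
  const robot = context('You are a robot')`Greeting: ${{ prompt: 'a short greeting' }}`

  // Each collect() threads its own system prompt into every generate call.
  return [await pirate.collect(), await robot.collect()]
}
```
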
59 changes: 22 additions & 37 deletions src/loadModel.ts
@@ -58,11 +58,11 @@ const perf = (() => {
get entries() { return entries },
summarize: () => {
const sums = Object.fromEntries(Object.entries(entries).map(
([label, results]) => [label, results.reduce((a,x) => a + x, 0) ]
([label, results]) => [label, results.reduce((a, x) => a + x, 0)]
))

const averages = Object.fromEntries(Object.entries(sums).map(
([label, sum]) => [label, sum / entries[label].length ]
([label, sum]) => [label, sum / entries[label].length]
))

console.debug('#perf', { sums, averages, entries })
@@ -98,17 +98,17 @@ export default async (
const device = targetDevice === TargetDevice.GPU ? tvm.webgpu() : tvm.cpu()

if (targetDevice === TargetDevice.GPU) {
updateReport({detectGPU: 'waiting'})
updateReport({ detectGPU: 'waiting' })

const gpu = await detectGPUDevice()
if (gpu == undefined) {
updateReport({detectGPU: 'failed'})
updateReport({ detectGPU: 'failed' })
throw Error('Cannot find GPU in environment')
}

updateReport({ detectGPU: gpu.adapterInfo.vendor })

tvm.initWebGPU(gpu.device)
tvm.initWebGPU(gpu.device)
}

let isLoadingGpuShaders = false
@@ -146,7 +146,7 @@ export default async (
'tokenizer.model': Tokenizer.fromSentencePiece,
'tokenizer.json': Tokenizer.fromJSON
}).find(([file, _create]) => config.tokenizer_files.includes(file))
// preference comes from the order of tokenizer_files -- seems like .json is preferred over .model
// preference comes from the order of tokenizer_files -- seems like .json is preferred over .model

if (configTokenizerFiles == undefined) {
const err = `Cant handle tokenizer files ${config.tokenizer_files}`;
@@ -158,7 +158,7 @@

const tokenizerResult =
await cacheScope('model')
.fetchWithCache(new URL(path, spec.modelWeightsConfigUrl).href)
.fetchWithCache(new URL(path, spec.modelWeightsConfigUrl).href)

const tokenizer = await create(await tokenizerResult.arrayBuffer())
const w = window as any
@@ -192,7 +192,6 @@ export default async (
const getMetadata = vm.getFunction('get_metadata')
const metadata = JSON.parse(tvm.detachFromCurrentScope(getMetadata()).toString())

console.info({metadata})
const stopTokens: number[] = metadata.stop_tokens

const createKvCache = vm.getFunction('create_kv_cache')
@@ -308,8 +307,6 @@ export default async (
}

let modelState: ModelState = ModelState.Waiting
let system_ = '<<sys>>You are a helpful assistant<</sys>>\n\n'
let preprompt_ = '[INST]'

let totalTokenCount = 0

@@ -333,10 +330,6 @@
totalTokenCount += 1

if (stopTokens.includes(token)) {
console.info('stop token', token, {
tokens,
completion
})
return false
}

@@ -374,10 +367,13 @@
}


const generate = async (
prompt: string,
priorCompletion: string,
stops: string[],
const generate: LoadedModel['generate'] = async ({
prompt,
priorCompletion,
stops,
system,
preprompt
},
options?: GenerateOptions
): Promise<string> => {
modelState = ModelState.Running as ModelState
@@ -394,11 +390,11 @@
const buildSampler = options_?.sampler
const sample =
buildSampler
? buildSampler(priorCompletion, stops, options_.temperature, options_.top_p)
: (logits: CpuNDArray) => sampleTokenFromLogits(logits, options_.temperature, options_.top_p)
? buildSampler(priorCompletion, stops, options_.temperature, options_.top_p)
: (logits: CpuNDArray) => sampleTokenFromLogits(logits, options_.temperature, options_.top_p)

const prefillText = `${system_}${preprompt_} ${prompt} [/INST] ${priorCompletion}`
console.info('[generate:start]', prompt, {...options_, prefillText})
const prefillText = `<<sys>>${system ?? 'You are a helpful assistant'}<</sys>>\n\n[INST]${preprompt ? ` ${preprompt}` : ''} ${prompt} [/INST] ${priorCompletion}`
console.info('[generate:start]', prompt, { ...options_, prefillText })

if (filledKvCacheLength > 0) {
unfill()
@@ -457,9 +453,9 @@
content: accepted.completion
})

console.log({failedValidation: accepted.completion})
console.info('[validation-failed]', accepted.completion)

return await generate(prompt, priorCompletion, stops, {
return await generate({ prompt, priorCompletion, stops }, {
...options_,
validate: {
...options_.validate,
@@ -509,27 +505,16 @@
updateReport({ ready: true })

return {
generate: async (prompt, priorCompletion, stops, options?) => {
generate: async (params, options?) => {
try {
return await generate(prompt, priorCompletion, stops, options)
return await generate(params, options)
} catch (e) {
unfill()
modelState = ModelState.Waiting
throw e
}
},
bias,
setContext: async (system: string, preprompt?: string) => {
system_ = `<<sys>>${system}<</sys>>\n\n`
preprompt_ = preprompt ? `[INST] ${preprompt}` : preprompt_

console.log('[context]', system, preprompt)

// TODO prefill here, save kvCache, reset kvCache on each generate as necessary
// Is that possible? can I prefill with existing kvCache?
// This only saves prefilling the system + preprompt anyway - it won't do anything for generates since the generate prompt
// goes before the completion body
},
cancel: async () => {
if (modelState === ModelState.Running) {
modelState = ModelState.Cancelling
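
For reference, the Llama 2 prefill string that generate now assembles per call. The template string is taken from the diff above; the values are example inputs:

```ts
const system: string | undefined = 'You are a helpful assistant'
const preprompt: string | undefined = 'Answer in one sentence.'
const prompt = 'What is WebGPU?'
const priorCompletion = ''

// Same expression as prefillText in the diff above (Llama 2 chat format)
const prefillText =
  `<<sys>>${system ?? 'You are a helpful assistant'}<</sys>>\n\n` +
  `[INST]${preprompt ? ` ${preprompt}` : ''} ${prompt} [/INST] ${priorCompletion}`

console.log(prefillText)
// <<sys>>You are a helpful assistant<</sys>>
//
// [INST] Answer in one sentence. What is WebGPU? [/INST]
```
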
16 changes: 8 additions & 8 deletions src/types.ts
@@ -62,8 +62,8 @@ export type TemplateExpressionOptions = {

export type TemplateContextOptions =
Omit<TemplateExpressionOptions, 'id'> & {
preword?: string
}
preword?: string
}

export type TemplateExpression = {
prompt: string
@@ -80,17 +80,17 @@ export type TemplateExpression = {
* ```
*/
export type LoadedModel = {
setContext: (system: string, preprompt?: string) => Promise<void>
generate: (
generate: (params: {
prompt: string,
completion: string,
priorCompletion: string,
stops: string[],
config?: GenerateOptions
) => Promise<string>
system?: string,
preprompt?: string
}, config?: GenerateOptions) => Promise<string>
cancel: () => Promise<void>
bias: Bias
totalTokenCount: number
}
}

/**
* Specifies where to retrieve model weights and configuration
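
A minimal stub satisfying the updated `LoadedModel` shape, e.g. for tests. The import path and the `Bias` placeholder are assumptions, since that type isn't shown in this diff:

```ts
import type { LoadedModel, Bias } from './types' // names/path assumed

export const fakeModel: LoadedModel = {
  generate: async ({ prompt, priorCompletion, system }) =>
    `${priorCompletion}(stub answer to "${prompt}" as "${system ?? 'default system'}")`,
  cancel: async () => {},
  bias: {} as Bias,   // placeholder only; real Bias construction not shown here
  totalTokenCount: 0,
}
```
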
