Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making the LLM providers more generics #10

Merged
merged 5 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@
"@langchain/core": "^0.3.13",
"@langchain/mistralai": "^0.1.1",
"@lumino/coreutils": "^2.1.2",
"@lumino/polling": "^2.1.2"
"@lumino/polling": "^2.1.2",
"@lumino/signaling": "^2.1.2"
},
"devDependencies": {
"@jupyterlab/builder": "^4.0.0",
Expand Down
21 changes: 21 additions & 0 deletions schema/ai-provider.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
"title": "AI provider",
"description": "Provider settings",
"type": "object",
"properties": {
"provider": {
"type": "string",
"title": "The AI provider",
"description": "The AI provider to use for chat and completion",
"default": "None",
"enum": ["None", "MistralAI"]
},
"apiKey": {
"type": "string",
"title": "The Codestral API key",
"description": "The API key to use for Codestral",
"default": ""
}
},
"additionalProperties": false
}
14 changes: 0 additions & 14 deletions schema/inline-provider.json

This file was deleted.

41 changes: 31 additions & 10 deletions src/handler.ts → src/chat-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,30 @@ import {
IChatMessage,
INewMessage
} from '@jupyter/chat';
import { UUID } from '@lumino/coreutils';
import type { ChatMistralAI } from '@langchain/mistralai';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import {
AIMessage,
HumanMessage,
mergeMessageRuns
} from '@langchain/core/messages';
import { UUID } from '@lumino/coreutils';

export type ConnectionMessage = {
type: 'connection';
client_id: string;
};

export class CodestralHandler extends ChatModel {
constructor(options: CodestralHandler.IOptions) {
export class ChatHandler extends ChatModel {
constructor(options: ChatHandler.IOptions) {
super(options);
this._mistralClient = options.mistralClient;
this._provider = options.provider;
}

get provider(): BaseChatModel | null {
return this._provider;
}
set provider(provider: BaseChatModel | null) {
this._provider = provider;
}

async sendMessage(message: INewMessage): Promise<boolean> {
Expand All @@ -38,6 +45,19 @@ export class CodestralHandler extends ChatModel {
type: 'msg'
};
this.messageAdded(msg);

if (this._provider === null) {
const botMsg: IChatMessage = {
id: UUID.uuid4(),
body: '**AI provider not configured for the chat**',
sender: { username: 'ERROR' },
time: Date.now(),
type: 'msg'
};
this.messageAdded(botMsg);
return false;
}

this._history.messages.push(msg);

const messages = mergeMessageRuns(
Expand All @@ -48,13 +68,14 @@ export class CodestralHandler extends ChatModel {
return new AIMessage(msg.body);
})
);
const response = await this._mistralClient.invoke(messages);

const response = await this._provider.invoke(messages);
// TODO: fix deprecated response.text
const content = response.text;
const botMsg: IChatMessage = {
id: UUID.uuid4(),
body: content,
sender: { username: 'Codestral' },
sender: { username: 'Bot' },
time: Date.now(),
type: 'msg'
};
Expand All @@ -75,12 +96,12 @@ export class CodestralHandler extends ChatModel {
super.messageAdded(message);
}

private _mistralClient: ChatMistralAI;
private _provider: BaseChatModel | null;
private _history: IChatHistory = { messages: [] };
}

export namespace CodestralHandler {
export namespace ChatHandler {
export interface IOptions extends ChatModel.IOptions {
mistralClient: ChatMistralAI;
provider: BaseChatModel | null;
}
}
61 changes: 61 additions & 0 deletions src/completion-provider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import {
CompletionHandler,
IInlineCompletionContext,
IInlineCompletionProvider
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';

import { getCompleter, IBaseCompleter } from './llm-models';

/**
* The generic completion provider to register to the completion provider manager.
*/
export class CompletionProvider implements IInlineCompletionProvider {
readonly identifier = '@jupyterlite/ai';

constructor(options: CompletionProvider.IOptions) {
this.name = options.name;
}

/**
* Getter and setter of the name.
* The setter will create the appropriate completer, accordingly to the name.
*/
get name(): string {
return this._name;
}
set name(name: string) {
this._name = name;
this._completer = getCompleter(name);
}

/**
* get the current completer.
*/
get completer(): IBaseCompleter | null {
return this._completer;
}

/**
* Get the LLM completer.
*/
get llmCompleter(): LLM | null {
return this._completer?.provider || null;
}

async fetch(
request: CompletionHandler.IRequest,
context: IInlineCompletionContext
) {
return this._completer?.fetch(request, context);
}

private _name: string = 'None';
private _completer: IBaseCompleter | null = null;
}

export namespace CompletionProvider {
export interface IOptions {
name: string;
}
}
114 changes: 47 additions & 67 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,55 +13,20 @@ import { ICompletionProviderManager } from '@jupyterlab/completer';
import { INotebookTracker } from '@jupyterlab/notebook';
import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
import { ISettingRegistry } from '@jupyterlab/settingregistry';
import { ChatMistralAI, MistralAI } from '@langchain/mistralai';

import { CodestralHandler } from './handler';
import { CodestralProvider } from './provider';

const inlineProviderPlugin: JupyterFrontEndPlugin<void> = {
id: 'jupyterlab-codestral:inline-provider',
autoStart: true,
requires: [ICompletionProviderManager, ISettingRegistry],
activate: (
app: JupyterFrontEnd,
manager: ICompletionProviderManager,
settingRegistry: ISettingRegistry
): void => {
const mistralClient = new MistralAI({
model: 'codestral-latest',
apiKey: 'TMP'
});
const provider = new CodestralProvider({ mistralClient });
manager.registerInlineProvider(provider);

settingRegistry
.load(inlineProviderPlugin.id)
.then(settings => {
const updateKey = () => {
const apiKey = settings.get('apiKey').composite as string;
mistralClient.apiKey = apiKey;
};

settings.changed.connect(() => updateKey());
updateKey();
})
.catch(reason => {
console.error(
`Failed to load settings for ${inlineProviderPlugin.id}`,
reason
);
});
}
};
import { ChatHandler } from './chat-handler';
import { AIProvider } from './provider';
import { IAIProvider } from './token';

const chatPlugin: JupyterFrontEndPlugin<void> = {
id: 'jupyterlab-codestral:chat',
description: 'Codestral chat extension',
description: 'LLM chat extension',
autoStart: true,
optional: [INotebookTracker, ISettingRegistry, IThemeManager],
requires: [IRenderMimeRegistry],
requires: [IAIProvider, IRenderMimeRegistry],
activate: async (
app: JupyterFrontEnd,
aiProvider: IAIProvider,
rmRegistry: IRenderMimeRegistry,
notebookTracker: INotebookTracker | null,
settingsRegistry: ISettingRegistry | null,
Expand All @@ -75,15 +40,15 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
});
}

const mistralClient = new ChatMistralAI({
model: 'codestral-latest',
apiKey: 'TMP'
});
const chatHandler = new CodestralHandler({
mistralClient,
const chatHandler = new ChatHandler({
provider: aiProvider.chatModel,
activeCellManager: activeCellManager
});

aiProvider.modelChange.connect(() => {
chatHandler.provider = aiProvider.chatModel;
});

let sendWithShiftEnter = false;
let enableCodeToolbar = true;

Expand All @@ -94,25 +59,6 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
chatHandler.config = { sendWithShiftEnter, enableCodeToolbar };
}

// TODO: handle the apiKey better
settingsRegistry
?.load(inlineProviderPlugin.id)
.then(settings => {
const updateKey = () => {
const apiKey = settings.get('apiKey').composite as string;
mistralClient.apiKey = apiKey;
};

settings.changed.connect(() => updateKey());
updateKey();
})
.catch(reason => {
console.error(
`Failed to load settings for ${inlineProviderPlugin.id}`,
reason
);
});

Promise.all([app.restored, settingsRegistry?.load(chatPlugin.id)])
.then(([, settings]) => {
if (!settings) {
Expand Down Expand Up @@ -148,4 +94,38 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
}
};

export default [inlineProviderPlugin, chatPlugin];
const aiProviderPlugin: JupyterFrontEndPlugin<IAIProvider> = {
id: 'jupyterlab-codestral:ai-provider',
autoStart: true,
requires: [ICompletionProviderManager, ISettingRegistry],
provides: IAIProvider,
activate: (
app: JupyterFrontEnd,
manager: ICompletionProviderManager,
settingRegistry: ISettingRegistry
): IAIProvider => {
const aiProvider = new AIProvider({ completionProviderManager: manager });

settingRegistry
.load(aiProviderPlugin.id)
.then(settings => {
const updateProvider = () => {
const provider = settings.get('provider').composite as string;
aiProvider.setModels(provider, settings.composite);
};

settings.changed.connect(() => updateProvider());
updateProvider();
})
.catch(reason => {
console.error(
`Failed to load settings for ${aiProviderPlugin.id}`,
reason
);
});

return aiProvider;
}
};

export default [chatPlugin, aiProviderPlugin];
20 changes: 20 additions & 0 deletions src/llm-models/base-completer.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import {
CompletionHandler,
IInlineCompletionContext
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';

export interface IBaseCompleter {
/**
* The LLM completer.
*/
provider: LLM;

/**
* The fetch request for the LLM completer.
*/
fetch(
request: CompletionHandler.IRequest,
context: IInlineCompletionContext
): Promise<any>;
}
Loading
Loading