Skip to content

Commit

Permalink
allow model selection from available providers
Browse files Browse the repository at this point in the history
  • Loading branch information
rjmacarthy committed Sep 4, 2024
1 parent 96560b2 commit 36c6211
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 66 deletions.
13 changes: 12 additions & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 5 additions & 9 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,8 @@
"default": true,
"description": "Enable twinny debug mode"
},
"twinny.symmetryModelName": {
"order": 19,
"type": "string",
"default": "llama3.1:latest",
"description": "The symmetry model name for chat."
},
"twinny.symmetryProvider": {
"order": 20,
"order": 19,
"type": "string",
"description": "The symmetry provider type.",
"enum": [
Expand All @@ -392,7 +386,7 @@
"default": "ollama"
},
"twinny.symmetryServerKey": {
"order": 21,
"order": 20,
"type": "string",
"description": "The symmetry master server key.",
"default": "4b4a9cc325d134dee6679e9407420023531fd7e96c563f6c5d00fd5549b77435"
Expand Down Expand Up @@ -425,6 +419,7 @@
"@types/string_score": "^0.1.31",
"@types/uuid": "^9.0.8",
"@types/vscode": "^1.70.0",
"@types/ws": "^8.5.12",
"@typescript-eslint/eslint-plugin": "^5.31.0",
"@typescript-eslint/parser": "^5.31.0",
"@vscode/test-cli": "^0.0.6",
Expand Down Expand Up @@ -482,7 +477,8 @@
"tippy.js": "^6.3.7",
"tiptap-markdown": "^0.8.10",
"toxe": "^1.1.0",
"uuid": "^9.0.1"
"uuid": "^9.0.1",
"ws": "^8.18.0"
},
"os": [
"darwin",
Expand Down
2 changes: 2 additions & 0 deletions src/common/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ export const SKIP_IMPORT_KEYWORDS_AFTER = ['from', 'as', 'import']
export const MIN_COMPLETION_CHUNKS = 2
export const MAX_EMPTY_COMPLETION_CHARS = 250
export const DEFAULT_RERANK_THRESHOLD = 0.5
export const URL_SYMMETRY_WS = 'wss://twinny.dev/ws'

export const defaultChunkOptions = {
maxSize: 500,
Expand All @@ -39,6 +40,7 @@ export const EVENT_NAME = {
twinnyChatMessage: 'twinny-chat-message',
twinnyClickSuggestion: 'twinny-click-suggestion',
twinnyConnectedToSymmetry: 'twinny-connected-to-symmetry',
twinnySymmetryModeles: 'twinny-symmetry-models',
twinnyConnectSymmetry: 'twinny-connect-symmetry',
twinnyDisconnectedFromSymmetry: 'twinny-disconnected-from-symmetry',
twinnyDisconnectSymmetry: 'twinny-disconnect-symmetry',
Expand Down
14 changes: 14 additions & 0 deletions src/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,20 @@ export interface SymmetryConnection {
provider: string;
id: string;
}

export interface SymmetryModelProvider {
connections: number | null;
data_collection_enabled: number;
id: number;
last_seen: string;
max_connections: number;
model_name: string;
name: string;
online: number;
provider: string;
public: number;
}

export interface InferenceRequest {
key: string;
messages: Message[];
Expand Down
7 changes: 4 additions & 3 deletions src/extension/providers/sidebar.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ import {
Message,
ApiModel,
ServerMessage,
InferenceRequest
InferenceRequest,
SymmetryModelProvider
} from '../../common/types'
import { TemplateProvider } from '../template-provider'
import { OllamaService } from '../ollama-service'
Expand Down Expand Up @@ -424,9 +425,9 @@ export class SidebarProvider implements vscode.WebviewViewProvider {
})
}

private connectToSymmetry = () => {
private connectToSymmetry = (data: ClientMessage<SymmetryModelProvider>) => {
if (this._config.symmetryServerKey) {
this.symmetryService?.connect(this._config.symmetryServerKey)
this.symmetryService?.connect(this._config.symmetryServerKey, data.data?.model_name)
}
}

Expand Down
20 changes: 10 additions & 10 deletions src/extension/symmetry-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ import {
SYMMETRY_DATA_MESSAGE,
WEBUI_TABS,
ACTIVE_CHAT_PROVIDER_STORAGE_KEY,
GLOBAL_STORAGE_KEY
GLOBAL_STORAGE_KEY,
} from '../common/constants'
import { SessionManager } from './session-manager'
import { EventEmitter } from 'stream'
import { TwinnyProvider } from './provider-manager'
import { SymmetryWs } from './symmetry-ws'

export class SymmetryService extends EventEmitter {
private _config = workspace.getConfiguration('twinny')
Expand All @@ -50,10 +51,9 @@ export class SymmetryService extends EventEmitter {
private _providerTopic: Buffer | undefined
private _emitterKey = ''
private _provider: SymmetryProvider | undefined
private _modelName = this._config.symmetryModelName
private _symmetryProvider = this._config.symmetryProvider
private _symmetryServerKey = this._config.symmetryServerKey

private ws: SymmetryWs | undefined

constructor(
view: WebviewView | undefined,
Expand All @@ -76,9 +76,13 @@ export class SymmetryService extends EventEmitter {
if (!event.affectsConfiguration('twinny')) return
this.updateConfig()
})

this.ws = new SymmetryWs(view)
this.ws.connectSymmetryWs()
}

public connect = async (key: string) => {
public connect = async (key: string, model: string | undefined) => {
if (!model || !key) return
this._serverSwarm = new Hyperswarm()
const serverKey = Buffer.from(key)
const discoveryKey = crypto.discoveryKey(serverKey)
Expand All @@ -89,7 +93,7 @@ export class SymmetryService extends EventEmitter {
this._serverPeer = peer
peer.write(
createSymmetryMessage(SYMMETRY_DATA_MESSAGE.requestProvider, {
modelName: this._modelName,
modelName: model
})
)
peer.on('data', (message: Buffer) => {
Expand Down Expand Up @@ -182,10 +186,7 @@ export class SymmetryService extends EventEmitter {
this.handleInferenceEnd()

this.handleIncomingData(chunk, (response: StreamResponse) => {
const data = getChatDataFromProvider(
this._symmetryProvider,
response
)
const data = getChatDataFromProvider(this._symmetryProvider, response)
this._completion = this._completion + data
if (!data) return
this.emit(this._emitterKey, this._completion)
Expand Down Expand Up @@ -320,6 +321,5 @@ export class SymmetryService extends EventEmitter {
this._config = workspace.getConfiguration('twinny')
this._symmetryProvider = this._config.symmetryProvider
this._symmetryServerKey = this._config.symmetryServerKey
this._modelName = this._config.symmetryModelName
}
}
40 changes: 40 additions & 0 deletions src/extension/symmetry-ws.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import { WebSocket } from 'ws'
import * as vscode from 'vscode'
import { EVENT_NAME, URL_SYMMETRY_WS } from '../common/constants'

export class SymmetryWs {
private ws: WebSocket | null = null
private view: vscode.WebviewView | undefined

constructor(view: vscode.WebviewView | undefined) {
this.view = view
}

public connectSymmetryWs = () => {
this.ws = new WebSocket(URL_SYMMETRY_WS)

this.ws.on('message', (data) => {
try {
const parsedData = JSON.parse(data.toString())
this.view?.webview.postMessage({
type: EVENT_NAME.twinnySymmetryModeles,
value: {
data: parsedData?.allPeers
}
})
} catch (error) {
console.error('Error parsing WebSocket message:', error)
}
})

this.ws.on('error', (error) => {
console.error('WebSocket error:', error)
})
}

public dispose() {
if (this.ws) {
this.ws.close()
}
}
}
22 changes: 18 additions & 4 deletions src/webview/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
LanguageType,
ServerMessage,
SymmetryConnection,
SymmetryModelProvider,
ThemeType
} from '../common/types'
import { TwinnyProvider } from '../extension/provider-manager'
Expand Down Expand Up @@ -477,6 +478,9 @@ const useAutosizeTextArea = (

export const useSymmetryConnection = () => {
const [connecting, setConnecting] = useState(false)
const [models, setModels] = useState<SymmetryModelProvider[]>([])
const [selectedModel, setSelectedModel] =
useState<SymmetryModelProvider | null>(null)
const {
context: symmetryConnectionSession,
setContext: setSymmetryConnectionSession
Expand All @@ -501,8 +505,9 @@ export const useSymmetryConnection = () => {
const connectToSymmetry = () => {
setConnecting(true)
global.vscode.postMessage({
type: EVENT_NAME.twinnyConnectSymmetry
} as ClientMessage)
type: EVENT_NAME.twinnyConnectSymmetry,
data: selectedModel
} as ClientMessage<SymmetryModelProvider>)
}

const disconnectSymmetry = () => {
Expand All @@ -525,7 +530,9 @@ export const useSymmetryConnection = () => {
}

const handler = (event: MessageEvent) => {
const message: ServerMessage<SymmetryConnection | string> = event.data
const message: ServerMessage<
SymmetryConnection | string | SymmetryModelProvider[]
> = event.data
if (message?.type === EVENT_NAME.twinnyConnectedToSymmetry) {
setConnecting(false)
setSymmetryConnectionSession(message.value.data as SymmetryConnection)
Expand All @@ -537,6 +544,11 @@ export const useSymmetryConnection = () => {
if (message?.type === EVENT_NAME.twinnySendSymmetryMessage) {
setSymmetryProviderStatus(message?.value.data as string)
}

if (message?.type === EVENT_NAME.twinnySymmetryModeles) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
setModels(message?.value.data as unknown as SymmetryModelProvider[])
}
return () => window.removeEventListener('message', handler)
}

Expand All @@ -557,10 +569,12 @@ export const useSymmetryConnection = () => {
}
}, [autoConnectProviderContext, symmetryProviderStatus, connectAsProvider])


return {
autoConnectProviderContext,
connectAsProvider,
models,
selectedModel,
setSelectedModel,
connecting,
connectToSymmetry,
disconnectAsProvider,
Expand Down
9 changes: 9 additions & 0 deletions src/webview/symmetry.module.css
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,12 @@
display: flex;
justify-content: flex-start;
}

.dropdownContainer {
margin-bottom: 1rem;
}

.dropdownContainer label {
display: block;
margin-bottom: 0.5rem;
}
Loading

0 comments on commit 36c6211

Please sign in to comment.