
Commit 5603b08

✨ feat(ai): add AI utility helpers and a model picker service
- [Utilities] Add generateHelper.ts, implementing a retry mechanism for AI generation
- [Error handling] Add an error type enum and retry option configuration
- [Model service] Add ModelPickerService, supporting dynamic selection of AI models
- [UI] Implement a QuickPick interface for model selection
1 parent d708190 commit 5603b08

2 files changed: +172 -0

2 files changed

+172
-0
lines changed

src/ai/utils/generateHelper.ts

+89
@@ -0,0 +1,89 @@
import { NotificationHandler } from "../../utils/NotificationHandler";
import { LocalizationManager } from "../../utils/LocalizationManager";
import { generateCommitMessageSystemPrompt } from "../../prompt/prompt";
import { DEFAULT_CONFIG } from "../../config/default";
import { AIRequestParams } from "../types";

// Error types recognized by the retry logic
export enum AIGenerationErrorType {
  CONTEXT_LENGTH = "CONTEXT_LENGTH",
  TOKEN_LIMIT = "TOKEN_LIMIT",
  UNKNOWN = "UNKNOWN",
}

export interface GenerateWithRetryOptions {
  maxRetries?: number;
  initialMaxLength: number;
  reductionFactor?: number;
  provider: string;
  retryableErrors?: AIGenerationErrorType[]; // error types that allow a retry
  retryDelay?: number; // delay between retries (ms)
}

// Calls generateFn with the diff truncated to the current budget; on context-length or
// token-limit errors it shrinks the budget by reductionFactor and retries up to maxRetries times.
export async function generateWithRetry<T>(
  params: AIRequestParams,
  generateFn: (truncatedDiff: string) => Promise<T>,
  options: GenerateWithRetryOptions
): Promise<T> {
  const {
    maxRetries = 2,
    initialMaxLength,
    reductionFactor = 0.8,
    provider,
    retryableErrors = [
      AIGenerationErrorType.CONTEXT_LENGTH,
      AIGenerationErrorType.TOKEN_LIMIT,
    ],
    retryDelay = 1000,
  } = options;

  let retries = 0;
  let maxInputLength = initialMaxLength;

  while (true) {
    try {
      const truncatedPrompt = params.diff.substring(0, maxInputLength);

      if (params.diff.length > maxInputLength) {
        NotificationHandler.warn(
          LocalizationManager.getInstance().getMessage(
            `${provider}.input.truncated`
          )
        );
      }

      return await generateFn(truncatedPrompt);
    } catch (error: any) {
      console.log("error", error);
      if (
        retries < maxRetries &&
        (error.message?.includes("maximum context length") ||
          error.message?.includes("context length exceeded") ||
          error.message?.includes("exceeds token limit"))
      ) {
        retries++;
        maxInputLength = Math.floor(maxInputLength * reductionFactor);
        continue;
      }

      const errorMessage = LocalizationManager.getInstance().format(
        `${provider}.generation.failed`,
        error.message || String(error)
      );
      NotificationHandler.error(errorMessage);
      throw new Error(errorMessage);
    }
  }
}

// Returns the caller-supplied system prompt, or builds the default commit-message prompt.
export function getSystemPrompt(params: AIRequestParams): string {
  return (
    params.systemPrompt ||
    generateCommitMessageSystemPrompt(
      params.language || DEFAULT_CONFIG.language,
      params.allowMergeCommits || false,
      params.splitChangesInSingleFile || false,
      params.scm || "git"
    )
  );
}
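generateWithRetry is provider-agnostic: a provider hands over its own request callback plus a truncation budget, and the helper warns when the diff is cut down and retries with a smaller budget on context-length errors. The following is a minimal usage sketch; the import paths, the callApi request method, and the "openai" provider key are assumptions for illustration, not part of this commit.

// Illustrative sketch: how a provider might wrap its request with generateWithRetry.
// "callApi" stands in for the provider's real request method and is not defined in this commit.
import { generateWithRetry, getSystemPrompt } from "../ai/utils/generateHelper";
import { AIRequestParams } from "../ai/types";

declare function callApi(systemPrompt: string, diff: string): Promise<string>;

export async function generateCommitMessage(params: AIRequestParams): Promise<string> {
  return generateWithRetry(
    params,
    // Receives the (possibly truncated) diff and performs the actual request.
    async (truncatedDiff) => callApi(getSystemPrompt(params), truncatedDiff),
    {
      provider: "openai",      // prefix for localization keys such as "openai.input.truncated"
      initialMaxLength: 16000, // starting character budget for the diff
      maxRetries: 2,           // shrink and retry up to twice on context-length errors
      reductionFactor: 0.8,    // each retry keeps 80% of the previous budget
    }
  );
}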

src/services/ModelPickerService.ts

+83
@@ -0,0 +1,83 @@
import * as vscode from "vscode";
import { AIProviderFactory } from "../ai/AIProviderFactory";
import { NotificationHandler } from "../utils/NotificationHandler";
import { LocalizationManager } from "../utils/LocalizationManager";

export class ModelPickerService {
  static async showModelPicker(
    currentProvider: string,
    currentModel: string
  ): Promise<{ provider: string; model: string } | undefined> {
    const locManager = LocalizationManager.getInstance();
    try {
      const providers = AIProviderFactory.getAllProviders();
      const modelsMap = new Map<string, string[]>();

      console.log("providers", providers);

      // Load the model list of every available provider behind a progress notification.
      const progressMsg = locManager.getMessage("ai.model.loading");
      await vscode.window.withProgress(
        {
          location: vscode.ProgressLocation.Notification,
          title: progressMsg,
          cancellable: false,
        },
        async () => {
          await Promise.all(
            providers.map(async (provider) => {
              if (await provider.isAvailable()) {
                const models = await provider.getModels();
                modelsMap.set(
                  provider.getName(),
                  models.map((model) => model.name)
                );
              }
            })
          );
        }
      );

      // Build QuickPick items: a separator row per provider, followed by its models.
      const items: vscode.QuickPickItem[] = [];
      for (const [provider, models] of modelsMap) {
        items.push({
          label: provider,
          kind: vscode.QuickPickItemKind.Separator,
        });
        models.forEach((model) => {
          items.push({
            label: model,
            description: provider,
            picked: provider === currentProvider && model === currentModel,
          });
        });
      }

      const quickPick = vscode.window.createQuickPick();
      quickPick.items = items;
      quickPick.title = locManager.getMessage("ai.model.picker.title");
      quickPick.placeholder = locManager.getMessage(
        "ai.model.picker.placeholder"
      );
      quickPick.ignoreFocusOut = true;

      // Resolve with the accepted item, or undefined when the picker is dismissed.
      const result = await new Promise<vscode.QuickPickItem | undefined>(
        (resolve) => {
          quickPick.onDidAccept(() => resolve(quickPick.selectedItems[0]));
          quickPick.onDidHide(() => resolve(undefined));
          quickPick.show();
        }
      );

      quickPick.dispose();

      if (result && result.description) {
        return { provider: result.description, model: result.label };
      }
      return undefined;
    } catch (error) {
      console.error("Failed to load the model list:", error);
      await NotificationHandler.error("model.list.failed");
      return undefined;
    }
  }
}
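ModelPickerService.showModelPicker returns the chosen provider/model pair, or undefined when the picker is dismissed, so the caller is responsible for persisting the selection. Below is a minimal sketch of a command handler doing that; the configuration section and keys ("myExtension", "provider", "model") are placeholders, not values taken from this commit.

// Illustrative sketch: wiring ModelPickerService into a command handler.
import * as vscode from "vscode";
import { ModelPickerService } from "./services/ModelPickerService";

export async function selectModelCommand(): Promise<void> {
  const config = vscode.workspace.getConfiguration("myExtension"); // placeholder section name
  const currentProvider = config.get<string>("provider", "");
  const currentModel = config.get<string>("model", "");

  // Returns undefined if the user closes the QuickPick without accepting.
  const picked = await ModelPickerService.showModelPicker(currentProvider, currentModel);
  if (!picked) {
    return;
  }

  // Persist the choice so later generation runs use the selected provider/model.
  await config.update("provider", picked.provider, vscode.ConfigurationTarget.Global);
  await config.update("model", picked.model, vscode.ConfigurationTarget.Global);
}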
