Skip to content

Commit c660a2d

Browse files
♻️ refactor(ai): 重构模型验证和选择逻辑
- 【重构】将模型验证和选择逻辑从命令类中抽离到独立模块 - 【新增】创建 modelValidation 工具类统一处理模型验证逻辑 - 【优化】简化命令类中的模型配置处理流程 - 【移动】迁移 CodeReviewReportGenerator 到 services 目录 - 【文档】新增 AI 模块文档说明其功能和使用方式
1 parent b8550cd commit c660a2d

10 files changed

+286
-179
lines changed

src/ai/providers/BaseOpenAIProvider.ts

+7-9
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import {
1515
} from "../utils/generateHelper";
1616

1717
import { getWeeklyReportPrompt } from "../../prompt/weeklyReport";
18-
import { CodeReviewReportGenerator } from "../../utils/review/CodeReviewReportGenerator";
18+
import { CodeReviewReportGenerator } from "../../services/CodeReviewReportGenerator";
1919
import { formatMessage } from "../../utils/i18n/LocalizationManager";
2020

2121
/**
@@ -193,10 +193,9 @@ export abstract class BaseOpenAIProvider implements AIProvider {
193193
},
194194
};
195195
} catch (error) {
196-
const message = formatMessage(
197-
"codeReview.generation.failed",
198-
[error instanceof Error ? error.message : String(error)]
199-
);
196+
const message = formatMessage("codeReview.generation.failed", [
197+
error instanceof Error ? error.message : String(error),
198+
]);
200199
throw new Error(message);
201200
}
202201
},
@@ -244,10 +243,9 @@ export abstract class BaseOpenAIProvider implements AIProvider {
244243
};
245244
} catch (error) {
246245
throw new Error(
247-
formatMessage(
248-
"weeklyReport.generation.failed",
249-
[error instanceof Error ? error.message : String(error)]
250-
)
246+
formatMessage("weeklyReport.generation.failed", [
247+
error instanceof Error ? error.message : String(error),
248+
])
251249
);
252250
}
253251
}

src/ai/providers/VscodeProvider.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import { generateCommitMessageSystemPrompt } from "../../prompt/prompt";
1111
import { getCodeReviewPrompt, getSystemPrompt } from "../utils/generateHelper";
1212
import { getWeeklyReportPrompt } from "../../prompt/weeklyReport";
1313
import { getMessage, formatMessage } from "../../utils/i18n";
14-
import { CodeReviewReportGenerator } from "../../utils/review/CodeReviewReportGenerator";
14+
import { CodeReviewReportGenerator } from "../../services/CodeReviewReportGenerator";
1515

1616
interface DiffBlock {
1717
header: string;

src/commands/BaseCommand.ts

+26-60
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { SCMFactory } from "../scm/SCMProvider";
55
import { ModelPickerService } from "../services/ModelPickerService";
66
import { notify } from "../utils/notification/NotificationManager";
77
import { getMessage, formatMessage } from "../utils/i18n";
8+
import { validateAndGetModel } from "../utils/ai/modelValidation";
89

910
/**
1011
* 基础命令类,提供通用的命令执行功能
@@ -65,81 +66,46 @@ export abstract class BaseCommand {
6566
let model = configuration.base.model;
6667

6768
if (!provider || !model) {
68-
return this.selectAndUpdateModelConfiguration(provider, model);
69+
return this.selectAndUpdateModelConfiguration(provider, model, true);
6970
}
7071

7172
return { provider, model };
7273
}
7374

74-
/**
75-
* 获取模型并更新配置
76-
* @param provider - AI提供商名称
77-
* @param model - 模型名称
78-
* @returns 更新后的提供商、模型和AI实例信息
79-
* @throws Error 当无法获取模型列表或找不到指定模型时
80-
*/
81-
protected async getModelAndUpdateConfiguration(
82-
provider = "Ollama",
83-
model = "Ollama"
84-
) {
85-
let aiProvider = AIProviderFactory.getProvider(provider);
86-
let models = await aiProvider.getModels();
87-
88-
if (!models || models.length === 0) {
89-
const { provider: newProvider, model: newModel } =
90-
await this.selectAndUpdateModelConfiguration(provider, model);
91-
provider = newProvider;
92-
model = newModel;
93-
94-
aiProvider = AIProviderFactory.getProvider(provider);
95-
models = await aiProvider.getModels();
96-
97-
if (!models || models.length === 0) {
98-
throw new Error(getMessage("model.list.empty"));
99-
}
100-
}
101-
102-
let selectedModel = models.find((m) => m.name === model);
103-
104-
if (!selectedModel) {
105-
const { provider: newProvider, model: newModel } =
106-
await this.selectAndUpdateModelConfiguration(provider, model);
107-
provider = newProvider;
108-
model = newModel;
109-
110-
aiProvider = AIProviderFactory.getProvider(provider);
111-
models = await aiProvider.getModels();
112-
selectedModel = models.find((m) => m.name === model);
113-
114-
if (!selectedModel) {
115-
throw new Error(getMessage("model.notFound"));
116-
}
117-
}
118-
119-
return { provider, model, selectedModel, aiProvider };
120-
}
121-
12275
/**
12376
* 选择模型并更新配置
12477
* @param provider - 当前AI提供商
12578
* @param model - 当前模型名称
79+
* @param throwError - 是否抛出错误,默认为false
12680
* @returns 更新后的提供商和模型信息
12781
*/
12882
protected async selectAndUpdateModelConfiguration(
12983
provider = "Ollama",
130-
model = "Ollama"
84+
model = "Ollama",
85+
throwError = false
13186
) {
132-
const modelSelection = await this.showModelPicker(provider, model);
133-
if (!modelSelection) {
134-
return { provider, model };
87+
try {
88+
const result = await validateAndGetModel(provider, model);
89+
return {
90+
provider: result.provider,
91+
model: result.model,
92+
selectedModel: result.selectedModel,
93+
aiProvider: result.aiProvider,
94+
};
95+
} catch (error: any) {
96+
if (throwError) {
97+
await notify.error(error.message);
98+
throw error;
99+
}
100+
// 如果不抛出错误,返回原始值
101+
const aiProvider = AIProviderFactory.getProvider(provider);
102+
return {
103+
provider,
104+
model,
105+
selectedModel: undefined,
106+
aiProvider,
107+
};
135108
}
136-
137-
const config = ConfigurationManager.getInstance();
138-
await config.updateAIConfiguration(
139-
modelSelection.provider,
140-
modelSelection.model
141-
);
142-
return { provider: modelSelection.provider, model: modelSelection.model };
143109
}
144110

145111
/**

src/commands/GenerateCommitCommand.ts

+4-100
Original file line numberDiff line numberDiff line change
@@ -8,98 +8,12 @@ import { ModelPickerService } from "../services/ModelPickerService";
88
import { notify } from "../utils/notification";
99
import { getMessage, formatMessage } from "../utils/i18n";
1010
import { ProgressHandler } from "../utils/notification/ProgressHandler";
11+
import { validateAndGetModel } from "../utils/ai/modelValidation";
1112

1213
/**
1314
* 提交信息生成命令类
1415
*/
1516
export class GenerateCommitCommand extends BaseCommand {
16-
/**
17-
* 获取模型并更新配置
18-
* @param provider - 当前AI提供商
19-
* @param model - 当前模型名称
20-
* @returns 更新后的提供商、模型和AI实例信息
21-
* @throws Error 当无法获取模型列表或找不到指定模型时
22-
*/
23-
protected async getModelAndUpdateConfiguration(
24-
provider = "Ollama",
25-
model = "Ollama"
26-
) {
27-
let aiProvider = AIProviderFactory.getProvider(provider);
28-
// 获取模型列表
29-
let models = await aiProvider.getModels();
30-
31-
// 如果模型为空或无法获取,直接让用户选择模型
32-
if (!models || models.length === 0) {
33-
const { provider: newProvider, model: newModel } =
34-
await this.selectAndUpdateModelConfiguration(provider, model);
35-
provider = newProvider;
36-
model = newModel;
37-
38-
// 获取更新后的模型列表
39-
aiProvider = AIProviderFactory.getProvider(provider);
40-
models = await aiProvider.getModels();
41-
42-
// 如果新的模型列表仍然为空,则抛出错误
43-
if (!models || models.length === 0) {
44-
throw new Error(getMessage("model.list.empty"));
45-
}
46-
}
47-
48-
// 查找已选择的模型
49-
let selectedModel = models.find((m) => m.name === model);
50-
51-
// 如果没有找到对应的模型,弹窗让用户重新选择
52-
if (!selectedModel) {
53-
const { provider: newProvider, model: newModel } =
54-
await this.selectAndUpdateModelConfiguration(provider, model);
55-
provider = newProvider;
56-
model = newModel;
57-
58-
// 获取更新后的模型列表
59-
aiProvider = AIProviderFactory.getProvider(provider);
60-
models = await aiProvider.getModels();
61-
62-
// 选择有效的模型
63-
selectedModel = models.find((m) => m.name === model);
64-
65-
// 如果依然没有找到对应的模型,抛出错误
66-
if (!selectedModel) {
67-
throw new Error(getMessage("model.notFound"));
68-
}
69-
}
70-
71-
return { provider, model, selectedModel, aiProvider };
72-
}
73-
74-
/**
75-
* 选择模型并更新配置
76-
* @param provider - 当前AI提供商
77-
* @param model - 当前模型名称
78-
* @returns 更新后的提供商和模型信息
79-
*/
80-
protected async selectAndUpdateModelConfiguration(
81-
provider = "Ollama",
82-
model = "Ollama"
83-
) {
84-
// 获取模型选择
85-
const modelSelection = await this.showModelPicker(provider, model);
86-
87-
// 如果没有选择模型,则直接返回当前的 provider 和 model
88-
if (!modelSelection) {
89-
return { provider, model };
90-
}
91-
92-
const config = ConfigurationManager.getInstance();
93-
// 使用新的封装方法更新配置
94-
await config.updateAIConfiguration(
95-
modelSelection.provider,
96-
modelSelection.model
97-
);
98-
99-
// 返回更新后的 provider 和 model
100-
return { provider: modelSelection.provider, model: modelSelection.model };
101-
}
102-
10317
/**
10418
* 处理AI配置
10519
* @returns AI提供商和模型信息
@@ -142,6 +56,7 @@ export class GenerateCommitCommand extends BaseCommand {
14256
if (!configResult) {
14357
return;
14458
}
59+
const { provider, model } = configResult;
14560

14661
try {
14762
// 检测SCM提供程序
@@ -154,21 +69,10 @@ export class GenerateCommitCommand extends BaseCommand {
15469
// 获取当前提交输入框内容
15570
const currentInput = await scmProvider.getCommitInput();
15671

157-
// 获取配置信息
72+
// 获取配置信息以用于后续操作
15873
const config = ConfigurationManager.getInstance();
15974
const configuration = config.getConfiguration();
16075

161-
// 获取或更新AI提供商和模型配置
162-
let provider = configuration.base.provider;
163-
let model = configuration.base.model;
164-
165-
if (!provider || !model) {
166-
const { provider: newProvider, model: newModel } =
167-
await this.selectAndUpdateModelConfiguration(provider, model);
168-
provider = newProvider;
169-
model = newModel;
170-
}
171-
17276
// 使用进度提示生成提交信息
17377
const response = await ProgressHandler.withProgress(
17478
formatMessage("progress.generating.commit", [
@@ -191,7 +95,7 @@ export class GenerateCommitCommand extends BaseCommand {
19195
model: newModel,
19296
aiProvider,
19397
selectedModel,
194-
} = await this.getModelAndUpdateConfiguration(provider, model);
98+
} = await this.selectAndUpdateModelConfiguration(provider, model);
19599

196100
// 生成提交信息
197101
const result = await aiProvider.generateResponse({

src/commands/ReviewCodeCommand.ts

+5-6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import {
66
withProgress,
77
} from "../utils/notification/NotificationManager";
88
import * as path from "path";
9+
import { validateAndGetModel } from "../utils/ai/modelValidation";
910

1011
/**
1112
* 代码审查命令类
@@ -73,12 +74,10 @@ export class ReviewCodeCommand extends BaseCommand {
7374
const { config, configuration } = this.getExtConfig();
7475
let { provider, model } = configResult;
7576

76-
const {
77-
provider: newProvider,
78-
model: newModel,
79-
aiProvider,
80-
selectedModel,
81-
} = await this.getModelAndUpdateConfiguration(provider, model);
77+
const { aiProvider, selectedModel } = await validateAndGetModel(
78+
provider,
79+
model
80+
);
8281

8382
await withProgress(getMessage("reviewing.code"), async (progress) => {
8483
// 获取所有选中文件的差异

src/utils/review/CodeReviewReportGenerator.ts src/services/CodeReviewReportGenerator.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import { CodeReviewResult, CodeReviewIssue } from "../../ai/types";
1+
import { CodeReviewResult, CodeReviewIssue } from "../ai/types";
22
import * as vscode from "vscode";
3-
import { getMessage } from "../i18n";
3+
import { getMessage } from "../utils/i18n";
44

55
/**
66
* 代码审查报告生成器,将代码审查结果转换为格式化的 Markdown 文档

0 commit comments

Comments
 (0)