@@ -5,7 +5,7 @@ import { NotificationHandler } from "../utils/NotificationHandler";
5
5
import { ProgressHandler } from "../utils/ProgressHandler" ;
6
6
import { AIProviderFactory } from "../ai/AIProviderFactory" ;
7
7
import { SCMFactory } from "../scm/SCMProvider" ;
8
- import { getProviderModelConfig } from "../config/types" ;
8
+ import { getProviderModelConfig , type ConfigKey } from "../config/types" ;
9
9
import { DISPLAY_NAME } from "../constants" ;
10
10
import { getMaxCharacters } from "../ai/types" ;
11
11
import { LocalizationManager } from "../utils/LocalizationManager" ;
@@ -35,16 +35,17 @@ export class GenerateCommitCommand extends BaseCommand {
35
35
}
36
36
37
37
const config = ConfigurationManager . getInstance ( ) ;
38
- await config . updateConfig ( "OPENAI_BASE_URL" , baseURL ) ;
39
- await config . updateConfig ( "OPENAI_API_KEY" , apiKey ) ;
38
+ await config . updateConfig ( "PROVIDERS_OPENAI_BASEURL" as ConfigKey , baseURL ) ;
39
+ await config . updateConfig ( "PROVIDERS_OPENAI_APIKEY" as ConfigKey , apiKey ) ;
40
40
return true ;
41
41
}
42
42
43
43
private async ensureConfiguration ( ) : Promise < boolean > {
44
44
const locManager = LocalizationManager . getInstance ( ) ;
45
45
const config = ConfigurationManager . getInstance ( ) ;
46
- const baseURL = config . getConfig < string > ( "OPENAI_BASE_URL" , false ) ;
47
- const apiKey = config . getConfig < string > ( "OPENAI_API_KEY" , false ) ;
46
+ const configuration = config . getConfiguration ( ) ;
47
+ const baseURL = configuration . providers . openai . baseUrl ;
48
+ const apiKey = configuration . providers . openai . apiKey ;
48
49
49
50
if ( ! baseURL || ! apiKey ) {
50
51
const result = await vscode . window . showInformationMessage (
@@ -145,7 +146,8 @@ export class GenerateCommitCommand extends BaseCommand {
145
146
146
147
const config = ConfigurationManager . getInstance ( ) ;
147
148
const configuration = config . getConfiguration ( ) ;
148
- let { provider, model } = configuration ;
149
+ let provider = configuration . base . provider ;
150
+ let model = configuration . base . model ;
149
151
150
152
if ( ! provider || ! model ) {
151
153
const result = await this . selectAndUpdateModelConfiguration (
@@ -169,7 +171,6 @@ export class GenerateCommitCommand extends BaseCommand {
169
171
170
172
const locManager = LocalizationManager . getInstance ( ) ;
171
173
try {
172
- // 检测当前 SCM 类型
173
174
const scmProvider = await SCMFactory . detectSCM ( ) ;
174
175
if ( ! scmProvider ) {
175
176
await NotificationHandler . error (
@@ -178,12 +179,15 @@ export class GenerateCommitCommand extends BaseCommand {
178
179
return ;
179
180
}
180
181
182
+ // 获取当前提交消息输入框的内容
183
+ const currentInput = await scmProvider . getCommitInput ( ) ;
184
+
181
185
const config = ConfigurationManager . getInstance ( ) ;
182
186
const configuration = config . getConfiguration ( ) ;
183
187
184
188
// 检查是否已配置 AI 提供商和模型
185
- let provider = configuration . provider ;
186
- let model = configuration . model ;
189
+ let provider = configuration . base . provider ;
190
+ let model = configuration . base . model ;
187
191
188
192
// 如果没有配置提供商或模型,提示用户选择
189
193
if ( ! provider || ! model ) {
@@ -201,9 +205,7 @@ export class GenerateCommitCommand extends BaseCommand {
201
205
) ,
202
206
async ( progress ) => {
203
207
const selectedFiles = this . getSelectedFiles ( resources ) ;
204
- // progress.report({
205
- // message: locManager.getMessage("progress.analyzing.changes"),
206
- // });
208
+
207
209
const diffContent = await scmProvider . getDiff ( selectedFiles ) ;
208
210
if ( ! diffContent ) {
209
211
await NotificationHandler . info ( locManager . getMessage ( "no.changes" ) ) ;
@@ -219,25 +221,17 @@ export class GenerateCommitCommand extends BaseCommand {
219
221
selectedModel,
220
222
} = await this . getModelAndUpdateConfiguration ( provider , model ) ;
221
223
222
- // if (selectedModel) {
223
- // const maxChars = getMaxCharacters(selectedModel, 2600) - 1000;
224
- // if (diffContent.length > maxChars) {
225
- // throw new Error(
226
- // locManager.format("diff.too.long", diffContent.length, maxChars)
227
- // );
228
- // }
229
- // }
230
-
231
- // progress.report({ message: "正在生成提交信息..." });
232
-
233
224
const result = await aiProvider . generateResponse ( {
234
225
diff : diffContent ,
235
- systemPrompt : configuration . systemPrompt ,
226
+ systemPrompt : configuration . base . systemPrompt ,
236
227
model : selectedModel ,
237
- language : configuration . language ,
228
+ language : configuration . base . language ,
238
229
scm : scmProvider . type ?? "git" ,
239
- allowMergeCommits : configuration . allowMergeCommits ,
230
+ allowMergeCommits :
231
+ configuration . features . commitOptions . allowMergeCommits ,
240
232
splitChangesInSingleFile : false ,
233
+ additionalContext : currentInput , // 添加额外上下文
234
+ useEmoji : configuration . features . commitOptions . useEmoji , // 添加这一行
241
235
} ) ;
242
236
243
237
// progress.report({
0 commit comments