|
| 1 | +import { NotificationHandler } from "../../utils/NotificationHandler"; |
| 2 | +import { LocalizationManager } from "../../utils/LocalizationManager"; |
| 3 | +import { generateCommitMessageSystemPrompt } from "../../prompt/prompt"; |
| 4 | +import { DEFAULT_CONFIG } from "../../config/default"; |
| 5 | +import { AIRequestParams } from "../types"; |
| 6 | + |
| 7 | +// 添加错误类型枚举 |
| 8 | +export enum AIGenerationErrorType { |
| 9 | + CONTEXT_LENGTH = "CONTEXT_LENGTH", |
| 10 | + TOKEN_LIMIT = "TOKEN_LIMIT", |
| 11 | + UNKNOWN = "UNKNOWN", |
| 12 | +} |
| 13 | + |
| 14 | +export interface GenerateWithRetryOptions { |
| 15 | + maxRetries?: number; |
| 16 | + initialMaxLength: number; |
| 17 | + reductionFactor?: number; |
| 18 | + provider: string; |
| 19 | + retryableErrors?: AIGenerationErrorType[]; // 添加可重试的错误类型 |
| 20 | + retryDelay?: number; // 添加重试延迟时间 |
| 21 | +} |
| 22 | + |
| 23 | +export async function generateWithRetry<T>( |
| 24 | + params: AIRequestParams, |
| 25 | + generateFn: (truncatedDiff: string) => Promise<T>, |
| 26 | + options: GenerateWithRetryOptions |
| 27 | +): Promise<T> { |
| 28 | + const { |
| 29 | + maxRetries = 2, |
| 30 | + initialMaxLength, |
| 31 | + reductionFactor = 0.8, |
| 32 | + provider, |
| 33 | + retryableErrors = [ |
| 34 | + AIGenerationErrorType.CONTEXT_LENGTH, |
| 35 | + AIGenerationErrorType.TOKEN_LIMIT, |
| 36 | + ], |
| 37 | + retryDelay = 1000, |
| 38 | + } = options; |
| 39 | + |
| 40 | + let retries = 0; |
| 41 | + let maxInputLength = initialMaxLength; |
| 42 | + |
| 43 | + while (true) { |
| 44 | + try { |
| 45 | + const truncatedPrompt = params.diff.substring(0, maxInputLength); |
| 46 | + |
| 47 | + if (params.diff.length > maxInputLength) { |
| 48 | + NotificationHandler.warn( |
| 49 | + LocalizationManager.getInstance().getMessage( |
| 50 | + `${provider}.input.truncated` |
| 51 | + ) |
| 52 | + ); |
| 53 | + } |
| 54 | + |
| 55 | + return await generateFn(truncatedPrompt); |
| 56 | + } catch (error: any) { |
| 57 | + console.log("error", error); |
| 58 | + if ( |
| 59 | + retries < maxRetries && |
| 60 | + (error.message?.includes("maximum context length") || |
| 61 | + error.message?.includes("context length exceeded") || |
| 62 | + error.message?.includes("exceeds token limit")) |
| 63 | + ) { |
| 64 | + retries++; |
| 65 | + maxInputLength = Math.floor(maxInputLength * reductionFactor); |
| 66 | + continue; |
| 67 | + } |
| 68 | + |
| 69 | + const errorMessage = LocalizationManager.getInstance().format( |
| 70 | + `${provider}.generation.failed`, |
| 71 | + error.message || String(error) |
| 72 | + ); |
| 73 | + NotificationHandler.error(errorMessage); |
| 74 | + throw new Error(errorMessage); |
| 75 | + } |
| 76 | + } |
| 77 | +} |
| 78 | + |
| 79 | +export function getSystemPrompt(params: AIRequestParams): string { |
| 80 | + return ( |
| 81 | + params.systemPrompt || |
| 82 | + generateCommitMessageSystemPrompt( |
| 83 | + params.language || DEFAULT_CONFIG.language, |
| 84 | + params.allowMergeCommits || false, |
| 85 | + params.splitChangesInSingleFile || false, |
| 86 | + params.scm || "git" |
| 87 | + ) |
| 88 | + ); |
| 89 | +} |
0 commit comments