Skip to content

Commit

Permalink
Added nice ✨Enhancing loader for embedding process
Browse files Browse the repository at this point in the history
Added documents caching (PDF)
Added context length limit for RAG
Improved chunking and fixed PDF reading
First attempt of tests for rag.ts
  • Loading branch information
pfrankov committed Oct 10, 2024
1 parent 5e71727 commit 3efd693
Show file tree
Hide file tree
Showing 16 changed files with 826 additions and 256 deletions.
6 changes: 5 additions & 1 deletion jest.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ module.exports = {
moduleNameMapper: {
'^obsidian$': '<rootDir>/tests/__mocks__/obsidian.ts',
'^electron$': '<rootDir>/tests/__mocks__/electron.ts',
'^../logger.js$': '<rootDir>/tests/__mocks__/logger.ts'
'^../logger.js$': '<rootDir>/tests/__mocks__/logger.ts',
'^./pdf.worker.js$': '<rootDir>/tests/__mocks__/pdf.worker.js'
},
transformIgnorePatterns: [
'/node_modules/(?!pdfjs-dist).+\\.js$'
],
};
1 change: 1 addition & 0 deletions src/LocalGPTSettingTab.ts
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ export class LocalGPTSettingTab extends PluginSettingTab {
) || "",
)
.onChange(async (value) => {
clearEmbeddingsCache();
selectedProviderConfig.embeddingModel =
value;
await this.plugin.saveSettings();
Expand Down
6 changes: 5 additions & 1 deletion src/embeddings/CustomEmbeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export class CustomEmbeddings extends Embeddings {
constructor(
private config: {
aiProvider: AIProvider;
updateCompletedSteps: (steps: number) => void;
},
) {
super({});
Expand All @@ -16,7 +17,10 @@ export class CustomEmbeddings extends Embeddings {

async embedDocuments(texts: string[]): Promise<number[][]> {
logger.debug("Embedding documents", texts);
return await this.config.aiProvider.getEmbeddings(texts);
return await this.config.aiProvider.getEmbeddings(
texts,
this.config.updateCompletedSteps,
);
}

async embedQuery(text: string): Promise<number[]> {
Expand Down
42 changes: 36 additions & 6 deletions src/indexedDB.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import { openDB, IDBPDatabase } from "idb";

interface CacheItem {
interface EmbeddingsCacheItem {
mtime: number;
chunks: {
content: string;
embedding: number[];
}[];
}

class EmbeddingsCache {
interface ContentCacheItem {
mtime: number;
content: string;
}

class FileCache {
private db: IDBPDatabase | null = null;
private vaultId: string = "";

Expand All @@ -18,24 +23,49 @@ class EmbeddingsCache {
this.db = await openDB(dbName, 1, {
upgrade(db) {
db.createObjectStore("embeddings");
db.createObjectStore("content");
},
});
}

async get(key: string): Promise<CacheItem | undefined> {
async getEmbeddings(key: string): Promise<EmbeddingsCacheItem | undefined> {
if (!this.db) throw new Error("Database not initialized");
return this.db.get("embeddings", key);
}

async set(key: string, value: CacheItem): Promise<void> {
async setEmbeddings(
key: string,
value: EmbeddingsCacheItem,
): Promise<void> {
if (!this.db) throw new Error("Database not initialized");
await this.db.put("embeddings", value, key);
}

async clear(): Promise<void> {
async getContent(key: string): Promise<ContentCacheItem | undefined> {
if (!this.db) throw new Error("Database not initialized");
return this.db.get("content", key);
}

async setContent(key: string, value: ContentCacheItem): Promise<void> {
if (!this.db) throw new Error("Database not initialized");
await this.db.put("content", value, key);
}

async clearEmbeddings(): Promise<void> {
if (!this.db) throw new Error("Database not initialized");
await this.db.clear("embeddings");
}

async clearContent(): Promise<void> {
if (!this.db) throw new Error("Database not initialized");
await this.db.clear("content");
}

async clearAll(): Promise<void> {
if (!this.db) throw new Error("Database not initialized");
await this.db.clear("embeddings");
await this.db.clear("content");
}
}

export const embeddingsCache = new EmbeddingsCache();
export const fileCache = new FileCache();
6 changes: 5 additions & 1 deletion src/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,12 @@ export type AIProviderProcessingOptions = {
temperature: number;
};
};

export interface AIProvider {
abortController?: AbortController;
getEmbeddings(texts: string[]): Promise<number[][]>;
getEmbeddings(
texts: string[],
updateProgress: (progress: number) => void,
): Promise<number[][]>;
process(arg: AIProviderProcessingOptions): Promise<string>;
}
112 changes: 110 additions & 2 deletions src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,23 @@ import {
getLinkedFiles,
} from "./rag";
import { logger } from "./logger";
import { embeddingsCache } from "./indexedDB";
import { fileCache } from "./indexedDB";

export default class LocalGPT extends Plugin {
settings: LocalGPTSettings;
abortControllers: AbortController[] = [];
updatingInterval: number;
private statusBarItem: HTMLElement;
private currentPercentage: number = 0;
private targetPercentage: number = 0;
private animationFrameId: number | null = null;
private totalProgressSteps: number = 0;
private completedProgressSteps: number = 0;

async onload() {
await this.loadSettings();
// @ts-ignore
await embeddingsCache.init(this.app.appId);
await fileCache.init(this.app.appId);
this.reload();
this.app.workspace.onLayoutReady(async () => {
window.setTimeout(() => {
Expand All @@ -39,6 +45,13 @@ export default class LocalGPT extends Plugin {

this.registerEditorExtension(spinnerPlugin);
this.addSettingTab(new LocalGPTSettingTab(this.app, this));
this.initializeStatusBar();
}

private initializeStatusBar() {
this.statusBarItem = this.addStatusBarItem();
this.statusBarItem.addClass("local-gpt-status");
this.statusBarItem.hide();
}

processText(text: string, selectedText: string) {
Expand Down Expand Up @@ -281,17 +294,23 @@ export default class LocalGPT extends Plugin {
if (aiProvider.abortController?.signal.aborted) {
return "";
}

this.initializeProgress();

const processedDocs = await startProcessing(
linkedFiles,
this.app.vault,
this.app.metadataCache,
activeFile,
);

if (processedDocs.size === 0) {
this.hideStatusBar();
return "";
}

if (aiProvider.abortController?.signal.aborted) {
this.hideStatusBar();
return "";
}

Expand All @@ -300,9 +319,12 @@ export default class LocalGPT extends Plugin {
this,
activeFile.path,
aiProvider,
this.addTotalProgressSteps.bind(this),
this.updateCompletedSteps.bind(this),
);

if (aiProvider.abortController?.signal.aborted) {
this.hideStatusBar();
return "";
}

Expand All @@ -311,10 +333,13 @@ export default class LocalGPT extends Plugin {
vectorStore,
);

this.hideStatusBar();

if (relevantContext.trim()) {
return relevantContext;
}
} catch (error) {
this.hideStatusBar();
if (aiProvider.abortController?.signal.aborted) {
return "";
}
Expand All @@ -331,6 +356,9 @@ export default class LocalGPT extends Plugin {
onunload() {
document.removeEventListener("keydown", this.escapeHandler);
window.clearInterval(this.updatingInterval);
if (this.animationFrameId !== null) {
cancelAnimationFrame(this.animationFrameId);
}
}

async loadSettings() {
Expand Down Expand Up @@ -488,4 +516,84 @@ export default class LocalGPT extends Plugin {
await this.saveData(this.settings);
this.reload();
}

private initializeProgress() {
this.totalProgressSteps = 0;
this.completedProgressSteps = 0;
this.currentPercentage = 0;
this.targetPercentage = 0;
this.statusBarItem.show();
this.updateStatusBar();
}

private addTotalProgressSteps(steps: number) {
this.totalProgressSteps += steps;
this.updateProgressBar();
}

private updateCompletedSteps(steps: number) {
this.completedProgressSteps += steps;
this.updateProgressBar();
}

private updateProgressBar() {
const newTargetPercentage =
this.totalProgressSteps > 0
? Math.round(
(this.completedProgressSteps /
this.totalProgressSteps) *
100,
)
: 0;

if (this.targetPercentage !== newTargetPercentage) {
this.targetPercentage = newTargetPercentage;
if (this.animationFrameId === null) {
this.animatePercentage();
}
}
}

private updateStatusBar() {
this.statusBarItem.setAttr(
"data-text",
this.currentPercentage
? `✨ Enhancing ${this.currentPercentage}%`
: "✨ Enhancing",
);
this.statusBarItem.setText(` `);
}

private animatePercentage() {
const startTime = performance.now();
const duration = 300;

const animate = (currentTime: number) => {
const elapsedTime = currentTime - startTime;
const progress = Math.min(elapsedTime / duration, 1);

this.currentPercentage = Math.round(
this.currentPercentage +
(this.targetPercentage - this.currentPercentage) * progress,
);

this.updateStatusBar();

if (progress < 1) {
this.animationFrameId = requestAnimationFrame(animate);
} else {
this.animationFrameId = null;
}
};

this.animationFrameId = requestAnimationFrame(animate);
}

private hideStatusBar() {
this.statusBarItem.hide();
this.totalProgressSteps = 0;
this.completedProgressSteps = 0;
this.currentPercentage = 0;
this.targetPercentage = 0;
}
}
20 changes: 14 additions & 6 deletions src/processors/pdf.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { logger } from "../logger.js";
import * as pdfjs from "pdfjs-dist";
import { TextItem } from "pdfjs-dist/types/src/display/api.js";

// @ts-ignore
import WorkerMessageHandler from "./pdf.worker.js";
Expand Down Expand Up @@ -47,9 +46,18 @@ async function getPageText(
pageNum: number,
): Promise<string> {
const page = await pdf.getPage(pageNum);
const textContent = await page.getTextContent();
return textContent.items
.filter((item) => "str" in item)
.map((item: TextItem) => item.str)
.join(" ");
const content = await page.getTextContent();
let lastY;
const textItems = [];
for (const item of content.items) {
if ("str" in item) {
if (lastY === item.transform[5] || !lastY) {
textItems.push(item.str);
} else {
textItems.push(`\n${item.str}`);
}
lastY = item.transform[5];
}
}
return textItems.join("") + "\n\n";
}
6 changes: 5 additions & 1 deletion src/providers/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,10 @@ export class OllamaAIProvider implements AIProvider {
return modelInfo;
}

async getEmbeddings(texts: string[]): Promise<number[][]> {
async getEmbeddings(
texts: string[],
updateProgress: (progress: number) => void,
): Promise<number[][]> {
logger.info("Getting embeddings for texts");
const groupedTexts: string[][] = [];
let currentGroup: string[] = [];
Expand Down Expand Up @@ -257,6 +260,7 @@ export class OllamaAIProvider implements AIProvider {
embeddings: json.embeddings,
});
allEmbeddings.push(...json.embeddings);
updateProgress(json.embeddings.length);
}

return allEmbeddings;
Expand Down
6 changes: 5 additions & 1 deletion src/providers/openai-compatible.ts
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ export class OpenAICompatibleAIProvider implements AIProvider {
});
}

async getEmbeddings(texts: string[]): Promise<number[][]> {
async getEmbeddings(
texts: string[],
updateProgress: (progress: number) => void,
): Promise<number[][]> {
logger.info("Getting embeddings for texts");
const results: number[][] = [];

Expand All @@ -205,6 +208,7 @@ export class OpenAICompatibleAIProvider implements AIProvider {
embedding: json.data[0].embedding,
});
results.push(json.data[0].embedding);
updateProgress(1);
} catch (error) {
console.error("Error getting embedding:", { error });
throw error;
Expand Down
Loading

0 comments on commit 3efd693

Please sign in to comment.