Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add getMemoryByIds to database adapters #2293

Merged
merged 2 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions packages/adapter-pglite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,33 @@ export class PGLiteDatabaseAdapter
}, "getMemoryById");
}

async getMemoriesByIds(
memoryIds: UUID[],
tableName?: string
): Promise<Memory[]> {
return this.withDatabase(async () => {
if (memoryIds.length === 0) return [];
const placeholders = memoryIds.map((_, i) => `$${i + 1}`).join(",");
let sql = `SELECT * FROM memories WHERE id IN (${placeholders})`;
const queryParams: any[] = [...memoryIds];

if (tableName) {
sql += ` AND type = $${memoryIds.length + 1}`;
queryParams.push(tableName);
}

const { rows } = await this.query<Memory>(sql, queryParams);

return rows.map((row) => ({
...row,
content:
typeof row.content === "string"
? JSON.parse(row.content)
: row.content,
}));
}, "getMemoriesByIds");
}

async createMemory(memory: Memory, tableName: string): Promise<void> {
return this.withDatabase(async () => {
elizaLogger.debug("PostgresAdapter createMemory:", {
Expand Down
27 changes: 27 additions & 0 deletions packages/adapter-postgres/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,33 @@ export class PostgresDatabaseAdapter
}, "getMemoryById");
}

async getMemoriesByIds(
memoryIds: UUID[],
tableName?: string
): Promise<Memory[]> {
return this.withDatabase(async () => {
if (memoryIds.length === 0) return [];
const placeholders = memoryIds.map((_, i) => `$${i + 1}`).join(",");
let sql = `SELECT * FROM memories WHERE id IN (${placeholders})`;
const queryParams: any[] = [...memoryIds];

if (tableName) {
sql += ` AND type = $${memoryIds.length + 1}`;
queryParams.push(tableName);
}

const { rows } = await this.pool.query(sql, queryParams);

return rows.map((row) => ({
...row,
content:
typeof row.content === "string"
? JSON.parse(row.content)
: row.content,
}));
}, "getMemoriesByIds");
}

async createMemory(memory: Memory, tableName: string): Promise<void> {
return this.withDatabase(async () => {
elizaLogger.debug("PostgresAdapter createMemory:", {
Expand Down
27 changes: 27 additions & 0 deletions packages/adapter-sqlite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,33 @@ export class SqliteDatabaseAdapter
return null;
}

async getMemoriesByIds(
memoryIds: UUID[],
tableName?: string
): Promise<Memory[]> {
if (memoryIds.length === 0) return [];
const queryParams: any[] = [];
const placeholders = memoryIds.map(() => "?").join(",");
let sql = `SELECT * FROM memories WHERE id IN (${placeholders})`;
queryParams.push(...memoryIds);

if (tableName) {
sql += ` AND type = ?`;
queryParams.push(tableName);
}

const memories = this.db.prepare(sql).all(...queryParams) as Memory[];

return memories.map((memory) => ({
...memory,
createdAt:
typeof memory.createdAt === "string"
? Date.parse(memory.createdAt as string)
: memory.createdAt,
content: JSON.parse(memory.content as unknown as string),
}));
}

async createMemory(memory: Memory, tableName: string): Promise<void> {
// Delete any existing memory with the same ID first
// const deleteSql = `DELETE FROM memories WHERE id = ? AND type = ?`;
Expand Down
29 changes: 29 additions & 0 deletions packages/adapter-sqljs/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,35 @@ export class SqlJsDatabaseAdapter
return memory || null;
}

async getMemoriesByIds(
memoryIds: UUID[],
tableName?: string
): Promise<Memory[]> {
if (memoryIds.length === 0) return [];
const placeholders = memoryIds.map(() => "?").join(",");
let sql = `SELECT * FROM memories WHERE id IN (${placeholders})`;
const queryParams: any[] = [...memoryIds];

if (tableName) {
sql += ` AND type = ?`;
queryParams.push(tableName);
}

const stmt = this.db.prepare(sql);
stmt.bind(queryParams);

const memories: Memory[] = [];
while (stmt.step()) {
const memory = stmt.getAsObject() as unknown as Memory;
memories.push({
...memory,
content: JSON.parse(memory.content as unknown as string),
});
}
stmt.free();
return memories;
}

async createMemory(memory: Memory, tableName: string): Promise<void> {
let isUnique = true;
if (memory.embedding) {
Expand Down
25 changes: 25 additions & 0 deletions packages/adapter-supabase/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,31 @@ export class SupabaseDatabaseAdapter extends DatabaseAdapter {
return data as Memory;
}

async getMemoriesByIds(
memoryIds: UUID[],
tableName?: string
): Promise<Memory[]> {
if (memoryIds.length === 0) return [];

let query = this.supabase
.from("memories")
.select("*")
.in("id", memoryIds);

if (tableName) {
query = query.eq("type", tableName);
}

const { data, error } = await query;

if (error) {
console.error("Error retrieving memories by IDs:", error);
return [];
}

return data as Memory[];
}
wtfsayo marked this conversation as resolved.
Show resolved Hide resolved

async createMemory(
memory: Memory,
tableName: string,
Expand Down
12 changes: 12 additions & 0 deletions packages/core/__tests__/database.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ class MockDatabaseAdapter extends DatabaseAdapter {
getMemoryById(_id: UUID): Promise<Memory | null> {
throw new Error("Method not implemented.");
}
async getMemoriesByIds(
memoryIds: UUID[],
_tableName?: string
): Promise<Memory[]> {
return memoryIds.map((id) => ({
id: id,
content: { text: "Test Memory" },
roomId: "room-id" as UUID,
userId: "user-id" as UUID,
agentId: "agent-id" as UUID,
})) as Memory[];
}
log(_params: {
body: { [key: string]: unknown };
userId: UUID;
Expand Down
15 changes: 13 additions & 2 deletions packages/core/src/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ export abstract class DatabaseAdapter<DB = any> implements IDatabaseAdapter {

abstract getMemoryById(id: UUID): Promise<Memory | null>;

/**
* Retrieves multiple memories by their IDs
* @param memoryIds Array of UUIDs of the memories to retrieve
* @param tableName Optional table name to filter memories by type
* @returns Promise resolving to array of Memory objects
*/
abstract getMemoriesByIds(
memoryIds: UUID[],
tableName?: string
): Promise<Memory[]>;

/**
* Retrieves cached embeddings based on the specified query parameters.
* @param params An object containing parameters for the embedding retrieval.
Expand Down Expand Up @@ -382,12 +393,12 @@ export abstract class DatabaseAdapter<DB = any> implements IDatabaseAdapter {
userId: UUID;
}): Promise<Relationship[]>;

/**
/**
* Retrieves knowledge items based on specified parameters.
* @param params Object containing search parameters
* @returns Promise resolving to array of knowledge items
*/
abstract getKnowledge(params: {
abstract getKnowledge(params: {
id?: UUID;
agentId: UUID;
limit?: number;
Expand Down
36 changes: 30 additions & 6 deletions packages/core/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ export enum ModelProviderName {
AKASH_CHAT_API = "akash_chat_api",
LIVEPEER = "livepeer",
LETZAI = "letzai",
DEEPSEEK="deepseek",
INFERA="infera"
DEEPSEEK = "deepseek",
INFERA = "infera",
}

/**
Expand Down Expand Up @@ -909,6 +909,8 @@ export interface IDatabaseAdapter {

getMemoryById(id: UUID): Promise<Memory | null>;

getMemoriesByIds(ids: UUID[], tableName?: string): Promise<Memory[]>;

getMemoriesByRoomIds(params: {
tableName: string;
agentId: UUID;
Expand Down Expand Up @@ -1087,7 +1089,10 @@ export interface IMemoryManager {
): Promise<{ embedding: number[]; levenshtein_score: number }[]>;

getMemoryById(id: UUID): Promise<Memory | null>;
getMemoriesByRoomIds(params: { roomIds: UUID[], limit?: number }): Promise<Memory[]>;
getMemoriesByRoomIds(params: {
roomIds: UUID[];
limit?: number;
}): Promise<Memory[]>;
searchMemoriesByEmbedding(
embedding: number[],
opts: {
Expand Down Expand Up @@ -1378,9 +1383,28 @@ export interface IrysTimestamp {
}

export interface IIrysService extends Service {
getDataFromAnAgent(agentsWalletPublicKeys: string[], tags: GraphQLTag[], timestamp: IrysTimestamp): Promise<DataIrysFetchedFromGQL>;
workerUploadDataOnIrys(data: any, dataType: IrysDataType, messageType: IrysMessageType, serviceCategory: string[], protocol: string[], validationThreshold: number[], minimumProviders: number[], testProvider: boolean[], reputation: number[]): Promise<UploadIrysResult>;
providerUploadDataOnIrys(data: any, dataType: IrysDataType, serviceCategory: string[], protocol: string[]): Promise<UploadIrysResult>;
getDataFromAnAgent(
agentsWalletPublicKeys: string[],
tags: GraphQLTag[],
timestamp: IrysTimestamp
): Promise<DataIrysFetchedFromGQL>;
workerUploadDataOnIrys(
data: any,
dataType: IrysDataType,
messageType: IrysMessageType,
serviceCategory: string[],
protocol: string[],
validationThreshold: number[],
minimumProviders: number[],
testProvider: boolean[],
reputation: number[]
): Promise<UploadIrysResult>;
providerUploadDataOnIrys(
data: any,
dataType: IrysDataType,
serviceCategory: string[],
protocol: string[]
): Promise<UploadIrysResult>;
}

export interface ITeeLogService extends Service {
Expand Down
Loading