Skip to content

Commit

Permalink
🐛 fix: fix inbox agent can not save config (lobehub#6186)
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx authored Feb 16, 2025
1 parent daa60e9 commit 588cba7
Show file tree
Hide file tree
Showing 11 changed files with 174 additions and 49 deletions.
6 changes: 3 additions & 3 deletions src/database/server/models/__tests__/session.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ describe('SessionModel', () => {

describe('createInbox', () => {
it('should create inbox session if not exists', async () => {
const inbox = await sessionModel.createInbox();
const inbox = await sessionModel.createInbox({});

expect(inbox).toBeDefined();
expect(inbox?.slug).toBe('inbox');
Expand All @@ -641,10 +641,10 @@ describe('SessionModel', () => {

it('should not create duplicate inbox session', async () => {
// Create first inbox
await sessionModel.createInbox();
await sessionModel.createInbox({});

// Try to create another inbox
const duplicateInbox = await sessionModel.createInbox();
const duplicateInbox = await sessionModel.createInbox({});

// Should return undefined as inbox already exists
expect(duplicateInbox).toBeUndefined();
Expand Down
11 changes: 5 additions & 6 deletions src/database/server/models/session.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Column, count, sql } from 'drizzle-orm';
import { and, asc, desc, eq, gt, inArray, isNull, like, not, or } from 'drizzle-orm/expressions';
import { DeepPartial } from 'utility-types';

import { appEnv } from '@/config/app';
import { DEFAULT_INBOX_AVATAR } from '@/const/meta';
import { INBOX_SESSION_ID } from '@/const/session';
import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
Expand All @@ -13,7 +13,7 @@ import {
genWhere,
} from '@/database/utils/genWhere';
import { idGenerator } from '@/database/utils/idGenerator';
import { parseAgentConfig } from '@/server/globalConfig/parseDefaultAgent';
import { LobeAgentConfig } from '@/types/agent';
import { ChatSessionList, LobeAgentSession, SessionRankItem } from '@/types/session';
import { merge } from '@/utils/merge';

Expand Down Expand Up @@ -226,16 +226,15 @@ export class SessionModel {
});
};

createInbox = async () => {
createInbox = async (defaultAgentConfig: DeepPartial<LobeAgentConfig>) => {
const item = await this.db.query.sessions.findFirst({
where: and(eq(sessions.userId, this.userId), eq(sessions.slug, INBOX_SESSION_ID)),
});
if (item) return;

const serverAgentConfig = parseAgentConfig(appEnv.DEFAULT_AGENT_CONFIG) || {};
if (item) return;

return await this.create({
config: merge(DEFAULT_AGENT_CONFIG, serverAgentConfig),
config: merge(DEFAULT_AGENT_CONFIG, defaultAgentConfig),
slug: INBOX_SESSION_ID,
type: 'agent',
});
Expand Down
6 changes: 1 addition & 5 deletions src/database/server/models/user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import { merge } from '@/utils/merge';
import { today } from '@/utils/time';

import { NewUser, UserItem, UserSettingsItem, userSettings, users } from '../../schemas';
import { SessionModel } from './session';

type DecryptUserKeyVaults = (
encryptKeyVaultsStr: string | null,
Expand Down Expand Up @@ -160,10 +159,7 @@ export class UserModel {
.values({ ...params })
.returning();

// Create an inbox session for the user
const model = new SessionModel(db, user.id);

await model.createInbox();
return user;
};

static deleteUser = async (db: LobeChatDatabase, id: string) => {
Expand Down
67 changes: 34 additions & 33 deletions src/features/DevPanel/features/Table/TableCell.tsx
Original file line number Diff line number Diff line change
@@ -1,29 +1,27 @@
import { Typography } from 'antd';
import { createStyles } from 'antd-style';
import dayjs from 'dayjs';
import { get, isDate } from 'lodash-es';
import React, { useMemo } from 'react';

import TooltipContent from './TooltipContent';
// import TooltipContent from './TooltipContent';

const { Text } = Typography;
// const { Text } = Typography;

const useStyles = createStyles(({ token, css }) => ({
cell: css`
font-family: ${token.fontFamilyCode};
font-size: ${token.fontSizeSM}px;
`,
tooltip: css`
border: 1px solid ${token.colorBorder};
font-family: ${token.fontFamilyCode};
font-size: ${token.fontSizeSM}px;
color: ${token.colorText} !important;
word-break: break-all;
background: ${token.colorBgElevated} !important;
`,
}));
// const useStyles = createStyles(({ token, css }) => ({
// cell: css`
// font-family: ${token.fontFamilyCode};
// font-size: ${token.fontSizeSM}px;
// `,
// tooltip: css`
// border: 1px solid ${token.colorBorder};
//
// font-family: ${token.fontFamilyCode};
// font-size: ${token.fontSizeSM}px;
// color: ${token.colorText} !important;
// word-break: break-all;
//
// background: ${token.colorBgElevated} !important;
// `,
// }));

interface TableCellProps {
column: string;
Expand All @@ -32,7 +30,7 @@ interface TableCellProps {
}

const TableCell = ({ dataItem, column, rowIndex }: TableCellProps) => {
const { styles } = useStyles();
// const { styles } = useStyles();
const data = get(dataItem, column);
const content = useMemo(() => {
if (isDate(data)) return dayjs(data).format('YYYY-MM-DD HH:mm:ss');
Expand All @@ -54,18 +52,21 @@ const TableCell = ({ dataItem, column, rowIndex }: TableCellProps) => {

return (
<td key={column} onDoubleClick={() => console.log('Edit cell:', rowIndex, column)}>
<Text
className={styles.cell}
ellipsis={{
tooltip: {
arrow: false,
classNames: { body: styles.tooltip },
title: <TooltipContent>{content}</TooltipContent>,
},
}}
>
{content}
</Text>
{content}

{/* 不能使用 antd 的 Text, 会有大量的重渲染导致滚动极其卡顿 */}
{/*<Text*/}
{/* className={styles.cell}*/}
{/* ellipsis={{*/}
{/* tooltip: {*/}
{/* arrow: false,*/}
{/* classNames: { body: styles.tooltip },*/}
{/* title: <TooltipContent>{content}</TooltipContent>,*/}
{/* },*/}
{/* }}*/}
{/*>*/}
{/* {content}*/}
{/*</Text>*/}
</td>
);
};
Expand Down
7 changes: 7 additions & 0 deletions src/libs/next-auth/adapter/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { Adapter, AdapterAccount } from 'next-auth/adapters';

import * as schema from '@/database/schemas';
import { UserModel } from '@/database/server/models/user';
import { AgentService } from '@/server/services/agent';
import { merge } from '@/utils/merge';

import {
Expand Down Expand Up @@ -65,6 +66,7 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase<typeof schema>): Ad
const adapterUser = mapLobeUserToAdapterUser(existingUser);
return adapterUser;
}

// create a new user if it does not exist
await UserModel.createUser(
serverDB,
Expand All @@ -77,6 +79,11 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase<typeof schema>): Ad
name,
}),
);

// 3. Create an inbox session for the user
const agentService = new AgentService(serverDB, id);
await agentService.createInbox();

return { ...user, id: providerAccountId ?? id };
},
async createVerificationToken(data): Promise<VerificationToken | null | undefined> {
Expand Down
4 changes: 3 additions & 1 deletion src/server/routers/lambda/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { SessionModel } from '@/database/server/models/session';
import { UserModel } from '@/database/server/models/user';
import { pino } from '@/libs/logger';
import { authedProcedure, router } from '@/libs/trpc';
import { AgentService } from '@/server/services/agent';
import { KnowledgeItem, KnowledgeType } from '@/types/knowledgeBase';

const agentProcedure = authedProcedure.use(async (opts) => {
Expand All @@ -18,6 +19,7 @@ const agentProcedure = authedProcedure.use(async (opts) => {
return opts.next({
ctx: {
agentModel: new AgentModel(serverDB, ctx.userId),
agentService: new AgentService(serverDB, ctx.userId),
fileModel: new FileModel(serverDB, ctx.userId),
knowledgeBaseModel: new KnowledgeBaseModel(serverDB, ctx.userId),
sessionModel: new SessionModel(serverDB, ctx.userId),
Expand Down Expand Up @@ -91,7 +93,7 @@ export const agentRouter = router({
const user = await UserModel.findById(serverDB, ctx.userId);
if (!user) return DEFAULT_AGENT_CONFIG;

const res = await ctx.sessionModel.createInbox();
const res = await ctx.agentService.createInbox();
pino.info('create inbox session', res);
}
}
Expand Down
65 changes: 65 additions & 0 deletions src/server/services/agent/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// @vitest-environment node
import { beforeEach, describe, expect, it, vi } from 'vitest';

import { SessionModel } from '@/database/server/models/session';
import { parseAgentConfig } from '@/server/globalConfig/parseDefaultAgent';

import { AgentService } from './index';

vi.mock('@/config/app', () => ({
appEnv: {
DEFAULT_AGENT_CONFIG: 'model=gpt-4;temperature=0.7',
},
}));

vi.mock('@/server/globalConfig/parseDefaultAgent', () => ({
parseAgentConfig: vi.fn(),
}));

vi.mock('@/database/server/models/session', () => ({
SessionModel: vi.fn(),
}));

describe('AgentService', () => {
let service: AgentService;
const mockDb = {} as any;
const mockUserId = 'test-user-id';

beforeEach(() => {
vi.clearAllMocks();
service = new AgentService(mockDb, mockUserId);
});

describe('createInbox', () => {
it('should create inbox with default agent config', async () => {
const mockConfig = { model: 'gpt-4', temperature: 0.7 };
const mockSessionModel = {
createInbox: vi.fn(),
};

(SessionModel as any).mockImplementation(() => mockSessionModel);
(parseAgentConfig as any).mockReturnValue(mockConfig);

await service.createInbox();

expect(SessionModel).toHaveBeenCalledWith(mockDb, mockUserId);
expect(parseAgentConfig).toHaveBeenCalledWith('model=gpt-4;temperature=0.7');
expect(mockSessionModel.createInbox).toHaveBeenCalledWith(mockConfig);
});

it('should create inbox with empty config if parseAgentConfig returns undefined', async () => {
const mockSessionModel = {
createInbox: vi.fn(),
};

(SessionModel as any).mockImplementation(() => mockSessionModel);
(parseAgentConfig as any).mockReturnValue(undefined);

await service.createInbox();

expect(SessionModel).toHaveBeenCalledWith(mockDb, mockUserId);
expect(parseAgentConfig).toHaveBeenCalledWith('model=gpt-4;temperature=0.7');
expect(mockSessionModel.createInbox).toHaveBeenCalledWith({});
});
});
});
22 changes: 22 additions & 0 deletions src/server/services/agent/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import { appEnv } from '@/config/app';
import { SessionModel } from '@/database/server/models/session';
import { LobeChatDatabase } from '@/database/type';
import { parseAgentConfig } from '@/server/globalConfig/parseDefaultAgent';

export class AgentService {
private readonly userId: string;
private readonly db: LobeChatDatabase;

constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
this.db = db;
}

async createInbox() {
const sessionModel = new SessionModel(this.db, this.userId);

const defaultAgentConfig = parseAgentConfig(appEnv.DEFAULT_AGENT_CONFIG) || {};

await sessionModel.createInbox(defaultAgentConfig);
}
}
18 changes: 17 additions & 1 deletion src/server/services/user/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest';
import { UserItem } from '@/database/schemas';
import { UserModel } from '@/database/server/models/user';
import { pino } from '@/libs/logger';
import { AgentService } from '@/server/services/agent';

import { UserService } from './index';

Expand All @@ -29,6 +30,12 @@ vi.mock('@/libs/logger', () => ({
},
}));

vi.mock('@/server/services/agent', () => ({
AgentService: vi.fn().mockImplementation(() => ({
createInbox: vi.fn().mockResolvedValue(undefined),
})),
}));

let service: UserService;
const mockUserId = 'test-user-id';

Expand Down Expand Up @@ -57,7 +64,7 @@ describe('UserService', () => {
// Mock user not found
vi.mocked(UserModel.findById).mockResolvedValue(null as any);

await service.createUser(mockUserId, mockUserJSON);
const result = await service.createUser(mockUserId, mockUserJSON);

expect(UserModel.findById).toHaveBeenCalledWith(expect.anything(), mockUserId);
expect(UserModel.createUser).toHaveBeenCalledWith(
Expand All @@ -73,6 +80,12 @@ describe('UserService', () => {
clerkCreatedAt: new Date('2023-01-01T00:00:00Z'),
}),
);
expect(AgentService).toHaveBeenCalledWith(expect.anything(), mockUserId);
expect(vi.mocked(AgentService).mock.results[0].value.createInbox).toHaveBeenCalled();
expect(result).toEqual({
message: 'user created',
success: true,
});
});

it('should not create user if already exists', async () => {
Expand All @@ -83,6 +96,7 @@ describe('UserService', () => {

expect(UserModel.findById).toHaveBeenCalledWith(expect.anything(), mockUserId);
expect(UserModel.createUser).not.toHaveBeenCalled();
expect(AgentService).not.toHaveBeenCalled();
expect(result).toEqual({
message: 'user not created due to user already existing in the database',
success: false,
Expand All @@ -106,6 +120,8 @@ describe('UserService', () => {
phone: '+1234567890', // Should use first phone number
}),
);
expect(AgentService).toHaveBeenCalledWith(expect.anything(), mockUserId);
expect(vi.mocked(AgentService).mock.results[0].value.createInbox).toHaveBeenCalled();
});
});

Expand Down
5 changes: 5 additions & 0 deletions src/server/services/user/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { serverDB } from '@/database/server';
import { UserModel } from '@/database/server/models/user';
import { pino } from '@/libs/logger';
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
import { AgentService } from '@/server/services/agent';

export class UserService {
createUser = async (id: string, params: UserJSON) => {
Expand Down Expand Up @@ -41,6 +42,10 @@ export class UserService {
username: params.username,
});

// 3. Create an inbox session for the user
const agentService = new AgentService(serverDB, id);
await agentService.createInbox();

/* ↓ cloud slot ↓ */

/* ↑ cloud slot ↑ */
Expand Down
Loading

0 comments on commit 588cba7

Please sign in to comment.