Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve workflow #31

Merged
merged 4 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# https://EditorConfig.org

root = true

[*]
charset = utf-8
end_of_line = lf
indent_size = 2
indent_style = space
insert_final_newline = true
trim_trailing_whitespace = true
5 changes: 5 additions & 0 deletions .prettierc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"printWidth": 120,
"tabWidth": 2,
"useTabs": false,
}
3 changes: 2 additions & 1 deletion server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ app.post("/ask", async (c) => {
return c.json({ error: "question is required." }, 400);
}

const response = await sentai.agent.run(content);
const response = await sentai.agent.execute(content);
return c.json({ data: response });
} catch (e) {
console.error(e);
return c.json({ error: "Internal server error." }, 400);
}
});
Expand Down
13 changes: 4 additions & 9 deletions src/SentientAI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,13 @@ import {

export class SentientAI {
weatherAgent = new Agent({
name: "Weather Agent",
description:
"Get current weather with CurrentWeatherAPITool and forecast weather with ForecastWeatherAPITool.",
tools: [
new CurrentWeatherAPITool(process.env.NUBILA_API_KEY!),
new ForecastWeatherAPITool(process.env.OPENWEATHER_API_KEY!),
new CurrentWeatherAPITool(),
new ForecastWeatherAPITool(),
],
});

newsTool = new NewsAPITool(process.env.NEWSAPI_API_KEY!);
newsTool = new NewsAPITool();

agent = new Agent({
tools: [this.weatherAgent, this.newsTool],
});
agent = new Agent({ tools: [this.weatherAgent, this.newsTool] });
}
61 changes: 37 additions & 24 deletions src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,48 @@ import { Tool } from "./tools/tool";
import { Memory } from "./memory";
import { Workflow } from "./workflow";

interface PromptContext {
tool: Tool;
toolOutput: string;
toolInput: string;
input: string;
}

export interface Agent {
name: string;
description: string;
tools: Tool[];
prompt: (ctx: PromptContext) => string;
}

export class Agent {
name: string = "";
description: string = "";
tools: (Tool | Agent)[] = [];
tools: Tool[] = [];
workflow: Workflow;

// support tempalte format
prompt = (ctx: PromptContext) => `
User Input: ${ctx.input}
Tool Used: ${ctx.tool.name}
Tool Input: ${ctx.toolInput}
Tool Output: ${ctx.toolOutput}

private workflow: Workflow;
Generate a human-readable response based on the tool output${ctx.tool.twitterAccount ? ` and mention x handle ${ctx.tool.twitterAccount} in the end.` : ""}`;

constructor({
name,
description,
fastllm,
llm,
tools,
memory,
}: {
name?: string;
description?: string;
fastllm?: LLM;
llm?: LLM;
tools: (Tool | Agent)[];
memory?: Memory;
}) {
this.name = name || "";
this.description = description || "";
this.tools = tools;
this.workflow = new Workflow({ fastllm, llm, tools, memory });
constructor(args: Partial<Agent> = {}) {
Object.assign(this, args);
this.tools = this.tools.flatMap((i) => {
if (i instanceof Agent) {
return i.tools;
}
return i;
});
if (!this.workflow) {
this.workflow = new Workflow({});
}
this.workflow.agent = this;
}

async run(input: string): Promise<string> {
async execute(input: string): Promise<string> {
return this.workflow.execute(input);
}
}
14 changes: 4 additions & 10 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@ async function runExample() {
return;
}

const openWeatherApiKey = process.env.OPEN_WEATHER_API_KEY;
if (!openWeatherApiKey) {
console.error("Please set the OPEN_WEATHER_API_KEY environment variable.");
return;
}

const newsApiKey = process.env.NEWSAPI_API_KEY;
if (!newsApiKey) {
Expand All @@ -47,16 +42,15 @@ async function runExample() {

const weatherAgent = new Agent({
tools: [
new CurrentWeatherAPITool(nubilaApiKey),
new ForecastWeatherAPITool(openWeatherApiKey),
new CurrentWeatherAPITool(),
new ForecastWeatherAPITool(),
],
});

const newsTool = new NewsAPITool(newsApiKey);

const tools: (Tool | Agent)[] = [weatherAgent, newsTool];
const memory = new SimpleMemory();
const agent = new Agent({ llm, tools, memory });
const agent = new Agent({ tools });

const inputs = [
"Hello World",
Expand All @@ -69,7 +63,7 @@ async function runExample() {
for (const input of inputs) {
console.log(`User Input: ${input}`);
try {
const response = await agent.run(input);
const response = await agent.execute(input);
console.log(`Agent Response: ${response}`);
} catch (error) {
console.error("Error running agent:", error);
Expand Down
83 changes: 64 additions & 19 deletions src/llm.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import OpenAI from "openai";
import { RAGApplication, RAGApplicationBuilder } from '@llm-tools/embedjs';
import { LibSqlDb } from '@llm-tools/embedjs-libsql';

import { OpenAiEmbeddings } from '@llm-tools/embedjs-openai';
import { RAGApplication, RAGApplicationBuilder } from "@llm-tools/embedjs";
import { LibSqlDb, LibSqlStore } from "@llm-tools/embedjs-libsql";
import { OpenAi } from "@llm-tools/embedjs-openai";
import { OpenAiEmbeddings } from "@llm-tools/embedjs-openai";
export interface LLM {
generate(prompt: string): Promise<string>;
}
Expand All @@ -21,7 +21,11 @@ export class OpenAILLM implements LLM {
private openai: OpenAI;
private model: string;

constructor(apiKey: string, model: string = "gpt-4") { // Default to gpt-4
constructor(
apiKey: string = process.env.OPENAI_API_KEY!,
model: string = "gpt-4",
) {
// Default to gpt-4
if (!apiKey) {
throw new Error("OpenAI API key is required.");
}
Expand All @@ -40,16 +44,20 @@ export class OpenAILLM implements LLM {

// Correctly access the message content
const message = completion.choices?.[0]?.message;
if (message) { // Check if message exists
if (message) {
// Check if message exists
return message.content?.trim() || "No content in message"; // Check if message.content exists
} else {
console.error("Unexpected OpenAI response format:", completion); // Log the full response
return "No message in response";
}

} catch (error: any) {
if (error.response) {
console.error("OpenAI API Error:", error.response.status, error.response.data);
console.error(
"OpenAI API Error:",
error.response.status,
error.response.data,
);
} else {
console.error("OpenAI Error:", error.message);
}
Expand All @@ -58,31 +66,68 @@ export class OpenAILLM implements LLM {
}
}

export class OpenAIRAG implements LLM {
rag: RAGApplication | null = null;

export class EmbedLLM implements LLM {
model: RAGApplication | null = null
constructor(args: Partial<FastLLM> = {}) {
Object.assign(this, args);
if (!this.rag) {
new RAGApplicationBuilder()
.setModel(
new OpenAi({
model: "gpt-3.5-turbo",
}),
)
.setEmbeddingModel(
new OpenAiEmbeddings({
model: "text-embedding-3-small",
}),
)
.setVectorDatabase(new LibSqlDb({ path: "./data.db" }))
.build()
.then((rag) => (this.rag = rag));
}
}

constructor(args: Partial<EmbedLLM> = {}) {
Object.assign(this, args)
async generate(prompt: string): Promise<string> {
try {
const result = await this.rag?.query(prompt);
console.log(result);
return result?.content.trim() || "No content in response";
} catch (error: any) {
console.error(" API Error:", error.message);
return ` API Error: ${error.message}`;
}
}
}

export class FastLLM implements LLM {
model: RAGApplication | null = null;

constructor(args: Partial<FastLLM> = {}) {
Object.assign(this, args);
if (!this.model) {
new RAGApplicationBuilder()
.setEmbeddingModel(new OpenAiEmbeddings({
model: 'text-embedding-3-small'
}))
.setVectorDatabase(new LibSqlDb({ path: './data.db' }))
.build().then(model => this.model = model)
.setModel(new OpenAi({ model: "gpt-3.5-turbo" }))
.setEmbeddingModel(
new OpenAiEmbeddings({
model: "text-embedding-3-small",
}),
)
.setVectorDatabase(new LibSqlDb({ path: "./data.db" }))
.build()
.then((model) => (this.model = model));
}
}

async generate(prompt: string): Promise<string> {
try {
const result = await this.model?.query(prompt);
console.log(result)
console.log(result);
return result?.content.trim() || "No content in response";
} catch (error: any) {
console.error("Together API Error:", error.message);
return `Together API Error: ${error.message}`;
}
}
}

8 changes: 6 additions & 2 deletions src/tools/newsapi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ interface NewsAPIResponse {
}

export class NewsAPITool extends APITool {
constructor(apiKey: string) {
super("NewsAPI", "Fetches today's headlines from News API", apiKey);
constructor() {
if (!process.env.NEWSAPI_API_KEY) {
console.error("Please set the NUBILA_API_KEY environment variable.");
return;
}
super("NewsAPI", "Fetches today's headlines from News API", process.env.NEWSAPI_API_KEY!);
}

async execute(input: string): Promise<string> {
Expand Down
Loading
Loading