DST-227: demo code for call summaries (#31)
Resolves DST-227
Adds a Chainlit interface for the demo.
ccheng26 authored Jun 3, 2024
1 parent e611107 commit c650128
Showing 5 changed files with 175 additions and 123 deletions.
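
To try the demo locally: Chainlit apps are launched through the chainlit CLI (the new file's #!/usr/bin/env chainlit run shebang assumes as much), so after installing the updated requirements the bot should start with something like the following, run from the 04-call-summaries directory. This is a sketch; it assumes the relevant API keys, or a local Ollama server, are already configured for whichever model is selected.

    pip install -r requirements.txt
    chainlit run chainlit-call-summaries-bot.py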
150 changes: 150 additions & 0 deletions 04-call-summaries/chainlit-call-summaries-bot.py
@@ -0,0 +1,150 @@
#!/usr/bin/env chainlit run

import pprint

import chainlit as cl
from chainlit.input_widget import Select, Switch, Slider

from chunking import chunking_ingest
from llm import LLM
from refinement import refinement_ingest
from run import get_transcript, stuffing_summary

OLLAMA_LLMS = ["openhermes", "dolphin-mistral"]
GOOGLE_LLMS = ["gemini-pro", "gemini-flash"]
OPENAI_LLMS = ["gpt-4", "gpt-4o"]
ANTHROPIC_LLMS = ["claude"]
SUMMARIZATION_TECHNIQUE = ["stuffing", "chunking", "refinement"]


@cl.on_chat_start
async def init_chat():
    # Greet the user and offer the two entry points: settings and upload.
    await cl.Message(
        content="Hi! This is a transcript summarizer. Upload a transcript file to summarize it.",
        actions=[
            cl.Action(name="settingsAct", value="chat_settings", label="Show settings"),
            cl.Action(
                name="uploadTranscriptAct",
                value="upload_transcript_file",
                label="Upload transcript file for summarization",
            ),
        ],
    ).send()

    settings = await cl.ChatSettings(
        [
            Select(
                id="model",
                label="LLM Model",
                values=OLLAMA_LLMS + GOOGLE_LLMS + OPENAI_LLMS + ANTHROPIC_LLMS,
                initial_index=0,
            ),
            Select(
                id="summarization_technique",
                label="Choose summarization technique",
                values=SUMMARIZATION_TECHNIQUE,
                initial_index=0,
            ),
            Slider(
                id="temperature",
                label="LLM Temperature",
                initial=0.1,
                min=0,
                max=1,
                step=0.1,
            ),
            Switch(id="streaming", label="Stream response tokens", initial=False),
        ]
    ).send()
    cl.user_session.set("settings", settings)
    await init_llm_client_if_needed()


@cl.action_callback("settingsAct")
async def on_click_settings(action: cl.Action):
    settings = cl.user_session.get("settings")
    settings_str = pprint.pformat(settings, indent=4)
    await cl.Message(content=f"{action.value}:\n`{settings_str}`").send()


@cl.on_settings_update
async def update_settings(settings):
    print("Settings updated:", pprint.pformat(settings, indent=4))
    cl.user_session.set("settings", settings)
    await set_llm_model()


async def set_llm_model():
    # Route the selected model name to the matching client wrapper in llm.py.
    settings = cl.user_session.get("settings")
    model_name = settings["model"]
    llm_settings = {k: settings[k] for k in ["temperature"] if k in settings}
    msg = cl.Message(
        author="backend",
        content=f"Setting up LLM: {model_name} with `{llm_settings}`...\n",
    )
    client = None
    if model_name in OLLAMA_LLMS:
        client = LLM(client_name="ollama", model_name=model_name, settings=llm_settings)
    elif model_name in GOOGLE_LLMS:
        client = LLM(client_name="gemini", model_name=model_name, settings=llm_settings)
    elif model_name in OPENAI_LLMS:
        client = LLM(client_name="gpt", model_name=model_name, settings=llm_settings)
    elif model_name in ANTHROPIC_LLMS:
        client = LLM(client_name="claude", settings=llm_settings)
    else:
        await cl.Message(content=f"Could not initialize model: {model_name}").send()
        return

    client.init_client()
    cl.user_session.set("client", client)
    await msg.stream_token(f"Done setting up {model_name} LLM")
    await msg.send()


async def init_llm_client_if_needed():
    client = cl.user_session.get("client")
    if not client:
        await set_llm_model()


@cl.action_callback("uploadTranscriptAct")
async def on_click_upload_file_query(action: cl.Action):
    files = None
    # Wait for the user to upload a file
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a transcript (plain-text, PDF, or JSON file).",
            accept=["text/plain", "application/pdf", "application/json"],
            max_size_mb=20,
            timeout=180,
        ).send()
    file = files[0]
    transcript = get_transcript(file_path=file.path)
    settings = cl.user_session.get("settings")
    msg = cl.Message(content=f"Processing `{file.name}`...", disable_feedback=True)
    await msg.send()
    msg.content = f"Processing `{file.name}` completed."
    await msg.update()
    msg.content = "Generating summary..."
    await msg.update()
    # Summarize with the technique currently selected in the chat settings.
    response = await run_summarization_technique(
        technique=settings["summarization_technique"], transcript=transcript
    )
    print(response)
    answer = f"Result:\n{response}"
    await cl.Message(content=answer).send()


async def run_summarization_technique(transcript, technique):
    await init_llm_client_if_needed()
    client = cl.user_session.get("client")
    print(f"client: {client}")
    if technique == "stuffing":
        print("running stuffing")
        return stuffing_summary(transcript=transcript, client=client)
    elif technique == "chunking":
        print("running chunking")
        return chunking_ingest(transcript=transcript, client=client)
    elif technique == "refinement":
        print("running refinement")
        return refinement_ingest(transcript=transcript, client=client)
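
llm.py is imported by every file in this change but is not itself part of the diff. For readers following along, the interface the code above relies on presumably looks roughly like the sketch below; the class and function names are real (they appear in the import statements), but the bodies, parameters, and defaults here are assumptions, not the repository's implementation.

# Assumed shape of llm.py, reconstructed from call sites -- not the actual module.
class LLM:
    def __init__(self, client_name, model_name=None, settings=None):
        self.client_name = client_name  # "ollama", "gemini", "gpt", or "claude"
        self.model_name = model_name    # e.g. "openhermes" or "gpt-4o"
        self.settings = settings or {}  # e.g. {"temperature": 0.1}
        self.client = None

    def init_client(self):
        """Instantiate the underlying SDK client for client_name."""
        ...

    def generate_text(self, prompt):
        """Send prompt to the configured model and return the completion text."""
        ...


def select_client():
    """Interactive CLI picker used by the __main__ blocks; returns an LLM instance."""
    ...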
15 changes: 6 additions & 9 deletions 04-call-summaries/chunking.py
@@ -63,12 +63,10 @@ def get_text_chunks(text, chunk_size, chunk_overlap, text_splitter_choice):
 """
 
 
-def chunking_ingest(transcript):
+def chunking_ingest(transcript, client):
     text_chunks = get_text_chunks(
         transcript, chunk_size=750, chunk_overlap=300, text_splitter_choice="2"
     )
-    client = select_client()
-    client.init_client()
     ct = datetime.datetime.now()
     print(ct)
     summaries = []
@@ -87,14 +85,13 @@ def chunking_ingest(transcript):
         summaries=text_summary
     )
     filled_out_template = client.generate_text(prompt_with_all_summaries)
-    print(filled_out_template)
     return filled_out_template
 
 
 if __name__ == "__main__":
-    print(
-        chunking_ingest(
-            transcript=get_transcript("./multi_benefit_transcript.txt"),
-        )
-    )
+    client = select_client()
+    client.init_client()
+    transcript = get_transcript("./multi_benefit_transcript.txt")
+    print(chunking_ingest(transcript=transcript, client=client))
     ct = datetime.datetime.now()
     print(ct)
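
For context, "chunking" here is a map-reduce style summary: split the transcript into overlapping chunks, summarize each chunk, then summarize the concatenated summaries. A minimal sketch of the splitting step using langchain-text-splitters (already in requirements.txt); mapping text_splitter_choice "2" to RecursiveCharacterTextSplitter is an assumption, since the real mapping lives inside get_text_chunks.

# Minimal sketch of the splitting step. Assumes text_splitter_choice "2"
# corresponds to a recursive character splitter -- the real mapping is
# inside get_text_chunks.
from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=750, chunk_overlap=300)
chunks = splitter.split_text(open("./multi_benefit_transcript.txt").read())
# Each chunk gets its own summary; the summaries are then combined into a
# single prompt and summarized once more.
print(f"{len(chunks)} chunks")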
17 changes: 7 additions & 10 deletions 04-call-summaries/refinement.py
@@ -33,15 +33,13 @@
 """
 
 
-def refinement_ingest(transcript, prompt):
+def refinement_ingest(transcript, client):
     text_chunks = get_text_chunks(
         transcript, chunk_size=750, chunk_overlap=300, text_splitter_choice="2"
     )
-    prompt_template = PromptTemplate.from_template(prompt)
+    prompt_template = PromptTemplate.from_template(CHUNKING_PROMPT)
     template = initial_template
 
-    client = select_client()
-    client.init_client()
     ct = datetime.datetime.now()
     print("current time:-", ct)
     for text in text_chunks:
@@ -55,11 +53,10 @@ def refinement_ingest(transcript, prompt):
 
 
 if __name__ == "__main__":
-    print(
-        refinement_ingest(
-            transcript=get_transcript("./multi_benefit_transcript.txt"),
-            prompt=CHUNKING_PROMPT,
-        )
-    )
+    client = select_client()
+    client.init_client()
+    transcript = get_transcript("./multi_benefit_transcript.txt")
+
+    print(refinement_ingest(transcript=transcript, client=client))
     ct = datetime.datetime.now()
     print(ct)
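
Unlike chunking, "refinement" carries a running summary forward: each chunk is folded into the summary produced so far instead of being summarized independently. A minimal sketch of that loop, with hypothetical prompt wording (the actual templates are defined at the top of refinement.py and elided from this diff):

# Hypothetical sketch of the refine loop -- prompt wording is illustrative,
# not the module's actual templates.
def refine_summary(chunks, client):
    summary = ""
    for chunk in chunks:
        if not summary:
            prompt = f"Summarize this call transcript excerpt:\n{chunk}"
        else:
            prompt = (
                f"Here is the summary so far:\n{summary}\n\n"
                f"Refine it using this additional excerpt:\n{chunk}"
            )
        summary = client.generate_text(prompt=prompt)
    return summary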
3 changes: 2 additions & 1 deletion 04-call-summaries/requirements.txt
@@ -8,4 +8,5 @@ sentence-transformers
 nltk
 spacy==3.7.4
 langchain-text-splitters
-langchain_openai
\ No newline at end of file
+langchain_openai
+chainlit
113 changes: 10 additions & 103 deletions 04-call-summaries/run.py
@@ -1,6 +1,6 @@
 from langchain_core.prompts import PromptTemplate
 
-from llm import LLM
+from llm import select_client
 
 
 # download transcripts from https://drive.google.com/drive/folders/19r6x3Zep4N9Rl_x4n4H6RpWkXviwbxyw?usp=sharing
@@ -33,108 +33,15 @@ def get_transcript(file_path="./transcript.txt"):
 """
 
 
-transcript = get_transcript()
-prompt_template = PromptTemplate.from_template(PROMPT)
-formatted_prompt = prompt_template.format(transcript=transcript)
-
-
-def stuffing_summary(prompt=None):
-    print("""
-    Select an llm
-    1. openhermes (default)
-    2. dolphin
-    3. gemini
-    4. gpt 4
-    5. gpt 4o
-    6. claude 3
-    7. all
-    """)
-
-    llm = input() or "1"
-
-    if llm == "2":
-        client = LLM(client_name="ollama", model_name="dolphin-mistral")
-        print("""----------
-        Dolphin
-        """)
-    elif llm == "3":
-        client = LLM(client_name="gemini")
-        print("""----------
-        Gemini Flash 1.5
-        """)
-    elif llm == "4":
-        print("""----------
-        GPT 4
-        """)
-        client = LLM(client_name="gpt", model_name="gpt4")
-    elif llm == "5":
-        print("""----------
-        GPT 4o
-        """)
-        client = LLM(client_name="gpt", model_name="gpt4o")
-    elif llm == "6":
-        print("""----------
-        Claude 3
-        """)
-        client = LLM(client_name="claude")
-    elif llm == "7":
-        print("""
-        Openhermes
-        """)
-        ollama_openhermes = LLM(client_name="ollama", model_name="openhermes")
-        ollama_openhermes.init_client()
-        ollama_openhermes_response = ollama_openhermes.generate_text(prompt=prompt)
-        print(ollama_openhermes_response)
-
-        print("""----------
-        Dolphin
-        """)
-        ollama_dolphin = LLM(client_name="ollama", model_name="dolphin-mistral")
-        ollama_dolphin.init_client()
-        dolphin_response = ollama_dolphin.generate_text(prompt=prompt)
-        print(dolphin_response)
-
-        print("""----------
-        Gemini Flash 1.5
-        """)
-        gemini = LLM(client_name="gemini")
-        gemini.init_client()
-        gemini_response = gemini.generate_text(prompt=prompt)
-        print(gemini_response)
-
-        print("""----------
-        GPT 4
-        """)
-        gpt_4 = LLM(client_name="gpt", model_name="gpt4")
-        gpt_4.init_client()
-        gpt_4_response = gpt_4.generate_text(prompt=prompt)
-        print(gpt_4_response)
-
-        print("""----------
-        GPT 4o
-        """)
-        gpt_4o = LLM(client_name="_4o", model_name="gpt4o")
-        gpt_4o.init_client()
-        gpt_4o_response = gpt_4o.generate_text(prompt=prompt)
-        print(gpt_4o_response)
-
-        print("""----------
-        Claude 3
-        """)
-        claude = LLM(client_name="claude")
-        claude.init_client()
-        claude_response = claude.generate_text(prompt=prompt)
-        print(claude_response)
-    else:
-        client = LLM(client_name="ollama", model_name="openhermes")
-        print("""
-        Openhermes
-        """)
-    client.init_client()
-    response = client.generate_text(prompt=prompt)
-    if response:
-        print(response)
+def stuffing_summary(transcript, client):
+    prompt_template = PromptTemplate.from_template(PROMPT)
+    formatted_prompt = prompt_template.format(transcript=transcript)
+    response = client.generate_text(prompt=formatted_prompt)
+    return response
 
 
 if __name__ == "__main__":
-    stuffing_summary(prompt=formatted_prompt)
+    transcript = get_transcript()
+    client = select_client()
+    client.init_client()
+    stuffing_summary(transcript=transcript, client=client)
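
"Stuffing" is the simplest of the three techniques: the entire transcript is inserted into a single prompt. A minimal standalone sketch of the PromptTemplate usage, with an illustrative template string (run.py's actual PROMPT is defined earlier in the file and elided from this diff):

from langchain_core.prompts import PromptTemplate

# Illustrative template -- run.py's real PROMPT is longer and elided above.
PROMPT = "Summarize the following call transcript:\n\n{transcript}"

prompt_template = PromptTemplate.from_template(PROMPT)
formatted_prompt = prompt_template.format(transcript="Caller: Hi, I have a question about my benefits...")
# formatted_prompt is a plain string, ready for client.generate_text()
print(formatted_prompt)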
