Skip to content

Commit

Permalink
Support the display of text + picture mixed mode for WebModule (#417)
Browse files Browse the repository at this point in the history
Co-authored-by: wangjian <[email protected]>
  • Loading branch information
wangjian052163 and wangjian authored Jan 23, 2025
1 parent 8970487 commit ff9c5e2
Showing 1 changed file with 34 additions and 6 deletions.
40 changes: 34 additions & 6 deletions lazyllm/tools/webpages/webmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from lazyllm.thirdparty import gradio as gr, PIL
import time
import re
from pathlib import Path
from typing import List, Union

import lazyllm
from lazyllm import LOG, globals, FileSystemQueue, OnlineChatModule, TrainableModule
Expand All @@ -34,8 +36,17 @@ class Mode:

def __init__(self, m, *, components=dict(), title='对话演示终端', port=None,
history=[], text_mode=None, trace_mode=None, audio=False, stream=False,
files_target=None) -> None:
files_target=None, static_paths: Union[str, Path, List[str | Path]] = None) -> None:
super().__init__()
# Set the static directory of gradio so that gradio can access local resources in the directory
if isinstance(static_paths, (str, Path)):
self._static_paths = [static_paths]
elif isinstance(static_paths, list) and all(isinstance(p, (str, Path)) for p in static_paths):
self._static_paths = static_paths
elif static_paths is None:
self._static_paths = []
else:
raise ValueError(f"static_paths only supported str, path or list types. Not supported {static_paths}")
self.m = lazyllm.ActionModule(m) if isinstance(m, lazyllm.FlowBase) else m
self.pool = lazyllm.ThreadPoolExecutor(max_workers=50)
self.title = title
Expand Down Expand Up @@ -82,6 +93,7 @@ def _set_up_caching(self):
return cach_path

def init_web(self, component_descs):
gr.set_static_paths(self._static_paths)
with gr.Blocks(css=css, title=self.title, analytics_enabled=False) as demo:
sess_data = gr.State(value={
'sess_titles': [''],
Expand Down Expand Up @@ -314,6 +326,15 @@ def get_log_and_message(s):
LOG.error(f"Uncaptured error `{e}` when parsing `{s}`, please contact us if you see this.")
return s, "".join(log_history), None

def contains_markdown_image(text: str):
pattern = r"!\[.*?\]\((.*?)\)"
return bool(re.search(pattern, text))

def extract_img_path(text: str):
pattern = r"!\[.*?\]\((.*?)\)"
urls = re.findall(pattern, text)
return urls

file_paths = None
if isinstance(result, (str, dict)):
result, log, file_paths = get_log_and_message(result)
Expand All @@ -334,11 +355,18 @@ def get_log_and_message(s):
if result:
chat_history.append([None, result])
else:
assert isinstance(result, (str, dict)), f'Result should only be str, but got {type(result)}'
if isinstance(result, dict): result = result.get('message', '')
count = (len(match.group(1)) if (match := re.search(r'(\n+)$', result)) else 0) + len(result) + 1
if result and not (result in chat_history[-1][1][-count:]):
chat_history[-1][1] += "\n\n" + result
assert isinstance(result, str), f'Result should only be str, but got {type(result)}'
if not contains_markdown_image(result):
count = (len(match.group(1)) if (match := re.search(r'(\n+)$', result)) else 0) + len(result) + 1
if result and not (result in chat_history[-1][1][-count:]):
chat_history[-1][1] += "\n\n" + result
else:
urls = extract_img_path(result)
for url in urls:
suffix = os.path.splitext(url)[-1].lower()
if suffix in PIL.Image.registered_extensions().keys() and os.path.exists(url):
result = result.replace(url, "file=" + url)
chat_history[-1][1] += result
except requests.RequestException as e:
chat_history = None
log = str(e)
Expand Down

0 comments on commit ff9c5e2

Please sign in to comment.