Skip to content

Commit

Permalink
Gradio theme cache
Browse files Browse the repository at this point in the history
  • Loading branch information
w-e-w committed Aug 6, 2023
1 parent c6278c1 commit bf38252
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ def list_samplers():
options_templates.update(options_section(('ui', "User interface"), {
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
"re_download_theme": OptionInfo(False, "Re-download the selected Gradio theme"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
Expand Down Expand Up @@ -846,6 +847,38 @@ def sd_model(self, value):

progress_print_out = sys.stdout


def from_hub_with_cache_wrapper(func):
def wrapper(*args, **kwargs):
import pickle
repo_name = ''
if 'key_name' in kwargs:
repo_name = kwargs['repo_name']
elif args and len(args) >= 1:
repo_name = args[0]
if repo_name:
theme_cache_path = os.path.join(script_path, 'tmp', 'gradio_themes', repo_name.replace('/', '_'))
# if theme is cached use cache and same gradio version
if not opts.re_download_theme and os.path.exists(theme_cache_path):
with open(theme_cache_path, 'rb') as cached_theme:
theme_cache = pickle.load(cached_theme)
if gr.__version__ == theme_cache.get('gradio_version'):
return theme_cache.get('theme')
# get theme from hub
result = func(*args, **kwargs)
# save theme to cache
os.makedirs(os.path.dirname(theme_cache_path), exist_ok=True)
with open(theme_cache_path, 'wb') as cached_theme:
theme_cache = {'theme': result, 'gradio_version': gr.__version__}
pickle.dump(theme_cache, cached_theme)

return result
return wrapper


gr.themes.ThemeClass.from_hub = from_hub_with_cache_wrapper(gr.themes.ThemeClass.from_hub) # decorates gr.themes.ThemeClass.from_hub with from_hub_with_cache_wrapper


gradio_theme = gr.themes.Base()


Expand All @@ -869,7 +902,6 @@ def reload_gradio_theme(theme_name=None):
gradio_theme = gr.themes.Default(**default_theme_args)



class TotalTQDM:
def __init__(self):
self._tqdm = None
Expand Down

0 comments on commit bf38252

Please sign in to comment.