From 15eccb1197366c925a5beff310fb5d7d369bde0c Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Thu, 30 May 2024 09:43:37 -0700 Subject: [PATCH] Simplify file_cache logic and re-enable for sequence salience demo. PiperOrigin-RevId: 638684231 --- lit_nlp/examples/lm_salience_demo.py | 19 +++++++++---------- lit_nlp/lib/file_cache.py | 1 + 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/lit_nlp/examples/lm_salience_demo.py b/lit_nlp/examples/lm_salience_demo.py index 37a68d1f..fb555bec 100644 --- a/lit_nlp/examples/lm_salience_demo.py +++ b/lit_nlp/examples/lm_salience_demo.py @@ -54,10 +54,7 @@ from lit_nlp import server_flags from lit_nlp.api import layout from lit_nlp.examples.datasets import lm as lm_data - -# TODO(b/333698148): file_cache doesn't work well with certain HF and KerasNLP -# preset names. Disabling until resolved. -# from lit_nlp.lib import file_cache +from lit_nlp.lib import file_cache # NOTE: additional flags defined in server_flags.py @@ -282,12 +279,14 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: model_name, path = model_string.split(":", 1) logging.info("Loading model '%s' from '%s'", model_name, path) - # TODO(b/333698148): file_cache doesn't work well with certain HF and - # KerasNLP preset names. Disabling until resolved. - # path = file_cache.cached_path( - # path, - # extract_compressed_file=path.endswith(".tar.gz"), - # ) + # Limit scope of caching to archive files and remote paths, as some preset + # names like "google/gemma-1.1-7b-it" look like file paths but should not + # be handled as such. + if path.endswith(".tar.gz") or file_cache.is_remote(path): + path = file_cache.cached_path( + path, + extract_compressed_file=path.endswith(".tar.gz"), + ) if _DL_FRAMEWORK.value == "kerasnlp": # pylint: disable=g-import-not-at-top diff --git a/lit_nlp/lib/file_cache.py b/lit_nlp/lib/file_cache.py index a8417ce0..88ab3934 100644 --- a/lit_nlp/lib/file_cache.py +++ b/lit_nlp/lib/file_cache.py @@ -217,6 +217,7 @@ def filename_fom_url(url: str, etag: Optional[str] = None) -> str: def is_remote(url_of_filepath: str) -> bool: + """Check if a path represents a remote URL or non-local file.""" parsed = urllib_parse.urlparse(url_of_filepath) return parsed.scheme in ('http', 'https')