Skip to content

Commit

Permalink
Simplify file_cache logic and re-enable for sequence salience demo.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 638684231
  • Loading branch information
iftenney authored and LIT team committed May 30, 2024
1 parent 5188c8c commit 15eccb1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
19 changes: 9 additions & 10 deletions lit_nlp/examples/lm_salience_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@
from lit_nlp import server_flags
from lit_nlp.api import layout
from lit_nlp.examples.datasets import lm as lm_data

# TODO(b/333698148): file_cache doesn't work well with certain HF and KerasNLP
# preset names. Disabling until resolved.
# from lit_nlp.lib import file_cache
from lit_nlp.lib import file_cache

# NOTE: additional flags defined in server_flags.py

Expand Down Expand Up @@ -282,12 +279,14 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
model_name, path = model_string.split(":", 1)
logging.info("Loading model '%s' from '%s'", model_name, path)

# TODO(b/333698148): file_cache doesn't work well with certain HF and
# KerasNLP preset names. Disabling until resolved.
# path = file_cache.cached_path(
# path,
# extract_compressed_file=path.endswith(".tar.gz"),
# )
# Limit scope of caching to archive files and remote paths, as some preset
# names like "google/gemma-1.1-7b-it" look like file paths but should not
# be handled as such.
if path.endswith(".tar.gz") or file_cache.is_remote(path):
path = file_cache.cached_path(
path,
extract_compressed_file=path.endswith(".tar.gz"),
)

if _DL_FRAMEWORK.value == "kerasnlp":
# pylint: disable=g-import-not-at-top
Expand Down
1 change: 1 addition & 0 deletions lit_nlp/lib/file_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def filename_fom_url(url: str, etag: Optional[str] = None) -> str:


def is_remote(url_of_filepath: str) -> bool:
"""Check if a path represents a remote URL or non-local file."""
parsed = urllib_parse.urlparse(url_of_filepath)
return parsed.scheme in ('http', 'https')

Expand Down

0 comments on commit 15eccb1

Please sign in to comment.