diff --git a/lit_nlp/examples/datasets/question_answering.py b/lit_nlp/examples/tydi/data.py similarity index 100% rename from lit_nlp/examples/datasets/question_answering.py rename to lit_nlp/examples/tydi/data.py diff --git a/lit_nlp/examples/tydi/demo.py b/lit_nlp/examples/tydi/demo.py index 02739c5b..307ad8a3 100644 --- a/lit_nlp/examples/tydi/demo.py +++ b/lit_nlp/examples/tydi/demo.py @@ -12,12 +12,11 @@ from absl import app from absl import flags - from lit_nlp import dev_server from lit_nlp import server_flags from lit_nlp.components import word_replacer -from lit_nlp.examples.datasets import question_answering -from lit_nlp.examples.tydi import model +from lit_nlp.examples.tydi import data as tydi_data +from lit_nlp.examples.tydi import model as tydi_model # NOTE: additional flags defined in server_flags.py _FLAGS = flags.FLAGS @@ -55,10 +54,10 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: # Ignore path prefix, if using /path/to/ to load from a # specific directory rather than the default shortcut. model_name = os.path.basename(model_name_or_path) - models[model_name] = model.TyDiModel(model_name=model_name_or_path) + models[model_name] = tydi_model.TyDiModel(model_name=model_name_or_path) max_examples: int = _MAX_EXAMPLES.value - dataset_defs: tuple[tuple[str, str]] = ( + dataset_defs: tuple[tuple[str, str], ...] = ( ("TyDiQA-Multilingual", "validation"), ("TyDiQA-English", "validation-en"), ("TyDiQA-Finnish", "validation-fi"), @@ -71,7 +70,7 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: ("TyDiQA-Telugu", "validation-te"), ) datasets = { - name: question_answering.TyDiQA(split=split, max_examples=max_examples) + name: tydi_data.TyDiQA(split=split, max_examples=max_examples) for name, split in dataset_defs }