Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
small fixes for tango (#5350)
Browse files Browse the repository at this point in the history
* small fixes for tango

* fix
  • Loading branch information
epwalsh authored Aug 10, 2021
1 parent 2e11a15 commit 90bf33b
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 4 deletions.
12 changes: 10 additions & 2 deletions allennlp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging
import os
import sys
import warnings


if os.environ.get("ALLENNLP_DEBUG"):
LEVEL = logging.DEBUG
Expand All @@ -27,10 +29,16 @@ def _transformers_log_filter(record):

logging.getLogger("transformers.file_utils").addFilter(_transformers_log_filter)

from allennlp.commands import main # noqa


def run():
# We issue a seperate warning from the tango command and ignore this one so that
# users won't see a Tango warning when they're not using the Tango command.
warnings.filterwarnings(
"ignore", category=UserWarning, message="AllenNLP Tango", module=r"allennlp\.tango"
)

from allennlp.commands import main # noqa

main(prog="allennlp")


Expand Down
5 changes: 5 additions & 0 deletions allennlp/commands/tango.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from os import PathLike
from pathlib import Path
from typing import Union, Dict, Any, List, Optional
import warnings

from overrides import overrides

Expand Down Expand Up @@ -69,6 +70,10 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument


def run_tango_from_args(args: argparse.Namespace):
warnings.warn(
"AllenNLP Tango is an experimental API and parts of it might change or disappear "
"every time we release a new version."
)
run_tango_from_file(
tango_filename=args.config_path,
serialization_dir=args.serialization_dir,
Expand Down
3 changes: 1 addition & 2 deletions allennlp/tango/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch

from allennlp.common import Lazy, Tqdm
from allennlp.common.checks import check_for_gpu
from allennlp.common.util import sanitize
from allennlp.models import Model
from allennlp.nn.util import move_to_device
Expand Down Expand Up @@ -58,10 +57,10 @@ def run( # type: ignore
concrete_data_loader = data_loader.construct(instances=dataset.splits[split])

if torch.cuda.device_count() > 0:
model = model.cuda()
cuda_device = torch.device(0)
else:
cuda_device = torch.device("cpu")
check_for_gpu(cuda_device)

generator_tqdm = Tqdm.tqdm(iter(concrete_data_loader))

Expand Down
2 changes: 2 additions & 0 deletions allennlp/tango/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,8 @@ def _run_with_work_dir(self, cache: StepCache, **kwargs) -> T:
if self.work_dir_for_run is not None:
raise ValueError("You can only run a Step's run() method once at a time.")

logger.info("Starting run for step %s of type %s", self.name, self.__class__)

if self.DETERMINISTIC:
random.seed(784507111)

Expand Down

0 comments on commit 90bf33b

Please sign in to comment.