Skip to content

Commit

Permalink
Allow --arg Value for booleans in HfArgumentParser (huggingface#9823)
Browse files Browse the repository at this point in the history
* Allow --arg Value for booleans in HfArgumentParser

* Update last test

* Better error message
  • Loading branch information
sgugger authored and Qbiwan committed Jan 31, 2021
1 parent 2a577d4 commit 2035586
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 9 deletions.
33 changes: 28 additions & 5 deletions src/transformers/hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import dataclasses
import json
import sys
from argparse import ArgumentParser
from argparse import ArgumentParser, ArgumentTypeError
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
Expand All @@ -25,6 +25,20 @@
DataClassType = NewType("DataClassType", Any)


# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def string_to_bool(v):
    """
    Convert a command-line value into a boolean.

    Intended to be used as the `type` callable of an `argparse` argument, since `type=bool` would treat any
    non-empty string (including "False") as `True`.

    Args:
        v: The value to convert. A `bool` is returned unchanged; a `str` is matched case-insensitively against
            yes/no, true/false, t/f, y/n and 1/0.

    Returns:
        `True` or `False` according to the value.

    Raises:
        ArgumentTypeError: If `v` is neither a `bool` nor one of the accepted strings.
    """
    if isinstance(v, bool):
        return v
    # Only strings can be lowered. Anything else (None, int, ...) is rejected below with an ArgumentTypeError
    # rather than an AttributeError, which argparse would not turn into a usage error.
    normalized = v.lower() if isinstance(v, str) else None
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    if normalized in ("no", "false", "f", "n", "0"):
        return False
    raise ArgumentTypeError(
        f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
    )


class HfArgumentParser(ArgumentParser):
"""
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
Expand Down Expand Up @@ -85,11 +99,20 @@ def _add_dataclass_arguments(self, dtype: DataClassType):
if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default
elif field.type is bool or field.type is Optional[bool]:
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
kwargs["action"] = "store_false" if field.default is True else "store_true"
if field.default is True:
field_name = f"--no_{field.name}"
kwargs["dest"] = field.name
self.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs)

# Hack because type=bool in argparse does not behave as we want.
kwargs["type"] = string_to_bool
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
# For a plain `bool` field with no default, the default value is True; otherwise the field's own default is used.
default = True if field.default is dataclasses.MISSING else field.default
# This is the value that will get picked if we don't include --field_name in any way
kwargs["default"] = default
# This tells argparse we accept 0 or 1 value after --field_name
kwargs["nargs"] = "?"
# This is the value that will get picked if we do --field_name (without value)
kwargs["const"] = True
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
kwargs["nargs"] = "+"
kwargs["type"] = field.type.__args__[0]
Expand Down
21 changes: 17 additions & 4 deletions tests/test_hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import List, Optional

from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import string_to_bool


def list_field(default=None, metadata=None):
Expand All @@ -44,6 +45,7 @@ class WithDefaultExample:
class WithDefaultBoolExample:
foo: bool = False
baz: bool = True
opt: Optional[bool] = None


class BasicEnum(Enum):
Expand Down Expand Up @@ -91,7 +93,7 @@ def test_basic(self):
expected.add_argument("--foo", type=int, required=True)
expected.add_argument("--bar", type=float, required=True)
expected.add_argument("--baz", type=str, required=True)
expected.add_argument("--flag", action="store_true")
expected.add_argument("--flag", type=string_to_bool, default=True, const=True, nargs="?")
self.argparsersEqual(parser, expected)

def test_with_default(self):
Expand All @@ -106,15 +108,26 @@ def test_with_default_bool(self):
parser = HfArgumentParser(WithDefaultBoolExample)

expected = argparse.ArgumentParser()
expected.add_argument("--foo", action="store_true")
expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
expected.add_argument("--no_baz", action="store_false", dest="baz")
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
expected.add_argument("--opt", type=string_to_bool, default=None)
self.argparsersEqual(parser, expected)

args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=False, baz=True))
self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))

args = parser.parse_args(["--foo", "--no_baz"])
self.assertEqual(args, Namespace(foo=True, baz=False))
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))

args = parser.parse_args(["--foo", "--baz"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))

args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))

args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))

def test_with_enum(self):
parser = HfArgumentParser(EnumExample)
Expand Down

0 comments on commit 2035586

Please sign in to comment.