diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py
index 378e4ab44dfe..ae18052a3d10 100644
--- a/src/transformers/hf_argparser.py
+++ b/src/transformers/hf_argparser.py
@@ -15,7 +15,7 @@
 import dataclasses
 import json
 import sys
-from argparse import ArgumentParser
+from argparse import ArgumentParser, ArgumentTypeError
 from enum import Enum
 from pathlib import Path
 from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
@@ -25,6 +25,27 @@
 DataClassType = NewType("DataClassType", Any)
 
 
+# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
+def string_to_bool(v):
+    """
+    Parse a command-line boolean value, for use as an argparse `type` callable.
+
+    `bool` inputs are returned unchanged; strings are matched case-insensitively against yes/no, true/false, t/f,
+    y/n and 1/0. Any other string raises `ArgumentTypeError`.
+    """
+    if isinstance(v, bool):
+        return v
+    lowered = v.lower()
+    if lowered in ("yes", "true", "t", "y", "1"):
+        return True
+    elif lowered in ("no", "false", "f", "n", "0"):
+        return False
+    else:
+        raise ArgumentTypeError(
+            f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
+        )
+
+
 class HfArgumentParser(ArgumentParser):
     """
     This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
@@ -85,11 +106,20 @@ def _add_dataclass_arguments(self, dtype: DataClassType):
                 if field.default is not dataclasses.MISSING:
                     kwargs["default"] = field.default
             elif field.type is bool or field.type is Optional[bool]:
-                if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
-                    kwargs["action"] = "store_false" if field.default is True else "store_true"
                 if field.default is True:
-                    field_name = f"--no_{field.name}"
-                    kwargs["dest"] = field.name
+                    self.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs)
+
+                # Hack because type=bool in argparse does not behave as we want.
+                kwargs["type"] = string_to_bool
+                if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
+                    # Default value is False if we have no default when of type bool (as with the former store_true).
+                    default = False if field.default is dataclasses.MISSING else field.default
+                    # This is the value that will get picked if we don't include --field_name in any way
+                    kwargs["default"] = default
+                    # This tells argparse we accept 0 or 1 value after --field_name
+                    kwargs["nargs"] = "?"
+                    # This is the value that will get picked if we do --field_name (without value)
+                    kwargs["const"] = True
             elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
                 kwargs["nargs"] = "+"
                 kwargs["type"] = field.type.__args__[0]
diff --git a/tests/test_hf_argparser.py b/tests/test_hf_argparser.py
index 523fb01f8c04..db937de93f8c 100644
--- a/tests/test_hf_argparser.py
+++ b/tests/test_hf_argparser.py
@@ -20,6 +20,7 @@
 from typing import List, Optional
 
 from transformers import HfArgumentParser, TrainingArguments
+from transformers.hf_argparser import string_to_bool
 
 
 def list_field(default=None, metadata=None):
@@ -44,6 +45,7 @@ class WithDefaultExample:
 class WithDefaultBoolExample:
     foo: bool = False
     baz: bool = True
+    opt: Optional[bool] = None
 
 
 class BasicEnum(Enum):
@@ -91,7 +93,7 @@ def test_basic(self):
         expected.add_argument("--foo", type=int, required=True)
         expected.add_argument("--bar", type=float, required=True)
         expected.add_argument("--baz", type=str, required=True)
-        expected.add_argument("--flag", action="store_true")
+        expected.add_argument("--flag", type=string_to_bool, default=False, const=True, nargs="?")
         self.argparsersEqual(parser, expected)
 
     def test_with_default(self):
@@ -106,15 +108,26 @@ def test_with_default_bool(self):
         parser = HfArgumentParser(WithDefaultBoolExample)
 
         expected = argparse.ArgumentParser()
-        expected.add_argument("--foo", action="store_true")
+        expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
         expected.add_argument("--no_baz", action="store_false", dest="baz")
+        expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
+        expected.add_argument("--opt", type=string_to_bool, default=None)
         self.argparsersEqual(parser, expected)
 
         args = parser.parse_args([])
-        self.assertEqual(args, Namespace(foo=False, baz=True))
+        self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))
 
         args = parser.parse_args(["--foo", "--no_baz"])
-        self.assertEqual(args, Namespace(foo=True, baz=False))
+        self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
+
+        args = parser.parse_args(["--foo", "--baz"])
+        self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
+
+        args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
+        self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))
+
+        args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
+        self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
 
     def test_with_enum(self):
         parser = HfArgumentParser(EnumExample)