From 54d68ab00108f01b1b003dc775c1e28b7d41ac60 Mon Sep 17 00:00:00 2001 From: Jason Date: Mon, 3 Feb 2020 15:55:16 -0800 Subject: [PATCH] Logging Refactor (#305) * add files from #292 Signed-off-by: Jason * update logger.py Signed-off-by: Jason * import and bug fixes Signed-off-by: Jason * update exp_logging to use new logger Signed-off-by: Jason * style fix Signed-off-by: Jason * style fix Signed-off-by: Jason * fix deprecated unittest Signed-off-by: Jason * isort Signed-off-by: Jason * update headeR Signed-off-by: Jason * remove unused imports Signed-off-by: Jason --- nemo/__init__.py | 4 +- nemo/constants.py | 50 +++++ nemo/utils/env_var_parsing.py | 208 +++++++++++++++++++ nemo/utils/exp_logging.py | 58 +++--- nemo/utils/formatters/__init__.py | 0 nemo/utils/formatters/base.py | 128 ++++++++++++ nemo/utils/formatters/colors.py | 121 +++++++++++ nemo/utils/formatters/utils.py | 45 +++++ nemo/utils/metaclasses.py | 29 +++ nemo/utils/nemo_logging.py | 313 +++++++++++++++++++++++++++++ requirements/requirements.txt | 1 + requirements/requirements_nlp.txt | 1 - requirements/requirements_test.txt | 1 + tests/test_deprecated.py | 64 ++++-- 14 files changed, 974 insertions(+), 49 deletions(-) create mode 100644 nemo/constants.py create mode 100644 nemo/utils/env_var_parsing.py create mode 100644 nemo/utils/formatters/__init__.py create mode 100644 nemo/utils/formatters/base.py create mode 100644 nemo/utils/formatters/colors.py create mode 100644 nemo/utils/formatters/utils.py create mode 100644 nemo/utils/metaclasses.py create mode 100644 nemo/utils/nemo_logging.py diff --git a/nemo/__init__.py b/nemo/__init__.py index a7c91e793960..56702dd8700a 100644 --- a/nemo/__init__.py +++ b/nemo/__init__.py @@ -33,9 +33,9 @@ ) if "NEMO_PACKAGE_BUILDING" not in os.environ: - import logging + from nemo.utils.nemo_logging import Logger as _Logger - logging = logging.getLogger(__name__) + logging = _Logger() from nemo import backends from nemo import core diff --git 
a/nemo/constants.py b/nemo/constants.py new file mode 100644 index 000000000000..6cd3a1f60ff8 --- /dev/null +++ b/nemo/constants.py @@ -0,0 +1,50 @@ +# Copyright (C) NVIDIA CORPORATION. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.**** + +import numpy as np + +# Supported Numpy DTypes: `np.sctypes` +ACCEPTED_INT_NUMBER_FORMATS = ( + int, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.int, + np.int8, + np.int16, + np.int32, + np.int64, +) + +ACCEPTED_FLOAT_NUMBER_FORMATS = ( + float, + np.float, + np.float16, + np.float32, + np.float64, + np.float128, +) + +ACCEPTED_STR_NUMBER_FORMATS = ( + str, + np.str, +) + +ACCEPTED_NUMBER_FORMATS = ACCEPTED_INT_NUMBER_FORMATS + ACCEPTED_FLOAT_NUMBER_FORMATS + ACCEPTED_STR_NUMBER_FORMATS + +# NEMO_ENV_VARNAME_DEBUG_VERBOSITY = "NEMO_DEBUG_VERBOSITY" +NEMO_ENV_VARNAME_ENABLE_COLORING = "NEMO_ENABLE_COLORING" +NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR = "NEMO_REDIRECT_LOGS_TO_STDERR" +# NEMO_ENV_VARNAME_SAVE_LOGS_TO_DIR = "NEMO_SAVE_LOGS_TO_DIR" diff --git a/nemo/utils/env_var_parsing.py b/nemo/utils/env_var_parsing.py new file mode 100644 index 000000000000..063c9d14db8a --- /dev/null +++ b/nemo/utils/env_var_parsing.py @@ -0,0 +1,208 @@ +# The MIT Licence (MIT) +# +# Copyright (c) 2016 YunoJuno Ltd +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, 
including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# Vendored dependency from : https://github.com/yunojuno/python-env-utils/blob/master/env_utils/utils.py +# +# ========================================================================================================= +# +# Modified by NVIDIA +# +# Copyright (C) NVIDIA CORPORATION. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License.**** + +import decimal +import json +import os + +from dateutil import parser + +__all__ = [ + "get_env", + "get_envbool", + "get_envint", + "get_envfloat", + "get_envdecimal", + "get_envdate", + "get_envdatetime", + "get_envlist", + "get_envdict", + "CoercionError", + "RequiredSettingMissingError", +] + + +class CoercionError(Exception): + """Custom error raised when a value cannot be coerced.""" + + def __init__(self, key, value, func): + msg = "Unable to coerce '{}={}' using {}.".format(key, value, func.__name__) + super(CoercionError, self).__init__(msg) + + +class RequiredSettingMissingError(Exception): + """Custom error raised when a required env var is missing.""" + + def __init__(self, key): + msg = "Required env var '{}' is missing.".format(key) + super(RequiredSettingMissingError, self).__init__(msg) + + +def _get_env(key, default=None, coerce=lambda x: x, required=False): + """ + Return env var coerced into a type other than string. + This function extends the standard os.getenv function to enable + the coercion of values into data types other than string (all env + vars are strings by default). + Args: + key: string, the name of the env var to look up + Kwargs: + default: the default value to return if the env var does not exist. NB the + default value is **not** coerced, and is assumed to be of the correct type. + coerce: a function that is used to coerce the value returned into + another type + required: bool, if True, then a RequiredSettingMissingError error is raised + if the env var does not exist. 
+ Returns the env var, passed through the coerce function + """ + try: + value = os.environ[key] + except KeyError: + if required is True: + raise RequiredSettingMissingError(key) + else: + return default + + try: + return coerce(value) + except Exception: + raise CoercionError(key, value, coerce) + + +# standard type coercion functions +def _bool(value): + if isinstance(value, bool): + return value + + return not (value is None or value.lower() in ("false", "0", "no", "n", "f", "none")) + + +def _int(value): + return int(value) + + +def _float(value): + return float(value) + + +def _decimal(value): + return decimal.Decimal(value) + + +def _dict(value): + return json.loads(value) + + +def _datetime(value): + return parser.parse(value) + + +def _date(value): + return parser.parse(value).date() + + +def get_env(key, *default, **kwargs): + """ + Return env var. + This is the parent function of all other get_foo functions, + and is responsible for unpacking args/kwargs into the values + that _get_env expects (it is the root function that actually + interacts with environ). + Args: + key: string, the env var name to look up. + default: (optional) the value to use if the env var does not + exist. If this value is not supplied, then the env var is + considered to be required, and a RequiredSettingMissingError + error will be raised if it does not exist. + Kwargs: + coerce: a func that may be supplied to coerce the value into + something else. This is used by the default get_foo functions + to cast strings to builtin types, but could be a function that + returns a custom class. + Returns the env var, coerced if required, and a default if supplied. + """ + assert len(default) in (0, 1), "Too many args supplied." 
+ func = kwargs.get('coerce', lambda x: x) + required = len(default) == 0 + default = default[0] if not required else None + return _get_env(key, default=default, coerce=func, required=required) + + +def get_envbool(key, *default): + """Return env var cast as boolean.""" + return get_env(key, *default, coerce=_bool) + + +def get_envint(key, *default): + """Return env var cast as integer.""" + return get_env(key, *default, coerce=_int) + + +def get_envfloat(key, *default): + """Return env var cast as float.""" + return get_env(key, *default, coerce=_float) + + +def get_envdecimal(key, *default): + """Return env var cast as Decimal.""" + return get_env(key, *default, coerce=_decimal) + + +def get_envdate(key, *default): + """Return env var as a date.""" + return get_env(key, *default, coerce=_date) + + +def get_envdatetime(key, *default): + """Return env var as a datetime.""" + return get_env(key, *default, coerce=_datetime) + + +def get_envlist(key, *default, **kwargs): + """Return env var as a list.""" + separator = kwargs.get('separator', ' ') + return get_env(key, *default, coerce=lambda x: x.split(separator)) + + +def get_envdict(key, *default): + """Return env var as a dict.""" + return get_env(key, *default, coerce=_dict) diff --git a/nemo/utils/exp_logging.py b/nemo/utils/exp_logging.py index 69868bc8c365..3af7e8d93139 100644 --- a/nemo/utils/exp_logging.py +++ b/nemo/utils/exp_logging.py @@ -4,37 +4,37 @@ import subprocess import sys import time -import warnings from shutil import copyfile import nemo +from nemo.utils.decorators import deprecated +@deprecated(version=0.11, explanation="Please use nemo.logging instead") def get_logger(unused): - warnings.warn("This function will be deprecated in the future. You " "can just use nemo.logging instead") return nemo.logging -class ContextFilter(logging.Filter): - """ - This is a filter which injects contextual information into the log. - Use it when we want to inject worker number into the log message. 
+# class ContextFilter(logging.Filter): +# """ +# This is a filter which injects contextual information into the log. +# Use it when we want to inject worker number into the log message. - Usage: - logger = get_logger(name) - tmp = logging.Formatter( - 'WORKER %(local_rank)s: %(asctime)s - %(levelname)s - %(message)s') - logger.addFilter(ContextFilter(self.local_rank)) +# Usage: +# logger = get_logger(name) +# tmp = logging.Formatter( +# 'WORKER %(local_rank)s: %(asctime)s - %(levelname)s - %(message)s') +# logger.addFilter(ContextFilter(self.local_rank)) - """ +# """ - def __init__(self, local_rank): - super().__init__() - self.local_rank = local_rank +# def __init__(self, local_rank): +# super().__init__() +# self.local_rank = local_rank - def filter(self, record): - record.local_rank = self.local_rank - return True +# def filter(self, record): +# record.local_rank = self.local_rank +# return True class ExpManager: @@ -155,21 +155,21 @@ def __init__( def create_logger(self, level=logging.INFO, log_file=True): logger = nemo.logging - tmp = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + # tmp = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') - if self.global_rank == 0: - logger.setLevel(level) - ch = logging.StreamHandler() - ch.setLevel(level) - ch.setFormatter(tmp) - logger.addHandler(ch) + # if self.global_rank == 0: + # logger.setLevel(level) + # ch = logging.StreamHandler() + # ch.setLevel(level) + # ch.setFormatter(tmp) + # logger.addHandler(ch) if log_file: self.log_file = f'{self.work_dir}/log_globalrank-{self.global_rank}_' f'localrank-{self.local_rank}.txt' - fh = logging.FileHandler(self.log_file) - fh.setLevel(level) - fh.setFormatter(tmp) - logger.addHandler(fh) + logger.add_file_handler(self.log_file) + # fh = logging.FileHandler(self.log_file) + # fh.setLevel(level) + # fh.setFormatter(tmp) self.logger = logger return logger diff --git a/nemo/utils/formatters/__init__.py b/nemo/utils/formatters/__init__.py new file 
mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/utils/formatters/base.py b/nemo/utils/formatters/base.py new file mode 100644 index 000000000000..6b844877b185 --- /dev/null +++ b/nemo/utils/formatters/base.py @@ -0,0 +1,128 @@ +# Copyright (C) NVIDIA CORPORATION. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.**** + +import logging + +from nemo.utils.formatters.colors import Fore as ForegroundColors +from nemo.utils.formatters.utils import check_color_support, to_unicode + +__all__ = ["BaseNeMoFormatter"] + + +class BaseFormatter(logging.Formatter): + """ + Log formatter used in Tornado. Key features of this formatter are: + * Color support when logging to a terminal that supports it. + * Timestamps on every log line. + * Robust against str/bytes encoding problems. + """ + + DEFAULT_FORMAT = "%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s" + + DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + + DEFAULT_COLORS = { + logging.DEBUG: ForegroundColors.CYAN, + logging.INFO: ForegroundColors.GREEN, + logging.WARNING: ForegroundColors.YELLOW, + logging.ERROR: ForegroundColors.MAGENTA, + logging.CRITICAL: ForegroundColors.RED, + } + + def __init__(self, color=True, fmt=None, datefmt=None, colors=None): + r""" + :arg bool color: Enables color support. + :arg string fmt: Log message format. + It will be applied to the attributes dict of log records. 
The + text between ``%(color)s`` and ``%(end_color)s`` will be colored + depending on the level if color support is on. + :arg dict colors: color mappings from logging level to terminal color + code + :arg string datefmt: Datetime format. + Used for formatting ``(asctime)`` placeholder in ``prefix_fmt``. + .. versionchanged:: 3.2 + Added ``fmt`` and ``datefmt`` arguments. + """ + + if fmt is None: + fmt = self.DEFAULT_FORMAT + + if datefmt is None: + datefmt = self.DEFAULT_DATE_FORMAT + + if colors is None: + colors = self.DEFAULT_COLORS + + logging.Formatter.__init__(self, datefmt=datefmt) + + self._fmt = fmt + self._colors = {} + self._normal = "" + + if color and check_color_support(): + self._colors = colors + self._normal = ForegroundColors.RESET + + def format(self, record): + try: + message = record.getMessage() + assert isinstance(message, str) # guaranteed by logging + # Encoding notes: The logging module prefers to work with character + # strings, but only enforces that log messages are instances of + # basestring. In python 2, non-ascii bytestrings will make + # their way through the logging framework until they blow up with + # an unhelpful decoding error (with this formatter it happens + # when we attach the prefix, but there are other opportunities for + # exceptions further along in the framework). + # + # If a byte string makes it this far, convert it to unicode to + # ensure it will make it out to the logs. Use repr() as a fallback + # to ensure that all byte strings can be converted successfully, + # but don't do it by default so we don't add extra quotes to ascii + # bytestrings. This is a bit of a hacky place to do this, but + # it's worth it since the encoding errors that would otherwise + # result are so useless (and tornado is fond of using utf8-encoded + # byte strings wherever possible). 
+ record.message = to_unicode(message) + + except Exception as e: + record.message = "Bad message (%r): %r" % (e, record.__dict__) + + record.asctime = self.formatTime(record, self.datefmt) + + if record.levelno in self._colors: + record.color = self._colors[record.levelno] + record.end_color = self._normal + else: + record.color = record.end_color = "" + + formatted = self._fmt % record.__dict__ + + if record.exc_info: + if not record.exc_text: + record.exc_text = self.formatException(record.exc_info) + + if record.exc_text: + # exc_text contains multiple lines. We need to _safe_unicode + # each line separately so that non-utf8 bytes don't cause + # all the newlines to turn into '\n'. + lines = [formatted.rstrip()] + lines.extend(to_unicode(ln) for ln in record.exc_text.split("\n")) + + formatted = "\n".join(lines) + return formatted.replace("\n", "\n ") + + +class BaseNeMoFormatter(BaseFormatter): + DEFAULT_FORMAT = "%(color)s[NeMo %(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s" diff --git a/nemo/utils/formatters/colors.py b/nemo/utils/formatters/colors.py new file mode 100644 index 000000000000..ec7a5d56f1c6 --- /dev/null +++ b/nemo/utils/formatters/colors.py @@ -0,0 +1,121 @@ +# Source: https://github.com/tartley/colorama/blob/master/colorama/ansi.py +# Copyright: Jonathan Hartley 2013. BSD 3-Clause license. +# +# Copyright (C) NVIDIA CORPORATION. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License.**** + +CSI = "\033[" +OSC = "\033]" +BEL = "\007" + + +def code_to_chars(code): + return CSI + str(code) + "m" + + +def set_title(title): + return OSC + "2;" + title + BEL + + +def clear_screen(mode=2): + return CSI + str(mode) + "J" + + +def clear_line(mode=2): + return CSI + str(mode) + "K" + + +class AnsiCodes(object): + def __init__(self): + # the subclasses declare class attributes which are numbers. + # Upon instantiation we define instance attributes, which are the same + # as the class attributes but wrapped with the ANSI escape sequence + for name in dir(self): + if not name.startswith("_"): + value = getattr(self, name) + setattr(self, name, code_to_chars(value)) + + +class AnsiCursor(object): + def UP(self, n=1): + return CSI + str(n) + "A" + + def DOWN(self, n=1): + return CSI + str(n) + "B" + + def FORWARD(self, n=1): + return CSI + str(n) + "C" + + def BACK(self, n=1): + return CSI + str(n) + "D" + + def POS(self, x=1, y=1): + return CSI + str(y) + ";" + str(x) + "H" + + +class AnsiFore(AnsiCodes): + BLACK = 30 + RED = 31 + GREEN = 32 + YELLOW = 33 + BLUE = 34 + MAGENTA = 35 + CYAN = 36 + WHITE = 37 + RESET = 39 + + # These are fairly well supported, but not part of the standard. + LIGHTBLACK_EX = 90 + LIGHTRED_EX = 91 + LIGHTGREEN_EX = 92 + LIGHTYELLOW_EX = 93 + LIGHTBLUE_EX = 94 + LIGHTMAGENTA_EX = 95 + LIGHTCYAN_EX = 96 + LIGHTWHITE_EX = 97 + + +class AnsiBack(AnsiCodes): + BLACK = 40 + RED = 41 + GREEN = 42 + YELLOW = 43 + BLUE = 44 + MAGENTA = 45 + CYAN = 46 + WHITE = 47 + RESET = 49 + + # These are fairly well supported, but not part of the standard. 
+ LIGHTBLACK_EX = 100 + LIGHTRED_EX = 101 + LIGHTGREEN_EX = 102 + LIGHTYELLOW_EX = 103 + LIGHTBLUE_EX = 104 + LIGHTMAGENTA_EX = 105 + LIGHTCYAN_EX = 106 + LIGHTWHITE_EX = 107 + + +class AnsiStyle(AnsiCodes): + BRIGHT = 1 + DIM = 2 + NORMAL = 22 + RESET_ALL = 0 + + +Fore = AnsiFore() +Back = AnsiBack() +Style = AnsiStyle() +Cursor = AnsiCursor() diff --git a/nemo/utils/formatters/utils.py b/nemo/utils/formatters/utils.py new file mode 100644 index 000000000000..954470a1810e --- /dev/null +++ b/nemo/utils/formatters/utils.py @@ -0,0 +1,45 @@ +# Copyright (C) NVIDIA CORPORATION. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.**** + +import sys + +from nemo.constants import NEMO_ENV_VARNAME_ENABLE_COLORING +from nemo.utils.env_var_parsing import get_envbool + +__all__ = ["check_color_support", "to_unicode"] + + +def check_color_support(): + # Colors can be forced with an env variable + if not sys.platform.lower().startswith("win") and get_envbool(NEMO_ENV_VARNAME_ENABLE_COLORING, False): + return True + + +def to_unicode(value): + """ + Converts a string argument to a unicode string. + If the argument is already a unicode string or None, it is returned + unchanged. Otherwise it must be a byte string and is decoded as utf8. 
+ """ + try: + if isinstance(value, (str, type(None))): + return value + + if not isinstance(value, bytes): + raise TypeError("Expected bytes, unicode, or None; got %r" % type(value)) + + return value.decode("utf-8") + + except UnicodeDecodeError: + return repr(value) diff --git a/nemo/utils/metaclasses.py b/nemo/utils/metaclasses.py new file mode 100644 index 000000000000..0f584aa76cad --- /dev/null +++ b/nemo/utils/metaclasses.py @@ -0,0 +1,29 @@ +# Copyright (C) NVIDIA CORPORATION. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.**** + +__all__ = [ + "SingletonMetaClass", +] + + +class SingletonMetaClass(type): + + _instances = {} + + def __call__(cls, *args, **kwargs): + + if cls not in cls._instances: + cls._instances[cls] = super(SingletonMetaClass, cls).__call__(*args, **kwargs) + + return cls._instances[cls] diff --git a/nemo/utils/nemo_logging.py b/nemo/utils/nemo_logging.py new file mode 100644 index 000000000000..4e49028d0b6c --- /dev/null +++ b/nemo/utils/nemo_logging.py @@ -0,0 +1,313 @@ +# Copyright (C) NVIDIA CORPORATION. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.**** + +import logging as _logging +import sys +import threading +import warnings +from contextlib import contextmanager + +# from nemo.constants import NEMO_ENV_VARNAME_SAVE_LOGS_TO_DIR +from nemo.constants import NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR +from nemo.utils.env_var_parsing import get_envbool, get_envint +from nemo.utils.formatters.base import BaseNeMoFormatter +from nemo.utils.metaclasses import SingletonMetaClass + +__all__ = [ + "Logger", +] + + +class Logger(metaclass=SingletonMetaClass): + + # Level 0 + NOTSET = _logging.NOTSET + + # Level 10 + DEBUG = _logging.DEBUG + + # Level 20 + INFO = _logging.INFO + + # Level 30 + WARNING = _logging.WARNING + + # Level 40 + ERROR = _logging.ERROR + + # Level 50 + CRITICAL = _logging.CRITICAL + + _level_names = { + 0: "NOTSET", + 10: "DEBUG", + 20: "INFO", + 30: "WARNING", + 40: "ERROR", + 50: "CRITICAL", + } + + def __init__(self): + + self._logger = None + + # Multi-GPU runs run in separate processes, thread locks shouldn't be needed + self._logger_lock = threading.Lock() + + self._handlers = dict() + + self.old_warnings_showwarning = None + + self._define_logger() + + def _define_logger(self): + + # Use double-checked locking to avoid taking lock unnecessarily. 
+ if self._logger is not None: + return self._logger + + with self._logger_lock: + try: + self._logger = _logging.getLogger("nemo_logger") + # By default, silence all loggers except the logger for rank 0 + self.remove_stream_handlers() + if get_envint("RANK", 0) == 0: + self.add_stream_handlers() + + finally: + self.set_verbosity(verbosity_level=Logger.INFO) + + self._logger.propagate = False + + def remove_stream_handlers(self): + if self._logger is None: + raise RuntimeError("Impossible to set handlers if the Logger is not predefined") + + # ======== Remove Handler if already existing ======== + + try: + self._logger.removeHandler(self._handlers["stream_stdout"]) + except KeyError: + pass + + try: + self._logger.removeHandler(self._handlers["stream_stderr"]) + except KeyError: + pass + + def add_stream_handlers(self): + if self._logger is None: + raise RuntimeError("Impossible to set handlers if the Logger is not predefined") + + # Add the output handler. + if get_envbool(NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR, False): + self._handlers["stream_stdout"] = _logging.StreamHandler(sys.stderr) + + else: + self._handlers["stream_stdout"] = _logging.StreamHandler(sys.stdout) + self._handlers["stream_stdout"].addFilter(lambda record: record.levelno <= _logging.INFO) + + self._handlers["stream_stderr"] = _logging.StreamHandler(sys.stderr) + self._handlers["stream_stderr"].addFilter(lambda record: record.levelno > _logging.INFO) + + formatter = BaseNeMoFormatter + + self._handlers["stream_stdout"].setFormatter(formatter()) + self._logger.addHandler(self._handlers["stream_stdout"]) + + try: + self._handlers["stream_stderr"].setFormatter(formatter()) + self._logger.addHandler(self._handlers["stream_stderr"]) + except KeyError: + pass + + def reset_stream_handler(self): + self.remove_stream_handlers() + self.add_stream_handlers() + + def add_file_handler(self, log_file): + if self._logger is None: + raise RuntimeError("Impossible to set handlers if the Logger is not 
predefined") + + self._handlers["file"] = _logging.FileHandler(log_file) + + formatter = BaseNeMoFormatter + self._handlers["file"].setFormatter(formatter()) + self._logger.addHandler(self._handlers["file"]) + + def getEffectiveLevel(self): + """Return how much logging output will be produced.""" + if self._logger is not None: + return self._logger.getEffectiveLevel() + + def get_verbosity(self): + return self.getEffectiveLevel() + + def setLevel(self, verbosity_level): + """Sets the threshold for what messages will be logged.""" + if self._logger is not None: + self._logger.setLevel(verbosity_level) + + for handler in self._logger.handlers: + handler.setLevel(verbosity_level) + + def set_verbosity(self, verbosity_level): + self.setLevel(verbosity_level) + + @contextmanager + def patch_stderr_handler(self, stream): + """ Useful for unittests + """ + if self._logger is not None: + try: + old_stream = self._handlers["stream_stderr"].stream + if old_stream is None: + raise ValueError + + # Port backwards set_stream() from python 3.7 + self._handlers["stream_stderr"].acquire() + try: + self._handlers["stream_stderr"].flush() + self._handlers["stream_stderr"].stream = stream + finally: + self._handlers["stream_stderr"].release() + + yield stream + except (KeyError, ValueError): + raise RuntimeError("Impossible to patch logging handlers if handler does not exist") + finally: + # Port backwards set_stream() from python 3.7 + self._handlers["stream_stderr"].acquire() + try: + self._handlers["stream_stderr"].flush() + self._handlers["stream_stderr"].stream = old_stream + finally: + self._handlers["stream_stderr"].release() + + else: + raise RuntimeError("Impossible to patch logging handlers if handler does not exist") + + @contextmanager + def temp_verbosity(self, verbosity_level): + """Sets the a temporary threshold for what messages will be logged.""" + + if self._logger is not None: + + old_verbosity = self.get_verbosity() + + try: + self.set_verbosity(verbosity_level) + 
yield + + finally: + self.set_verbosity(old_verbosity) + + else: + try: + yield + + finally: + pass + + def captureWarnings(self, capture): + """ + If capture is true, redirect all warnings to the logging package. + If capture is False, ensure that warnings are not redirected to logging + but to their original destinations. + """ + + if self._logger is not None: + + if capture and self.old_warnings_showwarning is None: + # Backup Method + self.old_warnings_showwarning = warnings.showwarning + warnings.showwarning = self._showwarning + + elif not capture and self.old_warnings_showwarning is not None: + # Restore Method + warnings.showwarning = self.old_warnings_showwarning + self.old_warnings_showwarning = None + + def _showwarning(self, message, category, filename, lineno, line=None): + """ + Implementation of showwarnings which redirects to logging. + It will call warnings.formatwarning and will log the resulting string + with level logging.WARNING. + """ + s = warnings.formatwarning(message, category, filename, lineno, line) + self.warning("%s", s) + + def debug(self, msg, *args, **kwargs): + """ + Log 'msg % args' with severity 'DEBUG'. + + To pass exception information, use the keyword argument exc_info with + a true value, e.g. + + logger.debug("Houston, we have a %s", "thorny problem", exc_info=1) + """ + if self._logger is not None and self._logger.isEnabledFor(Logger.DEBUG): + self._logger._log(Logger.DEBUG, msg, args, **kwargs) + + def info(self, msg, *args, **kwargs): + """ + Log 'msg % args' with severity 'INFO'. + + To pass exception information, use the keyword argument exc_info with + a true value, e.g. + + logger.info("Houston, we have a %s", "interesting problem", exc_info=1) + """ + if self._logger is not None and self._logger.isEnabledFor(Logger.INFO): + self._logger._log(Logger.INFO, msg, args, **kwargs) + + def warning(self, msg, *args, **kwargs): + """ + Log 'msg % args' with severity 'WARNING'. 
+ + To pass exception information, use the keyword argument exc_info with + a true value, e.g. + + logger.warning("Houston, we have a %s", "bit of a problem", exc_info=1) + """ + if self._logger is not None and self._logger.isEnabledFor(Logger.WARNING): + self._logger._log(Logger.WARNING, msg, args, **kwargs) + + def error(self, msg, *args, **kwargs): + """ + Log 'msg % args' with severity 'ERROR'. + + To pass exception information, use the keyword argument exc_info with + a true value, e.g. + + logger.error("Houston, we have a %s", "major problem", exc_info=1) + """ + if self._logger is not None and self._logger.isEnabledFor(Logger.ERROR): + self._logger._log(Logger.ERROR, msg, args, **kwargs) + + def critical(self, msg, *args, **kwargs): + """ + Log 'msg % args' with severity 'CRITICAL'. + + To pass exception information, use the keyword argument exc_info with + a true value, e.g. + + logger.critical("Houston, we have a %s", "major disaster", exc_info=1) + """ + if self._logger is not None and self._logger.isEnabledFor(Logger.CRITICAL): + self._logger._log(Logger.CRITICAL, msg, args, **kwargs) + + +# # Necessary to catch the correct caller +# _logging._srcfile = os.path.normcase(inspect.getfile(Logger.__class__)) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index d460a58a55b2..87e9c5b4fd50 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,6 +1,7 @@ onnx onnxruntime pandas +python-dateutil tensorboardX tensorboard torch diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt index 47133a09fb62..ebdb41560653 100644 --- a/requirements/requirements_nlp.txt +++ b/requirements/requirements_nlp.txt @@ -1,7 +1,6 @@ boto3 h5py matplotlib -python-dateutil<2.8.1,>=2.1 sentencepiece torchtext transformers diff --git a/requirements/requirements_test.txt b/requirements/requirements_test.txt index 192a3a3fddc7..493b8268cfd1 100644 --- a/requirements/requirements_test.txt +++ 
b/requirements/requirements_test.txt @@ -1,5 +1,6 @@ parameterized pytest +pytest-runner black isort[requirements] wrapt diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 9f41fa21340a..45089c7d8b70 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -17,15 +17,20 @@ # limitations under the License. # ============================================================================= - +import re from io import StringIO from unittest.mock import patch +from nemo import logging from nemo.utils.decorators import deprecated from tests.common_setup import NeMoUnitTest class DeprecatedTest(NeMoUnitTest): + NEMO_ERR_MSG_FORMAT = re.compile( + r"\[NeMo W [0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} deprecated:[0-9]*\] " + ) + def test_say_whee_deprecated(self): """ Tests whether both std and err streams return the right values when function is deprecated.""" @@ -36,14 +41,22 @@ def say_whee(): # Mock up both std and stderr streams. with patch('sys.stdout', new=StringIO()) as std_out: - with patch('sys.stderr', new=StringIO()) as std_err: + with logging.patch_stderr_handler(StringIO()) as std_err: say_whee() # Check std output. self.assertEqual(std_out.getvalue().strip(), "Whee!") # Check error output. 
- self.assertEqual(std_err.getvalue().strip(), 'Function ``say_whee`` is deprecated.') + # Error output now has NeMoBaseFormatter so attempt to strip formatting from error message + # Error formatting is always enclosed in '[' and ']' blocks so remove them + err_msg = std_err.getvalue().strip() + match = self.NEMO_ERR_MSG_FORMAT.match(err_msg) + if match: + err_msg = err_msg[match.end() :] + self.assertEqual(err_msg, 'Function ``say_whee`` is deprecated.') + else: + raise ValueError("Test case could not find a match, did the format of nemo loggin messages change?") def test_say_wow_twice_deprecated(self): """ Tests whether both std and err streams return the right values @@ -55,18 +68,24 @@ def say_wow(): # Mock up both std and stderr streams - first call with patch('sys.stdout', new=StringIO()) as std_out: - with patch('sys.stderr', new=StringIO()) as std_err: + with logging.patch_stderr_handler(StringIO()) as std_err: say_wow() # Check std output. self.assertEqual(std_out.getvalue().strip(), "Woooow!") # Check error output. - self.assertEqual(std_err.getvalue().strip(), 'Function ``say_wow`` is deprecated.') + err_msg = std_err.getvalue().strip() + match = self.NEMO_ERR_MSG_FORMAT.match(err_msg) + if match: + err_msg = err_msg[match.end() :] + self.assertEqual(err_msg, 'Function ``say_wow`` is deprecated.') + else: + raise ValueError("Test case could not find a match, did the format of nemo loggin messages change?") # Second call. with patch('sys.stdout', new=StringIO()) as std_out: - with patch('sys.stderr', new=StringIO()) as std_err: + with logging.patch_stderr_handler(StringIO()) as std_err: say_wow() # Check std output. @@ -79,23 +98,30 @@ def test_say_whoopie_deprecated_version(self): """ Tests whether both std and err streams return the right values when function is deprecated and version is provided. """ - @deprecated(version=0.1) + version = 0.1 + + @deprecated(version=version) def say_whoopie(): print("Whoopie!") # Mock up both std and stderr streams.
with patch('sys.stdout', new=StringIO()) as std_out: - with patch('sys.stderr', new=StringIO()) as std_err: + with logging.patch_stderr_handler(StringIO()) as std_err: say_whoopie() # Check std output. self.assertEqual(std_out.getvalue().strip(), "Whoopie!") - # Check error output. - self.assertEqual( - std_err.getvalue().strip(), - "Function ``say_whoopie`` is deprecated. It is going to be removed in the 0.1 version.", - ) + err_msg = std_err.getvalue().strip() + match = self.NEMO_ERR_MSG_FORMAT.match(err_msg) + if match: + err_msg = err_msg[match.end() :] + self.assertEqual( + err_msg, + f"Function ``say_whoopie`` is deprecated. It is going to be removed in the {version} version.", + ) + else: + raise ValueError("Test case could not find a match, did the format of nemo loggin messages change?") def test_say_kowabunga_deprecated_explanation(self): """ Tests whether both std and err streams return the right values @@ -107,13 +133,17 @@ def say_kowabunga(): # Mock up both std and stderr streams. with patch('sys.stdout', new=StringIO()) as std_out: - with patch('sys.stderr', new=StringIO()) as std_err: + with logging.patch_stderr_handler(StringIO()) as std_err: say_kowabunga() # Check std output. self.assertEqual(std_out.getvalue().strip(), "Kowabunga!") # Check error output. - self.assertEqual( - std_err.getvalue().strip(), 'Function ``say_kowabunga`` is deprecated. Please use ``print_ihaa`` instead.' - ) + err_msg = std_err.getvalue().strip() + match = self.NEMO_ERR_MSG_FORMAT.match(err_msg) + if match: + err_msg = err_msg[match.end() :] + self.assertEqual(err_msg, 'Function ``say_kowabunga`` is deprecated. Please use ``print_ihaa`` instead.') + else: + raise ValueError("Test case could not find a match, did the format of nemo loggin messages change?")