diff --git a/.lintrunner.toml b/.lintrunner.toml index f5d3a16be..23a76193a 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -13,3 +13,21 @@ init_command = [ '--dry-run={{DRYRUN}}', 'black==22.3.0', ] +is_formatter = true + +[[linter]] +code = 'USORT' +include_patterns = ['**/*.py'] +command = [ + 'python3', + 'tools/lint/usort_linter.py', + '--', + '@{{PATHSFILE}}' +] +init_command = [ + 'python3', + 'tools/lint/pip_init.py', + '--dry-run={{DRYRUN}}', + 'usort==1.0.2', +] +is_formatter = true diff --git a/requirements.txt b/requirements.txt index 15dfcdb72..f909c4784 100644 --- a/requirements.txt +++ b/requirements.txt @@ -38,4 +38,5 @@ tqdm typing-inspect typing_extensions urllib3 +usort websocket-client diff --git a/tools/lint/black_linter.py b/tools/lint/black_linter.py index cbf83d896..cfdc3d4e8 100644 --- a/tools/lint/black_linter.py +++ b/tools/lint/black_linter.py @@ -13,38 +13,9 @@ import subprocess import sys import time -from enum import Enum -from typing import Any, BinaryIO, List, NamedTuple, Optional +from typing import BinaryIO, List - -IS_WINDOWS: bool = os.name == "nt" - - -def eprint(*args: Any, **kwargs: Any) -> None: - print(*args, file=sys.stderr, flush=True, **kwargs) - - -class LintSeverity(str, Enum): - ERROR = "error" - WARNING = "warning" - ADVICE = "advice" - DISABLED = "disabled" - - -class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] - code: str - severity: LintSeverity - name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] - - -def as_posix(name: str) -> str: - return name.replace("\\", "/") if IS_WINDOWS else name +from utils import as_posix, IS_WINDOWS, LintMessage, LintSeverity def _run_command( @@ -218,8 +189,8 @@ def main() -> None: thread_name_prefix="Thread", ) as executor: futures = { - executor.submit(check_file, x, args.retries, args.timeout): x - for x in args.filenames + executor.submit(check_file, filename, args.retries, args.timeout): filename + for filename in args.filenames } for future in concurrent.futures.as_completed(futures): try: diff --git a/tools/lint/usort_linter.py b/tools/lint/usort_linter.py new file mode 100644 index 000000000..1201c64cf --- /dev/null +++ b/tools/lint/usort_linter.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import concurrent.futures +import json +import os +import subprocess +from typing import List + +from usort import config as usort_config, usort + +from utils import as_posix, LintMessage, LintSeverity + + +def check_file( + filename: str, +) -> List[LintMessage]: + try: + top_of_file_cat = usort_config.Category("top_of_file") + known = usort_config.known_factory() + # cinder magic imports must be on top (after future imports) + known["__strict__"] = top_of_file_cat + known["__static__"] = top_of_file_cat + + config = usort_config.Config( + categories=( + ( + usort_config.CAT_FUTURE, + top_of_file_cat, + usort_config.CAT_STANDARD_LIBRARY, + usort_config.CAT_THIRD_PARTY, + usort_config.CAT_FIRST_PARTY, + ) + ), + known=known, + ) + + with open(filename, mode="rb") as f: + original = f.read() + result = usort(original, config) + if result.error: + raise result.error + + except subprocess.TimeoutExpired: + return [ + LintMessage( + path=filename, + line=None, + char=None, + code="USORT", + severity=LintSeverity.ERROR, + name="timeout", + original=None, + replacement=None, + description=( + "usort timed out while trying to process a file. " + "Please report an issue in pytorch/torchrec." + ), + ) + ] + except (OSError, subprocess.CalledProcessError) as err: + return [ + LintMessage( + path=filename, + line=None, + char=None, + code="USORT", + severity=LintSeverity.ADVICE, + name="command-failed", + original=None, + replacement=None, + description=( + f"Failed due to {err.__class__.__name__}:\n{err}" + if not isinstance(err, subprocess.CalledProcessError) + else ( + "COMMAND (exit code {returncode})\n" + "{command}\n\n" + "STDERR\n{stderr}\n\n" + "STDOUT\n{stdout}" + ).format( + returncode=err.returncode, + command=" ".join(as_posix(x) for x in err.cmd), + stderr=err.stderr.decode("utf-8").strip() or "(empty)", + stdout=err.stdout.decode("utf-8").strip() or "(empty)", + ) + ), + ) + ] + + replacement = result.output + if original == replacement: + return [] + + return [ + LintMessage( + path=filename, + line=None, + char=None, + code="USORT", + severity=LintSeverity.WARNING, + name="format", + original=original.decode("utf-8"), + replacement=replacement.decode("utf-8"), + description="Run `lintrunner -a` to apply this patch.", + ) + ] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Format files with usort.", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "filenames", + nargs="+", + help="paths to lint", + ) + args = parser.parse_args() + + with concurrent.futures.ThreadPoolExecutor( + max_workers=os.cpu_count(), + thread_name_prefix="Thread", + ) as executor: + futures = { + executor.submit(check_file, filename): filename + for filename in args.filenames + } + for future in concurrent.futures.as_completed(futures): + try: + for lint_message in future.result(): + print(json.dumps(lint_message._asdict()), flush=True) + except Exception: + raise RuntimeError(f"Failed at {futures[future]}") + + +if __name__ == "__main__": + main() diff --git a/tools/lint/utils.py b/tools/lint/utils.py new file mode 100644 index 000000000..9f3e4e8ae --- /dev/null +++ b/tools/lint/utils.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import os +from enum import Enum +from typing import NamedTuple, Optional + +IS_WINDOWS: bool = os.name == "nt" + + +class LintSeverity(str, Enum): + ERROR = "error" + WARNING = "warning" + ADVICE = "advice" + DISABLED = "disabled" + + +class LintMessage(NamedTuple): + path: Optional[str] + line: Optional[int] + char: Optional[int] + code: str + severity: LintSeverity + name: str + original: Optional[str] + replacement: Optional[str] + description: Optional[str] + + +def as_posix(name: str) -> str: + return name.replace("\\", "/") if IS_WINDOWS else name