Skip to content

Commit

Permalink
Enable RUF018 rule for walrus assignments in asserts (#18886)
Browse the repository at this point in the history
(cherry picked from commit 018a308)
  • Loading branch information
awaelchli authored and Borda committed Nov 2, 2023
1 parent cc49cee commit 98a2f54
Show file tree
Hide file tree
Showing 15 changed files with 33 additions and 29 deletions.
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ repos:
- flake8-simplify
- flake8-return

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.1.3"
hooks:
- id: ruff
args: ["--fix", "--preview"]

- repo: https://github.com/psf/black
rev: 23.9.1
hooks:
Expand Down Expand Up @@ -120,9 +126,3 @@ repos:
- id: prettier
# https://prettier.io/docs/en/options.html#print-width
args: ["--print-width=120"]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.0.292"
hooks:
- id: ruff
args: ["--fix"]
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ select = [
"E", "W", # see: https://pypi.org/project/pycodestyle
"F", # see: https://pypi.org/project/pyflakes
"S", # see: https://pypi.org/project/flake8-bandit
"RUF018", # see: https://docs.astral.sh/ruff/rules/assignment-in-assert
]
extend-select = [
"I", # see: isort
Expand All @@ -64,6 +65,7 @@ extend-select = [
ignore = [
"E731", # Do not assign a lambda expression, use a def
"S108",
"E203", # conflicts with black
]
# Exclude a variety of commonly ignored directories.
exclude = [
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/app/cli/connect/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def connect_app(app_name_or_id: str):

for command_name, metadata in retriever.api_commands.items():
if "cls_path" in metadata:
target_file = os.path.join(commands_folder, f"{command_name.replace(' ','_')}.py")
target_file = os.path.join(commands_folder, f"{command_name.replace(' ', '_')}.py")
_download_command(
command_name,
metadata["cls_path"],
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/app/source_code/tar.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _get_split_size(
max_size = max_split_count * (1 << 31) # max size per part limited by Requests or urllib as shown in ref above
if total_size > max_size:
raise click.ClickException(
f"The size of the datastore to be uploaded is bigger than our {max_size/(1 << 40):.2f} TBytes limit"
f"The size of the datastore to be uploaded is bigger than our {max_size / (1 << 40):.2f} TBytes limit"
)

split_size = minimum_split_size
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/callbacks/lr_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ def _add_prefix(
def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: int, use_names: bool = True) -> str:
if len(param_groups) > 1:
if not use_names:
return f"{name}/pg{param_group_index+1}"
pg_name = param_groups[param_group_index].get("name", f"pg{param_group_index+1}")
return f"{name}/pg{param_group_index + 1}"
pg_name = param_groups[param_group_index].get("name", f"pg{param_group_index + 1}")
return f"{name}/{pg_name}"
if use_names:
pg_name = param_groups[param_group_index].get("name")
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,8 +627,8 @@ def get_automatic(
if len(optimizers) > 1 or len(lr_schedulers) > 1:
raise MisconfigurationException(
f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer "
f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user "
"is expected to link the argument groups and implement `configure_optimizers`, see "
f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers + lr_schedulers}. In this case the "
"user is expected to link the argument groups and implement `configure_optimizers`, see "
"https://lightning.ai/docs/pytorch/stable/common/lightning_cli.html"
"#optimizers-and-learning-rate-schedulers"
)
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_app/core/test_lightning_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def test_configure_api():
asyncio.set_event_loop(loop)
results = loop.run_until_complete(asyncio.gather(*coros))
response_time = time() - t0
print(f"RPS: {N/response_time}")
print(f"RPS: {N / response_time}")
assert response_time < 10
assert len(results) == N
assert all(r.get("detail", None) == ("HERE" if i % 5 == 0 else None) for i, r in enumerate(results))
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_app/storage/test_copier.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_copier_handles_exception(stat_mock, dir_mock, monkeypatch):
copy_request_queue.put(request)
copier.run_once()
response = copy_response_queue.get()
assert type(response.exception) == OSError
assert type(response.exception) is OSError
assert response.exception.args[0] == "Something went wrong"


Expand Down
6 changes: 3 additions & 3 deletions tests/tests_app/utilities/test_cli_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def test_arrow_time_callback():
assert _arrow_time_callback(Mock(), Mock(), "2022-08-23 12:34:00.000") == arrow.Arrow(2022, 8, 23, 12, 34)

# Just check humanized format is parsed
assert type(_arrow_time_callback(Mock(), Mock(), "48 hours ago")) == arrow.Arrow
assert type(_arrow_time_callback(Mock(), Mock(), "48 hours ago")) is arrow.Arrow

assert type(_arrow_time_callback(Mock(), Mock(), "60 minutes ago")) == arrow.Arrow
assert type(_arrow_time_callback(Mock(), Mock(), "60 minutes ago")) is arrow.Arrow

assert type(_arrow_time_callback(Mock(), Mock(), "120 seconds ago")) == arrow.Arrow
assert type(_arrow_time_callback(Mock(), Mock(), "120 seconds ago")) is arrow.Arrow

# Check raising errors
with pytest.raises(Exception, match="cannot parse time Mon"):
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_app/utilities/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_4xx_exceptions_caught_in_subcommands(self, mock_api_handled_group, mock

mock_subcommand.invoke.assert_called
assert result.exit_code == 1
assert type(result.exception) == ClickException
assert type(result.exception) is ClickException
assert api_error_msg == str(result.exception)

def test_original_thrown_if_cannot_decode_body(self, mock_api_handled_group, mock_subcommand):
Expand All @@ -81,4 +81,4 @@ def test_original_thrown_if_cannot_decode_body(self, mock_api_handled_group, moc

mock_subcommand.invoke.assert_called
assert result.exit_code == 1
assert type(result.exception) == ApiException
assert type(result.exception) is ApiException
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def _test_two_groups(strategy, left_collective, right_collective):


@skip_distributed_unavailable
@pytest.mark.flaky(reruns=5)
@RunIf(skip_windows=True) # unhandled timeouts
@pytest.mark.xfail(raises=TimeoutError, strict=False)
def test_two_groups():
Expand All @@ -285,6 +286,7 @@ def _test_default_process_group(strategy, *collectives):


@skip_distributed_unavailable
@pytest.mark.flaky(reruns=5)
@RunIf(skip_windows=True) # unhandled timeouts
def test_default_process_group():
collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_fabric/utilities/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_lazy_load_module(tmp_path):
model1.load_state_dict(checkpoint)

assert isinstance(checkpoint["weight"], _NotYetLoadedTensor)
assert type(model0.weight.data) == torch.Tensor
assert type(model0.weight.data) is torch.Tensor
assert torch.equal(model0.weight, model1.weight)
assert torch.equal(model0.bias, model1.bias)

Expand Down
12 changes: 6 additions & 6 deletions tests/tests_fabric/utilities/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def line_number():
output = stderr.getvalue()
expected_lines = [
f"test_warnings.py:{base_line}: test1",
f"test_warnings.py:{base_line+1}: test2",
f"test_warnings.py:{base_line+3}: test3",
f"test_warnings.py:{base_line+4}: test4",
f"test_warnings.py:{base_line+6}: test5",
f"test_warnings.py:{base_line+9}: test6",
f"test_warnings.py:{base_line+10}: test7",
f"test_warnings.py:{base_line + 1}: test2",
f"test_warnings.py:{base_line + 3}: test3",
f"test_warnings.py:{base_line + 4}: test4",
f"test_warnings.py:{base_line + 6}: test5",
f"test_warnings.py:{base_line + 9}: test6",
f"test_warnings.py:{base_line + 10}: test7",
]

for ln in expected_lines:
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_pytorch/profilers/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_simple_profiler_summary(tmpdir, extended):
f" {'Total time (s)':<15}\t| {'Percentage %':<15}\t|"
)
output_string_len = len(header_string.expandtabs())
sep_lines = f"{sep}{'-'* output_string_len}"
sep_lines = f"{sep}{'-' * output_string_len}"
expected_text = (
f"Profiler Report{sep}"
f"{sep_lines}"
Expand All @@ -236,7 +236,7 @@ def test_simple_profiler_summary(tmpdir, extended):
f"{sep}| {'Action':<{max_action_len}s}\t| {'Mean duration (s)':<15}\t| {'Total time (s)':<15}\t|"
)
output_string_len = len(header_string.expandtabs())
sep_lines = f"{sep}{'-'* output_string_len}"
sep_lines = f"{sep}{'-' * output_string_len}"
expected_text = (
f"Profiler Report{sep}"
f"{sep_lines}"
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/trainer/properties/test_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_trainer_loggers_setters():
logger2 = CustomLogger()

trainer = Trainer()
assert type(trainer.logger) == TensorBoardLogger
assert type(trainer.logger) is TensorBoardLogger
assert trainer.loggers == [trainer.logger]

# Test setters for trainer.logger
Expand Down

0 comments on commit 98a2f54

Please sign in to comment.