Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable RUF018 rule for walrus assignments in asserts #18886

Merged
merged 9 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ repos:
- flake8-simplify
- flake8-return

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.1.3"
hooks:
- id: ruff
args: ["--fix", "--preview"]

- repo: https://github.com/psf/black
rev: 23.9.1
hooks:
Expand Down Expand Up @@ -120,9 +126,3 @@ repos:
- id: prettier
# https://prettier.io/docs/en/options.html#print-width
args: ["--print-width=120"]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.0.292"
hooks:
- id: ruff
args: ["--fix"]
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ select = [
"E", "W", # see: https://pypi.org/project/pycodestyle
"F", # see: https://pypi.org/project/pyflakes
"S", # see: https://pypi.org/project/flake8-bandit
"RUF018", # see: https://docs.astral.sh/ruff/rules/assignment-in-assert
]
extend-select = [
"I", # see: isort
Expand All @@ -64,6 +65,7 @@ extend-select = [
ignore = [
"E731", # Do not assign a lambda expression, use a def
"S108",
"E203", # conflicts with black
]
# Exclude a variety of commonly ignored directories.
exclude = [
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/app/cli/connect/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def connect_app(app_name_or_id: str):

for command_name, metadata in retriever.api_commands.items():
if "cls_path" in metadata:
target_file = os.path.join(commands_folder, f"{command_name.replace(' ','_')}.py")
target_file = os.path.join(commands_folder, f"{command_name.replace(' ', '_')}.py")
_download_command(
command_name,
metadata["cls_path"],
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/app/source_code/tar.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _get_split_size(
max_size = max_split_count * (1 << 31) # max size per part limited by Requests or urllib as shown in ref above
if total_size > max_size:
raise click.ClickException(
f"The size of the datastore to be uploaded is bigger than our {max_size/(1 << 40):.2f} TBytes limit"
f"The size of the datastore to be uploaded is bigger than our {max_size / (1 << 40):.2f} TBytes limit"
)

split_size = minimum_split_size
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/callbacks/lr_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ def _add_prefix(
def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: int, use_names: bool = True) -> str:
if len(param_groups) > 1:
if not use_names:
return f"{name}/pg{param_group_index+1}"
pg_name = param_groups[param_group_index].get("name", f"pg{param_group_index+1}")
return f"{name}/pg{param_group_index + 1}"
pg_name = param_groups[param_group_index].get("name", f"pg{param_group_index + 1}")
return f"{name}/{pg_name}"
if use_names:
pg_name = param_groups[param_group_index].get("name")
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,8 +627,8 @@ def get_automatic(
if len(optimizers) > 1 or len(lr_schedulers) > 1:
raise MisconfigurationException(
f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer "
f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user "
"is expected to link the argument groups and implement `configure_optimizers`, see "
f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers + lr_schedulers}. In this case the "
"user is expected to link the argument groups and implement `configure_optimizers`, see "
"https://lightning.ai/docs/pytorch/stable/common/lightning_cli.html"
"#optimizers-and-learning-rate-schedulers"
)
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_app/core/test_lightning_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def test_configure_api():
asyncio.set_event_loop(loop)
results = loop.run_until_complete(asyncio.gather(*coros))
response_time = time() - t0
print(f"RPS: {N/response_time}")
print(f"RPS: {N / response_time}")
assert response_time < 10
assert len(results) == N
assert all(r.get("detail", None) == ("HERE" if i % 5 == 0 else None) for i, r in enumerate(results))
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_app/storage/test_copier.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_copier_handles_exception(stat_mock, dir_mock, monkeypatch):
copy_request_queue.put(request)
copier.run_once()
response = copy_response_queue.get()
assert type(response.exception) == OSError
assert type(response.exception) is OSError
assert response.exception.args[0] == "Something went wrong"


Expand Down
6 changes: 3 additions & 3 deletions tests/tests_app/utilities/test_cli_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def test_arrow_time_callback():
assert _arrow_time_callback(Mock(), Mock(), "2022-08-23 12:34:00.000") == arrow.Arrow(2022, 8, 23, 12, 34)

# Just check humanized format is parsed
assert type(_arrow_time_callback(Mock(), Mock(), "48 hours ago")) == arrow.Arrow
assert type(_arrow_time_callback(Mock(), Mock(), "48 hours ago")) is arrow.Arrow

assert type(_arrow_time_callback(Mock(), Mock(), "60 minutes ago")) == arrow.Arrow
assert type(_arrow_time_callback(Mock(), Mock(), "60 minutes ago")) is arrow.Arrow

assert type(_arrow_time_callback(Mock(), Mock(), "120 seconds ago")) == arrow.Arrow
assert type(_arrow_time_callback(Mock(), Mock(), "120 seconds ago")) is arrow.Arrow

# Check raising errors
with pytest.raises(Exception, match="cannot parse time Mon"):
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_app/utilities/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_4xx_exceptions_caught_in_subcommands(self, mock_api_handled_group, mock

mock_subcommand.invoke.assert_called
assert result.exit_code == 1
assert type(result.exception) == ClickException
assert type(result.exception) is ClickException
assert api_error_msg == str(result.exception)

def test_original_thrown_if_cannot_decode_body(self, mock_api_handled_group, mock_subcommand):
Expand All @@ -81,4 +81,4 @@ def test_original_thrown_if_cannot_decode_body(self, mock_api_handled_group, moc

mock_subcommand.invoke.assert_called
assert result.exit_code == 1
assert type(result.exception) == ApiException
assert type(result.exception) is ApiException
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def _test_two_groups(strategy, left_collective, right_collective):


@skip_distributed_unavailable
@pytest.mark.flaky(reruns=5)
@RunIf(skip_windows=True) # unhandled timeouts
@pytest.mark.xfail(raises=TimeoutError, strict=False)
def test_two_groups():
Expand All @@ -286,6 +287,7 @@ def _test_default_process_group(strategy, *collectives):


@skip_distributed_unavailable
@pytest.mark.flaky(reruns=5)
@RunIf(skip_windows=True) # unhandled timeouts
def test_default_process_group():
collective_launch(_test_default_process_group, [torch.device("cpu")] * 3, num_groups=2)
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_fabric/utilities/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_lazy_load_module(tmp_path):
model1.load_state_dict(checkpoint)

assert isinstance(checkpoint["weight"], _NotYetLoadedTensor)
assert type(model0.weight.data) == torch.Tensor
assert type(model0.weight.data) is torch.Tensor
assert torch.equal(model0.weight, model1.weight)
assert torch.equal(model0.bias, model1.bias)

Expand Down
12 changes: 6 additions & 6 deletions tests/tests_fabric/utilities/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def line_number():
output = stderr.getvalue()
expected_lines = [
f"test_warnings.py:{base_line}: test1",
f"test_warnings.py:{base_line+1}: test2",
f"test_warnings.py:{base_line+3}: test3",
f"test_warnings.py:{base_line+4}: test4",
f"test_warnings.py:{base_line+6}: test5",
f"test_warnings.py:{base_line+9}: test6",
f"test_warnings.py:{base_line+10}: test7",
f"test_warnings.py:{base_line + 1}: test2",
f"test_warnings.py:{base_line + 3}: test3",
f"test_warnings.py:{base_line + 4}: test4",
f"test_warnings.py:{base_line + 6}: test5",
f"test_warnings.py:{base_line + 9}: test6",
f"test_warnings.py:{base_line + 10}: test7",
]

for ln in expected_lines:
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_pytorch/profilers/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_simple_profiler_summary(tmpdir, extended):
f" {'Total time (s)':<15}\t| {'Percentage %':<15}\t|"
)
output_string_len = len(header_string.expandtabs())
sep_lines = f"{sep}{'-'* output_string_len}"
sep_lines = f"{sep}{'-' * output_string_len}"
expected_text = (
f"Profiler Report{sep}"
f"{sep_lines}"
Expand All @@ -236,7 +236,7 @@ def test_simple_profiler_summary(tmpdir, extended):
f"{sep}| {'Action':<{max_action_len}s}\t| {'Mean duration (s)':<15}\t| {'Total time (s)':<15}\t|"
)
output_string_len = len(header_string.expandtabs())
sep_lines = f"{sep}{'-'* output_string_len}"
sep_lines = f"{sep}{'-' * output_string_len}"
expected_text = (
f"Profiler Report{sep}"
f"{sep_lines}"
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/trainer/properties/test_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_trainer_loggers_setters():
logger2 = CustomLogger()

trainer = Trainer()
assert type(trainer.logger) == TensorBoardLogger
assert type(trainer.logger) is TensorBoardLogger
assert trainer.loggers == [trainer.logger]

# Test setters for trainer.logger
Expand Down