Skip to content

Commit

Permalink
Miscellaneous improvements throughout the codebase (#365)
Browse files Browse the repository at this point in the history
Co-authored-by: bswck <[email protected]>
  • Loading branch information
trag1c and bswck authored Aug 31, 2023
1 parent 95452e5 commit e03ce6f
Show file tree
Hide file tree
Showing 34 changed files with 280 additions and 364 deletions.
17 changes: 10 additions & 7 deletions src/cleo/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,23 @@ def find_similar_names(name: str, names: list[str]) -> list[str]:
distance = Levenshtein.distance(name, actual_name)

is_similar = distance <= len(name) / 3
is_sub_string = actual_name.find(name) != -1
substring_index = actual_name.find(name)
is_substring = substring_index != -1

if is_similar or is_sub_string:
if is_similar or is_substring:
distance_by_name[actual_name] = (
distance,
actual_name.find(name) if is_sub_string else float("inf"),
substring_index if is_substring else float("inf"),
)

# Only keep results with a distance below the threshold
distance_by_name = {
k: v for k, v in distance_by_name.items() if v[0] < 2 * threshold
key: value
for key, value in distance_by_name.items()
if value[0] < 2 * threshold
}
# Display results with shortest distance first
return sorted(distance_by_name, key=lambda x: distance_by_name[x])
return sorted(distance_by_name, key=lambda key: distance_by_name[key])


@dataclass
Expand Down Expand Up @@ -101,7 +104,7 @@ def apply(self, secs: float) -> str:


def format_time(secs: float) -> str:
format = next(
time_format = next(
(fmt for fmt in _TIME_FORMATS if secs < fmt.threshold), _TIME_FORMATS[-1]
)
return format.apply(secs)
return time_format.apply(secs)
8 changes: 2 additions & 6 deletions src/cleo/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def _run(self, io: IO) -> int:
break

if index is not None:
del argv[index + 1 : index + 1 + (len(name.split(" ")) - 1)]
del argv[index + 1 : index + 1 + name.count(" ")]

stream = io.input.stream
interactive = io.input.is_interactive()
Expand Down Expand Up @@ -614,11 +614,7 @@ def _get_command_name(self, io: IO) -> str | None:

def extract_namespace(self, name: str, limit: int | None = None) -> str:
parts = name.split(" ")[:-1]

if limit is not None:
return " ".join(parts[:limit])

return " ".join(parts)
return " ".join(parts[:limit])

def _get_default_ui(self) -> UI:
from cleo.ui.progress_bar import ProgressBar
Expand Down
4 changes: 2 additions & 2 deletions src/cleo/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
if option not in self.AVAILABLE_OPTIONS:
raise ValueError(
f'"{option}" is not a valid color option. '
f"It must be one of {', '.join(self.AVAILABLE_OPTIONS.keys())}"
f"It must be one of {', '.join(self.AVAILABLE_OPTIONS)}"
)

self._options[option] = self.AVAILABLE_OPTIONS[option]
Expand Down Expand Up @@ -114,7 +114,7 @@ def _parse_color(self, color: str, background: bool) -> str:
if color not in self.COLORS:
raise CleoValueError(
f'"{color}" is an invalid color.'
f" It must be one of {', '.join(self.COLORS.keys())}"
f" It must be one of {', '.join(self.COLORS)}"
)

return str(self.COLORS[color][int(background)])
Expand Down
11 changes: 3 additions & 8 deletions src/cleo/commands/base_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self) -> None:
self.configure()

for i, usage in enumerate(self.usages):
if self.name and usage.find(self.name) != 0:
if self.name and not usage.startswith(self.name):
self.usages[i] = f"{self.name} {usage}"

@property
Expand Down Expand Up @@ -85,7 +85,7 @@ def configure(self) -> None:
"""

def execute(self, io: IO) -> int:
raise NotImplementedError()
raise NotImplementedError

def interact(self, io: IO) -> None:
"""
Expand Down Expand Up @@ -114,12 +114,7 @@ def run(self, io: IO) -> int:

io.input.validate()

status_code = self.execute(io)

if status_code is None:
status_code = 0

return status_code
return self.execute(io) or 0

def merge_application_definition(self, merge_args: bool = True) -> None:
if self._application is None:
Expand Down
27 changes: 10 additions & 17 deletions src/cleo/commands/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,35 +64,29 @@ def execute(self, io: IO) -> int:

def handle(self) -> int:
"""
Executes the command.
Execute the command.
"""
raise NotImplementedError()
raise NotImplementedError

def call(self, name: str, args: str | None = None) -> int:
"""
Call another command.
"""
if args is None:
args = ""

input = StringInput(args)
assert self.application is not None
command = self.application.get(name)

return self.application._run_command(command, self._io.with_input(input))
return self.application._run_command(
command, self._io.with_input(StringInput(args or ""))
)

def call_silent(self, name: str, args: str | None = None) -> int:
"""
Call another command silently.
"""
if args is None:
args = ""

input = StringInput(args)
assert self.application is not None
command = self.application.get(name)

return self.application._run_command(command, NullIO(input))
return self.application._run_command(command, NullIO(StringInput(args or "")))

def argument(self, name: str) -> Any:
"""
Expand Down Expand Up @@ -166,7 +160,7 @@ def choice(
def create_question(
self,
question: str,
type: Literal["choice"] | Literal["confirmation"] | None = None,
type: Literal["choice", "confirmation"] | None = None,
**kwargs: Any,
) -> Question:
"""
Expand All @@ -176,14 +170,13 @@ def create_question(
from cleo.ui.confirmation_question import ConfirmationQuestion
from cleo.ui.question import Question

if not type:
return Question(question, **kwargs)
if type == "confirmation":
return ConfirmationQuestion(question, **kwargs)

if type == "choice":
return ChoiceQuestion(question, **kwargs)

if type == "confirmation":
return ConfirmationQuestion(question, **kwargs)
return Question(question, **kwargs)

def table(
self,
Expand Down
11 changes: 6 additions & 5 deletions src/cleo/commands/completions_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import re
import subprocess

from pathlib import Path
from typing import TYPE_CHECKING
from typing import ClassVar

Expand Down Expand Up @@ -142,7 +143,7 @@ def _get_script_name_and_path(self) -> tuple[str, str]:
# we incorrectly infer `script_name` as `__main__.py`
script_name = self._io.input.script_name or inspect.stack()[-1][1]
script_path = posixpath.realpath(script_name)
script_name = os.path.basename(script_path)
script_name = Path(script_path).name

return script_name, script_path

Expand All @@ -162,7 +163,7 @@ def render_bash(self) -> str:
cmds = []
cmds_opts = []
for cmd in sorted(self.application.all().values(), key=lambda c: c.name or ""):
if cmd.hidden or not cmd.enabled or not cmd.name:
if cmd.hidden or not (cmd.enabled and cmd.name):
continue
command_name = shell_quote(cmd.name) if " " in cmd.name else cmd.name
cmds.append(command_name)
Expand Down Expand Up @@ -207,7 +208,7 @@ def sanitize(s: str) -> str:
cmds = []
cmds_opts = []
for cmd in sorted(self.application.all().values(), key=lambda c: c.name or ""):
if cmd.hidden or not cmd.enabled or not cmd.name:
if cmd.hidden or not (cmd.enabled and cmd.name):
continue
command_name = shell_quote(cmd.name) if " " in cmd.name else cmd.name
cmds.append(self._zsh_describe(command_name, sanitize(cmd.description)))
Expand Down Expand Up @@ -287,11 +288,11 @@ def get_shell_type(self) -> str:
"Please specify your shell type by passing it as the first argument."
)

return os.path.basename(shell)
return Path(shell).name

def _generate_function_name(self, script_name: str, script_path: str) -> str:
sanitized_name = self._sanitize_for_function_name(script_name)
md5_hash = hashlib.md5(script_path.encode()).hexdigest()[0:16]
md5_hash = hashlib.md5(script_path.encode()).hexdigest()[:16]
return f"_{sanitized_name}_{md5_hash}_complete"

def _sanitize_for_function_name(self, name: str) -> str:
Expand Down
11 changes: 5 additions & 6 deletions src/cleo/descriptors/application_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,9 @@ def _sort_commands(
key = self._application.extract_namespace(name, 1) or "_global"
namespaced_commands[key][name] = command

namespaced_commands_lst: dict[str, list[tuple[str, Command]]] = {}
for namespace, commands in namespaced_commands.items():
namespaced_commands_lst[namespace] = sorted(
commands.items(), key=lambda x: x[0]
)
namespaced_commands_list: dict[str, list[tuple[str, Command]]] = {
namespace: sorted(commands.items())
for namespace, commands in namespaced_commands.items()
}

return sorted(namespaced_commands_lst.items(), key=lambda x: x[0])
return sorted(namespaced_commands_list.items())
10 changes: 5 additions & 5 deletions src/cleo/descriptors/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,16 @@ def _write(self, content: str, decorated: bool = True) -> None:
)

def _describe_argument(self, argument: Argument, **options: Any) -> None:
raise NotImplementedError()
raise NotImplementedError

def _describe_option(self, option: Option, **options: Any) -> None:
raise NotImplementedError()
raise NotImplementedError

def _describe_definition(self, definition: Definition, **options: Any) -> None:
raise NotImplementedError()
raise NotImplementedError

def _describe_command(self, command: Command, **options: Any) -> None:
raise NotImplementedError()
raise NotImplementedError

def _describe_application(self, application: Application, **options: Any) -> None:
raise NotImplementedError()
raise NotImplementedError
33 changes: 13 additions & 20 deletions src/cleo/descriptors/text_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,7 @@ def _describe_application(self, application: Application, **options: Any) -> Non

self._describe_definition(Definition(application.definition.options), **options)

self._write("\n")
self._write("\n")
self._write("\n\n")

commands = description.commands
namespaces = description.namespaces
Expand All @@ -181,7 +180,7 @@ def _describe_application(self, application: Application, **options: Any) -> Non
commands[name] = description.command(name)

# calculate max width based on available commands per namespace
all_commands = list(commands.keys())
all_commands = list(commands)
for namespace in namespaces.values():
all_commands += namespace["commands"]

Expand Down Expand Up @@ -223,23 +222,18 @@ def _describe_application(self, application: Application, **options: Any) -> Non
self._write("\n")

def _format_default_value(self, default: Any) -> str:
new_default: Any
if isinstance(default, str):
default = Formatter.escape(default)
elif isinstance(default, list):
new_default = []
for value in default:
if isinstance(value, str):
new_default.append(Formatter.escape(value))

default = new_default
default = [
Formatter.escape(value) for value in default if isinstance(value, str)
]
elif isinstance(default, dict):
new_default = {}
for key, value in default.items():
if isinstance(value, str):
new_default[key] = Formatter.escape(value)

default = new_default
default = {
key: Formatter.escape(value)
for key, value in default.items()
if isinstance(value, str)
}

return json.dumps(default).replace("\\\\", "\\")

Expand All @@ -261,7 +255,7 @@ def _calculate_total_width_for_options(self, options: list[Option]) -> int:
return total_width

def _get_column_width(self, commands: Sequence[Command | str]) -> int:
widths = []
widths: list[int] = []

for command in commands:
if isinstance(command, Command):
Expand All @@ -278,10 +272,9 @@ def _get_column_width(self, commands: Sequence[Command | str]) -> int:
return max(widths) + 2

def _get_command_aliases_text(self, command: Command) -> str:
text = ""
aliases = command.aliases

if aliases:
text = f"[{ '|'.join(aliases) }] "
return f"[{ '|'.join(aliases) }] "

return text
return ""
18 changes: 6 additions & 12 deletions src/cleo/events/event_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self) -> None:

def dispatch(self, event: Event, event_name: str | None = None) -> Event:
if event_name is None:
event_name = event.__class__.__name__
event_name = type(event).__name__

listeners = cast("list[Listener]", self.get_listeners(event_name))

Expand Down Expand Up @@ -58,11 +58,7 @@ def get_listener_priority(self, event_name: str, listener: Listener) -> int | No

def has_listeners(self, event_name: str | None = None) -> bool:
if event_name is not None:
if event_name not in self._listeners:
return False

return bool(self._listeners[event_name])

return bool(self._listeners.get(event_name))
return any(self._listeners.values())

def add_listener(
Expand Down Expand Up @@ -92,10 +88,8 @@ def _sort_listeners(self, event_name: str) -> None:
"""
Sorts the internal list of listeners for the given event by priority.
"""
self._sorted[event_name] = []
prioritized_listeners = self._listeners[event_name]
sorted_listeners = self._sorted[event_name] = []

for _, listeners in sorted(
self._listeners[event_name].items(), key=lambda t: -t[0]
):
for listener in listeners:
self._sorted[event_name].append(listener)
for priority in sorted(prioritized_listeners, reverse=True):
sorted_listeners.extend(prioritized_listeners[priority])
Loading

0 comments on commit e03ce6f

Please sign in to comment.