Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mypy: Add music_assistant.common #1428

Merged
merged 12 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions music_assistant/client/music.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ async def browse(
path: str | None = None,
limit: int | None = None,
offset: int | None = None,
) -> list[MediaItemType]:
) -> list[MediaItemType | ItemMapping]:
"""Browse Music providers."""
return [
media_from_dict(obj)
Expand All @@ -454,7 +454,7 @@ async def browse(

async def recently_played(
self, limit: int = 10, media_types: list[MediaType] | None = None
) -> list[MediaItemType]:
) -> list[MediaItemType | ItemMapping]:
"""Return a list of the last played items."""
return [
media_from_dict(item)
Expand All @@ -466,7 +466,7 @@ async def recently_played(
async def get_item_by_uri(
self,
uri: str,
) -> MediaItemType:
) -> MediaItemType | ItemMapping:
"""Get single music item providing a mediaitem uri."""
return media_from_dict(await self.client.send_command("music/item_by_uri", uri=uri))

Expand All @@ -478,7 +478,7 @@ async def get_item(
force_refresh: bool = False,
lazy: bool = True,
add_to_library: bool = False,
) -> MediaItemType:
) -> MediaItemType | ItemMapping:
"""Get single music item by id and media type."""
return media_from_dict(
await self.client.send_command(
Expand Down Expand Up @@ -534,7 +534,7 @@ async def add_item_to_library(self, item: str | MediaItemType) -> MediaItemType:
async def refresh_item(
self,
media_item: MediaItemType,
) -> MediaItemType | None:
) -> MediaItemType | ItemMapping | None:
"""Try to refresh a mediaitem by requesting it's full object or search for substitutes."""
if result := await self.client.send_command("music/refresh_item", media_item=media_item):
return media_from_dict(result)
Expand Down
2 changes: 1 addition & 1 deletion music_assistant/common/helpers/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def now_timestamp() -> float:
return now().timestamp()


def future_timestamp(**kwargs) -> float:
def future_timestamp(**kwargs: float) -> float:
"""Return current timestamp + timedelta."""
return (now() + datetime.timedelta(**kwargs)).timestamp()

Expand Down
2 changes: 1 addition & 1 deletion music_assistant/common/helpers/global_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# global cache - we use this on a few places (as limited as possible)
# where we have no other options
_global_cache_lock = asyncio.Lock()
_global_cache = {}
_global_cache: dict[str, Any] = {}


def get_global_cache_value(key: str, default: Any = None) -> Any:
Expand Down
12 changes: 6 additions & 6 deletions music_assistant/common/helpers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import base64
from _collections_abc import dict_keys, dict_values
from types import MethodType
from typing import Any
from typing import Any, TypeVar

import aiofiles
import orjson
from mashumaro.mixins.orjson import DataClassORJSONMixin

JSON_ENCODE_EXCEPTIONS = (TypeError, ValueError)
JSON_DECODE_EXCEPTIONS = (orjson.JSONDecodeError,)
Expand Down Expand Up @@ -59,12 +60,11 @@ def json_dumps(data: Any, indent: bool = False) -> str:

json_loads = orjson.loads

TargetT = TypeVar("TargetT", bound=DataClassORJSONMixin)

async def load_json_file(path: str, target_class: type | None = None) -> dict:

async def load_json_file(path: str, target_class: type[TargetT]) -> TargetT:
"""Load JSON from file."""
async with aiofiles.open(path, "r") as _file:
content = await _file.read()
if target_class:
# support for a mashumaro model
return target_class.from_json(content)
return json_loads(content)
return target_class.from_json(content)
2 changes: 1 addition & 1 deletion music_assistant/common/helpers/uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
base62_length22_id_pattern = re.compile(r"^[a-zA-Z0-9]{22}$")


def valid_base62_length22(item_id) -> bool:
def valid_base62_length22(item_id: str) -> bool:
"""Validate Spotify style ID."""
return bool(base62_length22_id_pattern.match(item_id))

Expand Down
42 changes: 18 additions & 24 deletions music_assistant/common/helpers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
import re
import socket
from collections.abc import Callable
from collections.abc import Set as AbstractSet
from typing import Any, TypeVar
from urllib.parse import urlparse
from uuid import UUID

# pylint: disable=invalid-name
T = TypeVar("T")
_UNDEF: dict = {}
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable)
CALLBACK_TYPE = Callable[[], None]
# pylint: enable=invalid-name

Expand Down Expand Up @@ -50,7 +49,7 @@ def try_parse_float(possible_float: Any, default: float | None = 0.0) -> float |
return default


def try_parse_bool(possible_bool: Any) -> str:
def try_parse_bool(possible_bool: Any) -> bool:
"""Try to parse a bool."""
if isinstance(possible_bool, bool):
return possible_bool
Expand Down Expand Up @@ -79,7 +78,7 @@ def create_sort_name(input_str: str) -> str:
return input_str.strip()


def parse_title_and_version(title: str, track_version: str | None = None):
def parse_title_and_version(title: str, track_version: str | None = None) -> tuple[str, str]:
"""Try to parse clean track title and version from the title."""
version = ""
for splitter in [" (", " [", " - ", " (", " [", "-"]:
Expand Down Expand Up @@ -135,7 +134,7 @@ def clean_title(title: str) -> str:
return title.strip()


def get_version_substitute(version_str: str):
def get_version_substitute(version_str: str) -> str:
"""Transform provider version str to universal version type."""
version_str = version_str.lower()
# substitute edit and edition with version
Expand Down Expand Up @@ -169,7 +168,7 @@ def strip_url(line: str) -> str:
).rstrip()


def strip_dotcom(line: str):
def strip_dotcom(line: str) -> str:
"""Strip scheme-less netloc from line."""
return dot_com_pattern.sub("", line)

Expand Down Expand Up @@ -227,17 +226,17 @@ def clean_stream_title(line: str) -> str:
return line


async def get_ip():
async def get_ip() -> str:
"""Get primary IP-address for this host."""

def _get_ip():
def _get_ip() -> str:
"""Get primary IP-address for this host."""
# pylint: disable=broad-except,no-member
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
# doesn't even have to be reachable
sock.connect(("10.255.255.255", 1))
_ip = sock.getsockname()[0]
_ip = str(sock.getsockname()[0])
except Exception:
_ip = "127.0.0.1"
finally:
Expand Down Expand Up @@ -273,7 +272,7 @@ async def select_free_port(range_start: int, range_end: int) -> int:
async def get_ip_from_host(dns_name: str) -> str | None:
"""Resolve (first) IP-address for given dns name."""

def _resolve():
def _resolve() -> str | None:
try:
return socket.gethostbyname(dns_name)
except Exception: # pylint: disable=broad-except
Expand All @@ -283,7 +282,7 @@ def _resolve():
return await asyncio.to_thread(_resolve)


async def get_ip_pton(ip_string: str | None = None):
async def get_ip_pton(ip_string: str | None = None) -> bytes:
"""Return socket pton for local ip."""
if ip_string is None:
ip_string = await get_ip()
Expand All @@ -294,7 +293,7 @@ async def get_ip_pton(ip_string: str | None = None):
return await asyncio.to_thread(socket.inet_pton, socket.AF_INET6, ip_string)


def get_folder_size(folderpath):
def get_folder_size(folderpath: str) -> float:
"""Return folder size in gb."""
total_size = 0
# pylint: disable=unused-variable
Expand All @@ -306,7 +305,9 @@ def get_folder_size(folderpath):
return total_size / float(1 << 30)


def merge_dict(base_dict: dict, new_dict: dict, allow_overwite=False):
def merge_dict(
base_dict: dict[Any, Any], new_dict: dict[Any, Any], allow_overwite: bool = False
) -> dict[Any, Any]:
"""Merge dict without overwriting existing values."""
final_dict = base_dict.copy()
for key, value in new_dict.items():
Expand All @@ -321,12 +322,12 @@ def merge_dict(base_dict: dict, new_dict: dict, allow_overwite=False):
return final_dict


def merge_tuples(base: tuple, new: tuple) -> tuple:
def merge_tuples(base: tuple[Any, ...], new: tuple[Any, ...]) -> tuple[Any, ...]:
"""Merge 2 tuples."""
return tuple(x for x in base if x not in new) + tuple(new)


def merge_lists(base: list, new: list) -> list:
def merge_lists(base: list[Any], new: list[Any]) -> list[Any]:
"""Merge 2 lists."""
return [x for x in base if x not in new] + list(new)

Expand All @@ -335,7 +336,7 @@ def get_changed_keys(
dict1: dict[str, Any],
dict2: dict[str, Any],
ignore_keys: list[str] | None = None,
) -> set[str]:
) -> AbstractSet[str]:
"""Compare 2 dicts and return set of changed keys."""
return get_changed_values(dict1, dict2, ignore_keys).keys()

Expand Down Expand Up @@ -369,7 +370,7 @@ def get_changed_values(
return changed_values


def empty_queue(q: asyncio.Queue) -> None:
def empty_queue(q: asyncio.Queue[T]) -> None:
"""Empty an asyncio Queue."""
for _ in range(q.qsize()):
try:
Expand All @@ -386,10 +387,3 @@ def is_valid_uuid(uuid_to_test: str) -> bool:
except ValueError:
return False
return str(uuid_obj) == uuid_to_test


class classproperty(property): # noqa: N801
"""Implement class property for python3.11+."""

def __get__(self, cls, owner): # noqa: D105
return classmethod(self.fget).__get__(None, owner)()
2 changes: 1 addition & 1 deletion music_assistant/common/models/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class ServerInfoMessage(DataClassORJSONMixin):
)


def parse_message(raw: dict) -> MessageType:
def parse_message(raw: dict[Any, Any]) -> MessageType:
"""Parse Message from raw dict object."""
if "event" in raw:
return EventMessage.from_dict(raw)
Expand Down
Loading