Skip to content

Commit

Permalink
Merge branch 'main' into zillion-dead-end-priority
Browse files Browse the repository at this point in the history
  • Loading branch information
beauxq authored Jan 19, 2025
2 parents 20738b6 + 9e353eb commit 015a8f7
Show file tree
Hide file tree
Showing 25 changed files with 255 additions and 324 deletions.
16 changes: 14 additions & 2 deletions .github/pyright-config.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
{
"include": [
"type_check.py",
"../BizHawkClient.py",
"../Patch.py",
"../test/general/test_groups.py",
"../test/general/test_helpers.py",
"../test/general/test_memory.py",
"../test/general/test_names.py",
"../test/multiworld/__init__.py",
"../test/multiworld/test_multiworlds.py",
"../test/netutils/__init__.py",
"../test/programs/__init__.py",
"../test/programs/test_multi_server.py",
"../test/utils/__init__.py",
"../test/webhost/test_descriptions.py",
"../worlds/AutoSNIClient.py",
"../Patch.py"
"type_check.py"
],

"exclude": [
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/strict-type-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:

- name: "Install dependencies"
run: |
python -m pip install --upgrade pip pyright==1.1.358
python -m pip install --upgrade pip pyright==1.1.392.post0
python ModuleUpdate.py --append "WebHostLib/requirements.txt" --force --yes
- name: "pyright: strict check on specific files"
Expand Down
7 changes: 7 additions & 0 deletions CommonClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,13 @@ async def send_connect(self, **kwargs: typing.Any) -> None:
await self.send_msgs([payload])
await self.send_msgs([{"cmd": "Get", "keys": ["_read_race_mode"]}])

async def check_locations(self, locations: typing.Collection[int]) -> set[int]:
"""Send new location checks to the server. Returns the set of actually new locations that were sent."""
locations = set(locations) & self.missing_locations
if locations:
await self.send_msgs([{"cmd": 'LocationChecks', "locations": tuple(locations)}])
return locations

async def console_input(self) -> str:
if self.ui:
self.ui.focus_textinput()
Expand Down
4 changes: 4 additions & 0 deletions LinksAwakeningClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,10 @@ async def server_auth(self, password_requested: bool = False):

while self.client.auth == None:
await asyncio.sleep(0.1)

# Just return if we're closing
if self.exit_event.is_set():
return
self.auth = self.client.auth
await self.send_connect()

Expand Down
1 change: 1 addition & 0 deletions MultiServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1992,6 +1992,7 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict):
args["cmd"] = "SetReply"
value = ctx.stored_data.get(args["key"], args.get("default", 0))
args["original_value"] = copy.copy(value)
args["slot"] = client.slot
for operation in args["operations"]:
func = modify_functions[operation["operation"]]
value = func(value, operation["value"])
Expand Down
4 changes: 2 additions & 2 deletions data/lua/connector_oot.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1816,7 +1816,7 @@ end

-- Main control handling: main loop and socket receive

function receive()
function APreceive()
l, e = ootSocket:receive()
-- Handle incoming message
if e == 'closed' then
Expand Down Expand Up @@ -1874,7 +1874,7 @@ function main()
end
if (curstate == STATE_OK) or (curstate == STATE_INITIAL_CONNECTION_MADE) or (curstate == STATE_TENTATIVELY_CONNECTED) then
if (frame % 30 == 0) then
receive()
APreceive()
end
elseif (curstate == STATE_UNINITIALIZED) then
if (frame % 60 == 0) then
Expand Down
1 change: 1 addition & 0 deletions docs/network protocol.md
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ Sent to clients in response to a [Set](#Set) package if want_reply was set to tr
| key | str | The key that was updated. |
| value | any | The new value for the key. |
| original_value | any | The value the key had before it was updated. Not present on "_read" prefixed special keys. |
| slot | int | The slot that originally sent the Set package causing this change. |

Additional arguments added to the [Set](#Set) package that triggered this [SetReply](#SetReply) will also be passed along.

Expand Down
11 changes: 7 additions & 4 deletions test/general/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import unittest
from typing import Callable, Dict, Optional

from typing_extensions import override

from BaseClasses import CollectionState, MultiWorld, Region


class TestHelpers(unittest.TestCase):
multiworld: MultiWorld
player: int = 1

@override
def setUp(self) -> None:
self.multiworld = MultiWorld(self.player)
self.multiworld.game[self.player] = "helper_test_game"
Expand Down Expand Up @@ -38,15 +41,15 @@ def test_region_helpers(self) -> None:
"TestRegion1": {"TestRegion2": "connection"},
"TestRegion2": {"TestRegion1": None},
}

reg_exit_set: Dict[str, set[str]] = {
"TestRegion1": {"TestRegion3"}
}

exit_rules: Dict[str, Callable[[CollectionState], bool]] = {
"TestRegion1": lambda state: state.has("test_item", self.player)
}

self.multiworld.regions += [Region(region, self.player, self.multiworld, regions[region]) for region in regions]

with self.subTest("Test Location Creation Helper"):
Expand All @@ -73,7 +76,7 @@ def test_region_helpers(self) -> None:
entrance_name = exit_name if exit_name else f"{parent} -> {exit_reg}"
self.assertEqual(exit_rules[exit_reg],
self.multiworld.get_entrance(entrance_name, self.player).access_rule)

for region in reg_exit_set:
current_region = self.multiworld.get_region(region, self.player)
current_region.add_exits(reg_exit_set[region])
Expand Down
2 changes: 1 addition & 1 deletion test/general/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class TestWorldMemory(unittest.TestCase):
def test_leak(self):
def test_leak(self) -> None:
"""Tests that worlds don't leak references to MultiWorld or themselves with default options."""
import gc
import weakref
Expand Down
4 changes: 2 additions & 2 deletions test/general/test_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@


class TestNames(unittest.TestCase):
def test_item_names_format(self):
def test_item_names_format(self) -> None:
"""Item names must not be all numeric in order to differentiate between ID and name in !hint"""
for gamename, world_type in AutoWorldRegister.world_types.items():
with self.subTest(game=gamename):
for item_name in world_type.item_name_to_id:
self.assertFalse(item_name.isnumeric(),
f"Item name \"{item_name}\" is invalid. It must not be numeric.")

def test_location_name_format(self):
def test_location_name_format(self) -> None:
"""Location names must not be all numeric in order to differentiate between ID and name in !hint_location"""
for gamename, world_type in AutoWorldRegister.world_types.items():
with self.subTest(game=gamename):
Expand Down
26 changes: 13 additions & 13 deletions worlds/_bizhawk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import enum
import json
import sys
import typing
from typing import Any, Sequence


BIZHAWK_SOCKET_PORT_RANGE_START = 43055
Expand Down Expand Up @@ -44,10 +44,10 @@ class SyncError(Exception):


class BizHawkContext:
streams: typing.Optional[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]]
streams: tuple[asyncio.StreamReader, asyncio.StreamWriter] | None
connection_status: ConnectionStatus
_lock: asyncio.Lock
_port: typing.Optional[int]
_port: int | None

def __init__(self) -> None:
self.streams = None
Expand Down Expand Up @@ -122,12 +122,12 @@ async def get_script_version(ctx: BizHawkContext) -> int:
return int(await ctx._send_message("VERSION"))


async def send_requests(ctx: BizHawkContext, req_list: typing.List[typing.Dict[str, typing.Any]]) -> typing.List[typing.Dict[str, typing.Any]]:
async def send_requests(ctx: BizHawkContext, req_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Sends a list of requests to the BizHawk connector and returns their responses.
It's likely you want to use the wrapper functions instead of this."""
responses = json.loads(await ctx._send_message(json.dumps(req_list)))
errors: typing.List[ConnectorError] = []
errors: list[ConnectorError] = []

for response in responses:
if response["type"] == "ERROR":
Expand Down Expand Up @@ -180,7 +180,7 @@ async def get_system(ctx: BizHawkContext) -> str:
return res["value"]


async def get_cores(ctx: BizHawkContext) -> typing.Dict[str, str]:
async def get_cores(ctx: BizHawkContext) -> dict[str, str]:
"""Gets the preferred cores for systems with multiple cores. Only systems with multiple available cores have
entries."""
res = (await send_requests(ctx, [{"type": "PREFERRED_CORES"}]))[0]
Expand Down Expand Up @@ -233,8 +233,8 @@ async def set_message_interval(ctx: BizHawkContext, value: float) -> None:
raise SyncError(f"Expected response of type SET_MESSAGE_INTERVAL_RESPONSE but got {res['type']}")


async def guarded_read(ctx: BizHawkContext, read_list: typing.Sequence[typing.Tuple[int, int, str]],
guard_list: typing.Sequence[typing.Tuple[int, typing.Sequence[int], str]]) -> typing.Optional[typing.List[bytes]]:
async def guarded_read(ctx: BizHawkContext, read_list: Sequence[tuple[int, int, str]],
guard_list: Sequence[tuple[int, Sequence[int], str]]) -> list[bytes] | None:
"""Reads an array of bytes at 1 or more addresses if and only if every byte in guard_list matches its expected
value.
Expand Down Expand Up @@ -262,7 +262,7 @@ async def guarded_read(ctx: BizHawkContext, read_list: typing.Sequence[typing.Tu
"domain": domain
} for address, size, domain in read_list])

ret: typing.List[bytes] = []
ret: list[bytes] = []
for item in res:
if item["type"] == "GUARD_RESPONSE":
if not item["value"]:
Expand All @@ -276,7 +276,7 @@ async def guarded_read(ctx: BizHawkContext, read_list: typing.Sequence[typing.Tu
return ret


async def read(ctx: BizHawkContext, read_list: typing.Sequence[typing.Tuple[int, int, str]]) -> typing.List[bytes]:
async def read(ctx: BizHawkContext, read_list: Sequence[tuple[int, int, str]]) -> list[bytes]:
"""Reads data at 1 or more addresses.
Items in `read_list` should be organized `(address, size, domain)` where
Expand All @@ -288,8 +288,8 @@ async def read(ctx: BizHawkContext, read_list: typing.Sequence[typing.Tuple[int,
return await guarded_read(ctx, read_list, [])


async def guarded_write(ctx: BizHawkContext, write_list: typing.Sequence[typing.Tuple[int, typing.Sequence[int], str]],
guard_list: typing.Sequence[typing.Tuple[int, typing.Sequence[int], str]]) -> bool:
async def guarded_write(ctx: BizHawkContext, write_list: Sequence[tuple[int, Sequence[int], str]],
guard_list: Sequence[tuple[int, Sequence[int], str]]) -> bool:
"""Writes data to 1 or more addresses if and only if every byte in guard_list matches its expected value.
Items in `write_list` should be organized `(address, value, domain)` where
Expand Down Expand Up @@ -326,7 +326,7 @@ async def guarded_write(ctx: BizHawkContext, write_list: typing.Sequence[typing.
return True


async def write(ctx: BizHawkContext, write_list: typing.Sequence[typing.Tuple[int, typing.Sequence[int], str]]) -> None:
async def write(ctx: BizHawkContext, write_list: Sequence[tuple[int, Sequence[int], str]]) -> None:
"""Writes data to 1 or more addresses.
Items in write_list should be organized `(address, value, domain)` where
Expand Down
12 changes: 6 additions & 6 deletions worlds/_bizhawk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, ClassVar

from worlds.LauncherComponents import Component, SuffixIdentifier, Type, components, launch_subprocess

Expand All @@ -24,9 +24,9 @@ def launch_client(*args) -> None:


class AutoBizHawkClientRegister(abc.ABCMeta):
game_handlers: ClassVar[Dict[Tuple[str, ...], Dict[str, BizHawkClient]]] = {}
game_handlers: ClassVar[dict[tuple[str, ...], dict[str, BizHawkClient]]] = {}

def __new__(cls, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]) -> AutoBizHawkClientRegister:
def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any]) -> AutoBizHawkClientRegister:
new_class = super().__new__(cls, name, bases, namespace)

# Register handler
Expand Down Expand Up @@ -54,7 +54,7 @@ def __new__(cls, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any])
return new_class

@staticmethod
async def get_handler(ctx: "BizHawkClientContext", system: str) -> Optional[BizHawkClient]:
async def get_handler(ctx: "BizHawkClientContext", system: str) -> BizHawkClient | None:
for systems, handlers in AutoBizHawkClientRegister.game_handlers.items():
if system in systems:
for handler in handlers.values():
Expand All @@ -65,13 +65,13 @@ async def get_handler(ctx: "BizHawkClientContext", system: str) -> Optional[BizH


class BizHawkClient(abc.ABC, metaclass=AutoBizHawkClientRegister):
system: ClassVar[Union[str, Tuple[str, ...]]]
system: ClassVar[str | tuple[str, ...]]
"""The system(s) that the game this client is for runs on"""

game: ClassVar[str]
"""The game this client is for"""

patch_suffix: ClassVar[Optional[Union[str, Tuple[str, ...]]]]
patch_suffix: ClassVar[str | tuple[str, ...] | None]
"""The file extension(s) this client is meant to open and patch (e.g. ".apz3")"""

@abc.abstractmethod
Expand Down
12 changes: 6 additions & 6 deletions worlds/_bizhawk/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import asyncio
import enum
import subprocess
from typing import Any, Dict, Optional
from typing import Any

from CommonClient import CommonContext, ClientCommandProcessor, get_base_parser, server_loop, logger, gui_enabled
import Patch
Expand Down Expand Up @@ -43,15 +43,15 @@ class BizHawkClientContext(CommonContext):
command_processor = BizHawkClientCommandProcessor
auth_status: AuthStatus
password_requested: bool
client_handler: Optional[BizHawkClient]
slot_data: Optional[Dict[str, Any]] = None
rom_hash: Optional[str] = None
client_handler: BizHawkClient | None
slot_data: dict[str, Any] | None = None
rom_hash: str | None = None
bizhawk_ctx: BizHawkContext

watcher_timeout: float
"""The maximum amount of time the game watcher loop will wait for an update from the server before executing"""

def __init__(self, server_address: Optional[str], password: Optional[str]):
def __init__(self, server_address: str | None, password: str | None):
super().__init__(server_address, password)
self.auth_status = AuthStatus.NOT_AUTHENTICATED
self.password_requested = False
Expand Down Expand Up @@ -241,7 +241,7 @@ def _patch_and_run_game(patch_file: str):
return {}


def launch(*launch_args) -> None:
def launch(*launch_args: str) -> None:
async def main():
parser = get_base_parser()
parser.add_argument("patch_file", default="", type=str, nargs="?", help="Path to an Archipelago patch file")
Expand Down
2 changes: 1 addition & 1 deletion worlds/alttp/Client.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def new_check(location_id):
snes_logger.info(f"Discarding recent {len(new_locations)} checks as ROM Status has changed.")
return False
else:
await ctx.send_msgs([{"cmd": 'LocationChecks', "locations": new_locations}])
await ctx.check_locations(new_locations)
await snes_flush_writes(ctx)
return True

Expand Down
Loading

0 comments on commit 015a8f7

Please sign in to comment.