diff --git a/client/client.py b/client/client.py index 8bfee69..fae4e92 100644 --- a/client/client.py +++ b/client/client.py @@ -2,27 +2,26 @@ from common.protocol import on_init, is_framed from common.utils import ( show_status, - compress, decompress, STATUS, + send_package, ) from client.package import unpack_and_process from common.config import DEFAULT_CODING class ClientProtocol(asyncio.Protocol): - def __init__(self, package_to_send, on_con_lost): - self.package_to_send = package_to_send - self.on_con_lost = on_con_lost + """Simple client protocol that can send packages and receive packages.""" + def __init__(self,result, is_lost): + self.is_lost = is_lost self.received_data = "" + self.result = result on_init(self) def connection_made(self, transport): self.transport = transport self.address = transport.get_extra_info("peername") show_status(STATUS.CONNECTED, self.address) - transport.write(compress(bytes(self.package_to_send, encoding=DEFAULT_CODING))) - show_status(STATUS.SEND, self.address, self.package_to_send) def data_received(self, more_data): try: @@ -33,26 +32,26 @@ def data_received(self, more_data): return if is_framed(self): show_status(STATUS.RECV, self.address, self.received_data) - res = unpack_and_process(self.received_data) - show_status(STATUS.RECV, self.address, res) - self.transport.close() + self.result.set_result(unpack_and_process(self.received_data)) + show_status(STATUS.RECV, self.address, self.result.result()) def connection_lost(self, exc): show_status(STATUS.DISCONNECTED, self.address) - self.on_con_lost.set_result(True) + self.is_lost.set_result(True) async def send_simple_package(package_to_send, server_address,ssl_context=None): loop = asyncio.get_running_loop() - on_con_lost = loop.create_future() + is_lost = loop.create_future() + result = loop.create_future() transport, protocol = await loop.create_connection( - lambda: ClientProtocol(package_to_send, on_con_lost), + lambda: ClientProtocol(result,is_lost), server_address[0], server_address[1], ssl=ssl_context, ) - try: - await on_con_lost - finally: - transport.close() \ No newline at end of file + send_package(transport, package_to_send) + await result + transport.close() + return result.result() \ No newline at end of file diff --git a/common/utils.py b/common/utils.py index 8843555..d8ac6a5 100644 --- a/common/utils.py +++ b/common/utils.py @@ -6,6 +6,7 @@ from sys import stderr from loguru import logger from common.config import * +import ssl # logger settings logger.remove() @@ -68,3 +69,23 @@ def handle_run_main(main, server_address): logger.info( "This might caused by that TLS support is enabled on the server but not on client." ) + +def send_package(transport, package): + transport.write(compress(bytes(package, encoding=DEFAULT_CODING))) + show_status(STATUS.SEND, transport.get_extra_info("peername"), package) + +def resolve_client_ssl_context(certfile): + context = None + if ENABLE_TLS: + context = ssl.create_default_context() + context.check_hostname = False + try: + context.load_verify_locations(certfile) + except FileNotFoundError: + logger.error("File missing when using TLS.") + return + else: + logger.info("TLS enabled.") + else: + logger.warning("TLS not enabled.") + return context \ No newline at end of file diff --git a/run_client.py b/run_client.py index d530542..8450d9b 100644 --- a/run_client.py +++ b/run_client.py @@ -1,34 +1,19 @@ """ client.py: High-performance async client codes. """ -from common.utils import handle_run_main +from common.utils import handle_run_main,resolve_client_ssl_context from loguru import logger from client.package import pack_request_login from client.config import SERVER_ADDRESS, CRT_PATH -from common.config import ENABLE_TLS -import asyncio -import ssl from client.client import send_simple_package async def main(): - context = None - if ENABLE_TLS: - context = ssl.create_default_context() - context.check_hostname = False - try: - context.load_verify_locations(CRT_PATH) - except FileNotFoundError: - logger.error("File missing when using TLS.") - return - else: - logger.info("TLS enabled.") - else: - logger.warning("TLS not enabled.") - + context = resolve_client_ssl_context(CRT_PATH) mypackage = pack_request_login("Joxos", "114514") - await send_simple_package(mypackage, SERVER_ADDRESS, context) + result = await send_simple_package(mypackage, SERVER_ADDRESS, context) + print(result) if __name__ == "__main__":