Skip to content

Commit

Permalink
controlling connection
Browse files Browse the repository at this point in the history
  • Loading branch information
Joxos committed Feb 16, 2024
1 parent 2fa5145 commit 55faf4a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 35 deletions.
31 changes: 15 additions & 16 deletions client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,26 @@
from common.protocol import on_init, is_framed
from common.utils import (
show_status,
compress,
decompress,
STATUS,
send_package,
)
from client.package import unpack_and_process
from common.config import DEFAULT_CODING


class ClientProtocol(asyncio.Protocol):
def __init__(self, package_to_send, on_con_lost):
self.package_to_send = package_to_send
self.on_con_lost = on_con_lost
"""Simple client protocol that can send packages and receive packages."""
def __init__(self,result, is_lost):
self.is_lost = is_lost
self.received_data = ""
self.result = result
on_init(self)

def connection_made(self, transport):
self.transport = transport
self.address = transport.get_extra_info("peername")
show_status(STATUS.CONNECTED, self.address)
transport.write(compress(bytes(self.package_to_send, encoding=DEFAULT_CODING)))
show_status(STATUS.SEND, self.address, self.package_to_send)

def data_received(self, more_data):
try:
Expand All @@ -33,26 +32,26 @@ def data_received(self, more_data):
return
if is_framed(self):
show_status(STATUS.RECV, self.address, self.received_data)
res = unpack_and_process(self.received_data)
show_status(STATUS.RECV, self.address, res)
self.transport.close()
self.result.set_result(unpack_and_process(self.received_data))
show_status(STATUS.RECV, self.address, self.result.result())

def connection_lost(self, exc):
show_status(STATUS.DISCONNECTED, self.address)
self.on_con_lost.set_result(True)
self.is_lost.set_result(True)

async def send_simple_package(package_to_send, server_address,ssl_context=None):
loop = asyncio.get_running_loop()
on_con_lost = loop.create_future()
is_lost = loop.create_future()
result = loop.create_future()

transport, protocol = await loop.create_connection(
lambda: ClientProtocol(package_to_send, on_con_lost),
lambda: ClientProtocol(result,is_lost),
server_address[0],
server_address[1],
ssl=ssl_context,
)

try:
await on_con_lost
finally:
transport.close()
send_package(transport, package_to_send)
await result
transport.close()
return result.result()
21 changes: 21 additions & 0 deletions common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sys import stderr
from loguru import logger
from common.config import *
import ssl

# logger settings
logger.remove()
Expand Down Expand Up @@ -68,3 +69,23 @@ def handle_run_main(main, server_address):
logger.info(
"This might caused by that TLS support is enabled on the server but not on client."
)

def send_package(transport, package):
transport.write(compress(bytes(package, encoding=DEFAULT_CODING)))
show_status(STATUS.SEND, transport.get_extra_info("peername"), package)

def resolve_client_ssl_context(certfile):
context = None
if ENABLE_TLS:
context = ssl.create_default_context()
context.check_hostname = False
try:
context.load_verify_locations(certfile)
except FileNotFoundError:
logger.error("File missing when using TLS.")
return
else:
logger.info("TLS enabled.")
else:
logger.warning("TLS not enabled.")
return context
23 changes: 4 additions & 19 deletions run_client.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,19 @@
"""
client.py: High-performance async client codes.
"""
from common.utils import handle_run_main
from common.utils import handle_run_main,resolve_client_ssl_context
from loguru import logger
from client.package import pack_request_login
from client.config import SERVER_ADDRESS, CRT_PATH
from common.config import ENABLE_TLS
import asyncio
import ssl

from client.client import send_simple_package


async def main():
context = None
if ENABLE_TLS:
context = ssl.create_default_context()
context.check_hostname = False
try:
context.load_verify_locations(CRT_PATH)
except FileNotFoundError:
logger.error("File missing when using TLS.")
return
else:
logger.info("TLS enabled.")
else:
logger.warning("TLS not enabled.")

context = resolve_client_ssl_context(CRT_PATH)
mypackage = pack_request_login("Joxos", "114514")
await send_simple_package(mypackage, SERVER_ADDRESS, context)
result = await send_simple_package(mypackage, SERVER_ADDRESS, context)
print(result)


if __name__ == "__main__":
Expand Down

0 comments on commit 55faf4a

Please sign in to comment.