Skip to content

Commit

Permalink
logging and error handling added
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex committed Dec 2, 2024
1 parent f2da463 commit a42f32a
Show file tree
Hide file tree
Showing 5 changed files with 604 additions and 159 deletions.
32 changes: 11 additions & 21 deletions python/nistoar/midas/dbio/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from nistoar.pdr.utils.prov import Action
from .. import MIDASException
from .status import RecordStatus
from .websocket import WebSocketServer
from .notifier import Notifier
from nistoar.pdr.utils.prov import ANONYMOUS_USER

DAP_PROJECTS = "dap"
Expand Down Expand Up @@ -383,14 +383,6 @@ def searched(self, cst: CST):
else:
and_conditions[key] = value

#print("======== and_conditions =========")
#for key, value in and_conditions.items():
# print(f"{key}: {value}")
#print("======== or_conditions =========")
#for key, values in or_conditions.items():
# for value in values:
# print(f"{key}: {value}")

rec = self._data

and_met = True
Expand All @@ -417,8 +409,6 @@ def searched(self, cst: CST):
if rec.get(key) == value:
or_met = True
break
#print(or_met)
#print(and_met)
if (not or_conditions or or_met) and and_met:
return True
return False
Expand Down Expand Up @@ -827,14 +817,14 @@ class DBClient(ABC):
the only allowed shoulder will be the default, ``grp0``.
"""

def __init__(self, config: Mapping, projcoll: str, websocket_server: WebSocketServer, nativeclient=None, foruser: str = ANONYMOUS):
def __init__(self, config: Mapping, projcoll: str, notification_server: Notifier = None, nativeclient=None, foruser: str = ANONYMOUS):
self._cfg = config
self._native = nativeclient
self._projcoll = projcoll
self._who = foruser
self._whogrps = None
self._dbgroups = DBGroups(self)
self.websocket_server = websocket_server
self.notification_server = notification_server


@property
Expand Down Expand Up @@ -897,17 +887,17 @@ def create_record(self, name: str, shoulder: str = None, foruser: str = None) ->
rec['name'] = name
rec = ProjectRecord(self._projcoll, rec, self)
rec.save()
# Send a message through the WebSocket
message = f"{name}"
self._send_message_async(message)
if self.notification_server:
message = "New record created : "f"{name}"
self._notify(message)
return rec

def _send_message_async(self, message):
def _notify(self, message):
loop = asyncio.get_event_loop()
if loop.is_running():
asyncio.create_task(self.websocket_server.send_message_to_clients(message))
asyncio.create_task(self.notification_server.send_message_to_clients(message))
else:
asyncio.run(self.websocket_server.send_message_to_clients(message))
asyncio.run(self.notification_server.send_message_to_clients(message))

def _default_shoulder(self):
out = self._cfg.get("default_shoulder")
Expand Down Expand Up @@ -1298,7 +1288,7 @@ class DBClientFactory(ABC):
an abstract class for creating client connections to the database
"""

def __init__(self, config, websocket_server: WebSocketServer):
def __init__(self, config, notification_server: Notifier = None):
"""
initialize the factory with its configuration. The configuration provided here serves as
the default parameters for the cient as these can be overridden by the configuration parameters
Expand All @@ -1308,7 +1298,7 @@ def __init__(self, config, websocket_server: WebSocketServer):
depend on the type of project being access (e.g. "dmp" vs. "dap").
"""
self._cfg = config
self.websocket_server = websocket_server
self.notification_server = notification_server

@abstractmethod
def create_client(self, servicetype: str, config: Mapping = {}, foruser: str = ANONYMOUS):
Expand Down
14 changes: 7 additions & 7 deletions python/nistoar/midas/dbio/inmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import Mapping, MutableMapping, Set
from typing import Iterator, List
from . import base
from .websocket import WebSocketServer
from .notifier import Notifier

from nistoar.base.config import merge_config

Expand All @@ -16,9 +16,9 @@ class InMemoryDBClient(base.DBClient):
an in-memory DBClient implementation
"""

def __init__(self, dbdata: Mapping, config: Mapping, projcoll: str, websocket=WebSocketServer ,foruser: str = base.ANONYMOUS):
def __init__(self, dbdata: Mapping, config: Mapping, projcoll: str, notification_server: Notifier = None ,foruser: str = base.ANONYMOUS):
self._db = dbdata
super(InMemoryDBClient, self).__init__(config, projcoll, websocket, self._db, foruser)
super(InMemoryDBClient, self).__init__(config, projcoll, notification_server, self._db, foruser)

def _next_recnum(self, shoulder):
if shoulder not in self._db['nextnum']:
Expand Down Expand Up @@ -139,7 +139,7 @@ class InMemoryDBClientFactory(base.DBClientFactory):
clients it creates.
"""

def __init__(self, config: Mapping, websocket_server:WebSocketServer, _dbdata = None):
def __init__(self, config: Mapping, notification_server: Notifier = None, _dbdata = None):
"""
Create the factory with the given configuration.
Expand All @@ -148,8 +148,8 @@ def __init__(self, config: Mapping, websocket_server:WebSocketServer, _dbdata =
of the in-memory data structure required to use this input.) If
not provided, an empty database is created.
"""
super(InMemoryDBClientFactory, self).__init__(config, websocket_server)
self.websocket_server = websocket_server
super(InMemoryDBClientFactory, self).__init__(config, notification_server)
self.notification_server = notification_server
self._db = {
base.DAP_PROJECTS: {},
base.DMP_PROJECTS: {},
Expand All @@ -166,5 +166,5 @@ def create_client(self, servicetype: str, config: Mapping={}, foruser: str = bas
cfg = merge_config(config, deepcopy(self._cfg))
if servicetype not in self._db:
self._db[servicetype] = {}
return InMemoryDBClient(self._db, cfg, servicetype, self.websocket_server, foruser)
return InMemoryDBClient(self._db, cfg, servicetype, self.notification_server, foruser)

Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
# websocket_server.py
# notifier.py
import asyncio
import websockets
from concurrent.futures import ThreadPoolExecutor
import copy
import logging

class WebSocketServer:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Notifier:
def __init__(self, host="localhost", port=8765):
self.host = host
self.port = port
self.server = None
self.clients = set() # Initialize the clients set

async def start(self):
self.server = await websockets.serve(self.websocket_handler, self.host, self.port)
#print(f"WebSocket server started on ws://{self.host}:{self.port}")
try:
self.server = await websockets.serve(self.websocket_handler, self.host, self.port)
logger.info(f"WebSocket server started on ws://{self.host}:{self.port}")
except Exception as e:
logger.error(f"Failed to start WebSocket server: {e}")


async def websocket_handler(self, websocket):
Expand All @@ -22,22 +29,26 @@ async def websocket_handler(self, websocket):
try:
async for message in websocket:
await self.send_message_to_clients(message)
except Exception as e:
logger.error(f"Error in websocket_handler: {e}")
finally:
# Remove the client from the set when they disconnect
self.clients.remove(websocket)

async def send_message_to_clients(self, message):
for client in self.clients:
print(client)
if self.clients:
for client in self.clients:
asyncio.create_task(client.send(message))
try:
asyncio.create_task(client.send(message))
except Exception as e:
logger.error(f"Failed to send message to client: {e}")

async def stop(self):
if self.server:
self.server.close()
await self.server.wait_closed()
self.server = None
logger.info("WebSocket server stopped")

async def wait_closed(self):
if self.server:
Expand Down
130 changes: 6 additions & 124 deletions python/tests/nistoar/midas/dbio/test_inmem.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import os, json, pdb, logging,asyncio
import os, json, pdb, logging
from pathlib import Path
import unittest as test
import websockets

from nistoar.midas.dbio import inmem, base
from nistoar.pdr.utils.prov import Action, Agent
from nistoar.midas.dbio.websocket import WebSocketServer

testuser = Agent("dbio", Agent.AUTO, "tester", "test")
testdir = Path(__file__).parents[0]
Expand All @@ -28,49 +26,12 @@
with open(dmp_path, 'r') as file:
dmp = json.load(file)


class TestInMemoryDBClientFactory(test.TestCase):

@classmethod
def initialize_websocket_server(cls):
websocket_server = WebSocketServer()
try:
cls.loop = asyncio.get_event_loop()
if cls.loop.is_closed():
raise RuntimeError
except RuntimeError:
cls.loop = asyncio.new_event_loop()
asyncio.set_event_loop(cls.loop)
cls.loop.run_until_complete(websocket_server.start())
#print("WebSocketServer initialized:", websocket_server)
return websocket_server

@classmethod
def setUpClass(cls):
cls.websocket_server = cls.initialize_websocket_server()

@classmethod
def tearDownClass(cls):
# Ensure the WebSocket server is properly closed
cls.loop.run_until_complete(cls.websocket_server.stop())
cls.loop.run_until_complete(cls.websocket_server.wait_closed())

# Cancel all lingering tasks
asyncio.set_event_loop(cls.loop) # Set the event loop as the current event loop
tasks = asyncio.all_tasks(loop=cls.loop)
for task in tasks:
task.cancel()
cls.loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))

# Close the event loop
cls.loop.close()



def setUp(self):
self.cfg = {"goob": "gurn"}
self.fact = inmem.InMemoryDBClientFactory(
self.cfg,self.websocket_server, {"nextnum": {"hank": 2}})
self.cfg, {"nextnum": {"hank": 2}})

def test_ctor(self):
self.assertEqual(self.fact._cfg, self.cfg)
Expand All @@ -94,44 +55,11 @@ def test_create_client(self):

class TestInMemoryDBClient(test.TestCase):

@classmethod
def initialize_websocket_server(cls):
websocket_server = WebSocketServer()
try:
cls.loop = asyncio.get_event_loop()
except RuntimeError:
cls.loop = asyncio.new_event_loop()
asyncio.set_event_loop(cls.loop)
cls.loop.run_until_complete(websocket_server.start())
#print("WebSocketServer initialized:", websocket_server)
return websocket_server

@classmethod
def setUpClass(cls):
cls.websocket_server = cls.initialize_websocket_server()

@classmethod
def tearDownClass(cls):
# Ensure the WebSocket server is properly closed
cls.loop.run_until_complete(cls.websocket_server.stop())
cls.loop.run_until_complete(cls.websocket_server.wait_closed())

# Cancel all lingering tasks
asyncio.set_event_loop(cls.loop)
tasks = asyncio.all_tasks(loop=cls.loop)
for task in tasks:
task.cancel()
cls.loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))

# Close the event loop
cls.loop.close()

def setUp(self):
self.cfg = {"default_shoulder": "mds3"}
self.user = "nist0:ava1"
self.cli = inmem.InMemoryDBClientFactory({},self.websocket_server).create_client(
self.cli = inmem.InMemoryDBClientFactory({}).create_client(
base.DMP_PROJECTS, self.cfg, self.user)


def test_next_recnum(self):
self.assertEqual(self.cli._next_recnum("goob"), 1)
Expand Down Expand Up @@ -342,6 +270,8 @@ def test_adv_select_records(self):
self.cli._db[base.DMP_PROJECTS][id] = rec.to_dict()




id = "pdr0:0006"
rec = base.ProjectRecord(
base.DMP_PROJECTS, {"id": id, "name": "test 2", "status": {
Expand Down Expand Up @@ -513,53 +443,5 @@ def test_record_action(self):
self.assertEqual(acts[1]['type'], Action.COMMENT)


class TestWebSocketServer(test.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.websocket_server = WebSocketServer()
self.loop = asyncio.get_event_loop()
await self.websocket_server.start()

# Initialize the InMemoryDBClientFactory with the websocket_server
self.cfg = {"default_shoulder": "mds3"}
self.user = "nist0:ava1"
self.cli = inmem.InMemoryDBClientFactory({},self.websocket_server).create_client(
base.DMP_PROJECTS, self.cfg, self.user)

async def asyncTearDown(self):
await self.websocket_server.stop()
await self.websocket_server.wait_closed()


async def test_create_records_websocket(self):
messages = []

async def receive_messages(uri):
try:
async with websockets.connect(uri) as websocket:
while True:
message = await websocket.recv()
#print(f"Received message: {message}")
messages.append(message)
#print(f"Messages: {messages}")
# Break the loop after receiving the first message for this test
except Exception as e:
print(f"Failed to connect to WebSocket server: {e}")

# Start the WebSocket client to receive messages
uri = 'ws://localhost:8765'
receive_task = asyncio.create_task(receive_messages(uri))
await asyncio.sleep(2)

#await self.websocket_server.send_message_to_clients("Connection established")
# Inject some data into the database
rec = self.cli.create_record("mine1")
await asyncio.sleep(2)

#print(f"Messages: {messages}")
self.assertEqual(len(messages), 1)
self.assertEqual(messages[0], "mine1")



if __name__ == '__main__':
test.main()
test.main()
Loading

0 comments on commit a42f32a

Please sign in to comment.