Skip to content

Commit

Permalink
fix: prevent queueing the same task
Browse files Browse the repository at this point in the history
  • Loading branch information
huenique committed Nov 9, 2021
1 parent 5fd42d7 commit e217c44
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 15 deletions.
3 changes: 2 additions & 1 deletion dayong/components/privilege_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ async def get_user_info(
username="",
nickname="",
message="",
)
),
"message_id",
)
info = result.first()
if isinstance(info, (AnonMessage,)):
Expand Down
11 changes: 7 additions & 4 deletions dayong/components/task_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""
import hikari
import tanjun
from sqlalchemy.exc import NoResultFound
from sqlalchemy.exc import NoResultFound, ProgrammingError

from dayong.abc import Database
from dayong.core.settings import CONTENT_PROVIDER
Expand Down Expand Up @@ -45,10 +45,11 @@ async def start_task(context: tanjun.abc.Context, source: str, db: Database):
channel_name=channel.name if channel.name else "", task_name=source, run=True
)

await db.create_table()

try:
await db.create_table()
result = await db.get_row(task_model, "task_name")
if bool(result.one().run) is False:
if bool(result.one().run) is True:
raise PermissionError
else:
await db.update_row(task_model, "task_name")
Expand Down Expand Up @@ -79,6 +80,8 @@ async def stop_task(context: tanjun.abc.Context, source: str, db: Database):
run=False,
)

# We can also update the row here, but for simplicity, it's best to just perform a
# delete query.
await db.remove_row(task_model, "task_name")


Expand Down Expand Up @@ -131,7 +134,7 @@ async def share_content(
try:
await stop_task(ctx, source, db)
await ctx.respond(f"Stopped content delivery for `{source}`")
except NoResultFound:
except (NoResultFound, ProgrammingError):
await ctx.respond("That task isn't running 🤔")
else:
await ctx.respond(
Expand Down
14 changes: 4 additions & 10 deletions dayong/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
Data model operations which include retrieval and update commands.
"""
import asyncio
from typing import Any

import tanjun
Expand Down Expand Up @@ -42,13 +41,10 @@ async def update(instance: Any, update: Any) -> Any:
async def connect(
self, config: DayongConfig = tanjun.injected(type=DayongConfig)
) -> None:
loop = asyncio.get_running_loop()
self._conn = await loop.run_in_executor(
None,
create_async_engine,
self._conn = create_async_engine(
config.database_uri
if config.database_uri
else DayongDynamicLoader().load().database_uri,
else DayongDynamicLoader().load().database_uri
)

async def create_table(self) -> None:
Expand All @@ -57,8 +53,7 @@ async def create_table(self) -> None:

async def add_row(self, table_model: SQLModel) -> None:
async with AsyncSession(self._conn) as session:
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, session.add, table_model)
session.add(table_model)
await session.commit()

async def remove_row(self, table_model: SQLModel, attribute: str) -> None:
Expand Down Expand Up @@ -93,7 +88,6 @@ async def get_all_row(self, table_model: type[SQLModel]) -> ScalarResult[Any]:
return await session.exec(select(table_model)) # type: ignore

async def update_row(self, table_model: SQLModel, attribute: str) -> None:
loop = asyncio.get_running_loop()
model = type(table_model)
table = table_model.__dict__

Expand All @@ -105,6 +99,6 @@ async def update_row(self, table_model: SQLModel, attribute: str) -> None:
)
task = row.one()
task = await self.update(task, table)
await loop.run_in_executor(None, session.add, task)
session.add(task)
await session.commit()
await session.refresh(task)

0 comments on commit e217c44

Please sign in to comment.