Skip to content

Commit

Permalink
[#16] refactor: create new reddit clients (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickatnight authored Feb 27, 2023
1 parent f91830d commit a3465e5
Show file tree
Hide file tree
Showing 10 changed files with 111 additions and 50 deletions.
Empty file added backend/src/clients/__init__.py
Empty file.
Empty file.
20 changes: 20 additions & 0 deletions backend/src/clients/reddit/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import platform

import asyncpraw

from src.core.config import settings
from src.interfaces.client import IClient


class RedditResource(IClient[asyncpraw.Reddit]):
@classmethod
def configure(cls) -> asyncpraw.Reddit:
platform_name = platform.uname()
reddit_config = {
"client_id": settings.CLIENT_ID,
"client_secret": settings.CLIENT_SECRET,
"username": settings.USERNAME,
"password": settings.PASSWORD,
"user_agent": f"{platform_name}/{settings.VERSION} ({settings.BOT_NAME} {settings.DEVELOPER});",
}
return asyncpraw.Reddit(**reddit_config)
20 changes: 20 additions & 0 deletions backend/src/clients/reddit/inbox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import AsyncIterator

from asyncpraw.models import Message

from src.clients.reddit.base import RedditResource


class InboxClient(RedditResource):
def __init__(self) -> None:
self.reddit = self.configure()

def stream(self) -> AsyncIterator[Message]:
"""stream incoming messages"""
s: AsyncIterator[Message] = self.reddit.inbox.stream()

return s

async def close(self) -> None:
"""Close requester"""
_ = await self.reddit.close()
16 changes: 7 additions & 9 deletions backend/src/core/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from src.core.config import settings
from src.core.engine import GameEngine
from src.core.enums import SupportedSubs
from src.core.enums import SupportedSubs, UserBlacklist
from src.db.session import SessionLocal
from src.repositories import GameRepository, PlayerRepository, SubRedditRepository
from src.services import GameService, PlayerService, SubRedditService
Expand Down Expand Up @@ -40,13 +40,6 @@ async def tag_init(subreddit_name: str = SupportedSubs.TAG_YOURE_IT_BOT) -> None
game=GameService(repo=game_repo),
subreddit=SubRedditService(repo=subreddit_repo),
),
reddit_config={ # type: ignore
"client_id": settings.CLIENT_ID,
"client_secret": settings.CLIENT_SECRET,
"username": settings.USERNAME,
"password": settings.PASSWORD,
"user_agent": f"{platform_name}/{settings.VERSION} ({settings.BOT_NAME} {settings.DEVELOPER});",
},
)

logger.info(f"Game of Tag has started for SubReddit[r/{subreddit_name}]")
Expand All @@ -56,8 +49,13 @@ async def tag_init(subreddit_name: str = SupportedSubs.TAG_YOURE_IT_BOT) -> None
if __name__ == "__main__":
loop = asyncio.get_event_loop()
tasks = []
supported_subs = (
SupportedSubs.all()
if settings.USERNAME == UserBlacklist.TAG_YOURE_IT_BOT
else SupportedSubs.test()
)

for sub in SupportedSubs.all():
for sub in supported_subs:
task = loop.create_task(tag_init(subreddit_name=sub))
tasks.append(task)

Expand Down
49 changes: 17 additions & 32 deletions backend/src/core/engine.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import copy
import logging
from typing import Optional, Union
from uuid import UUID

import asyncpraw
from aiohttp import ClientSession

from src.core.config import settings
from src.core.typed import RedditClientConfigTyped
from src.services.stream.comment import CommentStreamService
from src.services.stream.inbox import InboxStreamService
from src.services.tag import TagService
Expand All @@ -31,37 +26,27 @@ def __init__(
self,
tag_service: TagService,
stream_service: Union[InboxStreamService, CommentStreamService],
reddit_config: RedditClientConfigTyped,
):
self.tag_service = tag_service
self.stream_service = stream_service
self.reddit_config = reddit_config

async def run(self) -> None:
logger.info(
f"Starting session with [{self.stream_service.__class__.__name__}] stream class..."
f"Starting session with [{self.stream_service.__class__.__name__}] stream class for u/{settings.USERNAME}..."
)
async with ClientSession(trust_env=True) as session:
config: RedditClientConfigTyped = copy.deepcopy(self.reddit_config)
config.update(
{
"requestor_kwargs": {
"session": session
}, # must successfully close the session when finished
}
)
game_id: Optional[Union[UUID, str]] = None
tag_service: TagService = self.tag_service

async with asyncpraw.Reddit(**self.reddit_config) as reddit:
logger.info(f"Streaming mentions for u/{settings.USERNAME}")

async for mention in self.stream_service.stream(reddit):
# pass
pre_flight_check: bool = await self.stream_service.pre_flight_check(
tag_service, mention
)
if pre_flight_check:
game_id = await self.stream_service.process(tag_service, mention, game_id)

await mention.mark_read()
game_id: Optional[Union[UUID, str]] = None
tag_service: TagService = self.tag_service

# TODO: move to decorator?
try:
async for mention in self.stream_service.client.stream():
# pass
pre_flight_check: bool = await self.stream_service.pre_flight_check(
tag_service, mention
)
if pre_flight_check:
game_id = await self.stream_service.process(tag_service, mention, game_id)

await mention.mark_read()
finally:
await self.stream_service.client.close()
22 changes: 20 additions & 2 deletions backend/src/core/enums.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum
from typing import List


class BaseEnum(str, Enum):
Expand All @@ -10,7 +11,7 @@ class SimpleEnum:
"""A simple enum to get list of class member variables"""

@classmethod
def all(cls):
def all(cls) -> List[str]:
return [value for name, value in vars(cls).items() if name.isupper()]


Expand All @@ -24,13 +25,30 @@ class SupportedSubs(SimpleEnum):

TAG_YOURE_IT_BOT = "TagYoureItBot"
# DOGECOIN = "dogecoin"
TEST = "test"

@classmethod
def test(cls) -> List[str]:
return list(filter(lambda i: i == cls.TEST, cls.all()))


class TagEnum:
"""Various phrases to watch as input"""

KEY = "!tag"
ENABLE_PHRASE = "i want to play tag again"
DISABLE_PHRASE = "i dont want to play tag"


class UserBlackList(SimpleEnum):
class UserBlacklist(SimpleEnum):
"""Do not tag these users"""

TAG_YOURE_IT_BOT = "TagYoureItBot"
TAG_YOURE_IT_BOT_TEST = "TagYoureItBotTest"


class RestrictedReadMail(SimpleEnum):
"""Manually read any mail from these users"""

MOD_NEWS_LETTER = "ModNewsletter"
REDDIT = "reddit"
13 changes: 13 additions & 0 deletions backend/src/interfaces/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from abc import ABCMeta, abstractmethod
from typing import Generic, TypeVar


T = TypeVar("T")


class IClient(Generic[T], metaclass=ABCMeta):
@classmethod
@abstractmethod
def configure(cls) -> T:
"""Configures a new client."""
raise NotImplementedError
6 changes: 4 additions & 2 deletions backend/src/services/stream/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
ModelType = TypeVar("ModelType")
SchemaType = TypeVar("SchemaType")
PrawType = TypeVar("PrawType", bound=AsyncPRAWBase)
ClientType = TypeVar("ClientType")


class AbstractStream(Generic[PrawType], metaclass=ABCMeta):
"""interface to stream Reddit Comments, Messaging, etc from a particular Subreddit"""

def __init__(self, subreddit_name: str) -> None:
self.subreddit_name: str = subreddit_name
def __init__(self, subreddit_name: str, client: Optional[ClientType] = None) -> None:
self.subreddit_name = subreddit_name
self.client = client

@abstractmethod
async def pre_flight_check(self, tag_service: Any, obj: PrawType) -> bool:
Expand Down
15 changes: 10 additions & 5 deletions backend/src/services/stream/inbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from asyncpraw.models import Message, Redditor
from asyncpraw.models import Subreddit as PrawSubReddit

from src.clients.reddit.inbox import InboxClient
from src.core.config import settings
from src.core.const import TAG_TIME_HUMAN_READABLE, ReplyEnum
from src.core.enums import TagEnum, UserBlackList
from src.core.enums import RestrictedReadMail, TagEnum
from src.core.utils import is_tag_time_expired
from src.models.game import Game
from src.models.player import Player
Expand All @@ -21,6 +22,10 @@


class InboxStreamService(AbstractStream[Message]):
def __init__(self, subreddit_name: str, client: Optional[InboxClient] = None) -> None:
self.subreddit_name = subreddit_name
self.client = client or InboxClient()

async def pre_flight_check(self, tag_service: TagService, obj: Message) -> bool:
author = obj.author
await author.load() # Re-fetches the object
Expand All @@ -30,11 +35,11 @@ async def pre_flight_check(self, tag_service: TagService, obj: Message) -> bool:

# direct messages which may involve user engagement take precedence
if obj.was_comment is False:
logger.info(f"Subject of Message[{obj.subject}")
logger.info(f"Subject of Message[{obj.subject}]")

# automatically read mod new letter mail
if author_name in UserBlackList.all():
await obj.mark_read()
# skip mail from blacklist users
if author_name in RestrictedReadMail.all():
logger.info(f"NEW MAIL from: [{author_name}]...skipping")
return False

# disable PM replies when not in production
Expand Down

0 comments on commit a3465e5

Please sign in to comment.