Skip to content

Commit

Permalink
Add stickers converter
Browse files Browse the repository at this point in the history
  • Loading branch information
Soheab committed Nov 20, 2024
1 parent 7db879b commit f3a528d
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 15 deletions.
100 changes: 85 additions & 15 deletions discord/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
CommandT = TypeVar('CommandT', bound='Command[Any, ..., Any]')
# CHT = TypeVar('CHT', bound='Check')
GroupT = TypeVar('GroupT', bound='Group[Any, ..., Any]')
SpecialDataT = TypeVar('SpecialDataT', discord.Attachment, discord.StickerItem)

if TYPE_CHECKING:
P = ParamSpec('P')
Expand Down Expand Up @@ -252,6 +253,26 @@ async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
return wrapped


async def _convert_stickers(
sticker_type: Type[Union[discord.StickerItem, discord.Sticker, discord.StandardSticker, discord.GuildSticker]],
stickers: _SpecialIterator[discord.StickerItem],
param: Parameter,
/,
) -> Union[discord.StickerItem, discord.Sticker, discord.StandardSticker, discord.GuildSticker]:
if sticker_type is discord.StickerItem:
try:
return next(stickers)
except StopIteration:
raise MissingRequiredSticker(param)

for sticker in stickers:
fetched = await sticker.fetch()
if isinstance(fetched, sticker_type):
return sticker

raise MissingRequiredSticker(param)


class _CaseInsensitiveDict(dict):
def __contains__(self, k):
return super().__contains__(k.casefold())
Expand All @@ -272,15 +293,15 @@ def __setitem__(self, k, v):
super().__setitem__(k.casefold(), v)


class _AttachmentIterator:
def __init__(self, data: List[discord.Attachment]):
self.data: List[discord.Attachment] = data
class _SpecialIterator(Generic[SpecialDataT]):
def __init__(self, data: List[SpecialDataT]):
self.data: List[SpecialDataT] = data
self.index: int = 0

def __iter__(self) -> Self:
return self

def __next__(self) -> discord.Attachment:
def __next__(self) -> SpecialDataT:
try:
value = self.data[self.index]
except IndexError:
Expand Down Expand Up @@ -649,7 +670,14 @@ async def dispatch_error(self, ctx: Context[BotT], error: CommandError, /) -> No
finally:
ctx.bot.dispatch('command_error', ctx, error)

async def transform(self, ctx: Context[BotT], param: Parameter, attachments: _AttachmentIterator, /) -> Any:
async def transform(
self,
ctx: Context[BotT],
param: Parameter,
attachments: _SpecialIterator[discord.Attachment],
stickers: _SpecialIterator[discord.StickerItem],
/,
) -> Any:
converter = param.converter
consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw
view = ctx.view
Expand All @@ -661,6 +689,15 @@ async def transform(self, ctx: Context[BotT], param: Parameter, attachments: _At
# Special case for Greedy[discord.Attachment] to consume the attachments iterator
if converter.converter is discord.Attachment:
return list(attachments)
# Special case for Greedy[discord.StickerItem] to consume the stickers iterator
elif converter.converter in (
discord.StickerItem,
discord.Sticker,
discord.StandardSticker,
discord.GuildSticker,
):
# can only send one sticker at a time
return [await _convert_stickers(converter.converter, stickers, param)]

if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
return await self._transform_greedy_pos(ctx, param, param.required, converter.constructed_converter)
Expand All @@ -679,12 +716,27 @@ async def transform(self, ctx: Context[BotT], param: Parameter, attachments: _At
except StopIteration:
raise MissingRequiredAttachment(param)

if self._is_typing_optional(param.annotation) and param.annotation.__args__[0] is discord.Attachment:
if attachments.is_empty():
# I have no idea who would be doing Optional[discord.Attachment] = 1
# but for those cases then 1 should be returned instead of None
return None if param.default is param.empty else param.default
return next(attachments)
# Try to detect Optional[discord.StickerItem] or discord.StickerItem special converter
if converter in (discord.StickerItem, discord.Sticker, discord.StandardSticker, discord.GuildSticker):
return await _convert_stickers(converter, stickers, param)

if self._is_typing_optional(param.annotation):
if param.annotation.__args__[0] is discord.Attachment:
if attachments.is_empty():
# I have no idea who would be doing Optional[discord.Attachment] = 1
# but for those cases then 1 should be returned instead of None
return None if param.default is param.empty else param.default
return next(attachments)
elif param.annotation.__args__[0] in (
discord.StickerItem,
discord.Sticker,
discord.StandardSticker,
discord.GuildSticker,
):
if stickers.is_empty():
return None if param.default is param.empty else param.default

return await _convert_stickers(param.annotation.__args__[0], stickers, param)

if view.eof:
if param.kind == param.VAR_POSITIONAL:
Expand Down Expand Up @@ -834,30 +886,32 @@ async def _parse_arguments(self, ctx: Context[BotT]) -> None:
ctx.kwargs = {}
args = ctx.args
kwargs = ctx.kwargs
attachments = _AttachmentIterator(ctx.message.attachments)

attachments = _SpecialIterator(ctx.message.attachments)
stickers = _SpecialIterator(ctx.message.stickers)

view = ctx.view
iterator = iter(self.params.items())

for name, param in iterator:
ctx.current_parameter = param
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
transformed = await self.transform(ctx, param, attachments)
transformed = await self.transform(ctx, param, attachments, stickers)
args.append(transformed)
elif param.kind == param.KEYWORD_ONLY:
# kwarg only param denotes "consume rest" semantics
if self.rest_is_raw:
ctx.current_argument = argument = view.read_rest()
kwargs[name] = await run_converters(ctx, param.converter, argument, param)
else:
kwargs[name] = await self.transform(ctx, param, attachments)
kwargs[name] = await self.transform(ctx, param, attachments, stickers)
break
elif param.kind == param.VAR_POSITIONAL:
if view.eof and self.require_var_positional:
raise MissingRequiredArgument(param)
while not view.eof:
try:
transformed = await self.transform(ctx, param, attachments)
transformed = await self.transform(ctx, param, attachments, stickers)
args.append(transformed)
except RuntimeError:
break
Expand Down Expand Up @@ -1202,6 +1256,22 @@ def signature(self) -> str:
result.append(f'<{name} (upload a file)>')
continue

if annotation in (discord.StickerItem, discord.Sticker, discord.StandardSticker, discord.GuildSticker):
if annotation is discord.GuildSticker:
sticker_type = 'server sticker'
elif annotation is discord.StandardSticker:
sticker_type = 'standard sticker'
else:
sticker_type = 'sticker'

if optional:
result.append(f'[{name} (send a {sticker_type})]')
elif greedy:
result.append(f'[{name} (send {sticker_type}s)]...')
else:
result.append(f'<{name} (send a {sticker_type})>')
continue

# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
# parameter signature is a literal list of it's values
if origin is Literal:
Expand Down
30 changes: 30 additions & 0 deletions discord/ext/commands/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
'CommandError',
'MissingRequiredArgument',
'MissingRequiredAttachment',
'MissingRequiredSticker',
'BadArgument',
'PrivateMessageOnly',
'NoPrivateMessage',
Expand Down Expand Up @@ -206,6 +207,35 @@ def __init__(self, param: Parameter) -> None:
super().__init__(f'{param.displayed_name or param.name} is a required argument that is missing an attachment.')


class MissingRequiredSticker(UserInputError):
"""Exception raised when parsing a command and a parameter
that requires a sticker is not given.
This inherits from :exc:`UserInputError`
.. versionadded:: 2.5
Attributes
-----------
param: :class:`Parameter`
The argument that is missing a sticker.
"""

def __init__(self, param: Parameter) -> None:
from ...sticker import GuildSticker, StandardSticker

self.param: Parameter = param
converter = param.converter
if converter == GuildSticker:
sticker_type = 'server sticker'
elif converter == StandardSticker:
sticker_type = 'standard sticker'
else:
sticker_type = 'sticker'

super().__init__(f'{param.displayed_name or param.name} is a required argument that is missing a {sticker_type}.')


class TooManyArguments(UserInputError):
"""Exception raised when the command was passed too many arguments and its
:attr:`.Command.ignore_extra` attribute was not set to ``True``.
Expand Down
4 changes: 4 additions & 0 deletions docs/ext/commands/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,9 @@ Exceptions
.. autoexception:: discord.ext.commands.MissingRequiredAttachment
:members:

.. autoexception:: discord.ext.commands.MissingRequiredSticker
:members:

.. autoexception:: discord.ext.commands.ArgumentParsingError
:members:

Expand Down Expand Up @@ -789,6 +792,7 @@ Exception Hierarchy
- :exc:`~.commands.UserInputError`
- :exc:`~.commands.MissingRequiredArgument`
- :exc:`~.commands.MissingRequiredAttachment`
- :exc:`~.commands.MissingRequiredSticker`
- :exc:`~.commands.TooManyArguments`
- :exc:`~.commands.BadArgument`
- :exc:`~.commands.MessageNotFound`
Expand Down
41 changes: 41 additions & 0 deletions docs/ext/commands/commands.rst
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,47 @@ Note that using a :class:`discord.Attachment` converter after a :class:`~ext.com

If an attachment is expected but not given, then :exc:`~ext.commands.MissingRequiredAttachment` is raised to the error handlers.


Stickers
^^^^^^^^^^^^^^^^^^

.. versionadded:: 2.5

Annotating a parameter with any of the following sticker types will automatically get the uploaded sticker on a message and return the corresponding object:

- :class:`~discord.StickerItem`
- :class:`~discord.Sticker`
- :class:`~discord.StandardSticker`
- :class:`~discord.GuildSticker`

Consider the following example:

.. code-block:: python3
import discord
@bot.command()
async def sticker(ctx, sticker: discord.Sticker):
await ctx.send(f'You have uploaded {sticker.name} with format: {sticker.format}!')
When this command is invoked, the user must directly upload a sticker for the command body to be executed. When combined with the :data:`typing.Optional` converter, the user does not have to provide a sticker.

.. code-block:: python3
import typing
import discord
@bot.command()
async def upload(ctx, attachment: typing.Optional[discord.GuildSticker]):
if attachment is None:
await ctx.send('You did not upload anything!')
else:
await ctx.send(f'You have uploaded {sticker.name} with format: {sticker.format} from server: {sticker.guild}!')
If a sticker is expected but not given, then :exc:`~ext.commands.MissingRequiredSticker` is raised to the error handlers.

:class:`~ext.commands.Greedy` is supported too but at the moment, users can only upload one sticker at a time.

.. _ext_commands_flag_converter:

FlagConverter
Expand Down

0 comments on commit f3a528d

Please sign in to comment.