From 401785c9d2a8cddb688eb76cd256cee700a5f9fe Mon Sep 17 00:00:00 2001 From: Cordila <49218334+Cordila@users.noreply.github.com> Date: Sun, 12 Mar 2023 08:20:31 +0530 Subject: [PATCH 1/2] Import/export --- cogs/modmail.py | 47 +++++++++++++++++++++++++++++++++++++++++++++++ core/clients.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/cogs/modmail.py b/cogs/modmail.py index 484eb7aac8..607ca1458c 100644 --- a/cogs/modmail.py +++ b/cogs/modmail.py @@ -1,5 +1,6 @@ import asyncio import re +import os from datetime import datetime, timezone from itertools import zip_longest from typing import Optional, Union, List, Tuple, Literal @@ -2186,6 +2187,52 @@ async def isenable(self, ctx): return await ctx.send(embed=embed) + @commands.command(name="export") + @checks.has_permissions(PermissionLevel.ADMINISTRATOR) + async def export_backup(self, ctx, collection_name): + """ + Export a backup of a collection in the form of a json file. + + {prefix}export + """ + success_message, file = await self.bot.api.export_backups(collection_name) + await ctx.send(success_message) + await ctx.author.send(file=file) + + @commands.command(name="import") + @checks.has_permissions(PermissionLevel.ADMINISTRATOR) + async def import_backup(self,ctx): + """ + Import a backup from a json file. + + This will overwrite all data in the collection. + + {prefix}import + + """ + if len(ctx.message.attachments) == 1: + attachment = ctx.message.attachments[0] + await attachment.save(attachment.filename) + file = discord.File(attachment.filename) + collection_name = os.path.splitext(attachment.filename)[0] + await ctx.send(f"This will overwrite all data in the {collection_name} collection. Are you sure you want to continue? (yes/no)") + try: + msg = await self.bot.wait_for("message",timeout=30,check=lambda m: m.author == ctx.author and m.channel.id == ctx.channel.id) + if msg.content.lower() == "yes": + success_message = await self.bot.api.import_backups(collection_name, file) + await ctx.send(success_message) + os.remove(attachment.filename) + else: + return await ctx.send("Cancelled.") + + except asyncio.TimeoutError: + return await ctx.send("You took too long to respond. Please try again.") + + else: + return await ctx.send("Please attach 1 json file.") + + + async def setup(bot): await bot.add_cog(Modmail(bot)) diff --git a/core/clients.py b/core/clients.py index eebe3bcff6..403268a32a 100644 --- a/core/clients.py +++ b/core/clients.py @@ -1,12 +1,13 @@ import secrets import sys -from json import JSONDecodeError +import json from typing import Any, Dict, Union, Optional import discord from discord import Member, DMChannel, TextChannel, Message from discord.ext import commands +from bson import ObjectId from aiohttp import ClientResponseError, ClientResponse from motor.motor_asyncio import AsyncIOMotorClient from pymongo.errors import ConfigurationError @@ -16,6 +17,12 @@ logger = getLogger(__name__) +class CustomJSONEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, ObjectId): + return str(obj) + return super().default(obj) + class GitHub: """ The client for interacting with GitHub API. @@ -751,6 +758,27 @@ async def get_user_info(self) -> Optional[dict]: } } + async def export_backups(self, collection_name: str): + coll = self.db[collection_name] + documents = [] + async for document in coll.find(): + documents.append(document) + with open(f"{collection_name}.json", "w") as f: + json.dump(documents, f, cls=CustomJSONEncoder) + with open(f"{collection_name}.json", "rb") as f: + file = discord.File(f, f"{collection_name}.json") + success_message = f"Exported {len(documents)} documents from {collection_name} to JSON. Check your DMs for the file." + return success_message, file + + async def import_backups(self, collection_name: str, file: discord.File): + contents = await self.bot.loop.run_in_executor(None, file.fp.read) + documents = json.loads(contents.decode('utf-8')) + coll = self.db[collection_name] + await coll.delete_many({}) + result = await coll.insert_many(documents) + success_message = f"Imported {len(result.inserted_ids)} documents from {file.filename} into {collection_name}." + return success_message + class PluginDatabaseClient: def __init__(self, bot): From 59204d4ac64ccb890576eae7d981a805f7321054 Mon Sep 17 00:00:00 2001 From: Cordila <49218334+Cordila@users.noreply.github.com> Date: Sun, 12 Mar 2023 13:00:26 +0530 Subject: [PATCH 2/2] Black format --- bot.py | 3 --- cogs/modmail.py | 14 +++++++++----- cogs/plugins.py | 2 -- cogs/utility.py | 1 - core/clients.py | 7 +++++-- core/config.py | 1 - core/models.py | 1 - core/thread.py | 2 -- 8 files changed, 14 insertions(+), 17 deletions(-) diff --git a/bot.py b/bot.py index b23b2449b2..79b323ef94 100644 --- a/bot.py +++ b/bot.py @@ -657,7 +657,6 @@ async def get_or_fetch_user(self, id: int) -> discord.User: return self.get_user(id) or await self.fetch_user(id) async def retrieve_emoji(self) -> typing.Tuple[str, str]: - sent_emoji = self.config["sent_emoji"] blocked_emoji = self.config["blocked_emoji"] @@ -731,7 +730,6 @@ def check_manual_blocked_roles(self, author: discord.Member) -> bool: if isinstance(author, discord.Member): for r in author.roles: if str(r.id) in self.blocked_roles: - blocked_reason = self.blocked_roles.get(str(r.id)) or "" try: @@ -790,7 +788,6 @@ async def is_blocked( channel: discord.TextChannel = None, send_message: bool = False, ) -> bool: - member = self.guild.get_member(author.id) if member is None: # try to find in other guilds diff --git a/cogs/modmail.py b/cogs/modmail.py index 607ca1458c..6658d62c81 100644 --- a/cogs/modmail.py +++ b/cogs/modmail.py @@ -2201,7 +2201,7 @@ async def export_backup(self, ctx, collection_name): @commands.command(name="import") @checks.has_permissions(PermissionLevel.ADMINISTRATOR) - async def import_backup(self,ctx): + async def import_backup(self, ctx): """ Import a backup from a json file. @@ -2215,9 +2215,15 @@ async def import_backup(self,ctx): await attachment.save(attachment.filename) file = discord.File(attachment.filename) collection_name = os.path.splitext(attachment.filename)[0] - await ctx.send(f"This will overwrite all data in the {collection_name} collection. Are you sure you want to continue? (yes/no)") + await ctx.send( + f"This will overwrite all data in the {collection_name} collection. Are you sure you want to continue? (yes/no)" + ) try: - msg = await self.bot.wait_for("message",timeout=30,check=lambda m: m.author == ctx.author and m.channel.id == ctx.channel.id) + msg = await self.bot.wait_for( + "message", + timeout=30, + check=lambda m: m.author == ctx.author and m.channel.id == ctx.channel.id, + ) if msg.content.lower() == "yes": success_message = await self.bot.api.import_backups(collection_name, file) await ctx.send(success_message) @@ -2232,7 +2238,5 @@ async def import_backup(self,ctx): return await ctx.send("Please attach 1 json file.") - - async def setup(bot): await bot.add_cog(Modmail(bot)) diff --git a/cogs/plugins.py b/cogs/plugins.py index 2bfac509af..37020f7e0c 100644 --- a/cogs/plugins.py +++ b/cogs/plugins.py @@ -265,7 +265,6 @@ async def load_plugin(self, plugin): raise InvalidPluginError("Cannot load extension, plugin invalid.") from exc async def parse_user_input(self, ctx, plugin_name, check_version=False): - if not self.bot.config["enable_plugins"]: embed = discord.Embed( description="Plugins are disabled, enable them by setting `ENABLE_PLUGINS=true`", @@ -380,7 +379,6 @@ async def plugins_add(self, ctx, *, plugin_name: str): await self.bot.config.update() if self.bot.config.get("enable_plugins"): - invalidate_caches() try: diff --git a/cogs/utility.py b/cogs/utility.py index ac642d9eb3..26356289df 100644 --- a/cogs/utility.py +++ b/cogs/utility.py @@ -599,7 +599,6 @@ async def status(self, ctx, *, status_type: str.lower): return await ctx.send(embed=embed) async def set_presence(self, *, status=None, activity_type=None, activity_message=None): - if status is None: status = self.bot.config.get("status") diff --git a/core/clients.py b/core/clients.py index 403268a32a..fc2e776196 100644 --- a/core/clients.py +++ b/core/clients.py @@ -23,6 +23,7 @@ def default(self, obj): return str(obj) return super().default(obj) + class GitHub: """ The client for interacting with GitHub API. @@ -772,11 +773,13 @@ async def export_backups(self, collection_name: str): async def import_backups(self, collection_name: str, file: discord.File): contents = await self.bot.loop.run_in_executor(None, file.fp.read) - documents = json.loads(contents.decode('utf-8')) + documents = json.loads(contents.decode("utf-8")) coll = self.db[collection_name] await coll.delete_many({}) result = await coll.insert_many(documents) - success_message = f"Imported {len(result.inserted_ids)} documents from {file.filename} into {collection_name}." + success_message = ( + f"Imported {len(result.inserted_ids)} documents from {file.filename} into {collection_name}." + ) return success_message diff --git a/core/config.py b/core/config.py index 56db40c7b0..fb030b1b0f 100644 --- a/core/config.py +++ b/core/config.py @@ -21,7 +21,6 @@ class ConfigManager: - public_keys = { # activity "twitch_url": "https://www.twitch.tv/discordmodmail/", diff --git a/core/models.py b/core/models.py index 2eab1ceebb..4c6f956e82 100644 --- a/core/models.py +++ b/core/models.py @@ -202,7 +202,6 @@ async def convert(self, ctx, argument): try: return await super().convert(ctx, argument) except commands.ChannelNotFound: - if guild: categories = {c.name.casefold(): c for c in guild.categories} else: diff --git a/core/thread.py b/core/thread.py index 53cdd1d202..98756c46f5 100644 --- a/core/thread.py +++ b/core/thread.py @@ -716,7 +716,6 @@ async def delete_message( async def find_linked_message_from_dm( self, message, either_direction=False, get_thread_channel=False ) -> typing.List[discord.Message]: - joint_id = None if either_direction: joint_id = get_joint_id(message) @@ -909,7 +908,6 @@ async def send( persistent_note: bool = False, thread_creation: bool = False, ) -> None: - if not note and from_mod: self.bot.loop.create_task(self._restart_close_timer()) # Start or restart thread auto close