19
19
days ,
20
20
match_title ,
21
21
match_user_id ,
22
+ match_other_recipients ,
22
23
truncate ,
23
24
format_channel_name ,
24
25
)
@@ -34,6 +35,7 @@ def __init__(
34
35
manager : "ThreadManager" ,
35
36
recipient : typing .Union [discord .Member , discord .User , int ],
36
37
channel : typing .Union [discord .DMChannel , discord .TextChannel ] = None ,
38
+ other_recipients : typing .List [typing .Union [discord .Member , discord .User ]] = [],
37
39
):
38
40
self .manager = manager
39
41
self .bot = manager .bot
@@ -45,6 +47,7 @@ def __init__(
45
47
raise CommandError ("Recipient cannot be a bot." )
46
48
self ._id = recipient .id
47
49
self ._recipient = recipient
50
+ self ._other_recipients = other_recipients
48
51
self ._channel = channel
49
52
self .genesis_message = None
50
53
self ._ready_event = asyncio .Event ()
@@ -54,7 +57,7 @@ def __init__(
54
57
self ._cancelled = False
55
58
56
59
def __repr__ (self ):
57
- return f'Thread(recipient="{ self .recipient or self .id } ", channel={ self .channel .id } )'
60
+ return f'Thread(recipient="{ self .recipient or self .id } ", channel={ self .channel .id } , other_recipienets= { len ( self . _other_recipients ) } )'
58
61
59
62
async def wait_until_ready (self ) -> None :
60
63
"""Blocks execution until the thread is fully set up."""
@@ -80,6 +83,10 @@ def channel(self) -> typing.Union[discord.TextChannel, discord.DMChannel]:
80
83
def recipient (self ) -> typing .Optional [typing .Union [discord .User , discord .Member ]]:
81
84
return self ._recipient
82
85
86
+ @property
87
+ def recipients (self ) -> typing .List [typing .Union [discord .User , discord .Member ]]:
88
+ return [self ._recipient ] + self ._other_recipients
89
+
83
90
@property
84
91
def ready (self ) -> bool :
85
92
return self ._ready_event .is_set ()
@@ -103,6 +110,23 @@ def cancelled(self, flag: bool):
103
110
for i in self .wait_tasks :
104
111
i .cancel ()
105
112
113
+ @classmethod
114
+ async def from_channel (
115
+ cls , manager : "ThreadManager" , channel : discord .TextChannel
116
+ ) -> "Thread" :
117
+ recipient_id = match_user_id (
118
+ channel .topic
119
+ ) # there is a chance it grabs from another recipient's main thread
120
+ recipient = manager .bot .get_user (recipient_id ) or await manager .bot .fetch_user (
121
+ recipient_id
122
+ )
123
+
124
+ other_recipients = match_other_recipients (channel .topic )
125
+ for n , uid in enumerate (other_recipients ):
126
+ other_recipients [n ] = manager .bot .get_user (uid ) or await manager .bot .fetch_user (uid )
127
+
128
+ return cls (manager , recipient or recipient_id , channel , other_recipients )
129
+
106
130
async def setup (self , * , creator = None , category = None , initial_message = None ):
107
131
"""Create the thread channel and other io related initialisation tasks"""
108
132
self .bot .dispatch ("thread_initiate" , self , creator , category , initial_message )
@@ -619,23 +643,30 @@ async def find_linked_messages(
619
643
except ValueError :
620
644
raise ValueError ("Malformed thread message." )
621
645
622
- async for msg in self .recipient .history ():
623
- if either_direction :
624
- if msg .id == joint_id :
625
- return message1 , msg
646
+ messages = [message1 ]
647
+ for user in self .recipients :
648
+ async for msg in user .history ():
649
+ if either_direction :
650
+ if msg .id == joint_id :
651
+ return message1 , msg
626
652
627
- if not (msg .embeds and msg .embeds [0 ].author .url ):
628
- continue
629
- try :
630
- if int (msg .embeds [0 ].author .url .split ("#" )[- 1 ]) == joint_id :
631
- return message1 , msg
632
- except ValueError :
633
- continue
634
- raise ValueError ("DM message not found. Plain messages are not supported." )
653
+ if not (msg .embeds and msg .embeds [0 ].author .url ):
654
+ continue
655
+ try :
656
+ if int (msg .embeds [0 ].author .url .split ("#" )[- 1 ]) == joint_id :
657
+ messages .append (msg )
658
+ break
659
+ except ValueError :
660
+ continue
661
+
662
+ if len (messages ) > 1 :
663
+ return messages
664
+
665
+ raise ValueError ("DM message not found." )
635
666
636
667
async def edit_message (self , message_id : typing .Optional [int ], message : str ) -> None :
637
668
try :
638
- message1 , message2 = await self .find_linked_messages (message_id )
669
+ message1 , * message2 = await self .find_linked_messages (message_id )
639
670
except ValueError :
640
671
logger .warning ("Failed to edit message." , exc_info = True )
641
672
raise
@@ -644,10 +675,11 @@ async def edit_message(self, message_id: typing.Optional[int], message: str) ->
644
675
embed1 .description = message
645
676
646
677
tasks = [self .bot .api .edit_message (message1 .id , message ), message1 .edit (embed = embed1 )]
647
- if message2 is not None :
648
- embed2 = message2 .embeds [0 ]
649
- embed2 .description = message
650
- tasks += [message2 .edit (embed = embed2 )]
678
+ if message2 is not [None ]:
679
+ for m2 in message2 :
680
+ embed2 = message2 .embeds [0 ]
681
+ embed2 .description = message
682
+ tasks += [m2 .edit (embed = embed2 )]
651
683
elif message1 .embeds [0 ].author .name .startswith ("Persistent Note" ):
652
684
tasks += [self .bot .api .edit_note (message1 .id , message )]
653
685
@@ -657,14 +689,16 @@ async def delete_message(
657
689
self , message : typing .Union [int , discord .Message ] = None , note : bool = True
658
690
) -> None :
659
691
if isinstance (message , discord .Message ):
660
- message1 , message2 = await self .find_linked_messages (message1 = message , note = note )
692
+ message1 , * message2 = await self .find_linked_messages (message1 = message , note = note )
661
693
else :
662
- message1 , message2 = await self .find_linked_messages (message , note = note )
694
+ message1 , * message2 = await self .find_linked_messages (message , note = note )
695
+ print (message1 , message2 )
663
696
tasks = []
664
697
if not isinstance (message , discord .Message ):
665
698
tasks += [message1 .delete ()]
666
- elif message2 is not None :
667
- tasks += [message2 .delete ()]
699
+ elif message2 is not [None ]:
700
+ for m2 in message2 :
701
+ tasks += [m2 .delete ()]
668
702
elif message1 .embeds [0 ].author .name .startswith ("Persistent Note" ):
669
703
tasks += [self .bot .api .delete_note (message1 .id )]
670
704
if tasks :
@@ -750,16 +784,18 @@ async def reply(
750
784
)
751
785
)
752
786
787
+ user_msg_tasks = []
753
788
tasks = []
754
789
755
- try :
756
- user_msg = await self .send (
757
- message ,
758
- destination = self .recipient ,
759
- from_mod = True ,
760
- anonymous = anonymous ,
761
- plain = plain ,
790
+ for user in self .recipients :
791
+ user_msg_tasks .append (
792
+ self .send (
793
+ message , destination = user , from_mod = True , anonymous = anonymous , plain = plain ,
794
+ )
762
795
)
796
+
797
+ try :
798
+ user_msg = await asyncio .gather (* user_msg_tasks )
763
799
except Exception as e :
764
800
logger .error ("Message delivery failed:" , exc_info = True )
765
801
if isinstance (e , discord .Forbidden ):
@@ -1063,9 +1099,23 @@ def get_notifications(self) -> str:
1063
1099
1064
1100
return " " .join (mentions )
1065
1101
1066
- async def set_title (self , title ) -> None :
1102
+ async def set_title (self , title : str ) -> None :
1103
+ user_id = match_user_id (self .channel .topic )
1104
+ ids = "," .join (i .id for i in self ._other_recipients )
1105
+
1106
+ await self .channel .edit (
1107
+ topic = f"Title: { title } \n User ID: { user_id } \n Other Recipients: { ids } "
1108
+ )
1109
+
1110
+ async def add_user (self , user : typing .Union [discord .Member , discord .User ]) -> None :
1111
+ title = match_title (self .channel .topic )
1067
1112
user_id = match_user_id (self .channel .topic )
1068
- await self .channel .edit (topic = f"Title: { title } \n User ID: { user_id } " )
1113
+ self ._other_recipients .append (user )
1114
+
1115
+ ids = "," .join (str (i .id ) for i in self ._other_recipients )
1116
+ await self .channel .edit (
1117
+ topic = f"Title: { title } \n User ID: { user_id } \n Other Recipients: { ids } "
1118
+ )
1069
1119
1070
1120
1071
1121
class ThreadManager :
@@ -1127,11 +1177,13 @@ async def find(
1127
1177
await thread .close (closer = self .bot .user , silent = True , delete_channel = False )
1128
1178
thread = None
1129
1179
else :
1130
- channel = discord .utils .get (
1131
- self .bot .modmail_guild .text_channels , topic = f"User ID: { recipient_id } "
1180
+ channel = discord .utils .find (
1181
+ lambda x : str (recipient_id ) in x .topic if x .topic else False ,
1182
+ self .bot .modmail_guild .text_channels ,
1132
1183
)
1184
+
1133
1185
if channel :
1134
- thread = Thread (self , recipient or recipient_id , channel )
1186
+ thread = await Thread . from_channel (self , channel )
1135
1187
if thread .recipient :
1136
1188
# only save if data is valid
1137
1189
self .cache [recipient_id ] = thread
@@ -1161,10 +1213,14 @@ async def _find_from_channel(self, channel):
1161
1213
except discord .NotFound :
1162
1214
recipient = None
1163
1215
1216
+ other_recipients = match_other_recipients (channel .topic )
1217
+ for n , uid in enumerate (other_recipients ):
1218
+ other_recipients [n ] = self .bot .get_user (uid ) or await self .bot .fetch_user (uid )
1219
+
1164
1220
if recipient is None :
1165
- thread = Thread (self , user_id , channel )
1221
+ thread = Thread (self , user_id , channel , other_recipients )
1166
1222
else :
1167
- self .cache [user_id ] = thread = Thread (self , recipient , channel )
1223
+ self .cache [user_id ] = thread = Thread (self , recipient , channel , other_recipients )
1168
1224
thread .ready = True
1169
1225
1170
1226
return thread
0 commit comments