2
2
using System . Collections . Generic ;
3
3
using System . Linq ;
4
4
using System . Reflection ;
5
+ using System . Threading ;
5
6
using MediatR . Pipeline ;
6
7
using Microsoft . Extensions . DependencyInjection ;
7
8
using Microsoft . Extensions . DependencyInjection . Extensions ;
8
9
9
10
namespace MediatR . Registration ;
10
11
11
12
public static class ServiceRegistrar
12
- {
13
- public static void AddMediatRClasses ( IServiceCollection services , MediatRServiceConfiguration configuration )
14
- {
13
+ {
14
+ private static int MaxGenericTypeParameters ;
15
+ private static int MaxTypesClosing ;
16
+ private static int MaxGenericTypeRegistrations ;
17
+ private static int RegistrationTimeout ;
18
+
19
+ public static void SetGenericRequestHandlerRegistrationLimitations ( MediatRServiceConfiguration configuration )
20
+ {
21
+ MaxGenericTypeParameters = configuration . MaxGenericTypeParameters ;
22
+ MaxTypesClosing = configuration . MaxTypesClosing ;
23
+ MaxGenericTypeRegistrations = configuration . MaxGenericTypeRegistrations ;
24
+ RegistrationTimeout = configuration . RegistrationTimeout ;
25
+ }
26
+
27
+ public static void AddMediatRClassesWithTimeout ( IServiceCollection services , MediatRServiceConfiguration configuration )
28
+ {
29
+ using ( var cts = new CancellationTokenSource ( RegistrationTimeout ) )
30
+ {
31
+ try
32
+ {
33
+ AddMediatRClasses ( services , configuration , cts . Token ) ;
34
+ }
35
+ catch ( OperationCanceledException )
36
+ {
37
+ throw new TimeoutException ( "The generic handler registration process timed out." ) ;
38
+ }
39
+ }
40
+ }
41
+
42
+ public static void AddMediatRClasses ( IServiceCollection services , MediatRServiceConfiguration configuration , CancellationToken cancellationToken = default )
43
+ {
44
+
15
45
var assembliesToScan = configuration . AssembliesToRegister . Distinct ( ) . ToArray ( ) ;
16
46
17
- ConnectImplementationsToTypesClosing ( typeof ( IRequestHandler < , > ) , services , assembliesToScan , false , configuration ) ;
18
- ConnectImplementationsToTypesClosing ( typeof ( IRequestHandler < > ) , services , assembliesToScan , false , configuration ) ;
47
+ ConnectImplementationsToTypesClosing ( typeof ( IRequestHandler < , > ) , services , assembliesToScan , false , configuration , cancellationToken ) ;
48
+ ConnectImplementationsToTypesClosing ( typeof ( IRequestHandler < > ) , services , assembliesToScan , false , configuration , cancellationToken ) ;
19
49
ConnectImplementationsToTypesClosing ( typeof ( INotificationHandler < > ) , services , assembliesToScan , true , configuration ) ;
20
50
ConnectImplementationsToTypesClosing ( typeof ( IStreamRequestHandler < , > ) , services , assembliesToScan , false , configuration ) ;
21
51
ConnectImplementationsToTypesClosing ( typeof ( IRequestExceptionHandler < , , > ) , services , assembliesToScan , true , configuration ) ;
@@ -63,7 +93,8 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
63
93
IServiceCollection services ,
64
94
IEnumerable < Assembly > assembliesToScan ,
65
95
bool addIfAlreadyExists ,
66
- MediatRServiceConfiguration configuration )
96
+ MediatRServiceConfiguration configuration ,
97
+ CancellationToken cancellationToken = default )
67
98
{
68
99
var concretions = new List < Type > ( ) ;
69
100
var interfaces = new List < Type > ( ) ;
@@ -72,9 +103,10 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
72
103
73
104
var types = assembliesToScan
74
105
. SelectMany ( a => a . DefinedTypes )
106
+ . Where ( t => ! t . ContainsGenericParameters || configuration . RegisterGenericHandlers )
75
107
. Where ( t => t . IsConcrete ( ) && t . FindInterfacesThatClose ( openRequestInterface ) . Any ( ) )
76
108
. Where ( configuration . TypeEvaluator )
77
- . ToList ( ) ;
109
+ . ToList ( ) ;
78
110
79
111
foreach ( var type in types )
80
112
{
@@ -131,7 +163,7 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
131
163
foreach ( var @interface in genericInterfaces )
132
164
{
133
165
var exactMatches = genericConcretions . Where ( x => x . CanBeCastTo ( @interface ) ) . ToList ( ) ;
134
- AddAllConcretionsThatClose ( @interface , exactMatches , services , assembliesToScan ) ;
166
+ AddAllConcretionsThatClose ( @interface , exactMatches , services , assembliesToScan , cancellationToken ) ;
135
167
}
136
168
}
137
169
@@ -174,7 +206,7 @@ private static void AddConcretionsThatCouldBeClosed(Type @interface, List<Type>
174
206
175
207
private static ( Type Service , Type Implementation ) GetConcreteRegistrationTypes ( Type openRequestHandlerInterface , Type concreteGenericTRequest , Type openRequestHandlerImplementation )
176
208
{
177
- var closingType = concreteGenericTRequest . GetGenericArguments ( ) . First ( ) ;
209
+ var closingTypes = concreteGenericTRequest . GetGenericArguments ( ) ;
178
210
179
211
var concreteTResponse = concreteGenericTRequest . GetInterfaces ( )
180
212
. FirstOrDefault ( x => x . IsGenericType && x . GetGenericTypeDefinition ( ) == typeof ( IRequest < > ) )
@@ -187,33 +219,90 @@ private static (Type Service, Type Implementation) GetConcreteRegistrationTypes(
187
219
typeDefinition . MakeGenericType ( concreteGenericTRequest , concreteTResponse ) :
188
220
typeDefinition . MakeGenericType ( concreteGenericTRequest ) ;
189
221
190
- return ( serviceType , openRequestHandlerImplementation . MakeGenericType ( closingType ) ) ;
222
+ return ( serviceType , openRequestHandlerImplementation . MakeGenericType ( closingTypes ) ) ;
191
223
}
192
224
193
- private static List < Type > ? GetConcreteRequestTypes ( Type openRequestHandlerInterface , Type openRequestHandlerImplementation , IEnumerable < Assembly > assembliesToScan )
225
+ private static List < Type > ? GetConcreteRequestTypes ( Type openRequestHandlerInterface , Type openRequestHandlerImplementation , IEnumerable < Assembly > assembliesToScan , CancellationToken cancellationToken )
194
226
{
195
- var constraints = openRequestHandlerImplementation . GetGenericArguments ( ) . First ( ) . GetGenericParameterConstraints ( ) ;
196
-
197
- var typesThatCanClose = assembliesToScan
198
- . SelectMany ( assembly => assembly . GetTypes ( ) )
199
- . Where ( type => type . IsClass && ! type . IsAbstract && constraints . All ( constraint => constraint . IsAssignableFrom ( type ) ) )
200
- . ToList ( ) ;
227
+ //request generic type constraints
228
+ var constraintsForEachParameter = openRequestHandlerImplementation
229
+ . GetGenericArguments ( )
230
+ . Select ( x => x . GetGenericParameterConstraints ( ) )
231
+ . ToList ( ) ;
232
+
233
+ if ( constraintsForEachParameter . Count > 2 && constraintsForEachParameter . Any ( constraints => ! constraints . Where ( x => x . IsInterface || x . IsClass ) . Any ( ) ) )
234
+ throw new ArgumentException ( $ "Error registering the generic handler type: { openRequestHandlerImplementation . FullName } . When registering generic requests with more than two type parameters, each type parameter must have at least one constraint of type interface or class.") ;
235
+
236
+ var typesThatCanCloseForEachParameter = constraintsForEachParameter
237
+ . Select ( constraints => assembliesToScan
238
+ . SelectMany ( assembly => assembly . GetTypes ( ) )
239
+ . Where ( type => type . IsClass && ! type . IsAbstract && constraints . All ( constraint => constraint . IsAssignableFrom ( type ) ) ) . ToList ( )
240
+ ) . ToList ( ) ;
201
241
202
242
var requestType = openRequestHandlerInterface . GenericTypeArguments . First ( ) ;
203
243
204
244
if ( requestType . IsGenericParameter )
205
245
return null ;
206
246
207
247
var requestGenericTypeDefinition = requestType . GetGenericTypeDefinition ( ) ;
248
+
249
+ var combinations = GenerateCombinations ( requestType , typesThatCanCloseForEachParameter , 0 , cancellationToken ) ;
250
+
251
+ return combinations . Select ( types => requestGenericTypeDefinition . MakeGenericType ( types . ToArray ( ) ) ) . ToList ( ) ;
252
+ }
253
+
254
+ // Method to generate combinations recursively
255
+ public static List < List < Type > > GenerateCombinations ( Type requestType , List < List < Type > > lists , int depth = 0 , CancellationToken cancellationToken = default )
256
+ {
257
+ if ( depth == 0 )
258
+ {
259
+ // Initial checks
260
+ if ( MaxGenericTypeParameters > 0 && lists . Count > MaxGenericTypeParameters )
261
+ throw new ArgumentException ( $ "Error registering the generic type: { requestType . FullName } . The number of generic type parameters exceeds the maximum allowed ({ MaxGenericTypeParameters } ).") ;
262
+
263
+ foreach ( var list in lists )
264
+ {
265
+ if ( MaxTypesClosing > 0 && list . Count > MaxTypesClosing )
266
+ throw new ArgumentException ( $ "Error registering the generic type: { requestType . FullName } . One of the generic type parameter's count of types that can close exceeds the maximum length allowed ({ MaxTypesClosing } ).") ;
267
+ }
268
+
269
+ // Calculate the total number of combinations
270
+ long totalCombinations = 1 ;
271
+ foreach ( var list in lists )
272
+ {
273
+ totalCombinations *= list . Count ;
274
+ if ( MaxGenericTypeParameters > 0 && totalCombinations > MaxGenericTypeRegistrations )
275
+ throw new ArgumentException ( $ "Error registering the generic type: { requestType . FullName } . The total number of generic type registrations exceeds the maximum allowed ({ MaxGenericTypeRegistrations } ).") ;
276
+ }
277
+ }
278
+
279
+ if ( depth >= lists . Count )
280
+ return new List < List < Type > > { new List < Type > ( ) } ;
281
+
282
+ cancellationToken . ThrowIfCancellationRequested ( ) ;
208
283
209
- return typesThatCanClose . Select ( type => requestGenericTypeDefinition . MakeGenericType ( type ) ) . ToList ( ) ;
284
+ var currentList = lists [ depth ] ;
285
+ var childCombinations = GenerateCombinations ( requestType , lists , depth + 1 , cancellationToken ) ;
286
+ var combinations = new List < List < Type > > ( ) ;
287
+
288
+ foreach ( var item in currentList )
289
+ {
290
+ foreach ( var childCombination in childCombinations )
291
+ {
292
+ var currentCombination = new List < Type > { item } ;
293
+ currentCombination . AddRange ( childCombination ) ;
294
+ combinations . Add ( currentCombination ) ;
295
+ }
296
+ }
297
+
298
+ return combinations ;
210
299
}
211
300
212
- private static void AddAllConcretionsThatClose ( Type openRequestInterface , List < Type > concretions , IServiceCollection services , IEnumerable < Assembly > assembliesToScan )
301
+ private static void AddAllConcretionsThatClose ( Type openRequestInterface , List < Type > concretions , IServiceCollection services , IEnumerable < Assembly > assembliesToScan , CancellationToken cancellationToken )
213
302
{
214
303
foreach ( var concretion in concretions )
215
- {
216
- var concreteRequests = GetConcreteRequestTypes ( openRequestInterface , concretion , assembliesToScan ) ;
304
+ {
305
+ var concreteRequests = GetConcreteRequestTypes ( openRequestInterface , concretion , assembliesToScan , cancellationToken ) ;
217
306
218
307
if ( concreteRequests is null )
219
308
continue ;
@@ -223,6 +312,7 @@ private static void AddAllConcretionsThatClose(Type openRequestInterface, List<T
223
312
224
313
foreach ( var ( Service , Implementation ) in registrationTypes )
225
314
{
315
+ cancellationToken . ThrowIfCancellationRequested ( ) ;
226
316
services . AddTransient ( Service , Implementation ) ;
227
317
}
228
318
}
0 commit comments