Skip to content

Commit cac76be

Browse files
authored
Merge pull request #1048 from zachpainter77/master
Add Support For Generic Handlers With Multiple Generic Type Parameters
2 parents 3b8bf44 + 811ce54 commit cac76be

File tree

7 files changed

+1129
-488
lines changed

7 files changed

+1129
-488
lines changed

src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs

+480-455
Large diffs are not rendered by default.

src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ public static IServiceCollection AddMediatR(this IServiceCollection services,
4747
throw new ArgumentException("No assemblies found to scan. Supply at least one assembly to scan for handlers.");
4848
}
4949

50-
ServiceRegistrar.AddMediatRClasses(services, configuration);
50+
ServiceRegistrar.SetGenericRequestHandlerRegistrationLimitations(configuration);
51+
52+
ServiceRegistrar.AddMediatRClassesWithTimeout(services, configuration);
5153

5254
ServiceRegistrar.AddRequiredServices(services, configuration);
5355

src/MediatR/Registration/ServiceRegistrar.cs

+111-21
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,50 @@
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Reflection;
5+
using System.Threading;
56
using MediatR.Pipeline;
67
using Microsoft.Extensions.DependencyInjection;
78
using Microsoft.Extensions.DependencyInjection.Extensions;
89

910
namespace MediatR.Registration;
1011

1112
public static class ServiceRegistrar
12-
{
13-
public static void AddMediatRClasses(IServiceCollection services, MediatRServiceConfiguration configuration)
14-
{
13+
{
14+
private static int MaxGenericTypeParameters;
15+
private static int MaxTypesClosing;
16+
private static int MaxGenericTypeRegistrations;
17+
private static int RegistrationTimeout;
18+
19+
public static void SetGenericRequestHandlerRegistrationLimitations(MediatRServiceConfiguration configuration)
20+
{
21+
MaxGenericTypeParameters = configuration.MaxGenericTypeParameters;
22+
MaxTypesClosing = configuration.MaxTypesClosing;
23+
MaxGenericTypeRegistrations = configuration.MaxGenericTypeRegistrations;
24+
RegistrationTimeout = configuration.RegistrationTimeout;
25+
}
26+
27+
public static void AddMediatRClassesWithTimeout(IServiceCollection services, MediatRServiceConfiguration configuration)
28+
{
29+
using(var cts = new CancellationTokenSource(RegistrationTimeout))
30+
{
31+
try
32+
{
33+
AddMediatRClasses(services, configuration, cts.Token);
34+
}
35+
catch (OperationCanceledException)
36+
{
37+
throw new TimeoutException("The generic handler registration process timed out.");
38+
}
39+
}
40+
}
41+
42+
public static void AddMediatRClasses(IServiceCollection services, MediatRServiceConfiguration configuration, CancellationToken cancellationToken = default)
43+
{
44+
1545
var assembliesToScan = configuration.AssembliesToRegister.Distinct().ToArray();
1646

17-
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration);
18-
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration);
47+
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration, cancellationToken);
48+
ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration, cancellationToken);
1949
ConnectImplementationsToTypesClosing(typeof(INotificationHandler<>), services, assembliesToScan, true, configuration);
2050
ConnectImplementationsToTypesClosing(typeof(IStreamRequestHandler<,>), services, assembliesToScan, false, configuration);
2151
ConnectImplementationsToTypesClosing(typeof(IRequestExceptionHandler<,,>), services, assembliesToScan, true, configuration);
@@ -63,7 +93,8 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
6393
IServiceCollection services,
6494
IEnumerable<Assembly> assembliesToScan,
6595
bool addIfAlreadyExists,
66-
MediatRServiceConfiguration configuration)
96+
MediatRServiceConfiguration configuration,
97+
CancellationToken cancellationToken = default)
6798
{
6899
var concretions = new List<Type>();
69100
var interfaces = new List<Type>();
@@ -72,9 +103,10 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
72103

73104
var types = assembliesToScan
74105
.SelectMany(a => a.DefinedTypes)
106+
.Where(t => !t.ContainsGenericParameters || configuration.RegisterGenericHandlers)
75107
.Where(t => t.IsConcrete() && t.FindInterfacesThatClose(openRequestInterface).Any())
76108
.Where(configuration.TypeEvaluator)
77-
.ToList();
109+
.ToList();
78110

79111
foreach (var type in types)
80112
{
@@ -131,7 +163,7 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
131163
foreach (var @interface in genericInterfaces)
132164
{
133165
var exactMatches = genericConcretions.Where(x => x.CanBeCastTo(@interface)).ToList();
134-
AddAllConcretionsThatClose(@interface, exactMatches, services, assembliesToScan);
166+
AddAllConcretionsThatClose(@interface, exactMatches, services, assembliesToScan, cancellationToken);
135167
}
136168
}
137169

@@ -174,7 +206,7 @@ private static void AddConcretionsThatCouldBeClosed(Type @interface, List<Type>
174206

175207
private static (Type Service, Type Implementation) GetConcreteRegistrationTypes(Type openRequestHandlerInterface, Type concreteGenericTRequest, Type openRequestHandlerImplementation)
176208
{
177-
var closingType = concreteGenericTRequest.GetGenericArguments().First();
209+
var closingTypes = concreteGenericTRequest.GetGenericArguments();
178210

179211
var concreteTResponse = concreteGenericTRequest.GetInterfaces()
180212
.FirstOrDefault(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(IRequest<>))
@@ -187,33 +219,90 @@ private static (Type Service, Type Implementation) GetConcreteRegistrationTypes(
187219
typeDefinition.MakeGenericType(concreteGenericTRequest, concreteTResponse) :
188220
typeDefinition.MakeGenericType(concreteGenericTRequest);
189221

190-
return (serviceType, openRequestHandlerImplementation.MakeGenericType(closingType));
222+
return (serviceType, openRequestHandlerImplementation.MakeGenericType(closingTypes));
191223
}
192224

193-
private static List<Type>? GetConcreteRequestTypes(Type openRequestHandlerInterface, Type openRequestHandlerImplementation, IEnumerable<Assembly> assembliesToScan)
225+
private static List<Type>? GetConcreteRequestTypes(Type openRequestHandlerInterface, Type openRequestHandlerImplementation, IEnumerable<Assembly> assembliesToScan, CancellationToken cancellationToken)
194226
{
195-
var constraints = openRequestHandlerImplementation.GetGenericArguments().First().GetGenericParameterConstraints();
196-
197-
var typesThatCanClose = assembliesToScan
198-
.SelectMany(assembly => assembly.GetTypes())
199-
.Where(type => type.IsClass && !type.IsAbstract && constraints.All(constraint => constraint.IsAssignableFrom(type)))
200-
.ToList();
227+
//request generic type constraints
228+
var constraintsForEachParameter = openRequestHandlerImplementation
229+
.GetGenericArguments()
230+
.Select(x => x.GetGenericParameterConstraints())
231+
.ToList();
232+
233+
if (constraintsForEachParameter.Count > 2 && constraintsForEachParameter.Any(constraints => !constraints.Where(x => x.IsInterface || x.IsClass).Any()))
234+
throw new ArgumentException($"Error registering the generic handler type: {openRequestHandlerImplementation.FullName}. When registering generic requests with more than two type parameters, each type parameter must have at least one constraint of type interface or class.");
235+
236+
var typesThatCanCloseForEachParameter = constraintsForEachParameter
237+
.Select(constraints => assembliesToScan
238+
.SelectMany(assembly => assembly.GetTypes())
239+
.Where(type => type.IsClass && !type.IsAbstract && constraints.All(constraint => constraint.IsAssignableFrom(type))).ToList()
240+
).ToList();
201241

202242
var requestType = openRequestHandlerInterface.GenericTypeArguments.First();
203243

204244
if (requestType.IsGenericParameter)
205245
return null;
206246

207247
var requestGenericTypeDefinition = requestType.GetGenericTypeDefinition();
248+
249+
var combinations = GenerateCombinations(requestType, typesThatCanCloseForEachParameter, 0, cancellationToken);
250+
251+
return combinations.Select(types => requestGenericTypeDefinition.MakeGenericType(types.ToArray())).ToList();
252+
}
253+
254+
// Method to generate combinations recursively
255+
public static List<List<Type>> GenerateCombinations(Type requestType, List<List<Type>> lists, int depth = 0, CancellationToken cancellationToken = default)
256+
{
257+
if (depth == 0)
258+
{
259+
// Initial checks
260+
if (MaxGenericTypeParameters > 0 && lists.Count > MaxGenericTypeParameters)
261+
throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. The number of generic type parameters exceeds the maximum allowed ({MaxGenericTypeParameters}).");
262+
263+
foreach (var list in lists)
264+
{
265+
if (MaxTypesClosing > 0 && list.Count > MaxTypesClosing)
266+
throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. One of the generic type parameter's count of types that can close exceeds the maximum length allowed ({MaxTypesClosing}).");
267+
}
268+
269+
// Calculate the total number of combinations
270+
long totalCombinations = 1;
271+
foreach (var list in lists)
272+
{
273+
totalCombinations *= list.Count;
274+
if (MaxGenericTypeParameters > 0 && totalCombinations > MaxGenericTypeRegistrations)
275+
throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. The total number of generic type registrations exceeds the maximum allowed ({MaxGenericTypeRegistrations}).");
276+
}
277+
}
278+
279+
if (depth >= lists.Count)
280+
return new List<List<Type>> { new List<Type>() };
281+
282+
cancellationToken.ThrowIfCancellationRequested();
208283

209-
return typesThatCanClose.Select(type => requestGenericTypeDefinition.MakeGenericType(type)).ToList();
284+
var currentList = lists[depth];
285+
var childCombinations = GenerateCombinations(requestType, lists, depth + 1, cancellationToken);
286+
var combinations = new List<List<Type>>();
287+
288+
foreach (var item in currentList)
289+
{
290+
foreach (var childCombination in childCombinations)
291+
{
292+
var currentCombination = new List<Type> { item };
293+
currentCombination.AddRange(childCombination);
294+
combinations.Add(currentCombination);
295+
}
296+
}
297+
298+
return combinations;
210299
}
211300

212-
private static void AddAllConcretionsThatClose(Type openRequestInterface, List<Type> concretions, IServiceCollection services, IEnumerable<Assembly> assembliesToScan)
301+
private static void AddAllConcretionsThatClose(Type openRequestInterface, List<Type> concretions, IServiceCollection services, IEnumerable<Assembly> assembliesToScan, CancellationToken cancellationToken)
213302
{
214303
foreach (var concretion in concretions)
215-
{
216-
var concreteRequests = GetConcreteRequestTypes(openRequestInterface, concretion, assembliesToScan);
304+
{
305+
var concreteRequests = GetConcreteRequestTypes(openRequestInterface, concretion, assembliesToScan, cancellationToken);
217306

218307
if (concreteRequests is null)
219308
continue;
@@ -223,6 +312,7 @@ private static void AddAllConcretionsThatClose(Type openRequestInterface, List<T
223312

224313
foreach (var (Service, Implementation) in registrationTypes)
225314
{
315+
cancellationToken.ThrowIfCancellationRequested();
226316
services.AddTransient(Service, Implementation);
227317
}
228318
}

0 commit comments

Comments
 (0)