Skip to content

Support resolving keyed services from DI in RDF and RDG #50093

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ internal static string EmitParameterPreparation(this IEnumerable<EndpointParamet
case EndpointParameterSource.Service:
parameter.EmitServiceParameterPreparation(parameterPreparationBuilder);
break;
case EndpointParameterSource.KeyedService:
parameter.EmitKeyedServiceParameterPreparation(parameterPreparationBuilder);
break;
case EndpointParameterSource.AsParameters:
parameter.EmitAsParametersParameterPreparation(parameterPreparationBuilder, emitterContext);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,16 @@ internal static void EmitServiceParameterPreparation(this EndpointParameter endp
codeWriter.WriteLine($"var {endpointParameter.EmitHandlerArgument()} = {assigningCode};");
}

internal static void EmitKeyedServiceParameterPreparation(this EndpointParameter endpointParameter, CodeWriter codeWriter)
{
codeWriter.WriteLine(endpointParameter.EmitParameterDiagnosticComment());

var assigningCode = endpointParameter.IsOptional ?
$"httpContext.RequestServices.GetKeyedService<{endpointParameter.Type}>({endpointParameter.AssigningCode});" :
$"httpContext.RequestServices.GetRequiredKeyedService<{endpointParameter.Type}>({endpointParameter.AssigningCode})";
codeWriter.WriteLine($"var {endpointParameter.EmitHandlerArgument()} = {assigningCode};");
}

internal static void EmitAsParametersParameterPreparation(this EndpointParameter endpointParameter, CodeWriter codeWriter, EmitterContext emitterContext)
{
codeWriter.WriteLine(endpointParameter.EmitParameterDiagnosticComment());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ private void ProcessEndpointParameterSource(Endpoint endpoint, ISymbol symbol, I
{
Source = EndpointParameterSource.Service;
}
else if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work with derived types?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you referring to if the user is using a derived version of FromKeyedServices attribute? If so, the answer is no but not sure that is a scenario we want to support?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is what I'm referring to. The type is unsealed by design I believe.

{
Source = EndpointParameterSource.KeyedService;
var constructorArgument = keyedServicesAttribute.ConstructorArguments.FirstOrDefault();
AssigningCode = constructorArgument.IsNull ? string.Empty : SymbolDisplay.FormatPrimitive(constructorArgument.Value!, true, true);
}
else if (attributes.HasAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_AsParametersAttribute)))
{
Source = EndpointParameterSource.AsParameters;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ internal enum EndpointParameterSource
JsonBodyOrService,
FormBody,
Service,
KeyedService,
// SpecialType refers to HttpContext, HttpRequest, CancellationToken, Stream, etc...
// that are specially checked for in RequestDelegateFactory.CreateArgument()
SpecialType,
Expand Down
22 changes: 22 additions & 0 deletions src/Http/Http.Extensions/src/RequestDelegateFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ public static partial class RequestDelegateFactory
private static readonly MethodInfo ExecuteAwaitedReturnMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteAwaitedReturn), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo GetRequiredServiceMethod = typeof(ServiceProviderServiceExtensions).GetMethod(nameof(ServiceProviderServiceExtensions.GetRequiredService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider) })!;
private static readonly MethodInfo GetServiceMethod = typeof(ServiceProviderServiceExtensions).GetMethod(nameof(ServiceProviderServiceExtensions.GetService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider) })!;
private static readonly MethodInfo GetRequiredKeyedServiceMethod = typeof(ServiceProviderKeyedServiceExtensions).GetMethod(nameof(ServiceProviderKeyedServiceExtensions.GetRequiredKeyedService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider), typeof(object) })!;
private static readonly MethodInfo GetKeyedServiceMethod = typeof(ServiceProviderKeyedServiceExtensions).GetMethod(nameof(ServiceProviderKeyedServiceExtensions.GetKeyedService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider), typeof(object) })!;
private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteResultWriteResponse), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo StringResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteWriteStringResponseAsync), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo StringIsNullOrEmptyMethod = typeof(string).GetMethod(nameof(string.IsNullOrEmpty), BindingFlags.Static | BindingFlags.Public)!;
Expand Down Expand Up @@ -764,6 +766,11 @@ private static Expression CreateArgument(ParameterInfo parameter, RequestDelegat
factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.ServiceAttribute);
return BindParameterFromService(parameter, factoryContext);
}
else if (parameterCustomAttributes.OfType<FromKeyedServicesAttribute>().FirstOrDefault() is { } keyedServicesAttribute)
{
var key = keyedServicesAttribute.Key;
return BindParameterFromKeyedService(parameter, key, factoryContext);
}
else if (parameterCustomAttributes.OfType<AsParametersAttribute>().Any())
{
if (parameter is PropertyAsParameterInfo)
Expand Down Expand Up @@ -1563,6 +1570,21 @@ private static Expression BindParameterFromService(ParameterInfo parameter, Requ
return Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr);
}

private static Expression BindParameterFromKeyedService(ParameterInfo parameter, object key, RequestDelegateFactoryContext factoryContext)
{
var isOptional = IsOptionalParameter(parameter, factoryContext);

if (isOptional)
{
return Expression.Call(GetKeyedServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr, Expression.Convert(
Expression.Constant(key),
typeof(object)));
}
return Expression.Call(GetRequiredKeyedServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr, Expression.Convert(
Expression.Constant(key),
typeof(object)));
}

private static Expression BindParameterFromValue(ParameterInfo parameter, Expression valueExpression, RequestDelegateFactoryContext factoryContext, string source)
{
if (parameter.ParameterType == typeof(string) || parameter.ParameterType == typeof(string[])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ namespace Microsoft.AspNetCore.Http.Generated
private static readonly JsonOptions FallbackJsonOptions = new();
private static readonly string[] GetVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Get };

[InterceptsLocation(@"TestMapActions.cs", 25, 13)]
[InterceptsLocation(@"TestMapActions.cs", 26, 5)]
[InterceptsLocation(@"TestMapActions.cs", 26, 13)]
[InterceptsLocation(@"TestMapActions.cs", 27, 5)]
internal static RouteHandlerBuilder MapGet0(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -166,8 +166,8 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 28, 5)]
[InterceptsLocation(@"TestMapActions.cs", 29, 5)]
[InterceptsLocation(@"TestMapActions.cs", 30, 5)]
internal static RouteHandlerBuilder MapGet1(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Http.Generated
private static readonly JsonOptions FallbackJsonOptions = new();
private static readonly string[] GetVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Get };

[InterceptsLocation(@"TestMapActions.cs", 25, 13)]
[InterceptsLocation(@"TestMapActions.cs", 26, 13)]
internal static RouteHandlerBuilder MapGet0(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -167,7 +167,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 26, 5)]
[InterceptsLocation(@"TestMapActions.cs", 27, 5)]
internal static RouteHandlerBuilder MapGet1(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -262,7 +262,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 27, 5)]
[InterceptsLocation(@"TestMapActions.cs", 28, 5)]
internal static RouteHandlerBuilder MapGet2(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -359,7 +359,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 28, 5)]
[InterceptsLocation(@"TestMapActions.cs", 29, 5)]
internal static RouteHandlerBuilder MapGet3(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -454,7 +454,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 29, 5)]
[InterceptsLocation(@"TestMapActions.cs", 30, 5)]
internal static RouteHandlerBuilder MapGet4(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -563,7 +563,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 30, 5)]
[InterceptsLocation(@"TestMapActions.cs", 31, 5)]
internal static RouteHandlerBuilder MapGet5(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -658,7 +658,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 31, 5)]
[InterceptsLocation(@"TestMapActions.cs", 32, 5)]
internal static RouteHandlerBuilder MapGet6(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -755,7 +755,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 32, 5)]
[InterceptsLocation(@"TestMapActions.cs", 33, 5)]
internal static RouteHandlerBuilder MapGet7(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -850,7 +850,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 33, 5)]
[InterceptsLocation(@"TestMapActions.cs", 34, 5)]
internal static RouteHandlerBuilder MapGet8(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -958,7 +958,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 34, 5)]
[InterceptsLocation(@"TestMapActions.cs", 35, 5)]
internal static RouteHandlerBuilder MapGet9(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1052,7 +1052,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 35, 5)]
[InterceptsLocation(@"TestMapActions.cs", 36, 5)]
internal static RouteHandlerBuilder MapGet10(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1148,7 +1148,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 36, 5)]
[InterceptsLocation(@"TestMapActions.cs", 37, 5)]
internal static RouteHandlerBuilder MapGet11(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1242,7 +1242,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 37, 5)]
[InterceptsLocation(@"TestMapActions.cs", 38, 5)]
internal static RouteHandlerBuilder MapGet12(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1351,7 +1351,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 38, 5)]
[InterceptsLocation(@"TestMapActions.cs", 39, 5)]
internal static RouteHandlerBuilder MapGet13(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1446,7 +1446,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 39, 5)]
[InterceptsLocation(@"TestMapActions.cs", 40, 5)]
internal static RouteHandlerBuilder MapGet14(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1554,7 +1554,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 40, 5)]
[InterceptsLocation(@"TestMapActions.cs", 41, 5)]
internal static RouteHandlerBuilder MapGet15(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1648,7 +1648,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 41, 5)]
[InterceptsLocation(@"TestMapActions.cs", 42, 5)]
internal static RouteHandlerBuilder MapGet16(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1757,7 +1757,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 42, 5)]
[InterceptsLocation(@"TestMapActions.cs", 43, 5)]
internal static RouteHandlerBuilder MapGet17(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1852,7 +1852,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 43, 5)]
[InterceptsLocation(@"TestMapActions.cs", 44, 5)]
internal static RouteHandlerBuilder MapGet18(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1960,7 +1960,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 44, 5)]
[InterceptsLocation(@"TestMapActions.cs", 45, 5)]
internal static RouteHandlerBuilder MapGet19(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Http.Generated
private static readonly JsonOptions FallbackJsonOptions = new();
private static readonly string[] PostVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Post };

[InterceptsLocation(@"TestMapActions.cs", 25, 13)]
[InterceptsLocation(@"TestMapActions.cs", 26, 13)]
internal static RouteHandlerBuilder MapPost0(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -162,7 +162,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 26, 5)]
[InterceptsLocation(@"TestMapActions.cs", 27, 5)]
internal static RouteHandlerBuilder MapPost1(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Http.Generated
private static readonly JsonOptions FallbackJsonOptions = new();
private static readonly string[] GetVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Get };

[InterceptsLocation(@"TestMapActions.cs", 25, 13)]
[InterceptsLocation(@"TestMapActions.cs", 26, 13)]
internal static RouteHandlerBuilder MapGet0(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Http.Generated
private static readonly JsonOptions FallbackJsonOptions = new();
private static readonly string[] GetVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Get };

[InterceptsLocation(@"TestMapActions.cs", 25, 13)]
[InterceptsLocation(@"TestMapActions.cs", 26, 13)]
internal static RouteHandlerBuilder MapGet0(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Http.Generated
private static readonly JsonOptions FallbackJsonOptions = new();
private static readonly string[] GetVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Get };

[InterceptsLocation(@"TestMapActions.cs", 25, 13)]
[InterceptsLocation(@"TestMapActions.cs", 26, 13)]
internal static RouteHandlerBuilder MapGet0(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Http.Generated
private static readonly JsonOptions FallbackJsonOptions = new();
private static readonly string[] GetVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Get };

[InterceptsLocation(@"TestMapActions.cs", 25, 13)]
[InterceptsLocation(@"TestMapActions.cs", 26, 13)]
internal static RouteHandlerBuilder MapGet0(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Http.Generated
private static readonly JsonOptions FallbackJsonOptions = new();
private static readonly string[] GetVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Get };

[InterceptsLocation(@"TestMapActions.cs", 25, 13)]
[InterceptsLocation(@"TestMapActions.cs", 26, 13)]
internal static RouteHandlerBuilder MapGet0(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Http.Generated
private static readonly JsonOptions FallbackJsonOptions = new();
private static readonly string[] GetVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Get };

[InterceptsLocation(@"TestMapActions.cs", 25, 13)]
[InterceptsLocation(@"TestMapActions.cs", 26, 13)]
internal static RouteHandlerBuilder MapGet0(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down
Loading