Merge pull request #57 from rhodon-jargon/interface-support

Add support for interfaces in generic methods
This commit is contained in:
Koen
2023-01-14 19:14:55 -05:00
committed by GitHub
7 changed files with 173 additions and 84 deletions

View File

@@ -25,12 +25,12 @@ namespace EntityFrameworkCore.Projectables.Extensions
yield return type;
}
public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
private static bool CanHaveOverridingMethod(this Type derivedType, MethodInfo methodInfo)
{
// We only need to search for virtual instance methods who are not declared on the derivedType
if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual)
{
return methodInfo;
return false;
}
if (!derivedType.IsAssignableTo(methodInfo.DeclaringType))
@@ -38,76 +38,115 @@ namespace EntityFrameworkCore.Projectables.Extensions
throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo));
}
var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
foreach (var derivedMethodInfo in derivedMethods)
{
if (HasCompatibleSignature(methodInfo, derivedMethodInfo))
{
return derivedMethodInfo;
}
}
// No derived methods were found. Return the original methodInfo
return methodInfo;
static bool HasCompatibleSignature(MethodInfo methodInfo, MethodInfo derivedMethodInfo)
{
if (methodInfo.Name != derivedMethodInfo.Name)
{
return false;
}
var methodParameters = methodInfo.GetParameters();
var derivedMethodParameters = derivedMethodInfo.GetParameters();
if (methodParameters.Length != derivedMethodParameters.Length)
{
return false;
}
// Match all parameters
for (var parameterIndex = 0; parameterIndex < methodParameters.Length; parameterIndex++)
{
var parameter = methodParameters[parameterIndex];
var derivedParameter = derivedMethodParameters[parameterIndex];
if (parameter.ParameterType.IsGenericParameter)
{
if (!derivedParameter.ParameterType.IsGenericParameter)
{
return false;
}
}
else
{
if (parameter.ParameterType != derivedParameter.ParameterType)
{
return false;
}
}
}
// Match the number of generic type arguments
if (methodInfo.IsGenericMethodDefinition)
{
var methodGenericParameters = methodInfo.GetGenericArguments();
if (!derivedMethodInfo.IsGenericMethodDefinition)
{
return false;
}
var derivedGenericArguments = derivedMethodInfo.GetGenericArguments();
if (methodGenericParameters.Length != derivedGenericArguments.Length)
{
return false;
}
}
return true;
}
private static int? GetOverridingMethodIndex(this MethodInfo methodInfo, MethodInfo[]? allDerivedMethods)
{
if (allDerivedMethods is { Length: > 0 })
{
var baseDefinition = methodInfo.GetBaseDefinition();
for (var i = 0; i < allDerivedMethods.Length; i++)
{
var derivedMethodInfo = allDerivedMethods[i];
if (derivedMethodInfo.GetBaseDefinition() == baseDefinition)
{
return i;
}
}
}
return null;
}
public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
{
if (!derivedType.CanHaveOverridingMethod(methodInfo))
{
return methodInfo;
}
var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
return methodInfo.GetOverridingMethodIndex(derivedMethods) is { } i
? derivedMethods[i]
// No derived methods were found. Return the original methodInfo
: methodInfo;
}
public static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo)
{
var accessor = propertyInfo.GetAccessors()[0];
if (!derivedType.CanHaveOverridingMethod(accessor))
{
return propertyInfo;
}
var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
var derivedPropertyMethods = derivedProperties
.Select((Func<PropertyInfo, MethodInfo?>)
(propertyInfo.GetMethod == accessor ? p => p.GetMethod : p => p.SetMethod))
.OfType<MethodInfo>().ToArray();
return accessor.GetOverridingMethodIndex(derivedPropertyMethods) is { } i
? derivedProperties[i]
// No derived methods were found. Return the original methodInfo
: propertyInfo;
}
public static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo)
{
var interfaceType = methodInfo.DeclaringType;
// We only need to search for interface methods
if (interfaceType?.IsInterface != true || derivedType.IsInterface || methodInfo.IsStatic || !methodInfo.IsVirtual)
{
return methodInfo;
}
if (!derivedType.IsAssignableTo(interfaceType))
{
throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo));
}
var interfaceMap = derivedType.GetInterfaceMap(interfaceType);
for (var i = 0; i < interfaceMap.InterfaceMethods.Length; i++)
{
if (interfaceMap.InterfaceMethods[i] == methodInfo)
{
return interfaceMap.TargetMethods[i];
}
}
throw new ApplicationException(
$"The interface map for {derivedType} doesn't contain the implemented method for {methodInfo}!");
}
public static PropertyInfo GetImplementingProperty(this Type derivedType, PropertyInfo propertyInfo)
{
var accessor = propertyInfo.GetAccessors()[0];
var implementingAccessor = derivedType.GetImplementingMethod(accessor);
if (implementingAccessor == accessor)
{
return propertyInfo;
}
var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
return derivedProperties.First(propertyInfo.GetMethod == accessor
? p => p.GetMethod == implementingAccessor
: p => p.SetMethod == implementingAccessor);
}
public static MethodInfo GetConcreteMethod(this Type derivedType, MethodInfo methodInfo)
=> methodInfo.DeclaringType?.IsInterface == true
? derivedType.GetImplementingMethod(methodInfo)
: derivedType.GetOverridingMethod(methodInfo);
public static PropertyInfo GetConcreteProperty(this Type derivedType, PropertyInfo propertyInfo)
=> propertyInfo.DeclaringType?.IsInterface == true
? derivedType.GetImplementingProperty(propertyInfo)
: derivedType.GetOverridingProperty(propertyInfo);
}
}

View File

@@ -42,7 +42,7 @@ namespace EntityFrameworkCore.Projectables.Services
protected override Expression VisitMethodCall(MethodCallExpression node)
{
// Get the overriding methodInfo based on te type of the received of this expression
var methodInfo = node.Object?.Type.GetOverridingMethod(node.Method) ?? node.Method;
var methodInfo = node.Object?.Type.GetConcreteMethod(node.Method) ?? node.Method;
if (TryGetReflectedExpression(methodInfo, out var reflectedExpression))
{
@@ -74,16 +74,25 @@ namespace EntityFrameworkCore.Projectables.Services
protected override Expression VisitMember(MemberExpression node)
{
var nodeMember = node.Expression switch {
{ Type: { } } => node.Expression.Type.GetMember(node.Member.Name, node.Member.MemberType, BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static)[0],
var nodeExpression = node.Expression switch {
UnaryExpression { NodeType: ExpressionType.Convert, Type: { IsInterface: true } type, Operand: { } operand }
when type.IsAssignableFrom(operand.Type)
// This is an interface member. Operand contains the concrete (or at least more concrete) expression,
// from which we can try to find the concrete member.
=> operand,
_ => node.Expression
};
var nodeMember = node.Member switch {
PropertyInfo property when nodeExpression is not null
=> nodeExpression.Type.GetConcreteProperty(property),
_ => node.Member
};
if (TryGetReflectedExpression(nodeMember, out var reflectedExpression))
{
if (node.Expression is not null)
if (nodeExpression is not null)
{
_expressionArgumentReplacer.ParameterArgumentMapping.Add(reflectedExpression.Parameters[0], node.Expression);
_expressionArgumentReplacer.ParameterArgumentMapping.Add(reflectedExpression.Parameters[0], nodeExpression);
var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();

View File

@@ -1,2 +1,2 @@
SELECT 2
FROM [Concrete] AS [c]
FROM [MoreConcrete] AS [m]

View File

@@ -1,2 +1,2 @@
SELECT 2
FROM [Concrete] AS [c]
FROM [MoreConcrete] AS [m]

View File

@@ -20,7 +20,13 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
[UsesVerify]
public class InheritedModelTests
{
public abstract class Base
public interface IBase
{
int ComputedProperty { get; }
int ComputedMethod();
}
public abstract class Base : IBase
{
public int Id { get; set; }
@@ -62,9 +68,9 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
[Fact]
public Task ProjectOverInheritedPropertyImplementation()
{
using var dbContext = new SampleDbContext<Concrete>();
using var dbContext = new SampleDbContext<MoreConcrete>();
var query = dbContext.Set<Concrete>()
var query = dbContext.Set<MoreConcrete>()
.Select(x => x.ComputedProperty);
return Verifier.Verify(query.ToQueryString());
@@ -84,12 +90,43 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
[Fact]
public Task ProjectOverInheritedMethodImplementation()
{
using var dbContext = new SampleDbContext<Concrete>();
using var dbContext = new SampleDbContext<MoreConcrete>();
var query = dbContext.Set<Concrete>()
var query = dbContext.Set<MoreConcrete>()
.Select(x => x.ComputedMethod());
return Verifier.Verify(query.ToQueryString());
}
[Fact]
public Task ProjectOverImplementedProperty()
{
using var dbContext = new SampleDbContext<Concrete>();
var query = dbContext.Set<Concrete>().SelectComputedProperty();
return Verifier.Verify(query.ToQueryString());
}
[Fact]
public Task ProjectOverImplementedMethod()
{
using var dbContext = new SampleDbContext<Concrete>();
var query = dbContext.Set<Concrete>().SelectComputedMethod();
return Verifier.Verify(query.ToQueryString());
}
}
public static class ModelExtensions
{
public static IQueryable<int> SelectComputedProperty<TConcrete>(this IQueryable<TConcrete> concretes)
where TConcrete : InheritedModelTests.IBase
=> concretes.Select(x => x.ComputedProperty);
public static IQueryable<int> SelectComputedMethod<TConcrete>(this IQueryable<TConcrete> concretes)
where TConcrete : InheritedModelTests.IBase
=> concretes.Select(x => x.ComputedMethod());
}
}