mirror of
https://github.com/zoriya/EntityFrameworkCore.Projectables.git
synced 2025-12-06 05:56:10 +00:00
Add support for concrete interface implementations
Change overriding member logic to use MethodInfo.GetBaseDefinition()
This commit is contained in:
@@ -25,12 +25,12 @@ namespace EntityFrameworkCore.Projectables.Extensions
|
|||||||
yield return type;
|
yield return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
|
private static bool CanHaveOverridingMethod(this Type derivedType, MethodInfo methodInfo)
|
||||||
{
|
{
|
||||||
// We only need to search for virtual instance methods who are not declared on the derivedType
|
// We only need to search for virtual instance methods who are not declared on the derivedType
|
||||||
if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual)
|
if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual)
|
||||||
{
|
{
|
||||||
return methodInfo;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!derivedType.IsAssignableTo(methodInfo.DeclaringType))
|
if (!derivedType.IsAssignableTo(methodInfo.DeclaringType))
|
||||||
@@ -38,76 +38,115 @@ namespace EntityFrameworkCore.Projectables.Extensions
|
|||||||
throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo));
|
throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo));
|
||||||
}
|
}
|
||||||
|
|
||||||
var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
|
return true;
|
||||||
|
|
||||||
foreach (var derivedMethodInfo in derivedMethods)
|
|
||||||
{
|
|
||||||
if (HasCompatibleSignature(methodInfo, derivedMethodInfo))
|
|
||||||
{
|
|
||||||
return derivedMethodInfo;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// No derived methods were found. Return the original methodInfo
|
|
||||||
return methodInfo;
|
|
||||||
|
|
||||||
static bool HasCompatibleSignature(MethodInfo methodInfo, MethodInfo derivedMethodInfo)
|
|
||||||
{
|
|
||||||
if (methodInfo.Name != derivedMethodInfo.Name)
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
var methodParameters = methodInfo.GetParameters();
|
|
||||||
|
|
||||||
var derivedMethodParameters = derivedMethodInfo.GetParameters();
|
|
||||||
if (methodParameters.Length != derivedMethodParameters.Length)
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Match all parameters
|
|
||||||
for (var parameterIndex = 0; parameterIndex < methodParameters.Length; parameterIndex++)
|
|
||||||
{
|
|
||||||
var parameter = methodParameters[parameterIndex];
|
|
||||||
var derivedParameter = derivedMethodParameters[parameterIndex];
|
|
||||||
|
|
||||||
if (parameter.ParameterType.IsGenericParameter)
|
|
||||||
{
|
|
||||||
if (!derivedParameter.ParameterType.IsGenericParameter)
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
if (parameter.ParameterType != derivedParameter.ParameterType)
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Match the number of generic type arguments
|
|
||||||
if (methodInfo.IsGenericMethodDefinition)
|
|
||||||
{
|
|
||||||
var methodGenericParameters = methodInfo.GetGenericArguments();
|
|
||||||
|
|
||||||
if (!derivedMethodInfo.IsGenericMethodDefinition)
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
var derivedGenericArguments = derivedMethodInfo.GetGenericArguments();
|
|
||||||
|
|
||||||
if (methodGenericParameters.Length != derivedGenericArguments.Length)
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static int? GetOverridingMethodIndex(this MethodInfo methodInfo, MethodInfo[]? allDerivedMethods)
|
||||||
|
{
|
||||||
|
if (allDerivedMethods is { Length: > 0 })
|
||||||
|
{
|
||||||
|
var baseDefinition = methodInfo.GetBaseDefinition();
|
||||||
|
for (var i = 0; i < allDerivedMethods.Length; i++)
|
||||||
|
{
|
||||||
|
var derivedMethodInfo = allDerivedMethods[i];
|
||||||
|
if (derivedMethodInfo.GetBaseDefinition() == baseDefinition)
|
||||||
|
{
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
|
||||||
|
{
|
||||||
|
if (!derivedType.CanHaveOverridingMethod(methodInfo))
|
||||||
|
{
|
||||||
|
return methodInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
|
||||||
|
|
||||||
|
return methodInfo.GetOverridingMethodIndex(derivedMethods) is { } i
|
||||||
|
? derivedMethods[i]
|
||||||
|
// No derived methods were found. Return the original methodInfo
|
||||||
|
: methodInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo)
|
||||||
|
{
|
||||||
|
var accessor = propertyInfo.GetAccessors()[0];
|
||||||
|
|
||||||
|
if (!derivedType.CanHaveOverridingMethod(accessor))
|
||||||
|
{
|
||||||
|
return propertyInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
|
||||||
|
var derivedPropertyMethods = derivedProperties
|
||||||
|
.Select((Func<PropertyInfo, MethodInfo?>)
|
||||||
|
(propertyInfo.GetMethod == accessor ? p => p.GetMethod : p => p.SetMethod))
|
||||||
|
.OfType<MethodInfo>().ToArray();
|
||||||
|
|
||||||
|
return accessor.GetOverridingMethodIndex(derivedPropertyMethods) is { } i
|
||||||
|
? derivedProperties[i]
|
||||||
|
// No derived methods were found. Return the original methodInfo
|
||||||
|
: propertyInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo)
|
||||||
|
{
|
||||||
|
var interfaceType = methodInfo.DeclaringType;
|
||||||
|
// We only need to search for interface methods
|
||||||
|
if (interfaceType?.IsInterface != true || derivedType.IsInterface || methodInfo.IsStatic || !methodInfo.IsVirtual)
|
||||||
|
{
|
||||||
|
return methodInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!derivedType.IsAssignableTo(interfaceType))
|
||||||
|
{
|
||||||
|
throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo));
|
||||||
|
}
|
||||||
|
|
||||||
|
var interfaceMap = derivedType.GetInterfaceMap(interfaceType);
|
||||||
|
for (var i = 0; i < interfaceMap.InterfaceMethods.Length; i++)
|
||||||
|
{
|
||||||
|
if (interfaceMap.InterfaceMethods[i] == methodInfo)
|
||||||
|
{
|
||||||
|
return interfaceMap.TargetMethods[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new ApplicationException(
|
||||||
|
$"The interface map for {derivedType} doesn't contain the implemented method for {methodInfo}!");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PropertyInfo GetImplementingProperty(this Type derivedType, PropertyInfo propertyInfo)
|
||||||
|
{
|
||||||
|
var accessor = propertyInfo.GetAccessors()[0];
|
||||||
|
|
||||||
|
var implementingAccessor = derivedType.GetImplementingMethod(accessor);
|
||||||
|
if (implementingAccessor == accessor)
|
||||||
|
{
|
||||||
|
return propertyInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
|
||||||
|
|
||||||
|
return derivedProperties.First(propertyInfo.GetMethod == accessor
|
||||||
|
? p => p.GetMethod == implementingAccessor
|
||||||
|
: p => p.SetMethod == implementingAccessor);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static MethodInfo GetConcreteMethod(this Type derivedType, MethodInfo methodInfo)
|
||||||
|
=> methodInfo.DeclaringType?.IsInterface == true
|
||||||
|
? derivedType.GetImplementingMethod(methodInfo)
|
||||||
|
: derivedType.GetOverridingMethod(methodInfo);
|
||||||
|
|
||||||
|
public static PropertyInfo GetConcreteProperty(this Type derivedType, PropertyInfo propertyInfo)
|
||||||
|
=> propertyInfo.DeclaringType?.IsInterface == true
|
||||||
|
? derivedType.GetImplementingProperty(propertyInfo)
|
||||||
|
: derivedType.GetOverridingProperty(propertyInfo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ namespace EntityFrameworkCore.Projectables.Services
|
|||||||
protected override Expression VisitMethodCall(MethodCallExpression node)
|
protected override Expression VisitMethodCall(MethodCallExpression node)
|
||||||
{
|
{
|
||||||
// Get the overriding methodInfo based on te type of the received of this expression
|
// Get the overriding methodInfo based on te type of the received of this expression
|
||||||
var methodInfo = node.Object?.Type.GetOverridingMethod(node.Method) ?? node.Method;
|
var methodInfo = node.Object?.Type.GetConcreteMethod(node.Method) ?? node.Method;
|
||||||
|
|
||||||
if (TryGetReflectedExpression(methodInfo, out var reflectedExpression))
|
if (TryGetReflectedExpression(methodInfo, out var reflectedExpression))
|
||||||
{
|
{
|
||||||
@@ -74,16 +74,23 @@ namespace EntityFrameworkCore.Projectables.Services
|
|||||||
|
|
||||||
protected override Expression VisitMember(MemberExpression node)
|
protected override Expression VisitMember(MemberExpression node)
|
||||||
{
|
{
|
||||||
var nodeMember = node.Expression switch {
|
var nodeExpression = node.Expression switch {
|
||||||
{ Type: { } } => node.Expression.Type.GetMember(node.Member.Name, node.Member.MemberType, BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static)[0],
|
UnaryExpression { NodeType: ExpressionType.Convert, Type: { IsInterface: true } type, Operand: { } operand }
|
||||||
|
when type.IsAssignableFrom(operand.Type)
|
||||||
|
=> operand,
|
||||||
|
_ => node.Expression
|
||||||
|
};
|
||||||
|
var nodeMember = node.Member switch {
|
||||||
|
PropertyInfo property when nodeExpression is not null
|
||||||
|
=> nodeExpression.Type.GetConcreteProperty(property),
|
||||||
_ => node.Member
|
_ => node.Member
|
||||||
};
|
};
|
||||||
|
|
||||||
if (TryGetReflectedExpression(nodeMember, out var reflectedExpression))
|
if (TryGetReflectedExpression(nodeMember, out var reflectedExpression))
|
||||||
{
|
{
|
||||||
if (node.Expression is not null)
|
if (nodeExpression is not null)
|
||||||
{
|
{
|
||||||
_expressionArgumentReplacer.ParameterArgumentMapping.Add(reflectedExpression.Parameters[0], node.Expression);
|
_expressionArgumentReplacer.ParameterArgumentMapping.Add(reflectedExpression.Parameters[0], nodeExpression);
|
||||||
var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
|
var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
|
||||||
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
|
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,2 @@
|
|||||||
|
SELECT 2
|
||||||
|
FROM [Concrete] AS [c]
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
SELECT 2
|
||||||
|
FROM [Concrete] AS [c]
|
||||||
@@ -1,2 +1,2 @@
|
|||||||
SELECT 2
|
SELECT 2
|
||||||
FROM [Concrete] AS [c]
|
FROM [MoreConcrete] AS [m]
|
||||||
@@ -1,2 +1,2 @@
|
|||||||
SELECT 2
|
SELECT 2
|
||||||
FROM [Concrete] AS [c]
|
FROM [MoreConcrete] AS [m]
|
||||||
@@ -20,7 +20,13 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
|
|||||||
[UsesVerify]
|
[UsesVerify]
|
||||||
public class InheritedModelTests
|
public class InheritedModelTests
|
||||||
{
|
{
|
||||||
public abstract class Base
|
public interface IBase
|
||||||
|
{
|
||||||
|
int ComputedProperty { get; }
|
||||||
|
int ComputedMethod();
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract class Base : IBase
|
||||||
{
|
{
|
||||||
public int Id { get; set; }
|
public int Id { get; set; }
|
||||||
|
|
||||||
@@ -62,9 +68,9 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
|
|||||||
[Fact]
|
[Fact]
|
||||||
public Task ProjectOverInheritedPropertyImplementation()
|
public Task ProjectOverInheritedPropertyImplementation()
|
||||||
{
|
{
|
||||||
using var dbContext = new SampleDbContext<Concrete>();
|
using var dbContext = new SampleDbContext<MoreConcrete>();
|
||||||
|
|
||||||
var query = dbContext.Set<Concrete>()
|
var query = dbContext.Set<MoreConcrete>()
|
||||||
.Select(x => x.ComputedProperty);
|
.Select(x => x.ComputedProperty);
|
||||||
|
|
||||||
return Verifier.Verify(query.ToQueryString());
|
return Verifier.Verify(query.ToQueryString());
|
||||||
@@ -84,12 +90,43 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
|
|||||||
[Fact]
|
[Fact]
|
||||||
public Task ProjectOverInheritedMethodImplementation()
|
public Task ProjectOverInheritedMethodImplementation()
|
||||||
{
|
{
|
||||||
using var dbContext = new SampleDbContext<Concrete>();
|
using var dbContext = new SampleDbContext<MoreConcrete>();
|
||||||
|
|
||||||
var query = dbContext.Set<Concrete>()
|
var query = dbContext.Set<MoreConcrete>()
|
||||||
.Select(x => x.ComputedMethod());
|
.Select(x => x.ComputedMethod());
|
||||||
|
|
||||||
return Verifier.Verify(query.ToQueryString());
|
return Verifier.Verify(query.ToQueryString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public Task ProjectOverImplementedProperty()
|
||||||
|
{
|
||||||
|
using var dbContext = new SampleDbContext<Concrete>();
|
||||||
|
|
||||||
|
var query = dbContext.Set<Concrete>().SelectComputedProperty();
|
||||||
|
|
||||||
|
return Verifier.Verify(query.ToQueryString());
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public Task ProjectOverImplementedMethod()
|
||||||
|
{
|
||||||
|
using var dbContext = new SampleDbContext<Concrete>();
|
||||||
|
|
||||||
|
var query = dbContext.Set<Concrete>().SelectComputedMethod();
|
||||||
|
|
||||||
|
return Verifier.Verify(query.ToQueryString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class ModelExtensions
|
||||||
|
{
|
||||||
|
public static IQueryable<int> SelectComputedProperty<TConcrete>(this IQueryable<TConcrete> concretes)
|
||||||
|
where TConcrete : InheritedModelTests.IBase
|
||||||
|
=> concretes.Select(x => x.ComputedProperty);
|
||||||
|
|
||||||
|
public static IQueryable<int> SelectComputedMethod<TConcrete>(this IQueryable<TConcrete> concretes)
|
||||||
|
where TConcrete : InheritedModelTests.IBase
|
||||||
|
=> concretes.Select(x => x.ComputedMethod());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user