mirror of
https://github.com/zoriya/EntityFrameworkCore.Projectables.git
synced 2025-12-06 05:56:10 +00:00
Add support for concrete interface implementations
Change overriding member logic to use MethodInfo.GetBaseDefinition()
This commit is contained in:
@@ -25,12 +25,12 @@ namespace EntityFrameworkCore.Projectables.Extensions
|
||||
yield return type;
|
||||
}
|
||||
|
||||
public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
|
||||
private static bool CanHaveOverridingMethod(this Type derivedType, MethodInfo methodInfo)
|
||||
{
|
||||
// We only need to search for virtual instance methods who are not declared on the derivedType
|
||||
if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual)
|
||||
{
|
||||
return methodInfo;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!derivedType.IsAssignableTo(methodInfo.DeclaringType))
|
||||
@@ -38,76 +38,115 @@ namespace EntityFrameworkCore.Projectables.Extensions
|
||||
throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo));
|
||||
}
|
||||
|
||||
var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
|
||||
|
||||
foreach (var derivedMethodInfo in derivedMethods)
|
||||
{
|
||||
if (HasCompatibleSignature(methodInfo, derivedMethodInfo))
|
||||
{
|
||||
return derivedMethodInfo;
|
||||
}
|
||||
}
|
||||
|
||||
// No derived methods were found. Return the original methodInfo
|
||||
return methodInfo;
|
||||
|
||||
static bool HasCompatibleSignature(MethodInfo methodInfo, MethodInfo derivedMethodInfo)
|
||||
{
|
||||
if (methodInfo.Name != derivedMethodInfo.Name)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var methodParameters = methodInfo.GetParameters();
|
||||
|
||||
var derivedMethodParameters = derivedMethodInfo.GetParameters();
|
||||
if (methodParameters.Length != derivedMethodParameters.Length)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Match all parameters
|
||||
for (var parameterIndex = 0; parameterIndex < methodParameters.Length; parameterIndex++)
|
||||
{
|
||||
var parameter = methodParameters[parameterIndex];
|
||||
var derivedParameter = derivedMethodParameters[parameterIndex];
|
||||
|
||||
if (parameter.ParameterType.IsGenericParameter)
|
||||
{
|
||||
if (!derivedParameter.ParameterType.IsGenericParameter)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (parameter.ParameterType != derivedParameter.ParameterType)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Match the number of generic type arguments
|
||||
if (methodInfo.IsGenericMethodDefinition)
|
||||
{
|
||||
var methodGenericParameters = methodInfo.GetGenericArguments();
|
||||
|
||||
if (!derivedMethodInfo.IsGenericMethodDefinition)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var derivedGenericArguments = derivedMethodInfo.GetGenericArguments();
|
||||
|
||||
if (methodGenericParameters.Length != derivedGenericArguments.Length)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private static int? GetOverridingMethodIndex(this MethodInfo methodInfo, MethodInfo[]? allDerivedMethods)
|
||||
{
|
||||
if (allDerivedMethods is { Length: > 0 })
|
||||
{
|
||||
var baseDefinition = methodInfo.GetBaseDefinition();
|
||||
for (var i = 0; i < allDerivedMethods.Length; i++)
|
||||
{
|
||||
var derivedMethodInfo = allDerivedMethods[i];
|
||||
if (derivedMethodInfo.GetBaseDefinition() == baseDefinition)
|
||||
{
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
|
||||
{
|
||||
if (!derivedType.CanHaveOverridingMethod(methodInfo))
|
||||
{
|
||||
return methodInfo;
|
||||
}
|
||||
|
||||
var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
|
||||
|
||||
return methodInfo.GetOverridingMethodIndex(derivedMethods) is { } i
|
||||
? derivedMethods[i]
|
||||
// No derived methods were found. Return the original methodInfo
|
||||
: methodInfo;
|
||||
}
|
||||
|
||||
public static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo)
|
||||
{
|
||||
var accessor = propertyInfo.GetAccessors()[0];
|
||||
|
||||
if (!derivedType.CanHaveOverridingMethod(accessor))
|
||||
{
|
||||
return propertyInfo;
|
||||
}
|
||||
|
||||
var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
|
||||
var derivedPropertyMethods = derivedProperties
|
||||
.Select((Func<PropertyInfo, MethodInfo?>)
|
||||
(propertyInfo.GetMethod == accessor ? p => p.GetMethod : p => p.SetMethod))
|
||||
.OfType<MethodInfo>().ToArray();
|
||||
|
||||
return accessor.GetOverridingMethodIndex(derivedPropertyMethods) is { } i
|
||||
? derivedProperties[i]
|
||||
// No derived methods were found. Return the original methodInfo
|
||||
: propertyInfo;
|
||||
}
|
||||
|
||||
public static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo)
|
||||
{
|
||||
var interfaceType = methodInfo.DeclaringType;
|
||||
// We only need to search for interface methods
|
||||
if (interfaceType?.IsInterface != true || derivedType.IsInterface || methodInfo.IsStatic || !methodInfo.IsVirtual)
|
||||
{
|
||||
return methodInfo;
|
||||
}
|
||||
|
||||
if (!derivedType.IsAssignableTo(interfaceType))
|
||||
{
|
||||
throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo));
|
||||
}
|
||||
|
||||
var interfaceMap = derivedType.GetInterfaceMap(interfaceType);
|
||||
for (var i = 0; i < interfaceMap.InterfaceMethods.Length; i++)
|
||||
{
|
||||
if (interfaceMap.InterfaceMethods[i] == methodInfo)
|
||||
{
|
||||
return interfaceMap.TargetMethods[i];
|
||||
}
|
||||
}
|
||||
|
||||
throw new ApplicationException(
|
||||
$"The interface map for {derivedType} doesn't contain the implemented method for {methodInfo}!");
|
||||
}
|
||||
|
||||
public static PropertyInfo GetImplementingProperty(this Type derivedType, PropertyInfo propertyInfo)
|
||||
{
|
||||
var accessor = propertyInfo.GetAccessors()[0];
|
||||
|
||||
var implementingAccessor = derivedType.GetImplementingMethod(accessor);
|
||||
if (implementingAccessor == accessor)
|
||||
{
|
||||
return propertyInfo;
|
||||
}
|
||||
|
||||
var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
|
||||
|
||||
return derivedProperties.First(propertyInfo.GetMethod == accessor
|
||||
? p => p.GetMethod == implementingAccessor
|
||||
: p => p.SetMethod == implementingAccessor);
|
||||
}
|
||||
|
||||
public static MethodInfo GetConcreteMethod(this Type derivedType, MethodInfo methodInfo)
|
||||
=> methodInfo.DeclaringType?.IsInterface == true
|
||||
? derivedType.GetImplementingMethod(methodInfo)
|
||||
: derivedType.GetOverridingMethod(methodInfo);
|
||||
|
||||
public static PropertyInfo GetConcreteProperty(this Type derivedType, PropertyInfo propertyInfo)
|
||||
=> propertyInfo.DeclaringType?.IsInterface == true
|
||||
? derivedType.GetImplementingProperty(propertyInfo)
|
||||
: derivedType.GetOverridingProperty(propertyInfo);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ namespace EntityFrameworkCore.Projectables.Services
|
||||
protected override Expression VisitMethodCall(MethodCallExpression node)
|
||||
{
|
||||
// Get the overriding methodInfo based on te type of the received of this expression
|
||||
var methodInfo = node.Object?.Type.GetOverridingMethod(node.Method) ?? node.Method;
|
||||
var methodInfo = node.Object?.Type.GetConcreteMethod(node.Method) ?? node.Method;
|
||||
|
||||
if (TryGetReflectedExpression(methodInfo, out var reflectedExpression))
|
||||
{
|
||||
@@ -74,16 +74,23 @@ namespace EntityFrameworkCore.Projectables.Services
|
||||
|
||||
protected override Expression VisitMember(MemberExpression node)
|
||||
{
|
||||
var nodeMember = node.Expression switch {
|
||||
{ Type: { } } => node.Expression.Type.GetMember(node.Member.Name, node.Member.MemberType, BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static)[0],
|
||||
var nodeExpression = node.Expression switch {
|
||||
UnaryExpression { NodeType: ExpressionType.Convert, Type: { IsInterface: true } type, Operand: { } operand }
|
||||
when type.IsAssignableFrom(operand.Type)
|
||||
=> operand,
|
||||
_ => node.Expression
|
||||
};
|
||||
var nodeMember = node.Member switch {
|
||||
PropertyInfo property when nodeExpression is not null
|
||||
=> nodeExpression.Type.GetConcreteProperty(property),
|
||||
_ => node.Member
|
||||
};
|
||||
|
||||
if (TryGetReflectedExpression(nodeMember, out var reflectedExpression))
|
||||
{
|
||||
if (node.Expression is not null)
|
||||
if (nodeExpression is not null)
|
||||
{
|
||||
_expressionArgumentReplacer.ParameterArgumentMapping.Add(reflectedExpression.Parameters[0], node.Expression);
|
||||
_expressionArgumentReplacer.ParameterArgumentMapping.Add(reflectedExpression.Parameters[0], nodeExpression);
|
||||
var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
|
||||
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
|
||||
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
SELECT 2
|
||||
FROM [Concrete] AS [c]
|
||||
@@ -0,0 +1,2 @@
|
||||
SELECT 2
|
||||
FROM [Concrete] AS [c]
|
||||
@@ -1,2 +1,2 @@
|
||||
SELECT 2
|
||||
FROM [Concrete] AS [c]
|
||||
FROM [MoreConcrete] AS [m]
|
||||
@@ -1,2 +1,2 @@
|
||||
SELECT 2
|
||||
FROM [Concrete] AS [c]
|
||||
FROM [MoreConcrete] AS [m]
|
||||
@@ -20,7 +20,13 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
|
||||
[UsesVerify]
|
||||
public class InheritedModelTests
|
||||
{
|
||||
public abstract class Base
|
||||
public interface IBase
|
||||
{
|
||||
int ComputedProperty { get; }
|
||||
int ComputedMethod();
|
||||
}
|
||||
|
||||
public abstract class Base : IBase
|
||||
{
|
||||
public int Id { get; set; }
|
||||
|
||||
@@ -62,9 +68,9 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
|
||||
[Fact]
|
||||
public Task ProjectOverInheritedPropertyImplementation()
|
||||
{
|
||||
using var dbContext = new SampleDbContext<Concrete>();
|
||||
using var dbContext = new SampleDbContext<MoreConcrete>();
|
||||
|
||||
var query = dbContext.Set<Concrete>()
|
||||
var query = dbContext.Set<MoreConcrete>()
|
||||
.Select(x => x.ComputedProperty);
|
||||
|
||||
return Verifier.Verify(query.ToQueryString());
|
||||
@@ -84,12 +90,43 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
|
||||
[Fact]
|
||||
public Task ProjectOverInheritedMethodImplementation()
|
||||
{
|
||||
using var dbContext = new SampleDbContext<Concrete>();
|
||||
using var dbContext = new SampleDbContext<MoreConcrete>();
|
||||
|
||||
var query = dbContext.Set<Concrete>()
|
||||
var query = dbContext.Set<MoreConcrete>()
|
||||
.Select(x => x.ComputedMethod());
|
||||
|
||||
return Verifier.Verify(query.ToQueryString());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public Task ProjectOverImplementedProperty()
|
||||
{
|
||||
using var dbContext = new SampleDbContext<Concrete>();
|
||||
|
||||
var query = dbContext.Set<Concrete>().SelectComputedProperty();
|
||||
|
||||
return Verifier.Verify(query.ToQueryString());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public Task ProjectOverImplementedMethod()
|
||||
{
|
||||
using var dbContext = new SampleDbContext<Concrete>();
|
||||
|
||||
var query = dbContext.Set<Concrete>().SelectComputedMethod();
|
||||
|
||||
return Verifier.Verify(query.ToQueryString());
|
||||
}
|
||||
}
|
||||
|
||||
public static class ModelExtensions
|
||||
{
|
||||
public static IQueryable<int> SelectComputedProperty<TConcrete>(this IQueryable<TConcrete> concretes)
|
||||
where TConcrete : InheritedModelTests.IBase
|
||||
=> concretes.Select(x => x.ComputedProperty);
|
||||
|
||||
public static IQueryable<int> SelectComputedMethod<TConcrete>(this IQueryable<TConcrete> concretes)
|
||||
where TConcrete : InheritedModelTests.IBase
|
||||
=> concretes.Select(x => x.ComputedMethod());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user