From 423879185e8f0eb1cc1a8c397d5aec391fbb2f50 Mon Sep 17 00:00:00 2001 From: Rhodon Date: Wed, 11 Jan 2023 14:08:44 +0100 Subject: [PATCH] Add support for concrete interface implementations Change overriding member logic to use MethodInfo.GetBaseDefinition() --- .../Extensions/TypeExtensions.cs | 183 +++++++++++------- .../Services/ProjectableExpressionReplacer.cs | 17 +- ....ProjectOverImplementedMethod.verified.txt | 2 + ...rojectOverImplementedProperty.verified.txt | 2 + ...InheritedMethodImplementation.verified.txt | 2 +- ...heritedPropertyImplementation.verified.txt | 2 +- .../InheritedModelTests.cs | 47 ++++- 7 files changed, 171 insertions(+), 84 deletions(-) create mode 100644 tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverImplementedMethod.verified.txt create mode 100644 tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverImplementedProperty.verified.txt diff --git a/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs b/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs index 1bec521..c9e1bef 100644 --- a/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs +++ b/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs @@ -25,12 +25,12 @@ namespace EntityFrameworkCore.Projectables.Extensions yield return type; } - public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo) + private static bool CanHaveOverridingMethod(this Type derivedType, MethodInfo methodInfo) { // We only need to search for virtual instance methods who are not declared on the derivedType if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual) { - return methodInfo; + return false; } if (!derivedType.IsAssignableTo(methodInfo.DeclaringType)) @@ -38,76 +38,115 @@ namespace EntityFrameworkCore.Projectables.Extensions throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo)); } - var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); - - foreach (var derivedMethodInfo in derivedMethods) - { - if (HasCompatibleSignature(methodInfo, derivedMethodInfo)) - { - return derivedMethodInfo; - } - } - - // No derived methods were found. Return the original methodInfo - return methodInfo; - - static bool HasCompatibleSignature(MethodInfo methodInfo, MethodInfo derivedMethodInfo) - { - if (methodInfo.Name != derivedMethodInfo.Name) - { - return false; - } - - var methodParameters = methodInfo.GetParameters(); - - var derivedMethodParameters = derivedMethodInfo.GetParameters(); - if (methodParameters.Length != derivedMethodParameters.Length) - { - return false; - } - - // Match all parameters - for (var parameterIndex = 0; parameterIndex < methodParameters.Length; parameterIndex++) - { - var parameter = methodParameters[parameterIndex]; - var derivedParameter = derivedMethodParameters[parameterIndex]; - - if (parameter.ParameterType.IsGenericParameter) - { - if (!derivedParameter.ParameterType.IsGenericParameter) - { - return false; - } - } - else - { - if (parameter.ParameterType != derivedParameter.ParameterType) - { - return false; - } - } - } - - // Match the number of generic type arguments - if (methodInfo.IsGenericMethodDefinition) - { - var methodGenericParameters = methodInfo.GetGenericArguments(); - - if (!derivedMethodInfo.IsGenericMethodDefinition) - { - return false; - } - - var derivedGenericArguments = derivedMethodInfo.GetGenericArguments(); - - if (methodGenericParameters.Length != derivedGenericArguments.Length) - { - return false; - } - } - - return true; - } + return true; } + + private static int? GetOverridingMethodIndex(this MethodInfo methodInfo, MethodInfo[]? allDerivedMethods) + { + if (allDerivedMethods is { Length: > 0 }) + { + var baseDefinition = methodInfo.GetBaseDefinition(); + for (var i = 0; i < allDerivedMethods.Length; i++) + { + var derivedMethodInfo = allDerivedMethods[i]; + if (derivedMethodInfo.GetBaseDefinition() == baseDefinition) + { + return i; + } + } + } + + return null; + } + + public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo) + { + if (!derivedType.CanHaveOverridingMethod(methodInfo)) + { + return methodInfo; + } + + var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + return methodInfo.GetOverridingMethodIndex(derivedMethods) is { } i + ? derivedMethods[i] + // No derived methods were found. Return the original methodInfo + : methodInfo; + } + + public static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo) + { + var accessor = propertyInfo.GetAccessors()[0]; + + if (!derivedType.CanHaveOverridingMethod(accessor)) + { + return propertyInfo; + } + + var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + var derivedPropertyMethods = derivedProperties + .Select((Func) + (propertyInfo.GetMethod == accessor ? p => p.GetMethod : p => p.SetMethod)) + .OfType().ToArray(); + + return accessor.GetOverridingMethodIndex(derivedPropertyMethods) is { } i + ? derivedProperties[i] + // No derived methods were found. Return the original methodInfo + : propertyInfo; + } + + public static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo) + { + var interfaceType = methodInfo.DeclaringType; + // We only need to search for interface methods + if (interfaceType?.IsInterface != true || derivedType.IsInterface || methodInfo.IsStatic || !methodInfo.IsVirtual) + { + return methodInfo; + } + + if (!derivedType.IsAssignableTo(interfaceType)) + { + throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo)); + } + + var interfaceMap = derivedType.GetInterfaceMap(interfaceType); + for (var i = 0; i < interfaceMap.InterfaceMethods.Length; i++) + { + if (interfaceMap.InterfaceMethods[i] == methodInfo) + { + return interfaceMap.TargetMethods[i]; + } + } + + throw new ApplicationException( + $"The interface map for {derivedType} doesn't contain the implemented method for {methodInfo}!"); + } + + public static PropertyInfo GetImplementingProperty(this Type derivedType, PropertyInfo propertyInfo) + { + var accessor = propertyInfo.GetAccessors()[0]; + + var implementingAccessor = derivedType.GetImplementingMethod(accessor); + if (implementingAccessor == accessor) + { + return propertyInfo; + } + + var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + return derivedProperties.First(propertyInfo.GetMethod == accessor + ? p => p.GetMethod == implementingAccessor + : p => p.SetMethod == implementingAccessor); + } + + public static MethodInfo GetConcreteMethod(this Type derivedType, MethodInfo methodInfo) + => methodInfo.DeclaringType?.IsInterface == true + ? derivedType.GetImplementingMethod(methodInfo) + : derivedType.GetOverridingMethod(methodInfo); + + public static PropertyInfo GetConcreteProperty(this Type derivedType, PropertyInfo propertyInfo) + => propertyInfo.DeclaringType?.IsInterface == true + ? derivedType.GetImplementingProperty(propertyInfo) + : derivedType.GetOverridingProperty(propertyInfo); } } diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs index 1dfdf77..84f5283 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs @@ -42,7 +42,7 @@ namespace EntityFrameworkCore.Projectables.Services protected override Expression VisitMethodCall(MethodCallExpression node) { // Get the overriding methodInfo based on te type of the received of this expression - var methodInfo = node.Object?.Type.GetOverridingMethod(node.Method) ?? node.Method; + var methodInfo = node.Object?.Type.GetConcreteMethod(node.Method) ?? node.Method; if (TryGetReflectedExpression(methodInfo, out var reflectedExpression)) { @@ -74,16 +74,23 @@ namespace EntityFrameworkCore.Projectables.Services protected override Expression VisitMember(MemberExpression node) { - var nodeMember = node.Expression switch { - { Type: { } } => node.Expression.Type.GetMember(node.Member.Name, node.Member.MemberType, BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static)[0], + var nodeExpression = node.Expression switch { + UnaryExpression { NodeType: ExpressionType.Convert, Type: { IsInterface: true } type, Operand: { } operand } + when type.IsAssignableFrom(operand.Type) + => operand, + _ => node.Expression + }; + var nodeMember = node.Member switch { + PropertyInfo property when nodeExpression is not null + => nodeExpression.Type.GetConcreteProperty(property), _ => node.Member }; if (TryGetReflectedExpression(nodeMember, out var reflectedExpression)) { - if (node.Expression is not null) + if (nodeExpression is not null) { - _expressionArgumentReplacer.ParameterArgumentMapping.Add(reflectedExpression.Parameters[0], node.Expression); + _expressionArgumentReplacer.ParameterArgumentMapping.Add(reflectedExpression.Parameters[0], nodeExpression); var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body); _expressionArgumentReplacer.ParameterArgumentMapping.Clear(); diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverImplementedMethod.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverImplementedMethod.verified.txt new file mode 100644 index 0000000..9ba8adc --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverImplementedMethod.verified.txt @@ -0,0 +1,2 @@ +SELECT 2 +FROM [Concrete] AS [c] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverImplementedProperty.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverImplementedProperty.verified.txt new file mode 100644 index 0000000..9ba8adc --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverImplementedProperty.verified.txt @@ -0,0 +1,2 @@ +SELECT 2 +FROM [Concrete] AS [c] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedMethodImplementation.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedMethodImplementation.verified.txt index 9ba8adc..63ade64 100644 --- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedMethodImplementation.verified.txt +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedMethodImplementation.verified.txt @@ -1,2 +1,2 @@ SELECT 2 -FROM [Concrete] AS [c] \ No newline at end of file +FROM [MoreConcrete] AS [m] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedPropertyImplementation.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedPropertyImplementation.verified.txt index 9ba8adc..63ade64 100644 --- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedPropertyImplementation.verified.txt +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedPropertyImplementation.verified.txt @@ -1,2 +1,2 @@ SELECT 2 -FROM [Concrete] AS [c] \ No newline at end of file +FROM [MoreConcrete] AS [m] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs index 8d8acc5..cea0103 100644 --- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs @@ -20,7 +20,13 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests [UsesVerify] public class InheritedModelTests { - public abstract class Base + public interface IBase + { + int ComputedProperty { get; } + int ComputedMethod(); + } + + public abstract class Base : IBase { public int Id { get; set; } @@ -62,9 +68,9 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests [Fact] public Task ProjectOverInheritedPropertyImplementation() { - using var dbContext = new SampleDbContext(); + using var dbContext = new SampleDbContext(); - var query = dbContext.Set() + var query = dbContext.Set() .Select(x => x.ComputedProperty); return Verifier.Verify(query.ToQueryString()); @@ -84,12 +90,43 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests [Fact] public Task ProjectOverInheritedMethodImplementation() { - using var dbContext = new SampleDbContext(); + using var dbContext = new SampleDbContext(); - var query = dbContext.Set() + var query = dbContext.Set() .Select(x => x.ComputedMethod()); return Verifier.Verify(query.ToQueryString()); } + + [Fact] + public Task ProjectOverImplementedProperty() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set().SelectComputedProperty(); + + return Verifier.Verify(query.ToQueryString()); + } + + [Fact] + public Task ProjectOverImplementedMethod() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set().SelectComputedMethod(); + + return Verifier.Verify(query.ToQueryString()); + } + } + + public static class ModelExtensions + { + public static IQueryable SelectComputedProperty(this IQueryable concretes) + where TConcrete : InheritedModelTests.IBase + => concretes.Select(x => x.ComputedProperty); + + public static IQueryable SelectComputedMethod(this IQueryable concretes) + where TConcrete : InheritedModelTests.IBase + => concretes.Select(x => x.ComputedMethod()); } }