From fead7bb41aa6b560c40260c2ec45cfca50128031 Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 12 Dec 2022 00:15:54 +0000 Subject: [PATCH] Find overriding implementations for virtual methods --- Directory.Build.props | 2 +- .../Extensions/TypeExtensions.cs | 88 +++++++++++++++++ .../Services/ProjectableExpressionReplacer.cs | 14 ++- .../Services/ProjectionExpressionResolver.cs | 16 ++-- ...InheritedMethodImplementation.verified.txt | 2 + ...heritedPropertyImplementation.verified.txt | 2 + ...verriddenMethodImplementation.verified.txt | 2 + ...rriddenPropertyImplementation.verified.txt | 2 + .../InheritedModelTests.cs | 95 +++++++++++++++++++ .../Extensions/TypeExtensionTests.cs | 63 +++++++++++- 10 files changed, 274 insertions(+), 12 deletions(-) create mode 100644 tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedMethodImplementation.verified.txt create mode 100644 tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedPropertyImplementation.verified.txt create mode 100644 tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverOverriddenMethodImplementation.verified.txt create mode 100644 tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverOverriddenPropertyImplementation.verified.txt create mode 100644 tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs diff --git a/Directory.Build.props b/Directory.Build.props index 0bac34f..d79998e 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -2,7 +2,7 @@ true - 9.0 + 11.0 enable true CS1591 diff --git a/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs b/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs index 148ed23..1bec521 100644 --- a/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs +++ b/src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs @@ -1,7 +1,10 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Reflection; +using System.Reflection.Metadata; using System.Runtime.CompilerServices; +using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading.Tasks; @@ -21,5 +24,90 @@ namespace EntityFrameworkCore.Projectables.Extensions yield return type; } + + public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo) + { + // We only need to search for virtual instance methods who are not declared on the derivedType + if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual) + { + return methodInfo; + } + + if (!derivedType.IsAssignableTo(methodInfo.DeclaringType)) + { + throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo)); + } + + var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + foreach (var derivedMethodInfo in derivedMethods) + { + if (HasCompatibleSignature(methodInfo, derivedMethodInfo)) + { + return derivedMethodInfo; + } + } + + // No derived methods were found. Return the original methodInfo + return methodInfo; + + static bool HasCompatibleSignature(MethodInfo methodInfo, MethodInfo derivedMethodInfo) + { + if (methodInfo.Name != derivedMethodInfo.Name) + { + return false; + } + + var methodParameters = methodInfo.GetParameters(); + + var derivedMethodParameters = derivedMethodInfo.GetParameters(); + if (methodParameters.Length != derivedMethodParameters.Length) + { + return false; + } + + // Match all parameters + for (var parameterIndex = 0; parameterIndex < methodParameters.Length; parameterIndex++) + { + var parameter = methodParameters[parameterIndex]; + var derivedParameter = derivedMethodParameters[parameterIndex]; + + if (parameter.ParameterType.IsGenericParameter) + { + if (!derivedParameter.ParameterType.IsGenericParameter) + { + return false; + } + } + else + { + if (parameter.ParameterType != derivedParameter.ParameterType) + { + return false; + } + } + } + + // Match the number of generic type arguments + if (methodInfo.IsGenericMethodDefinition) + { + var methodGenericParameters = methodInfo.GetGenericArguments(); + + if (!derivedMethodInfo.IsGenericMethodDefinition) + { + return false; + } + + var derivedGenericArguments = derivedMethodInfo.GetGenericArguments(); + + if (methodGenericParameters.Length != derivedGenericArguments.Length) + { + return false; + } + } + + return true; + } + } } } diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs index 4a826fb..1dfdf77 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs @@ -1,4 +1,5 @@ using System; +using System.Buffers; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; @@ -7,6 +8,7 @@ using System.Reflection; using System.Text; using System.Threading.Tasks; using System.Xml.Linq; +using EntityFrameworkCore.Projectables.Extensions; namespace EntityFrameworkCore.Projectables.Services { @@ -39,7 +41,10 @@ namespace EntityFrameworkCore.Projectables.Services protected override Expression VisitMethodCall(MethodCallExpression node) { - if (TryGetReflectedExpression(node.Method, out var reflectedExpression)) + // Get the overriding methodInfo based on te type of the received of this expression + var methodInfo = node.Object?.Type.GetOverridingMethod(node.Method) ?? node.Method; + + if (TryGetReflectedExpression(methodInfo, out var reflectedExpression)) { for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++) { @@ -69,7 +74,12 @@ namespace EntityFrameworkCore.Projectables.Services protected override Expression VisitMember(MemberExpression node) { - if (TryGetReflectedExpression(node.Member, out var reflectedExpression)) + var nodeMember = node.Expression switch { + { Type: { } } => node.Expression.Type.GetMember(node.Member.Name, node.Member.MemberType, BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static)[0], + _ => node.Member + }; + + if (TryGetReflectedExpression(nodeMember, out var reflectedExpression)) { if (node.Expression is not null) { diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs index ef82ed7..7cd974f 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs @@ -14,15 +14,15 @@ namespace EntityFrameworkCore.Projectables.Services { public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo) { - var reflectedType = projectableMemberInfo.ReflectedType ?? throw new InvalidOperationException("Expected a valid type here"); - var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(reflectedType.Namespace, reflectedType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name); + var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here"); + var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name); var genericArguments = projectableMemberInfo switch { MethodInfo methodInfo => methodInfo.GetGenericArguments(), _ => null }; - var expressionFactoryMethod = reflectedType.Assembly.GetType(generatedContainingTypeName) + var expressionFactoryMethod = declaringType.Assembly.GetType(generatedContainingTypeName) ?.GetMethods() ?.FirstOrDefault(); @@ -40,20 +40,20 @@ namespace EntityFrameworkCore.Projectables.Services if (useMemberBody is not null) { - var exprProperty = reflectedType.GetProperty(useMemberBody, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic); + var exprProperty = declaringType.GetProperty(useMemberBody, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic); var lambda = exprProperty?.GetValue(null) as LambdaExpression; if (lambda is not null) { if (projectableMemberInfo is PropertyInfo property && lambda.Parameters.Count == 1 && - lambda.Parameters[0].Type == reflectedType && lambda.ReturnType == property.PropertyType) + lambda.Parameters[0].Type == declaringType && lambda.ReturnType == property.PropertyType) { return lambda; } else if (projectableMemberInfo is MethodInfo method && lambda.Parameters.Count == method.GetParameters().Length + 1 && - lambda.Parameters.Last().Type == reflectedType && + lambda.Parameters.Last().Type == declaringType && !lambda.Parameters.Zip(method.GetParameters(), (a, b) => a.Type != b.ParameterType).Any()) { return lambda; @@ -62,8 +62,8 @@ namespace EntityFrameworkCore.Projectables.Services } var fullName = string.Join(".", Enumerable.Empty() - .Concat(new[] { reflectedType.Namespace }) - .Concat(reflectedType.GetNestedTypePath().Select(x => x.Name)) + .Concat(new[] { declaringType.Namespace }) + .Concat(declaringType.GetNestedTypePath().Select(x => x.Name)) .Concat(new[] { projectableMemberInfo.Name })); throw new InvalidOperationException($"Unable to resolve generated expression for {fullName}.") { diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedMethodImplementation.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedMethodImplementation.verified.txt new file mode 100644 index 0000000..6b36b12 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedMethodImplementation.verified.txt @@ -0,0 +1,2 @@ +SELECT 1 + 1 +FROM [Concrete] AS [c] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedPropertyImplementation.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedPropertyImplementation.verified.txt new file mode 100644 index 0000000..6b36b12 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverInheritedPropertyImplementation.verified.txt @@ -0,0 +1,2 @@ +SELECT 1 + 1 +FROM [Concrete] AS [c] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverOverriddenMethodImplementation.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverOverriddenMethodImplementation.verified.txt new file mode 100644 index 0000000..6b36b12 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverOverriddenMethodImplementation.verified.txt @@ -0,0 +1,2 @@ +SELECT 1 + 1 +FROM [Concrete] AS [c] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverOverriddenPropertyImplementation.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverOverriddenPropertyImplementation.verified.txt new file mode 100644 index 0000000..6b36b12 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.ProjectOverOverriddenPropertyImplementation.verified.txt @@ -0,0 +1,2 @@ +SELECT 1 + 1 +FROM [Concrete] AS [c] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs new file mode 100644 index 0000000..8d8acc5 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs @@ -0,0 +1,95 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations.Schema; +using System.Linq; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading.Tasks; +using EntityFrameworkCore.Projectables.FunctionalTests.Helpers; +using EntityFrameworkCore.Projectables.Services; +using Microsoft.EntityFrameworkCore; +using ScenarioTests; +using VerifyXunit; +using Xunit; + +#nullable disable + +namespace EntityFrameworkCore.Projectables.FunctionalTests +{ + + [UsesVerify] + public class InheritedModelTests + { + public abstract class Base + { + public int Id { get; set; } + + [Projectable] + public int ComputedProperty => SampleProperty + 1; + + public virtual int SampleProperty => 0; + + [Projectable] + public int ComputedMethod() => SampleMethod() + 1; + + public virtual int SampleMethod() => 0; + } + + public class Concrete : Base + { + [Projectable] + public override int SampleProperty => 1; + + [Projectable] + public override int SampleMethod() => 1; + } + + public class MoreConcrete : Concrete + { + } + + [Fact] + public Task ProjectOverOverriddenPropertyImplementation() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set() + .Select(x => x.ComputedProperty); + + return Verifier.Verify(query.ToQueryString()); + } + + [Fact] + public Task ProjectOverInheritedPropertyImplementation() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set() + .Select(x => x.ComputedProperty); + + return Verifier.Verify(query.ToQueryString()); + } + + [Fact] + public Task ProjectOverOverriddenMethodImplementation() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set() + .Select(x => x.ComputedMethod()); + + return Verifier.Verify(query.ToQueryString()); + } + + [Fact] + public Task ProjectOverInheritedMethodImplementation() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set() + .Select(x => x.ComputedMethod()); + + return Verifier.Verify(query.ToQueryString()); + } + } +} diff --git a/tests/EntityFrameworkCore.Projectables.Tests/Extensions/TypeExtensionTests.cs b/tests/EntityFrameworkCore.Projectables.Tests/Extensions/TypeExtensionTests.cs index 90df2d7..a565381 100644 --- a/tests/EntityFrameworkCore.Projectables.Tests/Extensions/TypeExtensionTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Tests/Extensions/TypeExtensionTests.cs @@ -1,9 +1,11 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Reflection.Metadata.Ecma335; using System.Text; using System.Threading.Tasks; using EntityFrameworkCore.Projectables.Extensions; +using Microsoft.EntityFrameworkCore.Metadata.Conventions; using Xunit; namespace EntityFrameworkCore.Projectables.Tests.Extensions @@ -18,6 +20,18 @@ namespace EntityFrameworkCore.Projectables.Tests.Extensions } } + class BaseType + { + public virtual void VirtualMethod(int arg1) { } + public virtual void GenericVirtualMethod(TArg arg1) { } + } + + class DerivedType : BaseType + { + public override void VirtualMethod(int arg1) { } + public override void GenericVirtualMethod(TArg arg1) { } + } + [Fact] public void GetNestedTypePath_OuterType_Returns1Entry() { @@ -28,7 +42,6 @@ namespace EntityFrameworkCore.Projectables.Tests.Extensions Assert.Single(result); } - [Fact] public void GetNestedTypePath_InnerType_Returns2Entries() { @@ -59,5 +72,53 @@ namespace EntityFrameworkCore.Projectables.Tests.Extensions Assert.Equal(typeof(TypeExtensionTests), result.First()); Assert.Equal(typeof(InnerType.SubsequentlyInnerType), result.Last()); } + + [Fact] + public void GetOverridingMethod_BaseTypeVirtualMethod_FindsSameMethod() + { + var type = typeof(BaseType); + var method = typeof(BaseType).GetMethod("VirtualMethod")!; + + var result = type.GetOverridingMethod(method); + + Assert.Equal(method, result); + } + + [Fact] + public void GetOverridingMethod_DerivedTypeVirtualMethod_FindsOverridingMethod() + { + var baseType = typeof(BaseType); + var baseMethod = baseType.GetMethod("VirtualMethod")!; + var derivedType = typeof(DerivedType); + var derivedMethod = typeof(DerivedType).GetMethod("VirtualMethod")!; + + var resolvedMethod = derivedType.GetOverridingMethod(baseMethod); + + Assert.Equal(derivedMethod, resolvedMethod); + } + + [Fact] + public void GetOverridingMethod_BaseTypeGenericVirtualMethod_FindsSameMethod() + { + var type = typeof(BaseType); + var method = typeof(BaseType).GetMethod("GenericVirtualMethod")!; + + var result = type.GetOverridingMethod(method); + + Assert.Equal(method, result); + } + + [Fact] + public void GetOverridingMethod_DerivedTypeGenericVirtualMethod_FindsOverridingMethod() + { + var baseType = typeof(BaseType); + var baseMethod = baseType.GetMethod("GenericVirtualMethod")!; + var derivedType = typeof(DerivedType); + var derivedMethod = typeof(DerivedType).GetMethod("GenericVirtualMethod")!; + + var resolvedMethod = derivedType.GetOverridingMethod(baseMethod); + + Assert.Equal(derivedMethod, resolvedMethod); + } } }