Merge pull request #55 from koenbeuk/issue-52

Find overriding implementations for virtual methods
This commit is contained in:
Koen
2022-12-12 00:55:13 +00:00
committed by GitHub
16 changed files with 282 additions and 20 deletions

View File

@@ -29,7 +29,7 @@ jobs:
- name: Setup .NET
uses: actions/setup-dotnet@v1
with:
dotnet-version: 6.0.x
dotnet-version: 7.0.x
- name: Restore dependencies
run: dotnet restore
- name: Build
@@ -46,7 +46,7 @@ jobs:
- name: Setup .NET
uses: actions/setup-dotnet@v1
with:
dotnet-version: 6.0.x
dotnet-version: 7.0.x
- name: Pack
run: |
dotnet pack -v normal -c Debug --include-symbols --include-source -p:PackageVersion=2.0.0-pre-$GITHUB_RUN_ID -o nupkg

View File

@@ -26,7 +26,7 @@ jobs:
- name: Setup .NET Core
uses: actions/setup-dotnet@v1
with:
dotnet-version: 6.0.x
dotnet-version: 7.0.x
include-prerelease: True
- name: Create Release NuGet package
run: |

View File

@@ -2,7 +2,7 @@
<PropertyGroup>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<LangVersion>9.0</LangVersion>
<LangVersion>11.0</LangVersion>
<Nullable>enable</Nullable>
<EnableNETAnalyzers>true</EnableNETAnalyzers>
<NoWarn>CS1591</NoWarn>

View File

@@ -2,13 +2,13 @@
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net6.0</TargetFramework>
<TargetFramework>net7.0</TargetFramework>
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="BenchmarkDotNet" Version="0.13.2" />
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="6.0.0" />
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="7.0.0" />
</ItemGroup>
<ItemGroup>

View File

@@ -1,7 +1,10 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Reflection.Metadata;
using System.Runtime.CompilerServices;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading.Tasks;
@@ -21,5 +24,90 @@ namespace EntityFrameworkCore.Projectables.Extensions
yield return type;
}
public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
{
// We only need to search for virtual instance methods who are not declared on the derivedType
if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual)
{
return methodInfo;
}
if (!derivedType.IsAssignableTo(methodInfo.DeclaringType))
{
throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo));
}
var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
foreach (var derivedMethodInfo in derivedMethods)
{
if (HasCompatibleSignature(methodInfo, derivedMethodInfo))
{
return derivedMethodInfo;
}
}
// No derived methods were found. Return the original methodInfo
return methodInfo;
static bool HasCompatibleSignature(MethodInfo methodInfo, MethodInfo derivedMethodInfo)
{
if (methodInfo.Name != derivedMethodInfo.Name)
{
return false;
}
var methodParameters = methodInfo.GetParameters();
var derivedMethodParameters = derivedMethodInfo.GetParameters();
if (methodParameters.Length != derivedMethodParameters.Length)
{
return false;
}
// Match all parameters
for (var parameterIndex = 0; parameterIndex < methodParameters.Length; parameterIndex++)
{
var parameter = methodParameters[parameterIndex];
var derivedParameter = derivedMethodParameters[parameterIndex];
if (parameter.ParameterType.IsGenericParameter)
{
if (!derivedParameter.ParameterType.IsGenericParameter)
{
return false;
}
}
else
{
if (parameter.ParameterType != derivedParameter.ParameterType)
{
return false;
}
}
}
// Match the number of generic type arguments
if (methodInfo.IsGenericMethodDefinition)
{
var methodGenericParameters = methodInfo.GetGenericArguments();
if (!derivedMethodInfo.IsGenericMethodDefinition)
{
return false;
}
var derivedGenericArguments = derivedMethodInfo.GetGenericArguments();
if (methodGenericParameters.Length != derivedGenericArguments.Length)
{
return false;
}
}
return true;
}
}
}
}

View File

@@ -1,4 +1,5 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
@@ -7,6 +8,7 @@ using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using System.Xml.Linq;
using EntityFrameworkCore.Projectables.Extensions;
namespace EntityFrameworkCore.Projectables.Services
{
@@ -39,7 +41,10 @@ namespace EntityFrameworkCore.Projectables.Services
protected override Expression VisitMethodCall(MethodCallExpression node)
{
if (TryGetReflectedExpression(node.Method, out var reflectedExpression))
// Get the overriding methodInfo based on te type of the received of this expression
var methodInfo = node.Object?.Type.GetOverridingMethod(node.Method) ?? node.Method;
if (TryGetReflectedExpression(methodInfo, out var reflectedExpression))
{
for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++)
{
@@ -69,7 +74,12 @@ namespace EntityFrameworkCore.Projectables.Services
protected override Expression VisitMember(MemberExpression node)
{
if (TryGetReflectedExpression(node.Member, out var reflectedExpression))
var nodeMember = node.Expression switch {
{ Type: { } } => node.Expression.Type.GetMember(node.Member.Name, node.Member.MemberType, BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static)[0],
_ => node.Member
};
if (TryGetReflectedExpression(nodeMember, out var reflectedExpression))
{
if (node.Expression is not null)
{

View File

@@ -14,15 +14,15 @@ namespace EntityFrameworkCore.Projectables.Services
{
public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo)
{
var reflectedType = projectableMemberInfo.ReflectedType ?? throw new InvalidOperationException("Expected a valid type here");
var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(reflectedType.Namespace, reflectedType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name);
var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here");
var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name);
var genericArguments = projectableMemberInfo switch {
MethodInfo methodInfo => methodInfo.GetGenericArguments(),
_ => null
};
var expressionFactoryMethod = reflectedType.Assembly.GetType(generatedContainingTypeName)
var expressionFactoryMethod = declaringType.Assembly.GetType(generatedContainingTypeName)
?.GetMethods()
?.FirstOrDefault();
@@ -40,20 +40,20 @@ namespace EntityFrameworkCore.Projectables.Services
if (useMemberBody is not null)
{
var exprProperty = reflectedType.GetProperty(useMemberBody, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);
var exprProperty = declaringType.GetProperty(useMemberBody, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);
var lambda = exprProperty?.GetValue(null) as LambdaExpression;
if (lambda is not null)
{
if (projectableMemberInfo is PropertyInfo property &&
lambda.Parameters.Count == 1 &&
lambda.Parameters[0].Type == reflectedType && lambda.ReturnType == property.PropertyType)
lambda.Parameters[0].Type == declaringType && lambda.ReturnType == property.PropertyType)
{
return lambda;
}
else if (projectableMemberInfo is MethodInfo method &&
lambda.Parameters.Count == method.GetParameters().Length + 1 &&
lambda.Parameters.Last().Type == reflectedType &&
lambda.Parameters.Last().Type == declaringType &&
!lambda.Parameters.Zip(method.GetParameters(), (a, b) => a.Type != b.ParameterType).Any())
{
return lambda;
@@ -62,8 +62,8 @@ namespace EntityFrameworkCore.Projectables.Services
}
var fullName = string.Join(".", Enumerable.Empty<string>()
.Concat(new[] { reflectedType.Namespace })
.Concat(reflectedType.GetNestedTypePath().Select(x => x.Name))
.Concat(new[] { declaringType.Namespace })
.Concat(declaringType.GetNestedTypePath().Select(x => x.Name))
.Concat(new[] { projectableMemberInfo.Name }));
throw new InvalidOperationException($"Unable to resolve generated expression for {fullName}.") {

View File

@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<TargetFramework>net7.0</TargetFramework>
<IsPackable>false</IsPackable>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
<CompilerGeneratedFilesOutputPath>$(BaseIntermediateOutputPath)Generated</CompilerGeneratedFilesOutputPath>

View File

@@ -0,0 +1,95 @@
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations.Schema;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading.Tasks;
using EntityFrameworkCore.Projectables.FunctionalTests.Helpers;
using EntityFrameworkCore.Projectables.Services;
using Microsoft.EntityFrameworkCore;
using ScenarioTests;
using VerifyXunit;
using Xunit;
#nullable disable
namespace EntityFrameworkCore.Projectables.FunctionalTests
{
[UsesVerify]
public class InheritedModelTests
{
public abstract class Base
{
public int Id { get; set; }
[Projectable]
public int ComputedProperty => SampleProperty + 1;
public virtual int SampleProperty => 0;
[Projectable]
public int ComputedMethod() => SampleMethod() + 1;
public virtual int SampleMethod() => 0;
}
public class Concrete : Base
{
[Projectable]
public override int SampleProperty => 1;
[Projectable]
public override int SampleMethod() => 1;
}
public class MoreConcrete : Concrete
{
}
[Fact]
public Task ProjectOverOverriddenPropertyImplementation()
{
using var dbContext = new SampleDbContext<Concrete>();
var query = dbContext.Set<Concrete>()
.Select(x => x.ComputedProperty);
return Verifier.Verify(query.ToQueryString());
}
[Fact]
public Task ProjectOverInheritedPropertyImplementation()
{
using var dbContext = new SampleDbContext<Concrete>();
var query = dbContext.Set<Concrete>()
.Select(x => x.ComputedProperty);
return Verifier.Verify(query.ToQueryString());
}
[Fact]
public Task ProjectOverOverriddenMethodImplementation()
{
using var dbContext = new SampleDbContext<Concrete>();
var query = dbContext.Set<Concrete>()
.Select(x => x.ComputedMethod());
return Verifier.Verify(query.ToQueryString());
}
[Fact]
public Task ProjectOverInheritedMethodImplementation()
{
using var dbContext = new SampleDbContext<Concrete>();
var query = dbContext.Set<Concrete>()
.Select(x => x.ComputedMethod());
return Verifier.Verify(query.ToQueryString());
}
}
}

View File

@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<TargetFramework>net7.0</TargetFramework>
<IsPackable>false</IsPackable>
</PropertyGroup>

View File

@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<TargetFramework>net7.0</TargetFramework>
<IsPackable>false</IsPackable>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>

View File

@@ -1,9 +1,11 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection.Metadata.Ecma335;
using System.Text;
using System.Threading.Tasks;
using EntityFrameworkCore.Projectables.Extensions;
using Microsoft.EntityFrameworkCore.Metadata.Conventions;
using Xunit;
namespace EntityFrameworkCore.Projectables.Tests.Extensions
@@ -18,6 +20,18 @@ namespace EntityFrameworkCore.Projectables.Tests.Extensions
}
}
class BaseType
{
public virtual void VirtualMethod(int arg1) { }
public virtual void GenericVirtualMethod<TArg>(TArg arg1) { }
}
class DerivedType : BaseType
{
public override void VirtualMethod(int arg1) { }
public override void GenericVirtualMethod<TArg>(TArg arg1) { }
}
[Fact]
public void GetNestedTypePath_OuterType_Returns1Entry()
{
@@ -28,7 +42,6 @@ namespace EntityFrameworkCore.Projectables.Tests.Extensions
Assert.Single(result);
}
[Fact]
public void GetNestedTypePath_InnerType_Returns2Entries()
{
@@ -59,5 +72,53 @@ namespace EntityFrameworkCore.Projectables.Tests.Extensions
Assert.Equal(typeof(TypeExtensionTests), result.First());
Assert.Equal(typeof(InnerType.SubsequentlyInnerType), result.Last());
}
[Fact]
public void GetOverridingMethod_BaseTypeVirtualMethod_FindsSameMethod()
{
var type = typeof(BaseType);
var method = typeof(BaseType).GetMethod("VirtualMethod")!;
var result = type.GetOverridingMethod(method);
Assert.Equal(method, result);
}
[Fact]
public void GetOverridingMethod_DerivedTypeVirtualMethod_FindsOverridingMethod()
{
var baseType = typeof(BaseType);
var baseMethod = baseType.GetMethod("VirtualMethod")!;
var derivedType = typeof(DerivedType);
var derivedMethod = typeof(DerivedType).GetMethod("VirtualMethod")!;
var resolvedMethod = derivedType.GetOverridingMethod(baseMethod);
Assert.Equal(derivedMethod, resolvedMethod);
}
[Fact]
public void GetOverridingMethod_BaseTypeGenericVirtualMethod_FindsSameMethod()
{
var type = typeof(BaseType);
var method = typeof(BaseType).GetMethod("GenericVirtualMethod")!;
var result = type.GetOverridingMethod(method);
Assert.Equal(method, result);
}
[Fact]
public void GetOverridingMethod_DerivedTypeGenericVirtualMethod_FindsOverridingMethod()
{
var baseType = typeof(BaseType);
var baseMethod = baseType.GetMethod("GenericVirtualMethod")!;
var derivedType = typeof(DerivedType);
var derivedMethod = typeof(DerivedType).GetMethod("GenericVirtualMethod")!;
var resolvedMethod = derivedType.GetOverridingMethod(baseMethod);
Assert.Equal(derivedMethod, resolvedMethod);
}
}
}