mirror of
https://github.com/zoriya/EntityFrameworkCore.Projectables.git
synced 2025-12-06 05:56:10 +00:00
Merge pull request #55 from koenbeuk/issue-52
Find overriding implementations for virtual methods
This commit is contained in:
4
.github/workflows/build.yml
vendored
4
.github/workflows/build.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
||||
- name: Setup .NET
|
||||
uses: actions/setup-dotnet@v1
|
||||
with:
|
||||
dotnet-version: 6.0.x
|
||||
dotnet-version: 7.0.x
|
||||
- name: Restore dependencies
|
||||
run: dotnet restore
|
||||
- name: Build
|
||||
@@ -46,7 +46,7 @@ jobs:
|
||||
- name: Setup .NET
|
||||
uses: actions/setup-dotnet@v1
|
||||
with:
|
||||
dotnet-version: 6.0.x
|
||||
dotnet-version: 7.0.x
|
||||
- name: Pack
|
||||
run: |
|
||||
dotnet pack -v normal -c Debug --include-symbols --include-source -p:PackageVersion=2.0.0-pre-$GITHUB_RUN_ID -o nupkg
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- name: Setup .NET Core
|
||||
uses: actions/setup-dotnet@v1
|
||||
with:
|
||||
dotnet-version: 6.0.x
|
||||
dotnet-version: 7.0.x
|
||||
include-prerelease: True
|
||||
- name: Create Release NuGet package
|
||||
run: |
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
<PropertyGroup>
|
||||
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
|
||||
<LangVersion>9.0</LangVersion>
|
||||
<LangVersion>11.0</LangVersion>
|
||||
<Nullable>enable</Nullable>
|
||||
<EnableNETAnalyzers>true</EnableNETAnalyzers>
|
||||
<NoWarn>CS1591</NoWarn>
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
|
||||
<PropertyGroup>
|
||||
<OutputType>Exe</OutputType>
|
||||
<TargetFramework>net6.0</TargetFramework>
|
||||
<TargetFramework>net7.0</TargetFramework>
|
||||
<IsPackable>false</IsPackable>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="BenchmarkDotNet" Version="0.13.2" />
|
||||
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="6.0.0" />
|
||||
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="7.0.0" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using System.Reflection.Metadata;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
@@ -21,5 +24,90 @@ namespace EntityFrameworkCore.Projectables.Extensions
|
||||
|
||||
yield return type;
|
||||
}
|
||||
|
||||
public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
|
||||
{
|
||||
// We only need to search for virtual instance methods who are not declared on the derivedType
|
||||
if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual)
|
||||
{
|
||||
return methodInfo;
|
||||
}
|
||||
|
||||
if (!derivedType.IsAssignableTo(methodInfo.DeclaringType))
|
||||
{
|
||||
throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo));
|
||||
}
|
||||
|
||||
var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
|
||||
|
||||
foreach (var derivedMethodInfo in derivedMethods)
|
||||
{
|
||||
if (HasCompatibleSignature(methodInfo, derivedMethodInfo))
|
||||
{
|
||||
return derivedMethodInfo;
|
||||
}
|
||||
}
|
||||
|
||||
// No derived methods were found. Return the original methodInfo
|
||||
return methodInfo;
|
||||
|
||||
static bool HasCompatibleSignature(MethodInfo methodInfo, MethodInfo derivedMethodInfo)
|
||||
{
|
||||
if (methodInfo.Name != derivedMethodInfo.Name)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var methodParameters = methodInfo.GetParameters();
|
||||
|
||||
var derivedMethodParameters = derivedMethodInfo.GetParameters();
|
||||
if (methodParameters.Length != derivedMethodParameters.Length)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Match all parameters
|
||||
for (var parameterIndex = 0; parameterIndex < methodParameters.Length; parameterIndex++)
|
||||
{
|
||||
var parameter = methodParameters[parameterIndex];
|
||||
var derivedParameter = derivedMethodParameters[parameterIndex];
|
||||
|
||||
if (parameter.ParameterType.IsGenericParameter)
|
||||
{
|
||||
if (!derivedParameter.ParameterType.IsGenericParameter)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (parameter.ParameterType != derivedParameter.ParameterType)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Match the number of generic type arguments
|
||||
if (methodInfo.IsGenericMethodDefinition)
|
||||
{
|
||||
var methodGenericParameters = methodInfo.GetGenericArguments();
|
||||
|
||||
if (!derivedMethodInfo.IsGenericMethodDefinition)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var derivedGenericArguments = derivedMethodInfo.GetGenericArguments();
|
||||
|
||||
if (methodGenericParameters.Length != derivedGenericArguments.Length)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Linq;
|
||||
@@ -7,6 +8,7 @@ using System.Reflection;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using System.Xml.Linq;
|
||||
using EntityFrameworkCore.Projectables.Extensions;
|
||||
|
||||
namespace EntityFrameworkCore.Projectables.Services
|
||||
{
|
||||
@@ -39,7 +41,10 @@ namespace EntityFrameworkCore.Projectables.Services
|
||||
|
||||
protected override Expression VisitMethodCall(MethodCallExpression node)
|
||||
{
|
||||
if (TryGetReflectedExpression(node.Method, out var reflectedExpression))
|
||||
// Get the overriding methodInfo based on te type of the received of this expression
|
||||
var methodInfo = node.Object?.Type.GetOverridingMethod(node.Method) ?? node.Method;
|
||||
|
||||
if (TryGetReflectedExpression(methodInfo, out var reflectedExpression))
|
||||
{
|
||||
for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++)
|
||||
{
|
||||
@@ -69,7 +74,12 @@ namespace EntityFrameworkCore.Projectables.Services
|
||||
|
||||
protected override Expression VisitMember(MemberExpression node)
|
||||
{
|
||||
if (TryGetReflectedExpression(node.Member, out var reflectedExpression))
|
||||
var nodeMember = node.Expression switch {
|
||||
{ Type: { } } => node.Expression.Type.GetMember(node.Member.Name, node.Member.MemberType, BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static)[0],
|
||||
_ => node.Member
|
||||
};
|
||||
|
||||
if (TryGetReflectedExpression(nodeMember, out var reflectedExpression))
|
||||
{
|
||||
if (node.Expression is not null)
|
||||
{
|
||||
|
||||
@@ -14,15 +14,15 @@ namespace EntityFrameworkCore.Projectables.Services
|
||||
{
|
||||
public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo)
|
||||
{
|
||||
var reflectedType = projectableMemberInfo.ReflectedType ?? throw new InvalidOperationException("Expected a valid type here");
|
||||
var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(reflectedType.Namespace, reflectedType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name);
|
||||
var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here");
|
||||
var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name);
|
||||
|
||||
var genericArguments = projectableMemberInfo switch {
|
||||
MethodInfo methodInfo => methodInfo.GetGenericArguments(),
|
||||
_ => null
|
||||
};
|
||||
|
||||
var expressionFactoryMethod = reflectedType.Assembly.GetType(generatedContainingTypeName)
|
||||
var expressionFactoryMethod = declaringType.Assembly.GetType(generatedContainingTypeName)
|
||||
?.GetMethods()
|
||||
?.FirstOrDefault();
|
||||
|
||||
@@ -40,20 +40,20 @@ namespace EntityFrameworkCore.Projectables.Services
|
||||
|
||||
if (useMemberBody is not null)
|
||||
{
|
||||
var exprProperty = reflectedType.GetProperty(useMemberBody, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);
|
||||
var exprProperty = declaringType.GetProperty(useMemberBody, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);
|
||||
var lambda = exprProperty?.GetValue(null) as LambdaExpression;
|
||||
|
||||
if (lambda is not null)
|
||||
{
|
||||
if (projectableMemberInfo is PropertyInfo property &&
|
||||
lambda.Parameters.Count == 1 &&
|
||||
lambda.Parameters[0].Type == reflectedType && lambda.ReturnType == property.PropertyType)
|
||||
lambda.Parameters[0].Type == declaringType && lambda.ReturnType == property.PropertyType)
|
||||
{
|
||||
return lambda;
|
||||
}
|
||||
else if (projectableMemberInfo is MethodInfo method &&
|
||||
lambda.Parameters.Count == method.GetParameters().Length + 1 &&
|
||||
lambda.Parameters.Last().Type == reflectedType &&
|
||||
lambda.Parameters.Last().Type == declaringType &&
|
||||
!lambda.Parameters.Zip(method.GetParameters(), (a, b) => a.Type != b.ParameterType).Any())
|
||||
{
|
||||
return lambda;
|
||||
@@ -62,8 +62,8 @@ namespace EntityFrameworkCore.Projectables.Services
|
||||
}
|
||||
|
||||
var fullName = string.Join(".", Enumerable.Empty<string>()
|
||||
.Concat(new[] { reflectedType.Namespace })
|
||||
.Concat(reflectedType.GetNestedTypePath().Select(x => x.Name))
|
||||
.Concat(new[] { declaringType.Namespace })
|
||||
.Concat(declaringType.GetNestedTypePath().Select(x => x.Name))
|
||||
.Concat(new[] { projectableMemberInfo.Name }));
|
||||
|
||||
throw new InvalidOperationException($"Unable to resolve generated expression for {fullName}.") {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net6.0</TargetFramework>
|
||||
<TargetFramework>net7.0</TargetFramework>
|
||||
<IsPackable>false</IsPackable>
|
||||
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
|
||||
<CompilerGeneratedFilesOutputPath>$(BaseIntermediateOutputPath)Generated</CompilerGeneratedFilesOutputPath>
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
SELECT 1 + 1
|
||||
FROM [Concrete] AS [c]
|
||||
@@ -0,0 +1,2 @@
|
||||
SELECT 1 + 1
|
||||
FROM [Concrete] AS [c]
|
||||
@@ -0,0 +1,2 @@
|
||||
SELECT 1 + 1
|
||||
FROM [Concrete] AS [c]
|
||||
@@ -0,0 +1,2 @@
|
||||
SELECT 1 + 1
|
||||
FROM [Concrete] AS [c]
|
||||
@@ -0,0 +1,95 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.ComponentModel.DataAnnotations.Schema;
|
||||
using System.Linq;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using EntityFrameworkCore.Projectables.FunctionalTests.Helpers;
|
||||
using EntityFrameworkCore.Projectables.Services;
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
using ScenarioTests;
|
||||
using VerifyXunit;
|
||||
using Xunit;
|
||||
|
||||
#nullable disable
|
||||
|
||||
namespace EntityFrameworkCore.Projectables.FunctionalTests
|
||||
{
|
||||
|
||||
[UsesVerify]
|
||||
public class InheritedModelTests
|
||||
{
|
||||
public abstract class Base
|
||||
{
|
||||
public int Id { get; set; }
|
||||
|
||||
[Projectable]
|
||||
public int ComputedProperty => SampleProperty + 1;
|
||||
|
||||
public virtual int SampleProperty => 0;
|
||||
|
||||
[Projectable]
|
||||
public int ComputedMethod() => SampleMethod() + 1;
|
||||
|
||||
public virtual int SampleMethod() => 0;
|
||||
}
|
||||
|
||||
public class Concrete : Base
|
||||
{
|
||||
[Projectable]
|
||||
public override int SampleProperty => 1;
|
||||
|
||||
[Projectable]
|
||||
public override int SampleMethod() => 1;
|
||||
}
|
||||
|
||||
public class MoreConcrete : Concrete
|
||||
{
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public Task ProjectOverOverriddenPropertyImplementation()
|
||||
{
|
||||
using var dbContext = new SampleDbContext<Concrete>();
|
||||
|
||||
var query = dbContext.Set<Concrete>()
|
||||
.Select(x => x.ComputedProperty);
|
||||
|
||||
return Verifier.Verify(query.ToQueryString());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public Task ProjectOverInheritedPropertyImplementation()
|
||||
{
|
||||
using var dbContext = new SampleDbContext<Concrete>();
|
||||
|
||||
var query = dbContext.Set<Concrete>()
|
||||
.Select(x => x.ComputedProperty);
|
||||
|
||||
return Verifier.Verify(query.ToQueryString());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public Task ProjectOverOverriddenMethodImplementation()
|
||||
{
|
||||
using var dbContext = new SampleDbContext<Concrete>();
|
||||
|
||||
var query = dbContext.Set<Concrete>()
|
||||
.Select(x => x.ComputedMethod());
|
||||
|
||||
return Verifier.Verify(query.ToQueryString());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public Task ProjectOverInheritedMethodImplementation()
|
||||
{
|
||||
using var dbContext = new SampleDbContext<Concrete>();
|
||||
|
||||
var query = dbContext.Set<Concrete>()
|
||||
.Select(x => x.ComputedMethod());
|
||||
|
||||
return Verifier.Verify(query.ToQueryString());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net6.0</TargetFramework>
|
||||
<TargetFramework>net7.0</TargetFramework>
|
||||
<IsPackable>false</IsPackable>
|
||||
</PropertyGroup>
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net6.0</TargetFramework>
|
||||
<TargetFramework>net7.0</TargetFramework>
|
||||
<IsPackable>false</IsPackable>
|
||||
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reflection.Metadata.Ecma335;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using EntityFrameworkCore.Projectables.Extensions;
|
||||
using Microsoft.EntityFrameworkCore.Metadata.Conventions;
|
||||
using Xunit;
|
||||
|
||||
namespace EntityFrameworkCore.Projectables.Tests.Extensions
|
||||
@@ -18,6 +20,18 @@ namespace EntityFrameworkCore.Projectables.Tests.Extensions
|
||||
}
|
||||
}
|
||||
|
||||
class BaseType
|
||||
{
|
||||
public virtual void VirtualMethod(int arg1) { }
|
||||
public virtual void GenericVirtualMethod<TArg>(TArg arg1) { }
|
||||
}
|
||||
|
||||
class DerivedType : BaseType
|
||||
{
|
||||
public override void VirtualMethod(int arg1) { }
|
||||
public override void GenericVirtualMethod<TArg>(TArg arg1) { }
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetNestedTypePath_OuterType_Returns1Entry()
|
||||
{
|
||||
@@ -28,7 +42,6 @@ namespace EntityFrameworkCore.Projectables.Tests.Extensions
|
||||
Assert.Single(result);
|
||||
}
|
||||
|
||||
|
||||
[Fact]
|
||||
public void GetNestedTypePath_InnerType_Returns2Entries()
|
||||
{
|
||||
@@ -59,5 +72,53 @@ namespace EntityFrameworkCore.Projectables.Tests.Extensions
|
||||
Assert.Equal(typeof(TypeExtensionTests), result.First());
|
||||
Assert.Equal(typeof(InnerType.SubsequentlyInnerType), result.Last());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOverridingMethod_BaseTypeVirtualMethod_FindsSameMethod()
|
||||
{
|
||||
var type = typeof(BaseType);
|
||||
var method = typeof(BaseType).GetMethod("VirtualMethod")!;
|
||||
|
||||
var result = type.GetOverridingMethod(method);
|
||||
|
||||
Assert.Equal(method, result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOverridingMethod_DerivedTypeVirtualMethod_FindsOverridingMethod()
|
||||
{
|
||||
var baseType = typeof(BaseType);
|
||||
var baseMethod = baseType.GetMethod("VirtualMethod")!;
|
||||
var derivedType = typeof(DerivedType);
|
||||
var derivedMethod = typeof(DerivedType).GetMethod("VirtualMethod")!;
|
||||
|
||||
var resolvedMethod = derivedType.GetOverridingMethod(baseMethod);
|
||||
|
||||
Assert.Equal(derivedMethod, resolvedMethod);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOverridingMethod_BaseTypeGenericVirtualMethod_FindsSameMethod()
|
||||
{
|
||||
var type = typeof(BaseType);
|
||||
var method = typeof(BaseType).GetMethod("GenericVirtualMethod")!;
|
||||
|
||||
var result = type.GetOverridingMethod(method);
|
||||
|
||||
Assert.Equal(method, result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetOverridingMethod_DerivedTypeGenericVirtualMethod_FindsOverridingMethod()
|
||||
{
|
||||
var baseType = typeof(BaseType);
|
||||
var baseMethod = baseType.GetMethod("GenericVirtualMethod")!;
|
||||
var derivedType = typeof(DerivedType);
|
||||
var derivedMethod = typeof(DerivedType).GetMethod("GenericVirtualMethod")!;
|
||||
|
||||
var resolvedMethod = derivedType.GetOverridingMethod(baseMethod);
|
||||
|
||||
Assert.Equal(derivedMethod, resolvedMethod);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user