Fix interface property lookup in generic method

This commit is contained in:
Rhodon
2023-05-04 14:09:11 +02:00
parent 3b98b1ef36
commit aab4834002
2 changed files with 66 additions and 32 deletions
@@ -41,23 +41,8 @@ namespace EntityFrameworkCore.Projectables.Extensions
return true; return true;
} }
private static int? GetOverridingMethodIndex(this MethodInfo methodInfo, MethodInfo[]? allDerivedMethods) private static bool IsOverridingMethodOf(this MethodInfo methodInfo, MethodInfo baseDefinition)
{ => methodInfo.GetBaseDefinition() == baseDefinition;
if (allDerivedMethods is { Length: > 0 })
{
var baseDefinition = methodInfo.GetBaseDefinition();
for (var i = 0; i < allDerivedMethods.Length; i++)
{
var derivedMethodInfo = allDerivedMethods[i];
if (derivedMethodInfo.GetBaseDefinition() == baseDefinition)
{
return i;
}
}
}
return null;
}
public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo) public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
{ {
@@ -68,31 +53,38 @@ namespace EntityFrameworkCore.Projectables.Extensions
var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
return methodInfo.GetOverridingMethodIndex(derivedMethods) is { } i MethodInfo? overridingMethod = null;
? derivedMethods[i] if (derivedMethods is { Length: > 0 })
// No derived methods were found. Return the original methodInfo {
: methodInfo; var baseDefinition = methodInfo.GetBaseDefinition();
overridingMethod = derivedMethods.FirstOrDefault(derivedMethodInfo
=> derivedMethodInfo.IsOverridingMethodOf(baseDefinition));
}
return overridingMethod ?? methodInfo; // If no derived methods were found, return the original methodInfo
} }
public static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo) public static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo)
{ {
var accessor = propertyInfo.GetAccessors()[0]; var accessor = propertyInfo.GetAccessors(true).FirstOrDefault(derivedType.CanHaveOverridingMethod);
if (accessor is null)
if (!derivedType.CanHaveOverridingMethod(accessor))
{ {
return propertyInfo; return propertyInfo;
} }
var isGetAccessor = propertyInfo.GetMethod == accessor;
var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
var derivedPropertyMethods = derivedProperties
.Select((Func<PropertyInfo, MethodInfo?>)
(propertyInfo.GetMethod == accessor ? p => p.GetMethod : p => p.SetMethod))
.OfType<MethodInfo>().ToArray();
return accessor.GetOverridingMethodIndex(derivedPropertyMethods) is { } i PropertyInfo? overridingProperty = null;
? derivedProperties[i] if (derivedProperties is { Length: > 0 })
// No derived methods were found. Return the original methodInfo {
: propertyInfo; var baseDefinition = accessor.GetBaseDefinition();
overridingProperty = derivedProperties.FirstOrDefault(p
=> (isGetAccessor ? p.GetMethod : p.SetMethod)?.IsOverridingMethodOf(baseDefinition) == true);
}
return overridingProperty ?? propertyInfo; // If no derived methods were found, return the original methodInfo
} }
public static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo) public static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo)
@@ -20,8 +20,20 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
[UsesVerify] [UsesVerify]
public class InheritedModelTests public class InheritedModelTests
{ {
public interface IBaseProvider<TBase>
{
ICollection<TBase> Bases { get; set; }
}
public class BaseProvider : IBaseProvider<Concrete>
{
public int Id { get; set; }
public ICollection<Concrete> Bases { get; set; }
}
public interface IBase public interface IBase
{ {
int Id { get; }
int ComputedProperty { get; } int ComputedProperty { get; }
int ComputedMethod(); int ComputedMethod();
} }
@@ -117,6 +129,26 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
return Verifier.Verify(query.ToQueryString()); return Verifier.Verify(query.ToQueryString());
} }
[Fact]
public Task ProjectOverProvider()
{
using var dbContext = new SampleDbContext<BaseProvider>();
var query = dbContext.Set<BaseProvider>().AllBases<BaseProvider, Concrete>();
return Verifier.Verify(query.ToQueryString());
}
[Fact]
public Task ProjectOverExtensionMethod()
{
using var dbContext = new SampleDbContext<Concrete>();
var query = dbContext.Set<Concrete>().Select(c => c.ComputedPropertyPlusMethod());
return Verifier.Verify(query.ToQueryString());
}
} }
public static class ModelExtensions public static class ModelExtensions
@@ -128,5 +160,15 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
public static IQueryable<int> SelectComputedMethod<TConcrete>(this IQueryable<TConcrete> concretes) public static IQueryable<int> SelectComputedMethod<TConcrete>(this IQueryable<TConcrete> concretes)
where TConcrete : InheritedModelTests.IBase where TConcrete : InheritedModelTests.IBase
=> concretes.Select(x => x.ComputedMethod()); => concretes.Select(x => x.ComputedMethod());
public static IQueryable<int> AllBases<TProvider, TBase>(this IQueryable<TProvider> concretes)
where TProvider : InheritedModelTests.IBaseProvider<TBase>
where TBase : InheritedModelTests.IBase
=> concretes.SelectMany(x => x.Bases).Select(x => x.Id);
[Projectable]
public static int ComputedPropertyPlusMethod<TConcrete>(this TConcrete concrete)
where TConcrete : InheritedModelTests.IBase
=> concrete.ComputedProperty + concrete.ComputedMethod();
} }
} }