Use a second path for rewritting query path

This commit is contained in:
2023-10-10 17:45:27 +02:00
parent e5eae5bf5a
commit 669b02a9f9
2 changed files with 51 additions and 10 deletions

4
.gitignore vendored
View File

@@ -363,4 +363,6 @@ MigrationBackup/
FodyWeavers.xsd FodyWeavers.xsd
# Received verify test results # Received verify test results
*.received.* *.received.*
.idea

View File

@@ -1,10 +1,10 @@
using System; using System.Collections.Generic;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Linq; using System.Linq;
using System.Linq.Expressions; using System.Linq.Expressions;
using System.Reflection; using System.Reflection;
using EntityFrameworkCore.Projectables.Extensions; using EntityFrameworkCore.Projectables.Extensions;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Query;
namespace EntityFrameworkCore.Projectables.Services namespace EntityFrameworkCore.Projectables.Services
@@ -13,11 +13,15 @@ namespace EntityFrameworkCore.Projectables.Services
{ {
readonly IProjectionExpressionResolver _resolver; readonly IProjectionExpressionResolver _resolver;
readonly ExpressionArgumentReplacer _expressionArgumentReplacer = new(); readonly ExpressionArgumentReplacer _expressionArgumentReplacer = new();
readonly QueryRootReplacer _queryRootReplacer;
readonly Dictionary<MemberInfo, LambdaExpression?> _projectableMemberCache = new(); readonly Dictionary<MemberInfo, LambdaExpression?> _projectableMemberCache = new();
private bool _disableRootRewrite = false;
private IEntityType? _entityType;
public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver) public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver)
{ {
_resolver = projectionExpressionResolver; _resolver = projectionExpressionResolver;
_queryRootReplacer = new(_resolver);
} }
bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression) bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression)
@@ -36,11 +40,42 @@ namespace EntityFrameworkCore.Projectables.Services
return reflectedExpression is not null; return reflectedExpression is not null;
} }
[return: NotNullIfNotNull(nameof(node))]
public override Expression? Visit(Expression? node)
{
var ret = base.Visit(node);
if (_disableRootRewrite)
{
return ret;
}
switch (node)
{
// Probably a First() or ToList()
case MethodCallExpression { Arguments.Count: > 0 } call when _entityType != null:
{
var self = _AddProjectableSelect(call.Arguments.First(), _entityType);
return call.Update(null, call.Arguments.Skip(1).Prepend(self));
}
// Probably a foreach call
case QueryRootExpression root:
return _AddProjectableSelect(root, root.EntityType);
default:
return ret;
}
}
protected override Expression VisitMethodCall(MethodCallExpression node) protected override Expression VisitMethodCall(MethodCallExpression node)
{ {
// Get the overriding methodInfo based on te type of the received of this expression // Get the overriding methodInfo based on te type of the received of this expression
var methodInfo = node.Object?.Type.GetConcreteMethod(node.Method) ?? node.Method; var methodInfo = node.Object?.Type.GetConcreteMethod(node.Method) ?? node.Method;
if (methodInfo.Name == nameof(Queryable.Select))
{
_disableRootRewrite = true;
}
if (TryGetReflectedExpression(methodInfo, out var reflectedExpression)) if (TryGetReflectedExpression(methodInfo, out var reflectedExpression))
{ {
for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++) for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++)
@@ -110,12 +145,16 @@ namespace EntityFrameworkCore.Projectables.Services
protected override Expression VisitExtension(Expression node) protected override Expression VisitExtension(Expression node)
{ {
if (node is not QueryRootExpression root) if (node is QueryRootExpression root)
{ {
return node; _entityType = root.EntityType;
} }
return base.VisitExtension(node);
}
var projectableProperties = root.EntityType.ClrType.GetProperties() private Expression _AddProjectableSelect(Expression node, IEntityType entityType)
{
var projectableProperties = entityType.ClrType.GetProperties()
.Where(x => x.IsDefined(typeof(ProjectableAttribute), false)) .Where(x => x.IsDefined(typeof(ProjectableAttribute), false))
.Where(x => x.CanWrite) .Where(x => x.CanWrite)
.ToList(); .ToList();
@@ -125,7 +164,7 @@ namespace EntityFrameworkCore.Projectables.Services
return node; return node;
} }
var properties = root.EntityType.GetProperties() var properties = entityType.GetProperties()
.Where(x => !x.IsShadowProperty()) .Where(x => !x.IsShadowProperty())
.Select(x => x.GetMemberInfo(false, false)) .Select(x => x.GetMemberInfo(false, false))
// Remove projectable properties from the ef properties. Since properties returned here for auto // Remove projectable properties from the ef properties. Since properties returned here for auto
@@ -140,15 +179,15 @@ namespace EntityFrameworkCore.Projectables.Services
.GetGenericArguments().First() // Func<T, Ret> .GetGenericArguments().First() // Func<T, Ret>
.GetGenericArguments().Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret> .GetGenericArguments().Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
) )
.MakeGenericMethod(root.EntityType.ClrType, root.EntityType.ClrType); .MakeGenericMethod(entityType.ClrType, entityType.ClrType);
var xParam = Expression.Parameter(root.EntityType.ClrType); var xParam = Expression.Parameter(entityType.ClrType);
return Expression.Call( return Expression.Call(
null, null,
select, select,
node, node,
Expression.Lambda( Expression.Lambda(
Expression.MemberInit( Expression.MemberInit(
Expression.New(root.EntityType.ClrType), Expression.New(entityType.ClrType),
properties.Select(x => Expression.Bind(x, Expression.MakeMemberAccess(xParam, x))) properties.Select(x => Expression.Bind(x, Expression.MakeMemberAccess(xParam, x)))
.Concat(projectableProperties .Concat(projectableProperties
.Select(x => Expression.Bind(x, _ReplaceParam(_resolver.FindGeneratedExpression(x), xParam))) .Select(x => Expression.Bind(x, _ReplaceParam(_resolver.FindGeneratedExpression(x), xParam)))