diff --git a/.gitignore b/.gitignore index f2f723c..4e57e37 100644 --- a/.gitignore +++ b/.gitignore @@ -363,4 +363,6 @@ MigrationBackup/ FodyWeavers.xsd # Received verify test results -*.received.* \ No newline at end of file +*.received.* + +.idea diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs index 010fb4c..db66e3d 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs @@ -1,10 +1,10 @@ -using System; -using System.Collections.Generic; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; using System.Reflection; using EntityFrameworkCore.Projectables.Extensions; +using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query; namespace EntityFrameworkCore.Projectables.Services @@ -13,11 +13,15 @@ namespace EntityFrameworkCore.Projectables.Services { readonly IProjectionExpressionResolver _resolver; readonly ExpressionArgumentReplacer _expressionArgumentReplacer = new(); + readonly QueryRootReplacer _queryRootReplacer; readonly Dictionary _projectableMemberCache = new(); + private bool _disableRootRewrite = false; + private IEntityType? _entityType; public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver) { _resolver = projectionExpressionResolver; + _queryRootReplacer = new(_resolver); } bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression) @@ -36,11 +40,42 @@ namespace EntityFrameworkCore.Projectables.Services return reflectedExpression is not null; } + [return: NotNullIfNotNull(nameof(node))] + public override Expression? Visit(Expression? node) + { + var ret = base.Visit(node); + + if (_disableRootRewrite) + { + return ret; + } + + switch (node) + { + // Probably a First() or ToList() + case MethodCallExpression { Arguments.Count: > 0 } call when _entityType != null: + { + var self = _AddProjectableSelect(call.Arguments.First(), _entityType); + return call.Update(null, call.Arguments.Skip(1).Prepend(self)); + } + // Probably a foreach call + case QueryRootExpression root: + return _AddProjectableSelect(root, root.EntityType); + default: + return ret; + } + } + protected override Expression VisitMethodCall(MethodCallExpression node) { // Get the overriding methodInfo based on te type of the received of this expression var methodInfo = node.Object?.Type.GetConcreteMethod(node.Method) ?? node.Method; + if (methodInfo.Name == nameof(Queryable.Select)) + { + _disableRootRewrite = true; + } + if (TryGetReflectedExpression(methodInfo, out var reflectedExpression)) { for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++) @@ -110,12 +145,16 @@ namespace EntityFrameworkCore.Projectables.Services protected override Expression VisitExtension(Expression node) { - if (node is not QueryRootExpression root) + if (node is QueryRootExpression root) { - return node; + _entityType = root.EntityType; } + return base.VisitExtension(node); + } - var projectableProperties = root.EntityType.ClrType.GetProperties() + private Expression _AddProjectableSelect(Expression node, IEntityType entityType) + { + var projectableProperties = entityType.ClrType.GetProperties() .Where(x => x.IsDefined(typeof(ProjectableAttribute), false)) .Where(x => x.CanWrite) .ToList(); @@ -125,7 +164,7 @@ namespace EntityFrameworkCore.Projectables.Services return node; } - var properties = root.EntityType.GetProperties() + var properties = entityType.GetProperties() .Where(x => !x.IsShadowProperty()) .Select(x => x.GetMemberInfo(false, false)) // Remove projectable properties from the ef properties. Since properties returned here for auto @@ -140,15 +179,15 @@ namespace EntityFrameworkCore.Projectables.Services .GetGenericArguments().First() // Func .GetGenericArguments().Length == 2 // Separate between Func and Func ) - .MakeGenericMethod(root.EntityType.ClrType, root.EntityType.ClrType); - var xParam = Expression.Parameter(root.EntityType.ClrType); + .MakeGenericMethod(entityType.ClrType, entityType.ClrType); + var xParam = Expression.Parameter(entityType.ClrType); return Expression.Call( null, select, node, Expression.Lambda( Expression.MemberInit( - Expression.New(root.EntityType.ClrType), + Expression.New(entityType.ClrType), properties.Select(x => Expression.Bind(x, Expression.MakeMemberAccess(xParam, x))) .Concat(projectableProperties .Select(x => Expression.Bind(x, _ReplaceParam(_resolver.FindGeneratedExpression(x), xParam)))