Handle cases like .First(where) and .Sum

This commit is contained in:
2023-10-11 01:49:38 +02:00
parent aa7cf4f9c6
commit 4fb6013aed
2 changed files with 94 additions and 10 deletions

View File

@@ -134,11 +134,31 @@ namespace BasicSample
{
Console.WriteLine($"User name: {u.FullName}");
}
foreach (var u in dbContext.Users.ToList())
{
Console.WriteLine($"User name: {u.FullName}");
}
foreach (var u in dbContext.Users.OrderBy(x => x.FullName))
{
Console.WriteLine($"User name: {u.FullName}");
}
}
{
foreach (var u in dbContext.Users.Where(x => x.TotalSpent >= 1))
{
Console.WriteLine($"User name: {u.FullName}");
}
}
{
var result = dbContext.Users.FirstOrDefault();
Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}");
result = dbContext.Users.FirstOrDefault(x => x.TotalSpent > 1);
Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}");
}
{

View File

@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
@@ -17,9 +18,26 @@ namespace EntityFrameworkCore.Projectables.Services
private bool _disableRootRewrite;
private IEntityType? _entityType;
private readonly MethodInfo _select;
private readonly MethodInfo _where;
public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver)
{
_resolver = projectionExpressionResolver;
_select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public)
.Where(x => x.Name == nameof(Queryable.Select))
.First(x =>
x.GetParameters().Last().ParameterType // Expression<Func<T, Ret>>
.GetGenericArguments().First() // Func<T, Ret>
.GetGenericArguments().Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
);
_where = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public)
.Where(x => x.Name == nameof(Queryable.Where))
.First(x =>
x.GetParameters().Last().ParameterType // Expression<Func<T, Ret>>
.GetGenericArguments().First() // Func<T, Ret>
.GetGenericArguments().Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
);
}
bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression)
@@ -45,6 +63,7 @@ namespace EntityFrameworkCore.Projectables.Services
if (_disableRootRewrite)
{
// This boolean is enabled when a "Select" is encountered
return ret;
}
@@ -53,10 +72,62 @@ namespace EntityFrameworkCore.Projectables.Services
// Probably a First() or ToList()
case MethodCallExpression { Arguments.Count: > 0, Object: null } call when _entityType != null:
{
// if return type != IQueryable {
// if return type is IEnuberable {
// // case of a ToList()
// return (ret.arg[0]).Select(...).ToList() or the other method
// } else {
// // case of a Max()
// return ret;
// }
// } else if retrun type == entitytype {
// // case of a first()
// return obj.MyMap(x => new Obj {});
// }
if (call.Method.ReturnType.IsAssignableTo(typeof(IQueryable)))
{
// Generic case where the return type is still a IQueryable<T>
return _AddProjectableSelect(call, _entityType);
}
if (call.Method.ReturnType == _entityType.ClrType)
{
// case of a .First(), .SingleAsync()
if (call.Arguments.Count != 1 && true /* Add && arg.count == 1 exist */)
{
// .First(x => whereCondition), since we need to add a select after the last condition but
// before the query become executed by EF (before the .First()), we rewrite the .First(where)
// as .Where(where).Select(x => ...).First()
var where = Expression.Call(null, _where.MakeGenericMethod(_entityType.ClrType), call.Arguments);
// The call instance is based on the wrong polymorphied method.
var first = call.Method.DeclaringType?.GetMethods()
.FirstOrDefault(x => x.Name == call.Method.Name && x.GetParameters().Length == 1);
if (first == null)
{
// Unknown case that should not happen.
return call;
}
return Expression.Call(null, first.MakeGenericMethod(_entityType.ClrType), _AddProjectableSelect(where, _entityType));
}
// .First() without arguments is the same case as bellow so we let it fallthrough
}
else if (!call.Method.ReturnType.IsAssignableTo(typeof(IEnumerable)))
{
// case of something like a .Max(), .Sum()
return call;
}
// return type is IEnumerable<EntityType> or EntityType (in case of fallthrough from a .First())
// case of something like .ToList(), .ToArrayAsync()
var self = _AddProjectableSelect(call.Arguments.First(), _entityType);
return call.Update(null, call.Arguments.Skip(1).Prepend(self));
}
// Probably a foreach call
case QueryRootExpression root:
return _AddProjectableSelect(root, root.EntityType);
default:
@@ -170,14 +241,7 @@ namespace EntityFrameworkCore.Projectables.Services
.Where(x => projectableProperties.All(y => x.Name != y.Name && x.Name != $"<{y.Name}>k__BackingField"));
// Replace db.Entities to db.Entities.Select(x => new Entity { Property1 = x.Property1, Rewritted = rewrittedProperty })
var select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public)
.Where(x => x.Name == nameof(Queryable.Select))
.First(x =>
x.GetParameters().Last().ParameterType // Expression<Func<T, Ret>>
.GetGenericArguments().First() // Func<T, Ret>
.GetGenericArguments().Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
)
.MakeGenericMethod(entityType.ClrType, entityType.ClrType);
var select = _select.MakeGenericMethod(entityType.ClrType, entityType.ClrType);
var xParam = Expression.Parameter(entityType.ClrType);
return Expression.Call(
null,