diff --git a/samples/BasicSample/Program.cs b/samples/BasicSample/Program.cs index c2be8b4..b16b119 100644 --- a/samples/BasicSample/Program.cs +++ b/samples/BasicSample/Program.cs @@ -134,11 +134,31 @@ namespace BasicSample { Console.WriteLine($"User name: {u.FullName}"); } + + foreach (var u in dbContext.Users.ToList()) + { + Console.WriteLine($"User name: {u.FullName}"); + } + + foreach (var u in dbContext.Users.OrderBy(x => x.FullName)) + { + Console.WriteLine($"User name: {u.FullName}"); + } + } + + { + foreach (var u in dbContext.Users.Where(x => x.TotalSpent >= 1)) + { + Console.WriteLine($"User name: {u.FullName}"); + } } { var result = dbContext.Users.FirstOrDefault(); Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}"); + + result = dbContext.Users.FirstOrDefault(x => x.TotalSpent > 1); + Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}"); } { diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs index f9e8212..f1c4f65 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System.Collections; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; @@ -17,9 +18,26 @@ namespace EntityFrameworkCore.Projectables.Services private bool _disableRootRewrite; private IEntityType? _entityType; + private readonly MethodInfo _select; + private readonly MethodInfo _where; + public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver) { _resolver = projectionExpressionResolver; + _select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public) + .Where(x => x.Name == nameof(Queryable.Select)) + .First(x => + x.GetParameters().Last().ParameterType // Expression> + .GetGenericArguments().First() // Func + .GetGenericArguments().Length == 2 // Separate between Func and Func + ); + _where = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public) + .Where(x => x.Name == nameof(Queryable.Where)) + .First(x => + x.GetParameters().Last().ParameterType // Expression> + .GetGenericArguments().First() // Func + .GetGenericArguments().Length == 2 // Separate between Func and Func + ); } bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression) @@ -45,6 +63,7 @@ namespace EntityFrameworkCore.Projectables.Services if (_disableRootRewrite) { + // This boolean is enabled when a "Select" is encountered return ret; } @@ -53,10 +72,62 @@ namespace EntityFrameworkCore.Projectables.Services // Probably a First() or ToList() case MethodCallExpression { Arguments.Count: > 0, Object: null } call when _entityType != null: { + // if return type != IQueryable { + // if return type is IEnuberable { + // // case of a ToList() + // return (ret.arg[0]).Select(...).ToList() or the other method + // } else { + // // case of a Max() + // return ret; + // } + // } else if retrun type == entitytype { + // // case of a first() + // return obj.MyMap(x => new Obj {}); + // } + + + if (call.Method.ReturnType.IsAssignableTo(typeof(IQueryable))) + { + // Generic case where the return type is still a IQueryable + return _AddProjectableSelect(call, _entityType); + } + + if (call.Method.ReturnType == _entityType.ClrType) + { + // case of a .First(), .SingleAsync() + if (call.Arguments.Count != 1 && true /* Add && arg.count == 1 exist */) + { + // .First(x => whereCondition), since we need to add a select after the last condition but + // before the query become executed by EF (before the .First()), we rewrite the .First(where) + // as .Where(where).Select(x => ...).First() + + var where = Expression.Call(null, _where.MakeGenericMethod(_entityType.ClrType), call.Arguments); + // The call instance is based on the wrong polymorphied method. + var first = call.Method.DeclaringType?.GetMethods() + .FirstOrDefault(x => x.Name == call.Method.Name && x.GetParameters().Length == 1); + if (first == null) + { + // Unknown case that should not happen. + return call; + } + + return Expression.Call(null, first.MakeGenericMethod(_entityType.ClrType), _AddProjectableSelect(where, _entityType)); + } + + // .First() without arguments is the same case as bellow so we let it fallthrough + } + else if (!call.Method.ReturnType.IsAssignableTo(typeof(IEnumerable))) + { + // case of something like a .Max(), .Sum() + return call; + } + + // return type is IEnumerable or EntityType (in case of fallthrough from a .First()) + + // case of something like .ToList(), .ToArrayAsync() var self = _AddProjectableSelect(call.Arguments.First(), _entityType); return call.Update(null, call.Arguments.Skip(1).Prepend(self)); } - // Probably a foreach call case QueryRootExpression root: return _AddProjectableSelect(root, root.EntityType); default: @@ -170,14 +241,7 @@ namespace EntityFrameworkCore.Projectables.Services .Where(x => projectableProperties.All(y => x.Name != y.Name && x.Name != $"<{y.Name}>k__BackingField")); // Replace db.Entities to db.Entities.Select(x => new Entity { Property1 = x.Property1, Rewritted = rewrittedProperty }) - var select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public) - .Where(x => x.Name == nameof(Queryable.Select)) - .First(x => - x.GetParameters().Last().ParameterType // Expression> - .GetGenericArguments().First() // Func - .GetGenericArguments().Length == 2 // Separate between Func and Func - ) - .MakeGenericMethod(entityType.ClrType, entityType.ClrType); + var select = _select.MakeGenericMethod(entityType.ClrType, entityType.ClrType); var xParam = Expression.Parameter(entityType.ClrType); return Expression.Call( null,