diff --git a/.gitignore b/.gitignore index f2f723c..4e57e37 100644 --- a/.gitignore +++ b/.gitignore @@ -363,4 +363,6 @@ MigrationBackup/ FodyWeavers.xsd # Received verify test results -*.received.* \ No newline at end of file +*.received.* + +.idea diff --git a/Directory.Build.props b/Directory.Build.props index d79998e..932ca59 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -23,11 +23,15 @@ - net6.0 + net7.0;net6.0 6.0.0 6.0.0 $(EFCoreVersion) - - \ No newline at end of file + + 7.0.0 + $(EFCoreVersion) + + + diff --git a/samples/BasicSample/BasicSample.csproj b/samples/BasicSample/BasicSample.csproj index 6e97dbc..937f9ed 100644 --- a/samples/BasicSample/BasicSample.csproj +++ b/samples/BasicSample/BasicSample.csproj @@ -2,7 +2,7 @@ Exe - net6.0 + net7.0 disable true $(BaseIntermediateOutputPath)Generated @@ -10,9 +10,9 @@ - - - + + + diff --git a/samples/BasicSample/Program.cs b/samples/BasicSample/Program.cs index 0ca49e5..08c25fc 100644 --- a/samples/BasicSample/Program.cs +++ b/samples/BasicSample/Program.cs @@ -1,16 +1,11 @@ -using EntityFrameworkCore.Projectables; -using EntityFrameworkCore.Projectables.Extensions; -using Microsoft.Data.Sqlite; -using Microsoft.EntityFrameworkCore; -using Microsoft.Extensions.Caching.Memory; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using System; -using System.Collections; +using System; using System.Collections.Generic; using System.ComponentModel.DataAnnotations.Schema; -using System.Diagnostics; using System.Linq; +using EntityFrameworkCore.Projectables; +using Microsoft.Data.Sqlite; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; namespace BasicSample { @@ -22,12 +17,13 @@ namespace BasicSample public ICollection Orders { get; set; } - [Projectable] - public string FullName - => FirstName + " " + LastName; + [Projectable(UseMemberBody = nameof(_FullName))] + public string FullName { get; set; } + private string _FullName => FirstName + " " + LastName; - [Projectable] - public double TotalSpent => Orders.Sum(x => x.PriceSum); + [Projectable(UseMemberBody = nameof(_TotalSpent))] + public double TotalSpent { get; set; } + private double _TotalSpent => Orders.Sum(x => x.PriceSum); [Projectable] public Order MostValuableOrder @@ -86,7 +82,7 @@ namespace BasicSample class Program { - static void Main(string[] args) + public static void Main(string[] args) { using var dbConnection = new SqliteConnection("Filename=:memory:"); dbConnection.Open(); @@ -95,6 +91,8 @@ namespace BasicSample .AddDbContext((provider, options) => { options .UseSqlite(dbConnection) + // .LogTo(Console.WriteLine) + .EnableSensitiveDataLogging() .UseProjectables(); }) .BuildServiceProvider(); @@ -105,9 +103,9 @@ namespace BasicSample var product1 = new Product { Name = "Red pen", Price = 1.5 }; var product2 = new Product { Name = "Blue pen", Price = 2.1 }; - var user = new User { - FirstName = "Jon", - LastName = "Doe", + var user = new User { + FirstName = "Jon", + LastName = "Doe", Orders = new List { new Order { Items = new List { @@ -130,6 +128,42 @@ namespace BasicSample dbContext.SaveChanges(); // What did our user spent in total + + { + foreach (var u in dbContext.Users) + { + Console.WriteLine($"User name: {u.FullName}"); + } + + foreach (var u in dbContext.Users.ToList()) + { + Console.WriteLine($"User name: {u.FullName}"); + } + + foreach (var u in dbContext.Users.OrderBy(x => x.FullName)) + { + Console.WriteLine($"User name: {u.FullName}"); + } + } + + { + foreach (var u in dbContext.Users.Where(x => x.TotalSpent >= 1)) + { + Console.WriteLine($"User name: {u.FullName}"); + } + } + + { + var result = dbContext.Users.FirstOrDefault(); + Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}"); + + result = dbContext.Users.FirstOrDefault(x => x.TotalSpent > 1); + Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}"); + + var spent = dbContext.Users.Sum(x => x.TotalSpent); + Console.WriteLine($"Our users combined spent: {spent}"); + } + { var query = dbContext.Users .Select(x => new { diff --git a/samples/ReadmeSample/ReadmeSample.csproj b/samples/ReadmeSample/ReadmeSample.csproj index a1f59ea..2da4034 100644 --- a/samples/ReadmeSample/ReadmeSample.csproj +++ b/samples/ReadmeSample/ReadmeSample.csproj @@ -10,9 +10,9 @@ - - - + + + diff --git a/shell.nix b/shell.nix new file mode 100644 index 0000000..3ee0422 --- /dev/null +++ b/shell.nix @@ -0,0 +1,12 @@ +{pkgs ? import {}}: let + dotnet = with pkgs.dotnetCorePackages; + combinePackages [ + sdk_7_0 + aspnetcore_7_0 + ]; +in + pkgs.mkShell { + packages = [dotnet]; + + DOTNET_ROOT = "${dotnet}"; + } diff --git a/src/EntityFrameworkCore.Projectables/EntityFrameworkCore.Projectables.csproj b/src/EntityFrameworkCore.Projectables/EntityFrameworkCore.Projectables.csproj index 079b595..9aa456d 100644 --- a/src/EntityFrameworkCore.Projectables/EntityFrameworkCore.Projectables.csproj +++ b/src/EntityFrameworkCore.Projectables/EntityFrameworkCore.Projectables.csproj @@ -1,7 +1,7 @@  - net6.0 + net7.0;net6.0 README.md diff --git a/src/EntityFrameworkCore.Projectables/Extensions/ExpressionExtensions.cs b/src/EntityFrameworkCore.Projectables/Extensions/ExpressionExtensions.cs index 565e784..49da220 100644 --- a/src/EntityFrameworkCore.Projectables/Extensions/ExpressionExtensions.cs +++ b/src/EntityFrameworkCore.Projectables/Extensions/ExpressionExtensions.cs @@ -18,6 +18,6 @@ namespace EntityFrameworkCore.Projectables.Extensions /// Replaces all calls to properties and methods that are marked with the Projectable attribute with their respective expression tree /// public static Expression ExpandProjectables(this Expression expression) - => new ProjectableExpressionReplacer(new ProjectionExpressionResolver()).Visit(expression); + => new ProjectableExpressionReplacer(new ProjectionExpressionResolver()).Replace(expression); } } diff --git a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs index 5a300a2..2a85c9c 100644 --- a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs +++ b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs @@ -34,6 +34,6 @@ namespace EntityFrameworkCore.Projectables.Infrastructure.Internal => _decoratedQueryCompiler.ExecuteAsync(Expand(query), cancellationToken); Expression Expand(Expression expression) - => _projectableExpressionReplacer.Visit(expression); + => _projectableExpressionReplacer.Replace(expression); } } diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs index 6b478f9..c664532 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs @@ -1,14 +1,12 @@ -using System; -using System.Buffers; +using System.Collections; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; using System.Reflection; -using System.Text; -using System.Threading.Tasks; -using System.Xml.Linq; using EntityFrameworkCore.Projectables.Extensions; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; namespace EntityFrameworkCore.Projectables.Services { @@ -17,10 +15,29 @@ namespace EntityFrameworkCore.Projectables.Services readonly IProjectionExpressionResolver _resolver; readonly ExpressionArgumentReplacer _expressionArgumentReplacer = new(); readonly Dictionary _projectableMemberCache = new(); + private bool _disableRootRewrite; + private IEntityType? _entityType; + + private readonly MethodInfo _select; + private readonly MethodInfo _where; public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver) { _resolver = projectionExpressionResolver; + _select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public) + .Where(x => x.Name == nameof(Queryable.Select)) + .First(x => + x.GetParameters().Last().ParameterType // Expression> + .GetGenericArguments().First() // Func + .GetGenericArguments().Length == 2 // Separate between Func and Func + ); + _where = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public) + .Where(x => x.Name == nameof(Queryable.Where)) + .First(x => + x.GetParameters().Last().ParameterType // Expression> + .GetGenericArguments().First() // Func + .GetGenericArguments().Length == 2 // Separate between Func and Func + ); } bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression) @@ -28,10 +45,10 @@ namespace EntityFrameworkCore.Projectables.Services if (!_projectableMemberCache.TryGetValue(memberInfo, out reflectedExpression)) { var projectableAttribute = memberInfo.GetCustomAttribute(false); - - reflectedExpression = projectableAttribute is not null + + reflectedExpression = projectableAttribute is not null ? _resolver.FindGeneratedExpression(memberInfo) - : (LambdaExpression?)null; + : null; _projectableMemberCache.Add(memberInfo, reflectedExpression); } @@ -39,6 +56,85 @@ namespace EntityFrameworkCore.Projectables.Services return reflectedExpression is not null; } + [return: NotNullIfNotNull(nameof(node))] + public Expression? Replace(Expression? node) + { + var ret = Visit(node); + + if (_disableRootRewrite) + { + // This boolean is enabled when a "Select" is encountered + return ret; + } + + switch (ret) + { + // Probably a First() or ToList() + case MethodCallExpression { Arguments.Count: > 0, Object: null } call when _entityType != null: + { + // if return type != IQueryable { + // if return type is IEnuberable { + // // case of a ToList() + // return (ret.arg[0]).Select(...).ToList() or the other method + // } else { + // // case of a Max() + // return ret; + // } + // } else if retrun type == entitytype { + // // case of a first() + // return obj.MyMap(x => new Obj {}); + // } + + + if (call.Method.ReturnType.IsAssignableTo(typeof(IQueryable))) + { + // Generic case where the return type is still a IQueryable + return _AddProjectableSelect(call, _entityType); + } + + if (call.Method.ReturnType == _entityType.ClrType) + { + // case of a .First(), .SingleAsync() + if (call.Arguments.Count != 1 && true /* Add && arg.count == 1 exist */) + { + // .First(x => whereCondition), since we need to add a select after the last condition but + // before the query become executed by EF (before the .First()), we rewrite the .First(where) + // as .Where(where).Select(x => ...).First() + + var where = Expression.Call(null, _where.MakeGenericMethod(_entityType.ClrType), call.Arguments); + // The call instance is based on the wrong polymorphied method. + var first = call.Method.DeclaringType?.GetMethods() + .FirstOrDefault(x => x.Name == call.Method.Name && x.GetParameters().Length == 1); + if (first == null) + { + // Unknown case that should not happen. + return call; + } + + return Expression.Call(null, first.MakeGenericMethod(_entityType.ClrType), _AddProjectableSelect(where, _entityType)); + } + + // .First() without arguments is the same case as bellow so we let it fallthrough + } + else if (!call.Method.ReturnType.IsAssignableTo(typeof(IEnumerable))) + { + // case of something like a .Max(), .Sum() + return call; + } + + // return type is IEnumerable or EntityType (in case of fallthrough from a .First()) + + // case of something like .ToList(), .ToArrayAsync() + var self = _AddProjectableSelect(call.Arguments.First(), _entityType); + return call.Update(null, call.Arguments.Skip(1).Prepend(self)); + } + case QueryRootExpression root when _entityType != null: + return _AddProjectableSelect(root, _entityType); + default: + return ret; + } + } + protected override Expression VisitMethodCall(MethodCallExpression node) { // Replace MethodGroup arguments with their reflected expressions. @@ -58,13 +154,18 @@ namespace EntityFrameworkCore.Projectables.Services // Get the overriding methodInfo based on te type of the received of this expression var methodInfo = node.Object?.Type.GetConcreteMethod(node.Method) ?? node.Method; + if (methodInfo.Name == nameof(Queryable.Select)) + { + _disableRootRewrite = true; + } + if (TryGetReflectedExpression(methodInfo, out var reflectedExpression)) { for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++) { var parameterExpession = reflectedExpression.Parameters[parameterIndex]; var mappedArgumentExpression = (parameterIndex, node.Object) switch { - (0, not null) => node.Object, + (0, not null) => node.Object, (_, not null) => node.Arguments[parameterIndex - 1], (_, null) => node.Arguments.Count > parameterIndex ? node.Arguments[parameterIndex] : null }; @@ -74,11 +175,11 @@ namespace EntityFrameworkCore.Projectables.Services _expressionArgumentReplacer.ParameterArgumentMapping.Add(parameterExpession, mappedArgumentExpression); } } - + var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body); _expressionArgumentReplacer.ParameterArgumentMapping.Clear(); - return Visit( + return base.Visit( updatedBody ); } @@ -110,13 +211,13 @@ namespace EntityFrameworkCore.Projectables.Services var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body); _expressionArgumentReplacer.ParameterArgumentMapping.Clear(); - return Visit( + return base.Visit( updatedBody ); } else { - return Visit( + return base.Visit( reflectedExpression.Body ); } @@ -124,5 +225,66 @@ namespace EntityFrameworkCore.Projectables.Services return base.VisitMember(node); } + + protected override Expression VisitExtension(Expression node) + { +#if NET7_0_OR_GREATER + if (node is EntityQueryRootExpression root) +#else + if (node is QueryRootExpression root) +#endif + { + _entityType = root.EntityType; + } + return base.VisitExtension(node); + } + + private Expression _AddProjectableSelect(Expression node, IEntityType entityType) + { + var projectableProperties = entityType.ClrType.GetProperties() + .Where(x => x.IsDefined(typeof(ProjectableAttribute), false)) + .Where(x => x.CanWrite) + .ToList(); + + if (!projectableProperties.Any()) + { + return node; + } + + var properties = entityType.GetProperties() + .Where(x => !x.IsShadowProperty()) + .Select(x => x.GetMemberInfo(false, false)) + // Remove projectable properties from the ef properties. Since properties returned here for auto + // properties (like `public string Test {get;set;}`) are generated fields, we also need to take them into account. + .Where(x => projectableProperties.All(y => x.Name != y.Name && x.Name != $"<{y.Name}>k__BackingField")); + + // Replace db.Entities to db.Entities.Select(x => new Entity { Property1 = x.Property1, Rewritted = rewrittedProperty }) + var select = _select.MakeGenericMethod(entityType.ClrType, entityType.ClrType); + var xParam = Expression.Parameter(entityType.ClrType); + return Expression.Call( + null, + select, + node, + Expression.Lambda( + Expression.MemberInit( + Expression.New(entityType.ClrType), + properties.Select(x => Expression.Bind(x, Expression.MakeMemberAccess(xParam, x))) + .Concat(projectableProperties + .Select(x => Expression.Bind(x, _GetAccessor(x, xParam))) + ) + ), + xParam + ) + ); + } + + private Expression _GetAccessor(PropertyInfo property, ParameterExpression para) + { + var lambda = _resolver.FindGeneratedExpression(property); + _expressionArgumentReplacer.ParameterArgumentMapping.Add(lambda.Parameters[0], para); + var updatedBody = _expressionArgumentReplacer.Visit(lambda.Body); + _expressionArgumentReplacer.ParameterArgumentMapping.Clear(); + return base.Visit(updatedBody); + } } } diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs index f913049..b9062dd 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs @@ -1,13 +1,8 @@ using System; -using System.Collections.Concurrent; -using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using System.Reflection; -using System.Text; -using System.Threading.Tasks; using EntityFrameworkCore.Projectables.Extensions; -using Microsoft.EntityFrameworkCore.Storage.ValueConversion.Internal; namespace EntityFrameworkCore.Projectables.Services { diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/EntityFrameworkCore.Projectables.FunctionalTests.csproj b/tests/EntityFrameworkCore.Projectables.FunctionalTests/EntityFrameworkCore.Projectables.FunctionalTests.csproj index 2e4c691..0c895e5 100644 --- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/EntityFrameworkCore.Projectables.FunctionalTests.csproj +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/EntityFrameworkCore.Projectables.FunctionalTests.csproj @@ -8,7 +8,7 @@ - + diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.UseMemberPropertyQueryRootExpression.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.UseMemberPropertyQueryRootExpression.verified.txt new file mode 100644 index 0000000..e8c699d --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.UseMemberPropertyQueryRootExpression.verified.txt @@ -0,0 +1,2 @@ +SELECT [e].[Id], [e].[Id] * 5 +FROM [Entity] AS [e] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.cs new file mode 100644 index 0000000..bef2283 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/QueryRootTests.cs @@ -0,0 +1,39 @@ +using System.ComponentModel.DataAnnotations.Schema; +using System.Threading.Tasks; +using EntityFrameworkCore.Projectables.FunctionalTests.Helpers; +using Microsoft.EntityFrameworkCore; +using VerifyXunit; +using Xunit; + +namespace EntityFrameworkCore.Projectables.FunctionalTests +{ + [UsesVerify] + public class QueryRootTests + { + public record Entity + { + public int Id { get; set; } + + [Projectable(UseMemberBody = nameof(Computed2))] + public int Computed1 => Id; + + private int Computed2 => Id * 2; + + [Projectable(UseMemberBody = nameof(_ComputedWithBaking))] + [NotMapped] + public int ComputedWithBacking { get; set; } + + private int _ComputedWithBaking => Id * 5; + } + + [Fact] + public Task UseMemberPropertyQueryRootExpression() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set(); + + return Verifier.Verify(query.ToQueryString()); + } + } +} diff --git a/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs b/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs index 133baeb..a9ed5dd 100644 --- a/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs @@ -61,7 +61,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Visit(input); + var actual = subject.Replace(input); Assert.Equal(expected.ToString(), actual.ToString()); } @@ -77,7 +77,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Visit(input); + var actual = subject.Replace(input); Assert.Equal(expected.ToString(), actual.ToString()); } @@ -93,7 +93,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Visit(input); + var actual = subject.Replace(input); Assert.Equal(expected.ToString(), actual.ToString()); } @@ -109,7 +109,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Visit(input); + var actual = subject.Replace(input); Assert.Equal(expected.ToString(), actual.ToString()); } @@ -125,7 +125,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Visit(input); + var actual = subject.Replace(input); Assert.Equal(expected.ToString(), actual.ToString()); } @@ -141,7 +141,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Visit(input); + var actual = subject.Replace(input); Assert.Equal(expected.ToString(), actual.ToString()); } @@ -157,7 +157,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Visit(input); + var actual = subject.Replace(input); Assert.Equal(expected.ToString(), actual.ToString()); }