Merge pull request #86 from zoriya/master

Add query root rewrite support
This commit is contained in:
Koen
2023-10-24 02:54:50 +01:00
committed by GitHub
15 changed files with 309 additions and 59 deletions

2
.gitignore vendored
View File

@@ -364,3 +364,5 @@ FodyWeavers.xsd
# Received verify test results
*.received.*
.idea

View File

@@ -23,11 +23,15 @@
</PropertyGroup>
<PropertyGroup>
<TargetFrameworkVersion>net6.0</TargetFrameworkVersion>
<TargetFrameworkVersion>net7.0;net6.0</TargetFrameworkVersion>
<MicrosoftExtensionsVersion>6.0.0</MicrosoftExtensionsVersion>
<EFCoreVersion>6.0.0</EFCoreVersion>
<TestEFCoreVersion>$(EFCoreVersion)</TestEFCoreVersion>
</PropertyGroup>
<PropertyGroup Condition="'$(TargetFramework)' == 'net7.0'">
<EFCoreVersion>7.0.0</EFCoreVersion>
<TestEFCoreVersion>$(EFCoreVersion)</TestEFCoreVersion>
</PropertyGroup>
</Project>

View File

@@ -2,7 +2,7 @@
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net6.0</TargetFramework>
<TargetFramework>net7.0</TargetFramework>
<Nullable>disable</Nullable>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
<CompilerGeneratedFilesOutputPath>$(BaseIntermediateOutputPath)Generated</CompilerGeneratedFilesOutputPath>
@@ -10,9 +10,9 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore.Sqlite" Version="6.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="6.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="6.0.0" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Sqlite" Version="7.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="7.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="7.0.0" />
</ItemGroup>
<ItemGroup>

View File

@@ -1,16 +1,11 @@
using EntityFrameworkCore.Projectables;
using EntityFrameworkCore.Projectables.Extensions;
using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using System;
using System.Collections;
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations.Schema;
using System.Diagnostics;
using System.Linq;
using EntityFrameworkCore.Projectables;
using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
namespace BasicSample
{
@@ -22,12 +17,13 @@ namespace BasicSample
public ICollection<Order> Orders { get; set; }
[Projectable]
public string FullName
=> FirstName + " " + LastName;
[Projectable(UseMemberBody = nameof(_FullName))]
public string FullName { get; set; }
private string _FullName => FirstName + " " + LastName;
[Projectable]
public double TotalSpent => Orders.Sum(x => x.PriceSum);
[Projectable(UseMemberBody = nameof(_TotalSpent))]
public double TotalSpent { get; set; }
private double _TotalSpent => Orders.Sum(x => x.PriceSum);
[Projectable]
public Order MostValuableOrder
@@ -86,7 +82,7 @@ namespace BasicSample
class Program
{
static void Main(string[] args)
public static void Main(string[] args)
{
using var dbConnection = new SqliteConnection("Filename=:memory:");
dbConnection.Open();
@@ -95,6 +91,8 @@ namespace BasicSample
.AddDbContext<ApplicationDbContext>((provider, options) => {
options
.UseSqlite(dbConnection)
// .LogTo(Console.WriteLine)
.EnableSensitiveDataLogging()
.UseProjectables();
})
.BuildServiceProvider();
@@ -130,6 +128,42 @@ namespace BasicSample
dbContext.SaveChanges();
// What did our user spent in total
{
foreach (var u in dbContext.Users)
{
Console.WriteLine($"User name: {u.FullName}");
}
foreach (var u in dbContext.Users.ToList())
{
Console.WriteLine($"User name: {u.FullName}");
}
foreach (var u in dbContext.Users.OrderBy(x => x.FullName))
{
Console.WriteLine($"User name: {u.FullName}");
}
}
{
foreach (var u in dbContext.Users.Where(x => x.TotalSpent >= 1))
{
Console.WriteLine($"User name: {u.FullName}");
}
}
{
var result = dbContext.Users.FirstOrDefault();
Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}");
result = dbContext.Users.FirstOrDefault(x => x.TotalSpent > 1);
Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}");
var spent = dbContext.Users.Sum(x => x.TotalSpent);
Console.WriteLine($"Our users combined spent: {spent}");
}
{
var query = dbContext.Users
.Select(x => new {

View File

@@ -10,9 +10,9 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="6.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="6.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="6.0.0" />
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="7.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="7.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="7.0.0" />
</ItemGroup>
<ItemGroup>

12
shell.nix Normal file
View File

@@ -0,0 +1,12 @@
{pkgs ? import <nixpkgs> {}}: let
dotnet = with pkgs.dotnetCorePackages;
combinePackages [
sdk_7_0
aspnetcore_7_0
];
in
pkgs.mkShell {
packages = [dotnet];
DOTNET_ROOT = "${dotnet}";
}

View File

@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<TargetFrameworks>net7.0;net6.0</TargetFrameworks>
<PackageReadmeFile>README.md</PackageReadmeFile>
</PropertyGroup>

View File

@@ -18,6 +18,6 @@ namespace EntityFrameworkCore.Projectables.Extensions
/// Replaces all calls to properties and methods that are marked with the <C>Projectable</C> attribute with their respective expression tree
/// </summary>
public static Expression ExpandProjectables(this Expression expression)
=> new ProjectableExpressionReplacer(new ProjectionExpressionResolver()).Visit(expression);
=> new ProjectableExpressionReplacer(new ProjectionExpressionResolver()).Replace(expression);
}
}

View File

@@ -34,6 +34,6 @@ namespace EntityFrameworkCore.Projectables.Infrastructure.Internal
=> _decoratedQueryCompiler.ExecuteAsync<TResult>(Expand(query), cancellationToken);
Expression Expand(Expression expression)
=> _projectableExpressionReplacer.Visit(expression);
=> _projectableExpressionReplacer.Replace(expression);
}
}

View File

@@ -1,14 +1,12 @@
using System;
using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using System.Xml.Linq;
using EntityFrameworkCore.Projectables.Extensions;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;
namespace EntityFrameworkCore.Projectables.Services
{
@@ -17,10 +15,29 @@ namespace EntityFrameworkCore.Projectables.Services
readonly IProjectionExpressionResolver _resolver;
readonly ExpressionArgumentReplacer _expressionArgumentReplacer = new();
readonly Dictionary<MemberInfo, LambdaExpression?> _projectableMemberCache = new();
private bool _disableRootRewrite;
private IEntityType? _entityType;
private readonly MethodInfo _select;
private readonly MethodInfo _where;
public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver)
{
_resolver = projectionExpressionResolver;
_select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public)
.Where(x => x.Name == nameof(Queryable.Select))
.First(x =>
x.GetParameters().Last().ParameterType // Expression<Func<T, Ret>>
.GetGenericArguments().First() // Func<T, Ret>
.GetGenericArguments().Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
);
_where = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public)
.Where(x => x.Name == nameof(Queryable.Where))
.First(x =>
x.GetParameters().Last().ParameterType // Expression<Func<T, Ret>>
.GetGenericArguments().First() // Func<T, Ret>
.GetGenericArguments().Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
);
}
bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression)
@@ -31,7 +48,7 @@ namespace EntityFrameworkCore.Projectables.Services
reflectedExpression = projectableAttribute is not null
? _resolver.FindGeneratedExpression(memberInfo)
: (LambdaExpression?)null;
: null;
_projectableMemberCache.Add(memberInfo, reflectedExpression);
}
@@ -39,6 +56,85 @@ namespace EntityFrameworkCore.Projectables.Services
return reflectedExpression is not null;
}
[return: NotNullIfNotNull(nameof(node))]
public Expression? Replace(Expression? node)
{
var ret = Visit(node);
if (_disableRootRewrite)
{
// This boolean is enabled when a "Select" is encountered
return ret;
}
switch (ret)
{
// Probably a First() or ToList()
case MethodCallExpression { Arguments.Count: > 0, Object: null } call when _entityType != null:
{
// if return type != IQueryable {
// if return type is IEnuberable {
// // case of a ToList()
// return (ret.arg[0]).Select(...).ToList() or the other method
// } else {
// // case of a Max()
// return ret;
// }
// } else if retrun type == entitytype {
// // case of a first()
// return obj.MyMap(x => new Obj {});
// }
if (call.Method.ReturnType.IsAssignableTo(typeof(IQueryable)))
{
// Generic case where the return type is still a IQueryable<T>
return _AddProjectableSelect(call, _entityType);
}
if (call.Method.ReturnType == _entityType.ClrType)
{
// case of a .First(), .SingleAsync()
if (call.Arguments.Count != 1 && true /* Add && arg.count == 1 exist */)
{
// .First(x => whereCondition), since we need to add a select after the last condition but
// before the query become executed by EF (before the .First()), we rewrite the .First(where)
// as .Where(where).Select(x => ...).First()
var where = Expression.Call(null, _where.MakeGenericMethod(_entityType.ClrType), call.Arguments);
// The call instance is based on the wrong polymorphied method.
var first = call.Method.DeclaringType?.GetMethods()
.FirstOrDefault(x => x.Name == call.Method.Name && x.GetParameters().Length == 1);
if (first == null)
{
// Unknown case that should not happen.
return call;
}
return Expression.Call(null, first.MakeGenericMethod(_entityType.ClrType), _AddProjectableSelect(where, _entityType));
}
// .First() without arguments is the same case as bellow so we let it fallthrough
}
else if (!call.Method.ReturnType.IsAssignableTo(typeof(IEnumerable)))
{
// case of something like a .Max(), .Sum()
return call;
}
// return type is IEnumerable<EntityType> or EntityType (in case of fallthrough from a .First())
// case of something like .ToList(), .ToArrayAsync()
var self = _AddProjectableSelect(call.Arguments.First(), _entityType);
return call.Update(null, call.Arguments.Skip(1).Prepend(self));
}
case QueryRootExpression root when _entityType != null:
return _AddProjectableSelect(root, _entityType);
default:
return ret;
}
}
protected override Expression VisitMethodCall(MethodCallExpression node)
{
// Replace MethodGroup arguments with their reflected expressions.
@@ -58,6 +154,11 @@ namespace EntityFrameworkCore.Projectables.Services
// Get the overriding methodInfo based on te type of the received of this expression
var methodInfo = node.Object?.Type.GetConcreteMethod(node.Method) ?? node.Method;
if (methodInfo.Name == nameof(Queryable.Select))
{
_disableRootRewrite = true;
}
if (TryGetReflectedExpression(methodInfo, out var reflectedExpression))
{
for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++)
@@ -78,7 +179,7 @@ namespace EntityFrameworkCore.Projectables.Services
var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
return Visit(
return base.Visit(
updatedBody
);
}
@@ -110,13 +211,13 @@ namespace EntityFrameworkCore.Projectables.Services
var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
return Visit(
return base.Visit(
updatedBody
);
}
else
{
return Visit(
return base.Visit(
reflectedExpression.Body
);
}
@@ -124,5 +225,66 @@ namespace EntityFrameworkCore.Projectables.Services
return base.VisitMember(node);
}
protected override Expression VisitExtension(Expression node)
{
#if NET7_0_OR_GREATER
if (node is EntityQueryRootExpression root)
#else
if (node is QueryRootExpression root)
#endif
{
_entityType = root.EntityType;
}
return base.VisitExtension(node);
}
private Expression _AddProjectableSelect(Expression node, IEntityType entityType)
{
var projectableProperties = entityType.ClrType.GetProperties()
.Where(x => x.IsDefined(typeof(ProjectableAttribute), false))
.Where(x => x.CanWrite)
.ToList();
if (!projectableProperties.Any())
{
return node;
}
var properties = entityType.GetProperties()
.Where(x => !x.IsShadowProperty())
.Select(x => x.GetMemberInfo(false, false))
// Remove projectable properties from the ef properties. Since properties returned here for auto
// properties (like `public string Test {get;set;}`) are generated fields, we also need to take them into account.
.Where(x => projectableProperties.All(y => x.Name != y.Name && x.Name != $"<{y.Name}>k__BackingField"));
// Replace db.Entities to db.Entities.Select(x => new Entity { Property1 = x.Property1, Rewritted = rewrittedProperty })
var select = _select.MakeGenericMethod(entityType.ClrType, entityType.ClrType);
var xParam = Expression.Parameter(entityType.ClrType);
return Expression.Call(
null,
select,
node,
Expression.Lambda(
Expression.MemberInit(
Expression.New(entityType.ClrType),
properties.Select(x => Expression.Bind(x, Expression.MakeMemberAccess(xParam, x)))
.Concat(projectableProperties
.Select(x => Expression.Bind(x, _GetAccessor(x, xParam)))
)
),
xParam
)
);
}
private Expression _GetAccessor(PropertyInfo property, ParameterExpression para)
{
var lambda = _resolver.FindGeneratedExpression(property);
_expressionArgumentReplacer.ParameterArgumentMapping.Add(lambda.Parameters[0], para);
var updatedBody = _expressionArgumentReplacer.Visit(lambda.Body);
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
return base.Visit(updatedBody);
}
}
}

View File

@@ -1,13 +1,8 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using EntityFrameworkCore.Projectables.Extensions;
using Microsoft.EntityFrameworkCore.Storage.ValueConversion.Internal;
namespace EntityFrameworkCore.Projectables.Services
{

View File

@@ -8,7 +8,7 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="$(TestEFCoreVersion)" />
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="7.0.12" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.0.0" />
<PackageReference Include="ScenarioTests.XUnit" Version="1.0.0" />
<PackageReference Include="Verify.Xunit" Version="11.5.0" />

View File

@@ -0,0 +1,2 @@
SELECT [e].[Id], [e].[Id] * 5
FROM [Entity] AS [e]

View File

@@ -0,0 +1,39 @@
using System.ComponentModel.DataAnnotations.Schema;
using System.Threading.Tasks;
using EntityFrameworkCore.Projectables.FunctionalTests.Helpers;
using Microsoft.EntityFrameworkCore;
using VerifyXunit;
using Xunit;
namespace EntityFrameworkCore.Projectables.FunctionalTests
{
[UsesVerify]
public class QueryRootTests
{
public record Entity
{
public int Id { get; set; }
[Projectable(UseMemberBody = nameof(Computed2))]
public int Computed1 => Id;
private int Computed2 => Id * 2;
[Projectable(UseMemberBody = nameof(_ComputedWithBaking))]
[NotMapped]
public int ComputedWithBacking { get; set; }
private int _ComputedWithBaking => Id * 5;
}
[Fact]
public Task UseMemberPropertyQueryRootExpression()
{
using var dbContext = new SampleDbContext<Entity>();
var query = dbContext.Set<Entity>();
return Verifier.Verify(query.ToQueryString());
}
}
}

View File

@@ -61,7 +61,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services
);
var subject = new ProjectableExpressionReplacer(resolver);
var actual = subject.Visit(input);
var actual = subject.Replace(input);
Assert.Equal(expected.ToString(), actual.ToString());
}
@@ -77,7 +77,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services
);
var subject = new ProjectableExpressionReplacer(resolver);
var actual = subject.Visit(input);
var actual = subject.Replace(input);
Assert.Equal(expected.ToString(), actual.ToString());
}
@@ -93,7 +93,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services
);
var subject = new ProjectableExpressionReplacer(resolver);
var actual = subject.Visit(input);
var actual = subject.Replace(input);
Assert.Equal(expected.ToString(), actual.ToString());
}
@@ -109,7 +109,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services
);
var subject = new ProjectableExpressionReplacer(resolver);
var actual = subject.Visit(input);
var actual = subject.Replace(input);
Assert.Equal(expected.ToString(), actual.ToString());
}
@@ -125,7 +125,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services
);
var subject = new ProjectableExpressionReplacer(resolver);
var actual = subject.Visit(input);
var actual = subject.Replace(input);
Assert.Equal(expected.ToString(), actual.ToString());
}
@@ -141,7 +141,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services
);
var subject = new ProjectableExpressionReplacer(resolver);
var actual = subject.Visit(input);
var actual = subject.Replace(input);
Assert.Equal(expected.ToString(), actual.ToString());
}
@@ -157,7 +157,7 @@ namespace EntityFrameworkCore.Projectables.Tests.Services
);
var subject = new ProjectableExpressionReplacer(resolver);
var actual = subject.Visit(input);
var actual = subject.Replace(input);
Assert.Equal(expected.ToString(), actual.ToString());
}