diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/EntityFrameworkCore.Projectables.Benchmarks.csproj b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/EntityFrameworkCore.Projectables.Benchmarks.csproj index 24a30b2..f13aec1 100644 --- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/EntityFrameworkCore.Projectables.Benchmarks.csproj +++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/EntityFrameworkCore.Projectables.Benchmarks.csproj @@ -7,7 +7,7 @@ - + diff --git a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs new file mode 100644 index 0000000..5a300a2 --- /dev/null +++ b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using System.Transactions; +using EntityFrameworkCore.Projectables.Services; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.Internal; + +namespace EntityFrameworkCore.Projectables.Infrastructure.Internal +{ + [System.Diagnostics.CodeAnalysis.SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "Needed")] + public sealed class CustomQueryCompiler : IQueryCompiler + { + readonly IQueryCompiler _decoratedQueryCompiler; + readonly ProjectableExpressionReplacer _projectableExpressionReplacer; + + public CustomQueryCompiler(IQueryCompiler decoratedQueryCompiler) + { + _decoratedQueryCompiler = decoratedQueryCompiler; + _projectableExpressionReplacer = new ProjectableExpressionReplacer(new ProjectionExpressionResolver()); + } + + public Func CreateCompiledAsyncQuery(Expression query) + => _decoratedQueryCompiler.CreateCompiledAsyncQuery(Expand(query)); + public Func CreateCompiledQuery(Expression query) + => _decoratedQueryCompiler.CreateCompiledQuery(Expand(query)); + public TResult Execute(Expression query) + => _decoratedQueryCompiler.Execute(Expand(query)); + public TResult ExecuteAsync(Expression query, CancellationToken cancellationToken) + => _decoratedQueryCompiler.ExecuteAsync(Expand(query), cancellationToken); + + Expression Expand(Expression expression) + => _projectableExpressionReplacer.Visit(expression); + } +} diff --git a/src/EntityFrameworkCore.Projectables/Services/ExpressionArgumentReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ExpressionArgumentReplacer.cs index 203e03a..1862274 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ExpressionArgumentReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ExpressionArgumentReplacer.cs @@ -9,21 +9,11 @@ namespace EntityFrameworkCore.Projectables.Services { public sealed class ExpressionArgumentReplacer : ExpressionVisitor { - readonly IEnumerable<(ParameterExpression parameter, Expression argument)>? _parameterArgumentMapping; - - public ExpressionArgumentReplacer(IEnumerable<(ParameterExpression, Expression)>? parameterArgumentMapping = null) - { - _parameterArgumentMapping = parameterArgumentMapping; - } + public Dictionary ParameterArgumentMapping { get; } = new(); protected override Expression VisitParameter(ParameterExpression node) { - var mappedArgument = _parameterArgumentMapping? - .Where(x => x.parameter == node) - .Select(x => x.argument) - .FirstOrDefault(); - - if (mappedArgument is not null) + if (ParameterArgumentMapping.TryGetValue(node, out var mappedArgument)) { return mappedArgument; } diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs index 5013dd1..3188e69 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs @@ -1,43 +1,63 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; +using System.Reflection; using System.Text; using System.Threading.Tasks; +using System.Xml.Linq; namespace EntityFrameworkCore.Projectables.Services { public sealed class ProjectableExpressionReplacer : ExpressionVisitor { readonly IProjectionExpressionResolver _resolver; + readonly ExpressionArgumentReplacer _expressionArgumentReplacer = new(); + readonly Dictionary _projectableMemberCache = new(); public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver) { _resolver = projectionExpressionResolver; } + bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression) + { + if (!_projectableMemberCache.TryGetValue(memberInfo, out reflectedExpression)) + { + var projectableAttribute = memberInfo.GetCustomAttribute(false); + + reflectedExpression = projectableAttribute is not null + ? _resolver.FindGeneratedExpression(memberInfo) + : (LambdaExpression?)null; + + _projectableMemberCache.Add(memberInfo, reflectedExpression); + } + + return reflectedExpression is not null; + } + protected override Expression VisitMethodCall(MethodCallExpression node) { - if (node.Method.GetCustomAttributes(false).OfType().Any()) + if (TryGetReflectedExpression(node.Method, out var reflectedExpression)) { - var reflectedExpression = _resolver.FindGeneratedExpression(node.Method); - - var parameterArgumentMapping = node.Object is not null - ? Enumerable.Repeat((reflectedExpression.Parameters[0], node.Object), 1) - : Enumerable.Empty<(ParameterExpression, Expression)>(); - - if (reflectedExpression.Parameters.Count > 0) + for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++) { - parameterArgumentMapping = parameterArgumentMapping.Concat( - node.Object is not null - ? reflectedExpression.Parameters.Skip(1).Zip(node.Arguments, (parameter, argument) => (parameter, argument)) - : reflectedExpression.Parameters.Zip(node.Arguments, (parameter, argument) => (parameter, argument)) - ); - } + var parameterExpession = reflectedExpression.Parameters[parameterIndex]; + var mappedArgumentExpression = (parameterIndex, node.Object) switch { + (0, not null) => node.Object, + (_, not null) => node.Arguments[parameterIndex - 1], + (_, null) => node.Arguments[parameterIndex] + }; + + _expressionArgumentReplacer.ParameterArgumentMapping.Add(parameterExpession, mappedArgumentExpression); + } + + var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body); + _expressionArgumentReplacer.ParameterArgumentMapping.Clear(); - var expressionArgumentReplacer = new ExpressionArgumentReplacer(parameterArgumentMapping); return Visit( - expressionArgumentReplacer.Visit(reflectedExpression.Body) + updatedBody ); } @@ -46,17 +66,16 @@ namespace EntityFrameworkCore.Projectables.Services protected override Expression VisitMember(MemberExpression node) { - if (node.Member.GetCustomAttributes(false).OfType().Any()) + if (TryGetReflectedExpression(node.Member, out var reflectedExpression)) { - var reflectedExpression = _resolver.FindGeneratedExpression(node.Member); - if (node.Expression is not null) { - var expressionArgumentReplacer = new ExpressionArgumentReplacer( - Enumerable.Repeat((reflectedExpression.Parameters[0], node.Expression), 1) - ); + _expressionArgumentReplacer.ParameterArgumentMapping.Add(reflectedExpression.Parameters[0], node.Expression); + var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body); + _expressionArgumentReplacer.ParameterArgumentMapping.Clear(); + return Visit( - expressionArgumentReplacer.Visit(reflectedExpression.Body) + updatedBody ); } else diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionsMethods/ExtensionMethodTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionsMethods/ExtensionMethodTests.cs index 723ceee..2c1c5d5 100644 --- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionsMethods/ExtensionMethodTests.cs +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionsMethods/ExtensionMethodTests.cs @@ -52,7 +52,7 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests.ExtensionMethods [Fact] public Task ExtensionMethodAcceptingDbContext() { - using var dbContext = new SampleDbContext(Infrastructure.CompatibilityMode.Full); + using var dbContext = new SampleDbContext(); var sampleQuery = dbContext.Set() .Select(x => dbContext.Set().Where(y => y.Id > x.Id).FirstOrDefault()); diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/Generics/MultipleGenericsTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/Generics/MultipleGenericsTests.cs index d6029b1..56cba75 100644 --- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/Generics/MultipleGenericsTests.cs +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/Generics/MultipleGenericsTests.cs @@ -28,7 +28,6 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests.Generics return Verifier.Verify(query.ToQueryString()); } - [Fact] public void MultipleInvocations() { diff --git a/tests/EntityFrameworkCore.Projectables.Tests/Services/ExpressionArgumentReplacerTests.cs b/tests/EntityFrameworkCore.Projectables.Tests/Services/ExpressionArgumentReplacerTests.cs index 30de92e..a8ba3a6 100644 --- a/tests/EntityFrameworkCore.Projectables.Tests/Services/ExpressionArgumentReplacerTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Tests/Services/ExpressionArgumentReplacerTests.cs @@ -16,7 +16,11 @@ namespace EntityFrameworkCore.Projectables.Tests.Services { var parameter = Expression.Parameter(typeof(int)); var argument = Expression.Constant(1); - var subject = new ExpressionArgumentReplacer(new[] { (parameter, (Expression)argument) }); + var subject = new ExpressionArgumentReplacer() { + ParameterArgumentMapping = { + { parameter, argument } + } + }; var result = subject.Visit(parameter);