diff --git a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/EntityFrameworkCore.Projectables.Benchmarks.csproj b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/EntityFrameworkCore.Projectables.Benchmarks.csproj
index 24a30b2..f13aec1 100644
--- a/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/EntityFrameworkCore.Projectables.Benchmarks.csproj
+++ b/benchmarks/EntityFrameworkCore.Projectables.Benchmarks/EntityFrameworkCore.Projectables.Benchmarks.csproj
@@ -7,7 +7,7 @@
-
+
diff --git a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs
new file mode 100644
index 0000000..5a300a2
--- /dev/null
+++ b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs
@@ -0,0 +1,39 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using System.Transactions;
+using EntityFrameworkCore.Projectables.Services;
+using Microsoft.EntityFrameworkCore.Query;
+using Microsoft.EntityFrameworkCore.Query.Internal;
+
+namespace EntityFrameworkCore.Projectables.Infrastructure.Internal
+{
+ [System.Diagnostics.CodeAnalysis.SuppressMessage("Usage", "EF1001:Internal EF Core API usage.", Justification = "Needed")]
+ public sealed class CustomQueryCompiler : IQueryCompiler
+ {
+ readonly IQueryCompiler _decoratedQueryCompiler;
+ readonly ProjectableExpressionReplacer _projectableExpressionReplacer;
+
+ public CustomQueryCompiler(IQueryCompiler decoratedQueryCompiler)
+ {
+ _decoratedQueryCompiler = decoratedQueryCompiler;
+ _projectableExpressionReplacer = new ProjectableExpressionReplacer(new ProjectionExpressionResolver());
+ }
+
+ public Func CreateCompiledAsyncQuery(Expression query)
+ => _decoratedQueryCompiler.CreateCompiledAsyncQuery(Expand(query));
+ public Func CreateCompiledQuery(Expression query)
+ => _decoratedQueryCompiler.CreateCompiledQuery(Expand(query));
+ public TResult Execute(Expression query)
+ => _decoratedQueryCompiler.Execute(Expand(query));
+ public TResult ExecuteAsync(Expression query, CancellationToken cancellationToken)
+ => _decoratedQueryCompiler.ExecuteAsync(Expand(query), cancellationToken);
+
+ Expression Expand(Expression expression)
+ => _projectableExpressionReplacer.Visit(expression);
+ }
+}
diff --git a/src/EntityFrameworkCore.Projectables/Services/ExpressionArgumentReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ExpressionArgumentReplacer.cs
index 203e03a..1862274 100644
--- a/src/EntityFrameworkCore.Projectables/Services/ExpressionArgumentReplacer.cs
+++ b/src/EntityFrameworkCore.Projectables/Services/ExpressionArgumentReplacer.cs
@@ -9,21 +9,11 @@ namespace EntityFrameworkCore.Projectables.Services
{
public sealed class ExpressionArgumentReplacer : ExpressionVisitor
{
- readonly IEnumerable<(ParameterExpression parameter, Expression argument)>? _parameterArgumentMapping;
-
- public ExpressionArgumentReplacer(IEnumerable<(ParameterExpression, Expression)>? parameterArgumentMapping = null)
- {
- _parameterArgumentMapping = parameterArgumentMapping;
- }
+ public Dictionary ParameterArgumentMapping { get; } = new();
protected override Expression VisitParameter(ParameterExpression node)
{
- var mappedArgument = _parameterArgumentMapping?
- .Where(x => x.parameter == node)
- .Select(x => x.argument)
- .FirstOrDefault();
-
- if (mappedArgument is not null)
+ if (ParameterArgumentMapping.TryGetValue(node, out var mappedArgument))
{
return mappedArgument;
}
diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs
index 5013dd1..3188e69 100644
--- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs
+++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs
@@ -1,43 +1,63 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
+using System.Reflection;
using System.Text;
using System.Threading.Tasks;
+using System.Xml.Linq;
namespace EntityFrameworkCore.Projectables.Services
{
public sealed class ProjectableExpressionReplacer : ExpressionVisitor
{
readonly IProjectionExpressionResolver _resolver;
+ readonly ExpressionArgumentReplacer _expressionArgumentReplacer = new();
+ readonly Dictionary _projectableMemberCache = new();
public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver)
{
_resolver = projectionExpressionResolver;
}
+ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression)
+ {
+ if (!_projectableMemberCache.TryGetValue(memberInfo, out reflectedExpression))
+ {
+ var projectableAttribute = memberInfo.GetCustomAttribute(false);
+
+ reflectedExpression = projectableAttribute is not null
+ ? _resolver.FindGeneratedExpression(memberInfo)
+ : (LambdaExpression?)null;
+
+ _projectableMemberCache.Add(memberInfo, reflectedExpression);
+ }
+
+ return reflectedExpression is not null;
+ }
+
protected override Expression VisitMethodCall(MethodCallExpression node)
{
- if (node.Method.GetCustomAttributes(false).OfType().Any())
+ if (TryGetReflectedExpression(node.Method, out var reflectedExpression))
{
- var reflectedExpression = _resolver.FindGeneratedExpression(node.Method);
-
- var parameterArgumentMapping = node.Object is not null
- ? Enumerable.Repeat((reflectedExpression.Parameters[0], node.Object), 1)
- : Enumerable.Empty<(ParameterExpression, Expression)>();
-
- if (reflectedExpression.Parameters.Count > 0)
+ for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++)
{
- parameterArgumentMapping = parameterArgumentMapping.Concat(
- node.Object is not null
- ? reflectedExpression.Parameters.Skip(1).Zip(node.Arguments, (parameter, argument) => (parameter, argument))
- : reflectedExpression.Parameters.Zip(node.Arguments, (parameter, argument) => (parameter, argument))
- );
- }
+ var parameterExpession = reflectedExpression.Parameters[parameterIndex];
+ var mappedArgumentExpression = (parameterIndex, node.Object) switch {
+ (0, not null) => node.Object,
+ (_, not null) => node.Arguments[parameterIndex - 1],
+ (_, null) => node.Arguments[parameterIndex]
+ };
+
+ _expressionArgumentReplacer.ParameterArgumentMapping.Add(parameterExpession, mappedArgumentExpression);
+ }
+
+ var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
+ _expressionArgumentReplacer.ParameterArgumentMapping.Clear();
- var expressionArgumentReplacer = new ExpressionArgumentReplacer(parameterArgumentMapping);
return Visit(
- expressionArgumentReplacer.Visit(reflectedExpression.Body)
+ updatedBody
);
}
@@ -46,17 +66,16 @@ namespace EntityFrameworkCore.Projectables.Services
protected override Expression VisitMember(MemberExpression node)
{
- if (node.Member.GetCustomAttributes(false).OfType().Any())
+ if (TryGetReflectedExpression(node.Member, out var reflectedExpression))
{
- var reflectedExpression = _resolver.FindGeneratedExpression(node.Member);
-
if (node.Expression is not null)
{
- var expressionArgumentReplacer = new ExpressionArgumentReplacer(
- Enumerable.Repeat((reflectedExpression.Parameters[0], node.Expression), 1)
- );
+ _expressionArgumentReplacer.ParameterArgumentMapping.Add(reflectedExpression.Parameters[0], node.Expression);
+ var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
+ _expressionArgumentReplacer.ParameterArgumentMapping.Clear();
+
return Visit(
- expressionArgumentReplacer.Visit(reflectedExpression.Body)
+ updatedBody
);
}
else
diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionsMethods/ExtensionMethodTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionsMethods/ExtensionMethodTests.cs
index 723ceee..2c1c5d5 100644
--- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionsMethods/ExtensionMethodTests.cs
+++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionsMethods/ExtensionMethodTests.cs
@@ -52,7 +52,7 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests.ExtensionMethods
[Fact]
public Task ExtensionMethodAcceptingDbContext()
{
- using var dbContext = new SampleDbContext(Infrastructure.CompatibilityMode.Full);
+ using var dbContext = new SampleDbContext();
var sampleQuery = dbContext.Set()
.Select(x => dbContext.Set().Where(y => y.Id > x.Id).FirstOrDefault());
diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/Generics/MultipleGenericsTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/Generics/MultipleGenericsTests.cs
index d6029b1..56cba75 100644
--- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/Generics/MultipleGenericsTests.cs
+++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/Generics/MultipleGenericsTests.cs
@@ -28,7 +28,6 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests.Generics
return Verifier.Verify(query.ToQueryString());
}
-
[Fact]
public void MultipleInvocations()
{
diff --git a/tests/EntityFrameworkCore.Projectables.Tests/Services/ExpressionArgumentReplacerTests.cs b/tests/EntityFrameworkCore.Projectables.Tests/Services/ExpressionArgumentReplacerTests.cs
index 30de92e..a8ba3a6 100644
--- a/tests/EntityFrameworkCore.Projectables.Tests/Services/ExpressionArgumentReplacerTests.cs
+++ b/tests/EntityFrameworkCore.Projectables.Tests/Services/ExpressionArgumentReplacerTests.cs
@@ -16,7 +16,11 @@ namespace EntityFrameworkCore.Projectables.Tests.Services
{
var parameter = Expression.Parameter(typeof(int));
var argument = Expression.Constant(1);
- var subject = new ExpressionArgumentReplacer(new[] { (parameter, (Expression)argument) });
+ var subject = new ExpressionArgumentReplacer() {
+ ParameterArgumentMapping = {
+ { parameter, argument }
+ }
+ };
var result = subject.Visit(parameter);