diff --git a/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs b/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs
index 248076a..4c3082b 100644
--- a/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs
+++ b/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs
@@ -17,5 +17,11 @@ namespace EntityFrameworkCore.Projectables
/// Get or set how null-conditional operators are handeled
///
public NullConditionalRewriteSupport NullConditionalRewriteSupport { get; set; }
+
+ ///
+ /// Get or set from which member to get the expression,
+ /// or null to get it from the current member.
+ ///
+ public string? UseMemberBody { get; set; }
}
}
diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectableInterpreter.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectableInterpreter.cs
index 3755611..ac1b9c2 100644
--- a/src/EntityFrameworkCore.Projectables.Generator/ProjectableInterpreter.cs
+++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectableInterpreter.cs
@@ -51,6 +51,70 @@ namespace EntityFrameworkCore.Projectables.Generator
.Cast()
.FirstOrDefault();
+ var useMemberBody = projectableAttributeClass.NamedArguments
+ .Where(x => x.Key == "UseMemberBody")
+ .Select(x => x.Value.Value)
+ .OfType()
+ .FirstOrDefault();
+
+ var memberBody = member;
+
+ if (useMemberBody is not null)
+ {
+ var comparer = SymbolEqualityComparer.Default;
+
+ memberBody = memberSymbol.ContainingType.GetMembers(useMemberBody)
+ .Where(x =>
+ {
+ if (memberSymbol is IMethodSymbol symbolMethod &&
+ x is IMethodSymbol xMethod &&
+ comparer.Equals(symbolMethod.ReturnType, xMethod.ReturnType) &&
+ symbolMethod.TypeArguments.Length == xMethod.TypeArguments.Length &&
+ !symbolMethod.TypeArguments.Zip(xMethod.TypeArguments, (a, b) => !comparer.Equals(a, b)).Any())
+ {
+ return true;
+ }
+ else if (memberSymbol is IPropertySymbol symbolProperty &&
+ x is IPropertySymbol xProperty &&
+ comparer.Equals(symbolProperty.Type, xProperty.Type))
+ {
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ })
+ .SelectMany(x => x.DeclaringSyntaxReferences)
+ .Select(x => x.GetSyntax())
+ .OfType()
+ .FirstOrDefault(x =>
+ {
+ if (x == null ||
+ x.SyntaxTree != member.SyntaxTree ||
+ x.Modifiers.Any(SyntaxKind.StaticKeyword) != member.Modifiers.Any(SyntaxKind.StaticKeyword))
+ {
+ return false;
+ }
+ else if (x is MethodDeclarationSyntax xMethod &&
+ xMethod.ExpressionBody is not null)
+ {
+ return true;
+ }
+ else if (x is PropertyDeclarationSyntax xProperty &&
+ xProperty.ExpressionBody is not null)
+ {
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ });
+
+ if (memberBody is null) return null;
+ }
+
var expressionSyntaxRewriter = new ExpressionSyntaxRewriter(memberSymbol.ContainingType, nullConditionalRewriteSupport, compilation, semanticModel, context);
var declarationSyntaxRewriter = new DeclarationSyntaxRewriter(semanticModel);
@@ -93,7 +157,7 @@ namespace EntityFrameworkCore.Projectables.Generator
descriptor.TargetNestedInClassNames = descriptor.NestedInClassNames;
}
- if (member is MethodDeclarationSyntax methodDeclarationSyntax)
+ if (memberBody is MethodDeclarationSyntax methodDeclarationSyntax)
{
if (methodDeclarationSyntax.ExpressionBody is null)
{
@@ -125,7 +189,7 @@ namespace EntityFrameworkCore.Projectables.Generator
.Select(x => (TypeParameterConstraintClauseSyntax)declarationSyntaxRewriter.Visit(x));
}
}
- else if (member is PropertyDeclarationSyntax propertyDeclarationSyntax)
+ else if (memberBody is PropertyDeclarationSyntax propertyDeclarationSyntax)
{
if (propertyDeclarationSyntax.ExpressionBody is null)
{
diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs
index 699ea3f..ef82ed7 100644
--- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs
+++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs
@@ -22,27 +22,55 @@ namespace EntityFrameworkCore.Projectables.Services
_ => null
};
- var expressionFactoryMethod = reflectedType.Assembly
- .GetTypes()
- .Where(x => x.FullName == generatedContainingTypeName)
- .SelectMany(x => x.GetMethods())
- .FirstOrDefault();
+ var expressionFactoryMethod = reflectedType.Assembly.GetType(generatedContainingTypeName)
+ ?.GetMethods()
+ ?.FirstOrDefault();
- if (expressionFactoryMethod is null)
+ if (expressionFactoryMethod is not null)
{
- throw new InvalidOperationException("Unable to resolve generated expression") {
- Data = {
- ["GeneratedContainingTypeName"] = generatedContainingTypeName
+ if (genericArguments is { Length: > 0 })
+ {
+ expressionFactoryMethod = expressionFactoryMethod.MakeGenericMethod(genericArguments);
+ }
+
+ return expressionFactoryMethod.Invoke(null, null) as LambdaExpression ?? throw new InvalidOperationException("Expected lambda");
+ }
+
+ var useMemberBody = projectableMemberInfo.GetCustomAttribute()?.UseMemberBody;
+
+ if (useMemberBody is not null)
+ {
+ var exprProperty = reflectedType.GetProperty(useMemberBody, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);
+ var lambda = exprProperty?.GetValue(null) as LambdaExpression;
+
+ if (lambda is not null)
+ {
+ if (projectableMemberInfo is PropertyInfo property &&
+ lambda.Parameters.Count == 1 &&
+ lambda.Parameters[0].Type == reflectedType && lambda.ReturnType == property.PropertyType)
+ {
+ return lambda;
}
- };
+ else if (projectableMemberInfo is MethodInfo method &&
+ lambda.Parameters.Count == method.GetParameters().Length + 1 &&
+ lambda.Parameters.Last().Type == reflectedType &&
+ !lambda.Parameters.Zip(method.GetParameters(), (a, b) => a.Type != b.ParameterType).Any())
+ {
+ return lambda;
+ }
+ }
}
- if (genericArguments is { Length: > 0 } )
- {
- expressionFactoryMethod = expressionFactoryMethod.MakeGenericMethod(genericArguments);
- }
+ var fullName = string.Join(".", Enumerable.Empty()
+ .Concat(new[] { reflectedType.Namespace })
+ .Concat(reflectedType.GetNestedTypePath().Select(x => x.Name))
+ .Concat(new[] { projectableMemberInfo.Name }));
- return expressionFactoryMethod.Invoke(null, null) as LambdaExpression ?? throw new InvalidOperationException("Expected lambda");
+ throw new InvalidOperationException($"Unable to resolve generated expression for {fullName}.") {
+ Data = {
+ ["GeneratedContainingTypeName"] = generatedContainingTypeName
+ }
+ };
}
}
}
diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.UseMemberPropertyExpression.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.UseMemberPropertyExpression.verified.txt
new file mode 100644
index 0000000..91b8106
--- /dev/null
+++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.UseMemberPropertyExpression.verified.txt
@@ -0,0 +1,2 @@
+SELECT [e].[Id] * 3
+FROM [Entity] AS [e]
\ No newline at end of file
diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.UseMemberPropertyGenerated.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.UseMemberPropertyGenerated.verified.txt
new file mode 100644
index 0000000..028694e
--- /dev/null
+++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.UseMemberPropertyGenerated.verified.txt
@@ -0,0 +1,2 @@
+SELECT [e].[Id] * 2
+FROM [Entity] AS [e]
\ No newline at end of file
diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.cs
new file mode 100644
index 0000000..3b62b72
--- /dev/null
+++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.cs
@@ -0,0 +1,57 @@
+using System;
+using System.Collections.Generic;
+using System.Data.SqlTypes;
+using System.IO;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Text;
+using System.Threading.Tasks;
+using EntityFrameworkCore.Projectables.FunctionalTests.Helpers;
+using Microsoft.EntityFrameworkCore;
+using ScenarioTests;
+using VerifyXunit;
+using Xunit;
+
+namespace EntityFrameworkCore.Projectables.FunctionalTests
+{
+ [UsesVerify]
+ public class UseMemberBodyPropertyTests
+ {
+ public record Entity
+ {
+ public int Id { get; set; }
+
+ [Projectable(UseMemberBody = nameof(Computed2))]
+ public int Computed1 => Id;
+
+ private int Computed2 => Id * 2;
+
+ [Projectable(UseMemberBody = nameof(Computed4))]
+ public int Computed3 => Id;
+
+ private static Expression> Computed4 => x => x.Id * 3;
+ }
+
+ [Fact]
+ public Task UseMemberPropertyGenerated()
+ {
+ using var dbContext = new SampleDbContext();
+
+ var query = dbContext.Set()
+ .Select(x => x.Computed1);
+
+ return Verifier.Verify(query.ToQueryString());
+ }
+
+ [Fact]
+ public Task UseMemberPropertyExpression()
+ {
+ using var dbContext = new SampleDbContext();
+
+ var query = dbContext.Set()
+ .Select(x => x.Computed3);
+
+ return Verifier.Verify(query.ToQueryString());
+ }
+ }
+}