diff --git a/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs b/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs index 248076a..4c3082b 100644 --- a/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs +++ b/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs @@ -17,5 +17,11 @@ namespace EntityFrameworkCore.Projectables /// Get or set how null-conditional operators are handeled /// public NullConditionalRewriteSupport NullConditionalRewriteSupport { get; set; } + + /// + /// Get or set from which member to get the expression, + /// or null to get it from the current member. + /// + public string? UseMemberBody { get; set; } } } diff --git a/src/EntityFrameworkCore.Projectables.Generator/ProjectableInterpreter.cs b/src/EntityFrameworkCore.Projectables.Generator/ProjectableInterpreter.cs index 3755611..ac1b9c2 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ProjectableInterpreter.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ProjectableInterpreter.cs @@ -51,6 +51,70 @@ namespace EntityFrameworkCore.Projectables.Generator .Cast() .FirstOrDefault(); + var useMemberBody = projectableAttributeClass.NamedArguments + .Where(x => x.Key == "UseMemberBody") + .Select(x => x.Value.Value) + .OfType() + .FirstOrDefault(); + + var memberBody = member; + + if (useMemberBody is not null) + { + var comparer = SymbolEqualityComparer.Default; + + memberBody = memberSymbol.ContainingType.GetMembers(useMemberBody) + .Where(x => + { + if (memberSymbol is IMethodSymbol symbolMethod && + x is IMethodSymbol xMethod && + comparer.Equals(symbolMethod.ReturnType, xMethod.ReturnType) && + symbolMethod.TypeArguments.Length == xMethod.TypeArguments.Length && + !symbolMethod.TypeArguments.Zip(xMethod.TypeArguments, (a, b) => !comparer.Equals(a, b)).Any()) + { + return true; + } + else if (memberSymbol is IPropertySymbol symbolProperty && + x is IPropertySymbol xProperty && + comparer.Equals(symbolProperty.Type, xProperty.Type)) + { + return true; + } + else + { + return false; + } + }) + .SelectMany(x => x.DeclaringSyntaxReferences) + .Select(x => x.GetSyntax()) + .OfType() + .FirstOrDefault(x => + { + if (x == null || + x.SyntaxTree != member.SyntaxTree || + x.Modifiers.Any(SyntaxKind.StaticKeyword) != member.Modifiers.Any(SyntaxKind.StaticKeyword)) + { + return false; + } + else if (x is MethodDeclarationSyntax xMethod && + xMethod.ExpressionBody is not null) + { + return true; + } + else if (x is PropertyDeclarationSyntax xProperty && + xProperty.ExpressionBody is not null) + { + return true; + } + else + { + return false; + } + }); + + if (memberBody is null) return null; + } + var expressionSyntaxRewriter = new ExpressionSyntaxRewriter(memberSymbol.ContainingType, nullConditionalRewriteSupport, compilation, semanticModel, context); var declarationSyntaxRewriter = new DeclarationSyntaxRewriter(semanticModel); @@ -93,7 +157,7 @@ namespace EntityFrameworkCore.Projectables.Generator descriptor.TargetNestedInClassNames = descriptor.NestedInClassNames; } - if (member is MethodDeclarationSyntax methodDeclarationSyntax) + if (memberBody is MethodDeclarationSyntax methodDeclarationSyntax) { if (methodDeclarationSyntax.ExpressionBody is null) { @@ -125,7 +189,7 @@ namespace EntityFrameworkCore.Projectables.Generator .Select(x => (TypeParameterConstraintClauseSyntax)declarationSyntaxRewriter.Visit(x)); } } - else if (member is PropertyDeclarationSyntax propertyDeclarationSyntax) + else if (memberBody is PropertyDeclarationSyntax propertyDeclarationSyntax) { if (propertyDeclarationSyntax.ExpressionBody is null) { diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs index 699ea3f..ef82ed7 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs @@ -22,27 +22,55 @@ namespace EntityFrameworkCore.Projectables.Services _ => null }; - var expressionFactoryMethod = reflectedType.Assembly - .GetTypes() - .Where(x => x.FullName == generatedContainingTypeName) - .SelectMany(x => x.GetMethods()) - .FirstOrDefault(); + var expressionFactoryMethod = reflectedType.Assembly.GetType(generatedContainingTypeName) + ?.GetMethods() + ?.FirstOrDefault(); - if (expressionFactoryMethod is null) + if (expressionFactoryMethod is not null) { - throw new InvalidOperationException("Unable to resolve generated expression") { - Data = { - ["GeneratedContainingTypeName"] = generatedContainingTypeName + if (genericArguments is { Length: > 0 }) + { + expressionFactoryMethod = expressionFactoryMethod.MakeGenericMethod(genericArguments); + } + + return expressionFactoryMethod.Invoke(null, null) as LambdaExpression ?? throw new InvalidOperationException("Expected lambda"); + } + + var useMemberBody = projectableMemberInfo.GetCustomAttribute()?.UseMemberBody; + + if (useMemberBody is not null) + { + var exprProperty = reflectedType.GetProperty(useMemberBody, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic); + var lambda = exprProperty?.GetValue(null) as LambdaExpression; + + if (lambda is not null) + { + if (projectableMemberInfo is PropertyInfo property && + lambda.Parameters.Count == 1 && + lambda.Parameters[0].Type == reflectedType && lambda.ReturnType == property.PropertyType) + { + return lambda; } - }; + else if (projectableMemberInfo is MethodInfo method && + lambda.Parameters.Count == method.GetParameters().Length + 1 && + lambda.Parameters.Last().Type == reflectedType && + !lambda.Parameters.Zip(method.GetParameters(), (a, b) => a.Type != b.ParameterType).Any()) + { + return lambda; + } + } } - if (genericArguments is { Length: > 0 } ) - { - expressionFactoryMethod = expressionFactoryMethod.MakeGenericMethod(genericArguments); - } + var fullName = string.Join(".", Enumerable.Empty() + .Concat(new[] { reflectedType.Namespace }) + .Concat(reflectedType.GetNestedTypePath().Select(x => x.Name)) + .Concat(new[] { projectableMemberInfo.Name })); - return expressionFactoryMethod.Invoke(null, null) as LambdaExpression ?? throw new InvalidOperationException("Expected lambda"); + throw new InvalidOperationException($"Unable to resolve generated expression for {fullName}.") { + Data = { + ["GeneratedContainingTypeName"] = generatedContainingTypeName + } + }; } } } diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.UseMemberPropertyExpression.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.UseMemberPropertyExpression.verified.txt new file mode 100644 index 0000000..91b8106 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.UseMemberPropertyExpression.verified.txt @@ -0,0 +1,2 @@ +SELECT [e].[Id] * 3 +FROM [Entity] AS [e] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.UseMemberPropertyGenerated.verified.txt b/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.UseMemberPropertyGenerated.verified.txt new file mode 100644 index 0000000..028694e --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.UseMemberPropertyGenerated.verified.txt @@ -0,0 +1,2 @@ +SELECT [e].[Id] * 2 +FROM [Entity] AS [e] \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.cs new file mode 100644 index 0000000..3b62b72 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/UseMemberBodyPropertyTests.cs @@ -0,0 +1,57 @@ +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using System.Text; +using System.Threading.Tasks; +using EntityFrameworkCore.Projectables.FunctionalTests.Helpers; +using Microsoft.EntityFrameworkCore; +using ScenarioTests; +using VerifyXunit; +using Xunit; + +namespace EntityFrameworkCore.Projectables.FunctionalTests +{ + [UsesVerify] + public class UseMemberBodyPropertyTests + { + public record Entity + { + public int Id { get; set; } + + [Projectable(UseMemberBody = nameof(Computed2))] + public int Computed1 => Id; + + private int Computed2 => Id * 2; + + [Projectable(UseMemberBody = nameof(Computed4))] + public int Computed3 => Id; + + private static Expression> Computed4 => x => x.Id * 3; + } + + [Fact] + public Task UseMemberPropertyGenerated() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set() + .Select(x => x.Computed1); + + return Verifier.Verify(query.ToQueryString()); + } + + [Fact] + public Task UseMemberPropertyExpression() + { + using var dbContext = new SampleDbContext(); + + var query = dbContext.Set() + .Select(x => x.Computed3); + + return Verifier.Verify(query.ToQueryString()); + } + } +}