diff --git a/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs b/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs index f5215c6..535b7ce 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs @@ -72,6 +72,9 @@ namespace EntityFrameworkCore.Projectables.Generator { var diagnostic = Diagnostic.Create(Diagnostics.NullConditionalRewriteUnsupported, node.GetLocation(), node); _context.ReportDiagnostic(diagnostic); + + // Return the original node, do not attempt further rewrites + return node; } else if (_nullConditionalRewriteSupport is NullConditionalRewriteSupport.Ignore) @@ -112,34 +115,34 @@ namespace EntityFrameworkCore.Projectables.Generator public override SyntaxNode? VisitMemberBindingExpression(MemberBindingExpressionSyntax node) { - if (_conditionalAccessExpressionsStack.Count == 0) + if (_conditionalAccessExpressionsStack.Count > 0) { - throw new InvalidOperationException("Expected at least one conditional expression on the stack"); + var targetExpression = _conditionalAccessExpressionsStack.Pop(); + + return _nullConditionalRewriteSupport switch { + NullConditionalRewriteSupport.Ignore => SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, targetExpression, node.Name), + NullConditionalRewriteSupport.Rewrite => SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, targetExpression, node.Name), + _ => node + }; } - var targetExpression = _conditionalAccessExpressionsStack.Pop(); - - return _nullConditionalRewriteSupport switch { - NullConditionalRewriteSupport.Ignore => SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, targetExpression, node.Name), - NullConditionalRewriteSupport.Rewrite => SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, targetExpression, node.Name), - _ => node - }; + return base.VisitMemberBindingExpression(node); } public override SyntaxNode? VisitElementBindingExpression(ElementBindingExpressionSyntax node) { - if (_conditionalAccessExpressionsStack.Count == 0) + if (_conditionalAccessExpressionsStack.Count > 0) { - throw new InvalidOperationException("Expected at least one conditional expression on the stack"); + var targetExpression = _conditionalAccessExpressionsStack.Pop(); + + return _nullConditionalRewriteSupport switch { + NullConditionalRewriteSupport.Ignore => SyntaxFactory.ElementAccessExpression(targetExpression, node.ArgumentList), + NullConditionalRewriteSupport.Rewrite => SyntaxFactory.ElementAccessExpression(targetExpression, node.ArgumentList), + _ => Visit(node) + }; } - var targetExpression = _conditionalAccessExpressionsStack.Pop(); - - return _nullConditionalRewriteSupport switch { - NullConditionalRewriteSupport.Ignore => SyntaxFactory.ElementAccessExpression(targetExpression, node.ArgumentList), - NullConditionalRewriteSupport.Rewrite => SyntaxFactory.ElementAccessExpression(targetExpression, node.ArgumentList), - _ => Visit(node) - }; + return base.VisitElementBindingExpression(node); } public override SyntaxNode? VisitThisExpression(ThisExpressionSyntax node) diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs index 11cfe89..b79842d 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs @@ -644,6 +644,69 @@ namespace Foo { Assert.Equal("EFP0002", diagnostic.Id); } + [Fact] + public void NullableMemberBinding_UndefinedSupport_IsBeingReported() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using EntityFrameworkCore.Projectables; + +namespace Foo { + static class C { + [Projectable] + public static int? GetLength(this string input) => input?.Length; + } +} +"); + var result = RunGenerator(compilation); + + var diagnostic = Assert.Single(result.Diagnostics); + Assert.Equal("EFP0002", diagnostic.Id); + } + + + [Fact] + public void MultiLevelNullableMemberBinding_UndefinedSupport_IsBeingReported() + { + var compilation = CreateCompilation(@" +using System; +using System.Linq; +using EntityFrameworkCore.Projectables; + +namespace Foo { + public record Address + { + public int Id { get; set; } + public string? Country { get; set; } + } + + public record Party + { + public int Id { get; set; } + + public Address? Address { get; set; } + } + + public record Entity + { + public int Id { get; set; } + + public Party? Left { get; set; } + public Party? Right { get; set; } + + [Projectable] + public bool IsSameCountry => Left?.Address?.Country == Right?.Address?.Country; + } +} +"); + var result = RunGenerator(compilation); + + Assert.All(result.Diagnostics, diagnostic => { + Assert.Equal("EFP0002", diagnostic.Id); + }); + } + [Fact] public Task NullableMemberBinding_WithIgnoreSupport_IsBeingRewritten() {