Support conditional access rewriting for value types

This commit is contained in:
Koen
2022-11-06 15:59:01 +00:00
parent 552b593083
commit 3e2bf5c6ae
7 changed files with 78 additions and 28 deletions
@@ -74,28 +74,40 @@ namespace EntityFrameworkCore.Projectables.Generator
_context.ReportDiagnostic(diagnostic); _context.ReportDiagnostic(diagnostic);
} }
return _nullConditionalRewriteSupport switch { else if (_nullConditionalRewriteSupport is NullConditionalRewriteSupport.Ignore)
NullConditionalRewriteSupport.Ignore => Visit(node.WhenNotNull), {
NullConditionalRewriteSupport.Rewrite => // Ignore the conditional accesss and simply visit the WhenNotNull expression
SyntaxFactory.ConditionalExpression( return Visit(node.WhenNotNull);
}
else if (_nullConditionalRewriteSupport is NullConditionalRewriteSupport.Rewrite)
{
var whenNotNullSymbol = _semanticModel.GetSymbolInfo(node.WhenNotNull).Symbol as IPropertySymbol;
var typeInfo = _semanticModel.GetTypeInfo(node);
// Do not translate until we can resolve the target type
if (typeInfo.ConvertedType is not null)
{
// Translate null-conditional into a conditional expression
return SyntaxFactory.ConditionalExpression(
SyntaxFactory.BinaryExpression( SyntaxFactory.BinaryExpression(
SyntaxKind.NotEqualsExpression, SyntaxKind.NotEqualsExpression,
targetExpression targetExpression.WithTrailingTrivia(SyntaxFactory.Whitespace(" ")),
.WithTrailingTrivia(SyntaxFactory.Whitespace(" ")), SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression).WithLeadingTrivia(SyntaxFactory.Whitespace(" "))
).WithTrailingTrivia(SyntaxFactory.Whitespace(" ")),
SyntaxFactory.ParenthesizedExpression(
(ExpressionSyntax)Visit(node.WhenNotNull)
).WithLeadingTrivia(SyntaxFactory.Whitespace(" ")).WithTrailingTrivia(SyntaxFactory.Whitespace(" ")),
SyntaxFactory.CastExpression(
SyntaxFactory.ParseName(typeInfo.ConvertedType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)),
SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression) SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)
.WithLeadingTrivia(SyntaxFactory.Whitespace(" ")) ).WithLeadingTrivia(SyntaxFactory.Whitespace(" "))
) ).WithLeadingTrivia(node.GetLeadingTrivia()).WithTrailingTrivia(node.GetTrailingTrivia());
.WithTrailingTrivia(SyntaxFactory.Whitespace(" ")), }
SyntaxFactory.ParenthesizedExpression( }
(ExpressionSyntax)Visit(node.WhenNotNull)
) return base.VisitConditionalAccessExpression(node);
.WithLeadingTrivia(SyntaxFactory.Whitespace(" "))
.WithTrailingTrivia(SyntaxFactory.Whitespace(" ")),
SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)
.WithLeadingTrivia(SyntaxFactory.Whitespace(" "))
),
_ => base.VisitConditionalAccessExpression(node)
};
} }
public override SyntaxNode? VisitMemberBindingExpression(MemberBindingExpressionSyntax node) public override SyntaxNode? VisitMemberBindingExpression(MemberBindingExpressionSyntax node)
@@ -147,15 +159,13 @@ namespace EntityFrameworkCore.Projectables.Generator
var symbol = _semanticModel.GetSymbolInfo(node).Symbol; var symbol = _semanticModel.GetSymbolInfo(node).Symbol;
if (symbol is not null) if (symbol is not null)
{ {
var operation = node switch { var operation = node switch { { Parent: { } parent } when parent.IsKind(SyntaxKind.InvocationExpression) => _semanticModel.GetOperation(node.Parent),
{ Parent: { } parent } when parent.IsKind(SyntaxKind.InvocationExpression) => _semanticModel.GetOperation(node.Parent),
_ => _semanticModel.GetOperation(node!) _ => _semanticModel.GetOperation(node!)
}; };
if (operation is IMemberReferenceOperation memberReferenceOperation) if (operation is IMemberReferenceOperation memberReferenceOperation)
{ {
var memberAccessCanBeQualified = node switch { var memberAccessCanBeQualified = node switch { { Parent: { Parent: { } parent } } when parent.IsKind(SyntaxKind.ObjectInitializerExpression) => false,
{ Parent: { Parent: { } parent } } when parent.IsKind(SyntaxKind.ObjectInitializerExpression) => false,
_ => true _ => true
}; };
@@ -169,7 +179,7 @@ namespace EntityFrameworkCore.Projectables.Generator
node.WithoutLeadingTrivia() node.WithoutLeadingTrivia()
).WithLeadingTrivia(node.GetLeadingTrivia()); ).WithLeadingTrivia(node.GetLeadingTrivia());
} }
// if this operation is targeting a static member on our targetType implicitly // if this operation is targeting a static member on our targetType implicitly
if (memberReferenceOperation.Instance is null && SymbolEqualityComparer.Default.Equals(memberReferenceOperation.Member.ContainingType, _targetTypeSymbol)) if (memberReferenceOperation.Instance is null && SymbolEqualityComparer.Default.Equals(memberReferenceOperation.Member.ContainingType, _targetTypeSymbol))
{ {
@@ -0,0 +1,16 @@
// <auto-generated/>
using EntityFrameworkCore.Projectables;
namespace EntityFrameworkCore.Projectables.Generated
#nullable disable
{
[System.ComponentModel.EditorBrowsable(System.ComponentModel.EditorBrowsableState.Never)]
public static class _Foo_SomeNumber
{
public static System.Linq.Expressions.Expression<System.Func<global::Foo, int>> Expression()
{
return (global::Foo fancyClass) =>
fancyClass != null ? (fancyClass.FancyNumber ) : (int?)null ?? 3;
}
}
}
@@ -14,7 +14,7 @@ namespace EntityFrameworkCore.Projectables.Generated
public static System.Linq.Expressions.Expression<System.Func<global::Foo.EntityExtensions.Entity, global::Foo.EntityExtensions.Entity>> Expression() public static System.Linq.Expressions.Expression<System.Func<global::Foo.EntityExtensions.Entity, global::Foo.EntityExtensions.Entity>> Expression()
{ {
return (global::Foo.EntityExtensions.Entity entity) => return (global::Foo.EntityExtensions.Entity entity) =>
entity != null ? (entity.RelatedEntities != null ? (entity.RelatedEntities[0]) : null) : null; entity != null ? (entity.RelatedEntities != null ? (entity.RelatedEntities[0]) : (global::Foo.EntityExtensions.Entity)null) : (global::Foo.EntityExtensions.Entity)null;
} }
} }
} }
@@ -13,7 +13,7 @@ namespace EntityFrameworkCore.Projectables.Generated
public static System.Linq.Expressions.Expression<System.Func<string, int?>> Expression() public static System.Linq.Expressions.Expression<System.Func<string, int?>> Expression()
{ {
return (string input) => return (string input) =>
input != null ? (input.Length) : null; input != null ? (input.Length) : (int?)null;
} }
} }
} }
@@ -14,7 +14,7 @@ namespace EntityFrameworkCore.Projectables.Generated
public static System.Linq.Expressions.Expression<System.Func<global::Foo.EntityExtensions.Entity, string>> Expression() public static System.Linq.Expressions.Expression<System.Func<global::Foo.EntityExtensions.Entity, string>> Expression()
{ {
return (global::Foo.EntityExtensions.Entity entity) => return (global::Foo.EntityExtensions.Entity entity) =>
entity.FullName != null ? (entity.FullName.Substring(entity.FullName != null ? (entity.FullName.IndexOf(' ') ) : null?? 0)) : null; entity.FullName != null ? (entity.FullName.Substring(entity.FullName != null ? (entity.FullName.IndexOf(' ') ) : (int?)null ?? 0)) : (string)null;
} }
} }
} }
@@ -13,7 +13,7 @@ namespace EntityFrameworkCore.Projectables.Generated
public static System.Linq.Expressions.Expression<System.Func<string, char?>> Expression() public static System.Linq.Expressions.Expression<System.Func<string, char?>> Expression()
{ {
return (string input) => return (string input) =>
input != null ? (input[0]) : null; input != null ? (input[0]) : (char?)null;
} }
} }
} }
@@ -1542,6 +1542,30 @@ namespace One.Two {
return Verifier.Verify(result.GeneratedTrees[0].ToString()); return Verifier.Verify(result.GeneratedTrees[0].ToString());
} }
[Fact]
public Task NullConditionalNullCoalesceTypeConversion()
{
// issue: https://github.com/koenbeuk/EntityFrameworkCore.Projectables/issues/48
var compilation = CreateCompilation(@"
using EntityFrameworkCore.Projectables;
class Foo {
public int? FancyNumber { get; set; }
[Projectable(NullConditionalRewriteSupport = NullConditionalRewriteSupport.Rewrite)]
public static int SomeNumber(Foo fancyClass) => fancyClass?.FancyNumber ?? 3;
}
");
var result = RunGenerator(compilation);
Assert.Empty(result.Diagnostics);
Assert.Single(result.GeneratedTrees);
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}
#region Helpers #region Helpers
Compilation CreateCompilation(string source, bool expectedToCompile = true) Compilation CreateCompilation(string source, bool expectedToCompile = true)