mirror of
https://github.com/zoriya/EntityFrameworkCore.Projectables.git
synced 2025-12-06 05:56:10 +00:00
Merge pull request #51 from koenbeuk/issue-48
Support conditional access rewriting for value types
This commit is contained in:
@@ -74,28 +74,40 @@ namespace EntityFrameworkCore.Projectables.Generator
|
||||
_context.ReportDiagnostic(diagnostic);
|
||||
}
|
||||
|
||||
return _nullConditionalRewriteSupport switch {
|
||||
NullConditionalRewriteSupport.Ignore => Visit(node.WhenNotNull),
|
||||
NullConditionalRewriteSupport.Rewrite =>
|
||||
SyntaxFactory.ConditionalExpression(
|
||||
else if (_nullConditionalRewriteSupport is NullConditionalRewriteSupport.Ignore)
|
||||
{
|
||||
// Ignore the conditional accesss and simply visit the WhenNotNull expression
|
||||
return Visit(node.WhenNotNull);
|
||||
}
|
||||
|
||||
else if (_nullConditionalRewriteSupport is NullConditionalRewriteSupport.Rewrite)
|
||||
{
|
||||
var whenNotNullSymbol = _semanticModel.GetSymbolInfo(node.WhenNotNull).Symbol as IPropertySymbol;
|
||||
var typeInfo = _semanticModel.GetTypeInfo(node);
|
||||
|
||||
// Do not translate until we can resolve the target type
|
||||
if (typeInfo.ConvertedType is not null)
|
||||
{
|
||||
// Translate null-conditional into a conditional expression
|
||||
return SyntaxFactory.ConditionalExpression(
|
||||
SyntaxFactory.BinaryExpression(
|
||||
SyntaxKind.NotEqualsExpression,
|
||||
targetExpression
|
||||
.WithTrailingTrivia(SyntaxFactory.Whitespace(" ")),
|
||||
targetExpression.WithTrailingTrivia(SyntaxFactory.Whitespace(" ")),
|
||||
SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression).WithLeadingTrivia(SyntaxFactory.Whitespace(" "))
|
||||
).WithTrailingTrivia(SyntaxFactory.Whitespace(" ")),
|
||||
SyntaxFactory.ParenthesizedExpression(
|
||||
(ExpressionSyntax)Visit(node.WhenNotNull)
|
||||
).WithLeadingTrivia(SyntaxFactory.Whitespace(" ")).WithTrailingTrivia(SyntaxFactory.Whitespace(" ")),
|
||||
SyntaxFactory.CastExpression(
|
||||
SyntaxFactory.ParseName(typeInfo.ConvertedType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)),
|
||||
SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)
|
||||
.WithLeadingTrivia(SyntaxFactory.Whitespace(" "))
|
||||
)
|
||||
.WithTrailingTrivia(SyntaxFactory.Whitespace(" ")),
|
||||
SyntaxFactory.ParenthesizedExpression(
|
||||
(ExpressionSyntax)Visit(node.WhenNotNull)
|
||||
)
|
||||
.WithLeadingTrivia(SyntaxFactory.Whitespace(" "))
|
||||
.WithTrailingTrivia(SyntaxFactory.Whitespace(" ")),
|
||||
SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)
|
||||
.WithLeadingTrivia(SyntaxFactory.Whitespace(" "))
|
||||
),
|
||||
_ => base.VisitConditionalAccessExpression(node)
|
||||
};
|
||||
).WithLeadingTrivia(SyntaxFactory.Whitespace(" "))
|
||||
).WithLeadingTrivia(node.GetLeadingTrivia()).WithTrailingTrivia(node.GetTrailingTrivia());
|
||||
}
|
||||
}
|
||||
|
||||
return base.VisitConditionalAccessExpression(node);
|
||||
|
||||
}
|
||||
|
||||
public override SyntaxNode? VisitMemberBindingExpression(MemberBindingExpressionSyntax node)
|
||||
@@ -147,15 +159,13 @@ namespace EntityFrameworkCore.Projectables.Generator
|
||||
var symbol = _semanticModel.GetSymbolInfo(node).Symbol;
|
||||
if (symbol is not null)
|
||||
{
|
||||
var operation = node switch {
|
||||
{ Parent: { } parent } when parent.IsKind(SyntaxKind.InvocationExpression) => _semanticModel.GetOperation(node.Parent),
|
||||
var operation = node switch { { Parent: { } parent } when parent.IsKind(SyntaxKind.InvocationExpression) => _semanticModel.GetOperation(node.Parent),
|
||||
_ => _semanticModel.GetOperation(node!)
|
||||
};
|
||||
|
||||
if (operation is IMemberReferenceOperation memberReferenceOperation)
|
||||
{
|
||||
var memberAccessCanBeQualified = node switch {
|
||||
{ Parent: { Parent: { } parent } } when parent.IsKind(SyntaxKind.ObjectInitializerExpression) => false,
|
||||
var memberAccessCanBeQualified = node switch { { Parent: { Parent: { } parent } } when parent.IsKind(SyntaxKind.ObjectInitializerExpression) => false,
|
||||
_ => true
|
||||
};
|
||||
|
||||
@@ -169,7 +179,7 @@ namespace EntityFrameworkCore.Projectables.Generator
|
||||
node.WithoutLeadingTrivia()
|
||||
).WithLeadingTrivia(node.GetLeadingTrivia());
|
||||
}
|
||||
|
||||
|
||||
// if this operation is targeting a static member on our targetType implicitly
|
||||
if (memberReferenceOperation.Instance is null && SymbolEqualityComparer.Default.Equals(memberReferenceOperation.Member.ContainingType, _targetTypeSymbol))
|
||||
{
|
||||
|
||||
@@ -15,7 +15,7 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests.Helpers
|
||||
{
|
||||
readonly CompatibilityMode _compatibilityMode;
|
||||
|
||||
public SampleDbContext(CompatibilityMode compatibilityMode = CompatibilityMode.Limited)
|
||||
public SampleDbContext(CompatibilityMode compatibilityMode = CompatibilityMode.Full)
|
||||
{
|
||||
_compatibilityMode = compatibilityMode;
|
||||
}
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
SELECT 1 + 1
|
||||
SELECT 2
|
||||
FROM [Concrete] AS [c]
|
||||
@@ -1,2 +1,2 @@
|
||||
SELECT 1 + 1
|
||||
SELECT 2
|
||||
FROM [Concrete] AS [c]
|
||||
@@ -1,2 +1,2 @@
|
||||
SELECT 1 + 1
|
||||
SELECT 2
|
||||
FROM [Concrete] AS [c]
|
||||
@@ -1,2 +1,2 @@
|
||||
SELECT 1 + 1
|
||||
SELECT 2
|
||||
FROM [Concrete] AS [c]
|
||||
@@ -1,2 +1,2 @@
|
||||
SELECT CAST(LEN([e].[Name]) AS int)
|
||||
SELECT [e].[Name]
|
||||
FROM [Entity] AS [e]
|
||||
@@ -18,7 +18,7 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests.NullConditionals
|
||||
using var dbContext = new SampleDbContext<Entity>();
|
||||
|
||||
var query = dbContext.Set<Entity>()
|
||||
.Select(x => x.GetNameLengthRewriteNulls());
|
||||
.Select(x => x.GetNameRewriteNulls());
|
||||
|
||||
return Verifier.Verify(query.ToQueryString());
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
// <auto-generated/>
|
||||
using EntityFrameworkCore.Projectables;
|
||||
|
||||
namespace EntityFrameworkCore.Projectables.Generated
|
||||
#nullable disable
|
||||
{
|
||||
[System.ComponentModel.EditorBrowsable(System.ComponentModel.EditorBrowsableState.Never)]
|
||||
public static class _Foo_SomeNumber
|
||||
{
|
||||
public static System.Linq.Expressions.Expression<System.Func<global::Foo, int>> Expression()
|
||||
{
|
||||
return (global::Foo fancyClass) =>
|
||||
fancyClass != null ? (fancyClass.FancyNumber ) : (int?)null ?? 3;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,7 @@ namespace EntityFrameworkCore.Projectables.Generated
|
||||
public static System.Linq.Expressions.Expression<System.Func<global::Foo.EntityExtensions.Entity, global::Foo.EntityExtensions.Entity>> Expression()
|
||||
{
|
||||
return (global::Foo.EntityExtensions.Entity entity) =>
|
||||
entity != null ? (entity.RelatedEntities != null ? (entity.RelatedEntities[0]) : null) : null;
|
||||
entity != null ? (entity.RelatedEntities != null ? (entity.RelatedEntities[0]) : (global::Foo.EntityExtensions.Entity)null) : (global::Foo.EntityExtensions.Entity)null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,7 @@ namespace EntityFrameworkCore.Projectables.Generated
|
||||
public static System.Linq.Expressions.Expression<System.Func<string, string>> Expression()
|
||||
{
|
||||
return (string input) =>
|
||||
input != null ? (input[0].ToString()) : null;
|
||||
input != null ? (input[0].ToString()) : (string)null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,7 @@ namespace EntityFrameworkCore.Projectables.Generated
|
||||
public static System.Linq.Expressions.Expression<System.Func<string, int?>> Expression()
|
||||
{
|
||||
return (string input) =>
|
||||
input != null ? (input.Length) : null;
|
||||
input != null ? (input.Length) : (int?)null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,7 @@ namespace EntityFrameworkCore.Projectables.Generated
|
||||
public static System.Linq.Expressions.Expression<System.Func<global::Foo.EntityExtensions.Entity, string>> Expression()
|
||||
{
|
||||
return (global::Foo.EntityExtensions.Entity entity) =>
|
||||
entity.FullName != null ? (entity.FullName.Substring(entity.FullName != null ? (entity.FullName.IndexOf(' ') ) : null?? 0)) : null;
|
||||
entity.FullName != null ? (entity.FullName.Substring(entity.FullName != null ? (entity.FullName.IndexOf(' ') ) : (int?)null ?? 0)) : (string)null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,7 @@ namespace EntityFrameworkCore.Projectables.Generated
|
||||
public static System.Linq.Expressions.Expression<System.Func<string, char?>> Expression()
|
||||
{
|
||||
return (string input) =>
|
||||
input != null ? (input[0]) : null;
|
||||
input != null ? (input[0]) : (char?)null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1542,6 +1542,30 @@ namespace One.Two {
|
||||
return Verifier.Verify(result.GeneratedTrees[0].ToString());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public Task NullConditionalNullCoalesceTypeConversion()
|
||||
{
|
||||
// issue: https://github.com/koenbeuk/EntityFrameworkCore.Projectables/issues/48
|
||||
|
||||
var compilation = CreateCompilation(@"
|
||||
using EntityFrameworkCore.Projectables;
|
||||
|
||||
class Foo {
|
||||
public int? FancyNumber { get; set; }
|
||||
|
||||
[Projectable(NullConditionalRewriteSupport = NullConditionalRewriteSupport.Rewrite)]
|
||||
public static int SomeNumber(Foo fancyClass) => fancyClass?.FancyNumber ?? 3;
|
||||
}
|
||||
");
|
||||
|
||||
var result = RunGenerator(compilation);
|
||||
|
||||
Assert.Empty(result.Diagnostics);
|
||||
Assert.Single(result.GeneratedTrees);
|
||||
|
||||
return Verifier.Verify(result.GeneratedTrees[0].ToString());
|
||||
}
|
||||
|
||||
#region Helpers
|
||||
|
||||
Compilation CreateCompilation(string source, bool expectedToCompile = true)
|
||||
|
||||
Reference in New Issue
Block a user