Fixed nullability rewrites with parameters and return types

This commit is contained in:
Koen Bekkenutte
2021-10-18 23:40:43 +08:00
parent 8ccb73d424
commit 8c611b9fd4
11 changed files with 184 additions and 25 deletions
@@ -126,13 +126,7 @@ namespace EntityFrameworkCore.Projectables.Generator
{ {
if (symbolInfo.Symbol is IMethodSymbol methodSymbol && methodSymbol.IsExtensionMethod) if (symbolInfo.Symbol is IMethodSymbol methodSymbol && methodSymbol.IsExtensionMethod)
{ {
if (SymbolEqualityComparer.Default.Equals(symbolInfo.Symbol.ContainingType, _targetTypeSymbol))
{
//return SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
// methodSymbol.ReducedFrom.Sym
//);
//throw new Exception("foo");
}
} }
else if (symbolInfo.Symbol.Kind is SymbolKind.Property or SymbolKind.Method or SymbolKind.Field && SymbolEqualityComparer.Default.Equals(symbolInfo.Symbol.ContainingType, _targetTypeSymbol)) else if (symbolInfo.Symbol.Kind is SymbolKind.Property or SymbolKind.Method or SymbolKind.Field && SymbolEqualityComparer.Default.Equals(symbolInfo.Symbol.ContainingType, _targetTypeSymbol))
{ {
@@ -178,5 +172,21 @@ namespace EntityFrameworkCore.Projectables.Generator
return base.VisitQualifiedName(node); return base.VisitQualifiedName(node);
} }
public override SyntaxNode? VisitNullableType(NullableTypeSyntax node)
{
var typeInfo = _semanticModel.GetTypeInfo(node);
if (typeInfo.Type is not null)
{
if (typeInfo.Type.TypeKind is not TypeKind.Struct)
{
return Visit(node.ElementType)
.WithLeadingTrivia(node.GetLeadingTrivia())
.WithTrailingTrivia(node.GetTrailingTrivia());
}
}
return base.VisitNullableType(node);
}
} }
} }
@@ -17,25 +17,48 @@ namespace EntityFrameworkCore.Projectables.Generator
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node) public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
{ {
var symbol = _semanticModel.GetDeclaredSymbol(node); var visitedNode = base.VisitIdentifierName(node);
var symbol = _semanticModel.GetDeclaredSymbol(visitedNode);
if (symbol is not null) if (symbol is not null)
{ {
node = SyntaxFactory.IdentifierName(symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); return SyntaxFactory.IdentifierName(symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));
} }
return base.VisitIdentifierName(node); return visitedNode;
} }
public override SyntaxNode? VisitParameter(ParameterSyntax node) public override SyntaxNode? VisitParameter(ParameterSyntax node)
{ {
var thisKeywordIndex = node.Modifiers.IndexOf(SyntaxKind.ThisKeyword); var visitedNode = base.VisitParameter(node);
if (thisKeywordIndex != -1)
if (visitedNode is ParameterSyntax visitedParameterSyntax)
{ {
node = node.WithModifiers(node.Modifiers.RemoveAt(thisKeywordIndex)); var thisKeywordIndex = visitedParameterSyntax.Modifiers.IndexOf(SyntaxKind.ThisKeyword);
if (thisKeywordIndex != -1)
{
return visitedParameterSyntax.WithModifiers(node.Modifiers.RemoveAt(thisKeywordIndex));
}
} }
return base.VisitParameter(node); return visitedNode;
}
public override SyntaxNode? VisitNullableType(NullableTypeSyntax node)
{
var typeInfo = _semanticModel.GetTypeInfo(node);
if (typeInfo.Type is not null)
{
if (typeInfo.Type.TypeKind is not TypeKind.Struct)
{
return Visit(node.ElementType)
.WithLeadingTrivia(node.GetLeadingTrivia())
.WithTrailingTrivia(node.GetTrailingTrivia());
}
}
return base.VisitNullableType(node);
} }
} }
} }
@@ -100,13 +100,10 @@ namespace EntityFrameworkCore.Projectables.Generator
return null; return null;
} }
var returnTypeSymbol = semanticModel.GetSymbolInfo(returnTypeSyntaxRewriter.Visit(methodDeclarationSyntax.ReturnType)).Symbol; var returnType = returnTypeSyntaxRewriter.Visit(methodDeclarationSyntax.ReturnType);
if (returnTypeSymbol is null)
{
return null;
}
descriptor.ReturnTypeName = returnTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
descriptor.ReturnTypeName = returnType.ToString();
descriptor.Body = expressionSyntaxRewriter.Visit(methodDeclarationSyntax.ExpressionBody.Expression); descriptor.Body = expressionSyntaxRewriter.Visit(methodDeclarationSyntax.ExpressionBody.Expression);
foreach (var additionalParameter in ((ParameterListSyntax)parameterSyntaxRewriter.Visit(methodDeclarationSyntax.ParameterList)).Parameters) foreach (var additionalParameter in ((ParameterListSyntax)parameterSyntaxRewriter.Visit(methodDeclarationSyntax.ParameterList)).Parameters)
{ {
@@ -25,7 +25,9 @@ namespace EntityFrameworkCore.Projectables.Generator
{ {
if (typeInfo.Type.TypeKind is not TypeKind.Struct) if (typeInfo.Type.TypeKind is not TypeKind.Struct)
{ {
return Visit(node.ElementType); return Visit(node.ElementType)
.WithLeadingTrivia(node.GetLeadingTrivia())
.WithTrailingTrivia(node.GetTrailingTrivia());
} }
} }
@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using EntityFrameworkCore.Projectables;
using Foo;
namespace EntityFrameworkCore.Projectables.Generated
#nullable disable
{
public static class Foo_C_NextFoo
{
public static System.Linq.Expressions.Expression<System.Func<List<object> ,List<int?>, List<object>>> Expression =>
(List<object> input,List<int?> nullablePrimitiveArgument) => input;
}
}
@@ -9,7 +9,7 @@ namespace EntityFrameworkCore.Projectables.Generated
{ {
public static class Foo_EntityExtensions_GetFirstRelatedIgnoreNulls public static class Foo_EntityExtensions_GetFirstRelatedIgnoreNulls
{ {
public static System.Linq.Expressions.Expression<System.Func<Entity, global::Foo.EntityExtensions.Entity>> Expression => public static System.Linq.Expressions.Expression<System.Func<Entity, Entity>> Expression =>
(Entity entity) => entity.RelatedEntities[0]; (Entity entity) => entity.RelatedEntities[0];
} }
} }
@@ -9,7 +9,7 @@ namespace EntityFrameworkCore.Projectables.Generated
{ {
public static class Foo_EntityExtensions_GetFirstRelatedIgnoreNulls public static class Foo_EntityExtensions_GetFirstRelatedIgnoreNulls
{ {
public static System.Linq.Expressions.Expression<System.Func<Entity, global::Foo.EntityExtensions.Entity>> Expression => public static System.Linq.Expressions.Expression<System.Func<Entity, Entity>> Expression =>
(Entity entity) => entity != null ? (entity.RelatedEntities != null ? (entity.RelatedEntities[0]) : null) : null; (Entity entity) => entity != null ? (entity.RelatedEntities != null ? (entity.RelatedEntities[0]) : null) : null;
} }
} }
@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using EntityFrameworkCore.Projectables;
using Foo;
namespace EntityFrameworkCore.Projectables.Generated
#nullable disable
{
public static class Foo_C_NullableReferenceType
{
public static System.Linq.Expressions.Expression<System.Func<object, string>> Expression =>
(object input) => (string)input;
}
}
@@ -8,7 +8,7 @@ namespace EntityFrameworkCore.Projectables.Generated
{ {
public static class Foo_C_NextFoo public static class Foo_C_NextFoo
{ {
public static System.Linq.Expressions.Expression<System.Func<object? ,int?, object>> Expression => public static System.Linq.Expressions.Expression<System.Func<object ,int?, object>> Expression =>
(object? unusedArgument,int? nullablePrimitiveArgument) => null; (object unusedArgument,int? nullablePrimitiveArgument) => null;
} }
} }
@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using EntityFrameworkCore.Projectables;
using Foo;
namespace EntityFrameworkCore.Projectables.Generated
#nullable disable
{
public static class Foo_C_NullableValueType
{
public static System.Linq.Expressions.Expression<System.Func<object, int?>> Expression =>
(object input) => (int?)input;
}
}
@@ -487,6 +487,88 @@ namespace Foo {
return Verifier.Verify(result.GeneratedTrees[0].ToString()); return Verifier.Verify(result.GeneratedTrees[0].ToString());
} }
[Fact]
public Task GenericNullableReferenceTypesAreBeingEliminated()
{
var compilation = CreateCompilation(@"
using System;
using System.Collections.Generic;
using System.Linq;
using EntityFrameworkCore.Projectables;
#nullable enable
namespace Foo {
static class C {
[Projectable]
public static List<object?> NextFoo(this List<object?> input, List<int?> nullablePrimitiveArgument) => input;
}
}
");
var result = RunGenerator(compilation);
Assert.Empty(result.Diagnostics);
Assert.Single(result.GeneratedTrees);
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}
[Fact]
public Task NullableReferenceTypeCastOperatorGetsEliminated()
{
var compilation = CreateCompilation(@"
using System;
using System.Collections.Generic;
using System.Linq;
using EntityFrameworkCore.Projectables;
#nullable enable
namespace Foo {
static class C {
[Projectable]
public static string? NullableReferenceType(object? input) => (string?)input;
}
}
");
var result = RunGenerator(compilation);
Assert.Empty(result.Diagnostics);
Assert.Single(result.GeneratedTrees);
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}
[Fact]
public Task NullableValueCastOperatorsPersist()
{
var compilation = CreateCompilation(@"
using System;
using System.Collections.Generic;
using System.Linq;
using EntityFrameworkCore.Projectables;
#nullable enable
namespace Foo {
static class C {
[Projectable]
public static int? NullableValueType(object? input) => (int?)input;
}
}
");
var result = RunGenerator(compilation);
Assert.Empty(result.Diagnostics);
Assert.Single(result.GeneratedTrees);
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}
[Fact] [Fact]
public void NullableMemberBinding_WithoutSupport_IsBeingReported() public void NullableMemberBinding_WithoutSupport_IsBeingReported()
{ {