Fixed nullability rewrites with parameters and return types

This commit is contained in:
Koen Bekkenutte
2021-10-18 23:40:43 +08:00
parent 8ccb73d424
commit 8c611b9fd4
11 changed files with 184 additions and 25 deletions
@@ -126,13 +126,7 @@ namespace EntityFrameworkCore.Projectables.Generator
{
if (symbolInfo.Symbol is IMethodSymbol methodSymbol && methodSymbol.IsExtensionMethod)
{
if (SymbolEqualityComparer.Default.Equals(symbolInfo.Symbol.ContainingType, _targetTypeSymbol))
{
//return SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
// methodSymbol.ReducedFrom.Sym
//);
//throw new Exception("foo");
}
}
else if (symbolInfo.Symbol.Kind is SymbolKind.Property or SymbolKind.Method or SymbolKind.Field && SymbolEqualityComparer.Default.Equals(symbolInfo.Symbol.ContainingType, _targetTypeSymbol))
{
@@ -178,5 +172,21 @@ namespace EntityFrameworkCore.Projectables.Generator
return base.VisitQualifiedName(node);
}
public override SyntaxNode? VisitNullableType(NullableTypeSyntax node)
{
var typeInfo = _semanticModel.GetTypeInfo(node);
if (typeInfo.Type is not null)
{
if (typeInfo.Type.TypeKind is not TypeKind.Struct)
{
return Visit(node.ElementType)
.WithLeadingTrivia(node.GetLeadingTrivia())
.WithTrailingTrivia(node.GetTrailingTrivia());
}
}
return base.VisitNullableType(node);
}
}
}
@@ -17,25 +17,48 @@ namespace EntityFrameworkCore.Projectables.Generator
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
{
var symbol = _semanticModel.GetDeclaredSymbol(node);
var visitedNode = base.VisitIdentifierName(node);
var symbol = _semanticModel.GetDeclaredSymbol(visitedNode);
if (symbol is not null)
{
node = SyntaxFactory.IdentifierName(symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));
return SyntaxFactory.IdentifierName(symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));
}
return base.VisitIdentifierName(node);
return visitedNode;
}
public override SyntaxNode? VisitParameter(ParameterSyntax node)
{
var thisKeywordIndex = node.Modifiers.IndexOf(SyntaxKind.ThisKeyword);
if (thisKeywordIndex != -1)
var visitedNode = base.VisitParameter(node);
if (visitedNode is ParameterSyntax visitedParameterSyntax)
{
node = node.WithModifiers(node.Modifiers.RemoveAt(thisKeywordIndex));
var thisKeywordIndex = visitedParameterSyntax.Modifiers.IndexOf(SyntaxKind.ThisKeyword);
if (thisKeywordIndex != -1)
{
return visitedParameterSyntax.WithModifiers(node.Modifiers.RemoveAt(thisKeywordIndex));
}
}
return base.VisitParameter(node);
return visitedNode;
}
public override SyntaxNode? VisitNullableType(NullableTypeSyntax node)
{
var typeInfo = _semanticModel.GetTypeInfo(node);
if (typeInfo.Type is not null)
{
if (typeInfo.Type.TypeKind is not TypeKind.Struct)
{
return Visit(node.ElementType)
.WithLeadingTrivia(node.GetLeadingTrivia())
.WithTrailingTrivia(node.GetTrailingTrivia());
}
}
return base.VisitNullableType(node);
}
}
}
@@ -100,13 +100,10 @@ namespace EntityFrameworkCore.Projectables.Generator
return null;
}
var returnTypeSymbol = semanticModel.GetSymbolInfo(returnTypeSyntaxRewriter.Visit(methodDeclarationSyntax.ReturnType)).Symbol;
if (returnTypeSymbol is null)
{
return null;
}
var returnType = returnTypeSyntaxRewriter.Visit(methodDeclarationSyntax.ReturnType);
descriptor.ReturnTypeName = returnTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
descriptor.ReturnTypeName = returnType.ToString();
descriptor.Body = expressionSyntaxRewriter.Visit(methodDeclarationSyntax.ExpressionBody.Expression);
foreach (var additionalParameter in ((ParameterListSyntax)parameterSyntaxRewriter.Visit(methodDeclarationSyntax.ParameterList)).Parameters)
{
@@ -25,7 +25,9 @@ namespace EntityFrameworkCore.Projectables.Generator
{
if (typeInfo.Type.TypeKind is not TypeKind.Struct)
{
return Visit(node.ElementType);
return Visit(node.ElementType)
.WithLeadingTrivia(node.GetLeadingTrivia())
.WithTrailingTrivia(node.GetTrailingTrivia());
}
}
@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using EntityFrameworkCore.Projectables;
using Foo;
namespace EntityFrameworkCore.Projectables.Generated
#nullable disable
{
public static class Foo_C_NextFoo
{
public static System.Linq.Expressions.Expression<System.Func<List<object> ,List<int?>, List<object>>> Expression =>
(List<object> input,List<int?> nullablePrimitiveArgument) => input;
}
}
@@ -9,7 +9,7 @@ namespace EntityFrameworkCore.Projectables.Generated
{
public static class Foo_EntityExtensions_GetFirstRelatedIgnoreNulls
{
public static System.Linq.Expressions.Expression<System.Func<Entity, global::Foo.EntityExtensions.Entity>> Expression =>
public static System.Linq.Expressions.Expression<System.Func<Entity, Entity>> Expression =>
(Entity entity) => entity.RelatedEntities[0];
}
}
@@ -9,7 +9,7 @@ namespace EntityFrameworkCore.Projectables.Generated
{
public static class Foo_EntityExtensions_GetFirstRelatedIgnoreNulls
{
public static System.Linq.Expressions.Expression<System.Func<Entity, global::Foo.EntityExtensions.Entity>> Expression =>
public static System.Linq.Expressions.Expression<System.Func<Entity, Entity>> Expression =>
(Entity entity) => entity != null ? (entity.RelatedEntities != null ? (entity.RelatedEntities[0]) : null) : null;
}
}
@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using EntityFrameworkCore.Projectables;
using Foo;
namespace EntityFrameworkCore.Projectables.Generated
#nullable disable
{
public static class Foo_C_NullableReferenceType
{
public static System.Linq.Expressions.Expression<System.Func<object, string>> Expression =>
(object input) => (string)input;
}
}
@@ -8,7 +8,7 @@ namespace EntityFrameworkCore.Projectables.Generated
{
public static class Foo_C_NextFoo
{
public static System.Linq.Expressions.Expression<System.Func<object? ,int?, object>> Expression =>
(object? unusedArgument,int? nullablePrimitiveArgument) => null;
public static System.Linq.Expressions.Expression<System.Func<object ,int?, object>> Expression =>
(object unusedArgument,int? nullablePrimitiveArgument) => null;
}
}
@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using EntityFrameworkCore.Projectables;
using Foo;
namespace EntityFrameworkCore.Projectables.Generated
#nullable disable
{
public static class Foo_C_NullableValueType
{
public static System.Linq.Expressions.Expression<System.Func<object, int?>> Expression =>
(object input) => (int?)input;
}
}
@@ -487,6 +487,88 @@ namespace Foo {
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}
[Fact]
public Task GenericNullableReferenceTypesAreBeingEliminated()
{
var compilation = CreateCompilation(@"
using System;
using System.Collections.Generic;
using System.Linq;
using EntityFrameworkCore.Projectables;
#nullable enable
namespace Foo {
static class C {
[Projectable]
public static List<object?> NextFoo(this List<object?> input, List<int?> nullablePrimitiveArgument) => input;
}
}
");
var result = RunGenerator(compilation);
Assert.Empty(result.Diagnostics);
Assert.Single(result.GeneratedTrees);
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}
[Fact]
public Task NullableReferenceTypeCastOperatorGetsEliminated()
{
var compilation = CreateCompilation(@"
using System;
using System.Collections.Generic;
using System.Linq;
using EntityFrameworkCore.Projectables;
#nullable enable
namespace Foo {
static class C {
[Projectable]
public static string? NullableReferenceType(object? input) => (string?)input;
}
}
");
var result = RunGenerator(compilation);
Assert.Empty(result.Diagnostics);
Assert.Single(result.GeneratedTrees);
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}
[Fact]
public Task NullableValueCastOperatorsPersist()
{
var compilation = CreateCompilation(@"
using System;
using System.Collections.Generic;
using System.Linq;
using EntityFrameworkCore.Projectables;
#nullable enable
namespace Foo {
static class C {
[Projectable]
public static int? NullableValueType(object? input) => (int?)input;
}
}
");
var result = RunGenerator(compilation);
Assert.Empty(result.Diagnostics);
Assert.Single(result.GeneratedTrees);
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}
[Fact]
public void NullableMemberBinding_WithoutSupport_IsBeingReported()
{