Fully qualify extension method calls

This commit is contained in:
Koen
2022-11-05 18:11:28 +00:00
parent 0272bb2f8f
commit 636d0e5c6c
7 changed files with 78 additions and 8 deletions

View File

@@ -16,16 +16,14 @@ namespace EntityFrameworkCore.Projectables.Generator
readonly INamedTypeSymbol _targetTypeSymbol;
readonly SemanticModel _semanticModel;
readonly NullConditionalRewriteSupport _nullConditionalRewriteSupport;
readonly Compilation _compilation;
readonly SourceProductionContext _context;
readonly Stack<ExpressionSyntax> _conditionalAccessExpressionsStack = new();
public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullConditionalRewriteSupport nullConditionalRewriteSupport, Compilation compilation, SemanticModel semanticModel, SourceProductionContext context)
public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullConditionalRewriteSupport nullConditionalRewriteSupport, SemanticModel semanticModel, SourceProductionContext context)
{
_targetTypeSymbol = targetTypeSymbol;
_nullConditionalRewriteSupport = nullConditionalRewriteSupport;
_semanticModel = semanticModel;
_compilation = compilation;
_context = context;
}
@@ -37,6 +35,33 @@ namespace EntityFrameworkCore.Projectables.Generator
.WithTrailingTrivia(node.GetTrailingTrivia());
}
public override SyntaxNode? VisitInvocationExpression(InvocationExpressionSyntax node)
{
// Fully qualify extension method calls
if (node.Expression is MemberAccessExpressionSyntax memberAccessExpressionSyntax)
{
var symbol = _semanticModel.GetSymbolInfo(node).Symbol;
if (symbol is IMethodSymbol { IsExtensionMethod: true } methodSymbol)
{
return SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.ParseName(methodSymbol.ContainingType.ToDisplayString(NullableFlowState.None, SymbolDisplayFormat.FullyQualifiedFormat)),
memberAccessExpressionSyntax.Name
),
node.ArgumentList.WithArguments(
((ArgumentListSyntax)VisitArgumentList(node.ArgumentList)!).Arguments.Insert(0, SyntaxFactory.Argument(
(ExpressionSyntax)Visit(memberAccessExpressionSyntax.Expression)
)
)
)
);
}
}
return base.VisitInvocationExpression(node);
}
public override SyntaxNode? VisitConditionalAccessExpression(ConditionalAccessExpressionSyntax node)
{
var targetExpression = (ExpressionSyntax)Visit(node.Expression);

View File

@@ -115,7 +115,7 @@ namespace EntityFrameworkCore.Projectables.Generator
if (memberBody is null) return null;
}
var expressionSyntaxRewriter = new ExpressionSyntaxRewriter(memberSymbol.ContainingType, nullConditionalRewriteSupport, compilation, semanticModel, context);
var expressionSyntaxRewriter = new ExpressionSyntaxRewriter(memberSymbol.ContainingType, nullConditionalRewriteSupport, semanticModel, context);
var declarationSyntaxRewriter = new DeclarationSyntaxRewriter(semanticModel);
var descriptor = new ProjectableDescriptor {

View File

@@ -13,7 +13,7 @@ namespace EntityFrameworkCore.Projectables.Generated
public static System.Linq.Expressions.Expression<System.Func<object, object>> Expression()
{
return (object i) =>
i.Foo1();
global::Foo.C.Foo1(i);
}
}
}

View File

@@ -13,7 +13,7 @@ namespace EntityFrameworkCore.Projectables.Generated
public static System.Linq.Expressions.Expression<System.Func<global::Foo.C, global::Foo.D>> Expression()
{
return (global::Foo.C @this) =>
@this.Dees.First();
global::System.Linq.Enumerable.First(@this.Dees);
}
}
}

View File

@@ -0,0 +1,17 @@
// <auto-generated/>
using EntityFrameworkCore.Projectables;
using One.Two;
namespace EntityFrameworkCore.Projectables.Generated
#nullable disable
{
[System.ComponentModel.EditorBrowsable(System.ComponentModel.EditorBrowsableState.Never)]
public static class One_Two_Bar_Method
{
public static System.Linq.Expressions.Expression<System.Func<global::One.Two.Bar, int>> Expression()
{
return (global::One.Two.Bar @this) =>
global::One.IntExtensions.AddOne(1);
}
}
}

View File

@@ -13,7 +13,7 @@ namespace EntityFrameworkCore.Projectables.Generated
public static System.Linq.Expressions.Expression<System.Func<global::Foo.C, int>> Expression()
{
return (global::Foo.C @this) =>
@this.Dees.OfType<global::Foo.D>().Count();
global::System.Linq.Enumerable.Count(global::System.Linq.Enumerable.OfType<D>(@this.Dees));
}
}
}

View File

@@ -1514,6 +1514,34 @@ class Foo {
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}
[Fact]
public Task RequiredNamespace()
{
var compilation = CreateCompilation(@"
using EntityFrameworkCore.Projectables;
namespace One {
static class IntExtensions {
public static int AddOne(this int i) => i + 1;
}
}
namespace One.Two {
class Bar {
[Projectable]
public int Method() => 1.AddOne();
}
}
");
var result = RunGenerator(compilation);
Assert.Empty(result.Diagnostics);
Assert.Single(result.GeneratedTrees);
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}
#region Helpers
Compilation CreateCompilation(string source, bool expectedToCompile = true)