Skip to content

Commit

Permalink
Completing implementation of BoundTreeRewriter and migrate ConstantFo…
Browse files Browse the repository at this point in the history
…ldingBoundTreeRewriter (#113)

* Migrate ConstantFoldingBoundNodeVisitor to BoundTreeRewriter pattern

* Completing implementation of BoundTreeRewriter

* rename to ConstantFoldingBoundTreeRewriter
  • Loading branch information
ChrisKXu authored Feb 10, 2025
1 parent 18cb39b commit e1f8318
Show file tree
Hide file tree
Showing 10 changed files with 351 additions and 349 deletions.
40 changes: 31 additions & 9 deletions src/Todl.Compiler.Tests/CodeAnalysis/BinderTests/BoundNodeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using FluentAssertions;
using Todl.Compiler.CodeAnalysis.Binding;
using Todl.Compiler.CodeAnalysis.Binding.BoundTree;
Expand All @@ -19,7 +20,7 @@ public sealed class BoundNodeTests
{
[Theory]
[MemberData(nameof(GetAllSyntaxNodesForTest))]
void BoundNodeShouldHaveCorrectSyntaxNode(SyntaxNode syntaxNode, BoundNode boundNode)
internal void BoundNodeShouldHaveCorrectSyntaxNode(SyntaxNode syntaxNode, BoundNode boundNode)
{
boundNode.SyntaxNode.Should().NotBeNull();
boundNode.SyntaxNode.Should().Be(syntaxNode);
Expand Down Expand Up @@ -75,8 +76,19 @@ public void AllBoundNodeTypesAreEitherAbstractOrSealed()
.NotContain(t => !t.IsAbstract && !t.IsSealed);
}

private static readonly string[] testExpressions = new[]
[Theory]
[MemberData(nameof(GetAllSyntaxNodesForTest))]
internal void AllBoundNodeTypesHaveWalkerAndRewriterImplemented(SyntaxNode _, BoundNode boundNode)
{
var walker = new TestBoundTreeWalker();
var rewriter = new TestBoundTreeRewriter();

boundNode.Accept(walker).Should().Be(boundNode);
boundNode.Accept(rewriter).Should().Be(boundNode);
}

private static readonly string[] testExpressions =
[
"System.Uri", // BoundTypeExpression
"a = 5", // BoundAssignmentExpression
"-10", // BoundUnaryExpression
Expand All @@ -87,10 +99,10 @@ public void AllBoundNodeTypesAreEitherAbstractOrSealed()
"\"abc\".Length", // BoundClrPropertyAccessExpression
"int.MaxValue", // BoundClrFieldAccessExpression
"new System.Exception()" // BoundNewExpression
};
];

private static readonly string[] testStatements = new[]
{
private static readonly string[] testStatements =
[
"const a = 10;", // BoundVariableDeclarationStatement
"a = 10;", // BoundExpressionStatement
"{ const a = 5; a.ToString(); }", // BoundBlockStatement
Expand All @@ -99,7 +111,7 @@ public void AllBoundNodeTypesAreEitherAbstractOrSealed()
"break;", // BreakStatement
"continue;", // ContinueStatement
"while true { }" // WhileUntilStatement
};
];

private static readonly string[] testMembers = new[]
{
Expand All @@ -122,7 +134,7 @@ public static IEnumerable<object[]> GetAllSyntaxNodesForTest()
{
var sourceText = SourceText.FromString("{ const a = 5; a; }");
var blockStatement = SyntaxTree.ParseStatement(sourceText, TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var binder = Binder.CreateModuleBinder(TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var binder = Binder.CreateModuleBinder(TestDefaults.DefaultClrTypeCache, TestDefaults.ConstantValueFactory, diagnosticBuilder);
var boundBlockStatement =
binder.BindStatement(blockStatement).As<BoundBlockStatement>();

Expand All @@ -134,15 +146,15 @@ public static IEnumerable<object[]> GetAllSyntaxNodesForTest()
foreach (var inputText in testStatements)
{
var statement = SyntaxTree.ParseStatement(SourceText.FromString(inputText), TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var binder = Binder.CreateModuleBinder(TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var binder = Binder.CreateModuleBinder(TestDefaults.DefaultClrTypeCache, TestDefaults.ConstantValueFactory, diagnosticBuilder);
yield return new object[] { statement, binder.BindStatement(statement) };
}

foreach (var inputText in testMembers)
{
var syntaxTree = SyntaxTree.Parse(SourceText.FromString(inputText), TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var member = syntaxTree.Members[0];
var binder = Binder.CreateModuleBinder(TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var binder = Binder.CreateModuleBinder(TestDefaults.DefaultClrTypeCache, TestDefaults.ConstantValueFactory, diagnosticBuilder);
if (member is FunctionDeclarationMember functionDeclarationMember)
{
binder.Scope.DeclareFunction(FunctionSymbol.FromFunctionDeclarationMember(functionDeclarationMember));
Expand All @@ -151,4 +163,14 @@ public static IEnumerable<object[]> GetAllSyntaxNodesForTest()
yield return new object[] { member, binder.BindMember(member) };
}
}

private sealed class TestBoundTreeWalker : BoundTreeWalker
{
public override BoundNode DefaultVisit(BoundNode node) => default;
}

private sealed class TestBoundTreeRewriter : BoundTreeRewriter
{
public override BoundNode DefaultVisit(BoundNode node) => default;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
using FluentAssertions;
using Todl.Compiler.CodeAnalysis.Binding;
using Todl.Compiler.CodeAnalysis.Binding.BoundTree;
using Todl.Compiler.CodeAnalysis.Syntax;
using Todl.Compiler.CodeAnalysis.Text;
using Todl.Compiler.Diagnostics;
using Xunit;

namespace Todl.Compiler.Tests.CodeAnalysis;
Expand All @@ -31,15 +28,14 @@ public sealed class ConstantFoldingTests
[InlineData("const a = ~10UL;", ~10UL)]
public void ConstantFoldingUnaryOperatorTest(string inputText, object expectedValue)
{
var diagnosticBuilder = new DiagnosticBag.Builder();
var syntaxTree = SyntaxTree.Parse(SourceText.FromString(inputText), TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var module = BoundModule.Create(TestDefaults.DefaultClrTypeCache, [syntaxTree], diagnosticBuilder);
diagnosticBuilder.Build().Should().BeEmpty();
var constantFoldingBoundNodeVisitor = new ConstantFoldingBoundTreeRewriter(TestDefaults.ConstantValueFactory);
var boundVariableDeclarationStatement = TestUtils
.BindStatement<BoundVariableDeclarationStatement>(inputText)
.Accept(constantFoldingBoundNodeVisitor)
.As<BoundVariableDeclarationStatement>();

var variableMember = module.EntryPointType.Variables.ToList()[^1].As<BoundVariableMember>();
variableMember.BoundVariableDeclarationStatement.Variable.Constant.Should().Be(true);
var value = variableMember
.BoundVariableDeclarationStatement
boundVariableDeclarationStatement.Variable.Constant.Should().Be(true);
var value = boundVariableDeclarationStatement
.InitializerExpression
.As<BoundConstant>()
.Value;
Expand All @@ -56,20 +52,22 @@ public void ConstantFoldingUnaryOperatorTest(string inputText, object expectedVa
[InlineData("const a = -20;", -20)]
public void BasicConstantFoldingTests(string inputText, object expectedValue)
{
var diagnosticBuilder = new DiagnosticBag.Builder();
var syntaxTree = SyntaxTree.Parse(SourceText.FromString(inputText), TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var module = BoundModule.Create(TestDefaults.DefaultClrTypeCache, [syntaxTree], diagnosticBuilder);
diagnosticBuilder.Build().Should().BeEmpty();
var constantFoldingBoundNodeVisitor = new ConstantFoldingBoundTreeRewriter(TestDefaults.ConstantValueFactory);
var blockStatement = TestUtils
.BindStatement<BoundBlockStatement>("{ " + inputText + " }")
.Accept(constantFoldingBoundNodeVisitor)
.As<BoundBlockStatement>();

var variableMember = module.EntryPointType.Variables.ToList()[^1].As<BoundVariableMember>();
variableMember.BoundVariableDeclarationStatement.Variable.Constant.Should().Be(true);
var value = variableMember
.BoundVariableDeclarationStatement
var variableDeclarationStatements = blockStatement.Statements.Select(statement => statement.As<BoundVariableDeclarationStatement>());
variableDeclarationStatements.Count().Should().Be(blockStatement.Statements.Count());
variableDeclarationStatements.All(s => s.Variable.Constant).Should().BeTrue();
variableDeclarationStatements
.Last()
.InitializerExpression
.As<BoundConstant>()
.Value;

value.Should().Be(expectedValue);
.Value
.Should()
.Be(expectedValue);
}

[Theory]
Expand All @@ -79,27 +77,28 @@ public void BasicConstantFoldingTests(string inputText, object expectedValue)
[InlineData("const a = 10; let b = a + 10; const c = a + b;")]
public void BasicConstantFoldingNegativeTests(string inputText)
{
var diagnosticBuilder = new DiagnosticBag.Builder();
var syntaxTree = SyntaxTree.Parse(SourceText.FromString(inputText), TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var module = BoundModule.Create(TestDefaults.DefaultClrTypeCache, [syntaxTree], diagnosticBuilder);
diagnosticBuilder.Build().Should().BeEmpty();
var constantFoldingBoundNodeVisitor = new ConstantFoldingBoundTreeRewriter(TestDefaults.ConstantValueFactory);
var blockStatement = TestUtils
.BindStatement<BoundBlockStatement>("{ " + inputText + " }")
.Accept(constantFoldingBoundNodeVisitor)
.As<BoundBlockStatement>();

var variableMember = module.EntryPointType.Variables.ToList()[^1].As<BoundVariableMember>();
var boundVariableDeclarationStatement = variableMember.BoundVariableDeclarationStatement;
var boundVariableDeclarationStatement = blockStatement.Statements.Last().As<BoundVariableDeclarationStatement>();
boundVariableDeclarationStatement.Variable.Constant.Should().Be(false);
}

[Fact]
public void PartiallyFoldedConstantTests()
{
var diagnosticBuilder = new DiagnosticBag.Builder();
var syntaxTree = SyntaxTree.Parse(SourceText.FromString("let a = 10 + 10;"), TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var module = BoundModule.Create(TestDefaults.DefaultClrTypeCache, [syntaxTree], diagnosticBuilder);
diagnosticBuilder.Build().Should().BeEmpty();
var inputText = "let a = 10 + 10;";
var constantFoldingBoundNodeVisitor = new ConstantFoldingBoundTreeRewriter(TestDefaults.ConstantValueFactory);
var boundVariableDeclarationStatement = TestUtils
.BindStatement<BoundVariableDeclarationStatement>(inputText)
.Accept(constantFoldingBoundNodeVisitor)
.As<BoundVariableDeclarationStatement>();

var statement = module.EntryPointType.Variables.ToList()[^1].As<BoundVariableMember>().BoundVariableDeclarationStatement;
statement.Variable.Constant.Should().Be(false);
statement.InitializerExpression.Constant.Should().Be(true);
statement.InitializerExpression.As<BoundConstant>().Value.Should().Be(20);
boundVariableDeclarationStatement.Variable.Constant.Should().Be(false);
boundVariableDeclarationStatement.InitializerExpression.Constant.Should().Be(true);
boundVariableDeclarationStatement.InitializerExpression.As<BoundConstant>().Value.Should().Be(20);
}
}
3 changes: 3 additions & 0 deletions src/Todl.Compiler.Tests/TestDefaults.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
using System.Linq;
using System.Reflection;
using Todl.Compiler.CodeAnalysis;
using Todl.Compiler.CodeAnalysis.Binding;

namespace Todl.Compiler.Tests;

static class TestDefaults
{
public static readonly ClrTypeCache DefaultClrTypeCache;
public static readonly ConstantValueFactory ConstantValueFactory;
public static readonly MetadataLoadContext MetadataLoadContext;

static TestDefaults()
Expand All @@ -25,5 +27,6 @@ static TestDefaults()
}

DefaultClrTypeCache = ClrTypeCache.FromAssemblies(assemblies: MetadataLoadContext.GetAssemblies(), MetadataLoadContext.CoreAssembly);
ConstantValueFactory = new ConstantValueFactory(DefaultClrTypeCache.BuiltInTypes);
}
}
6 changes: 3 additions & 3 deletions src/Todl.Compiler.Tests/TestUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ internal static TBoundExpression BindExpression<TBoundExpression>(
where TBoundExpression : BoundExpression
{
var expression = SyntaxTree.ParseExpression(SourceText.FromString(inputText), TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var binder = Binder.CreateModuleBinder(TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var binder = Binder.CreateModuleBinder(TestDefaults.DefaultClrTypeCache, TestDefaults.ConstantValueFactory, diagnosticBuilder);
return binder.BindExpression(expression).As<TBoundExpression>();
}

Expand All @@ -41,7 +41,7 @@ internal static TBoundStatement BindStatement<TBoundStatement>(
where TBoundStatement : BoundStatement
{
var statement = SyntaxTree.ParseStatement(SourceText.FromString(inputText), TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var binder = Binder.CreateModuleBinder(TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var binder = Binder.CreateModuleBinder(TestDefaults.DefaultClrTypeCache, TestDefaults.ConstantValueFactory, diagnosticBuilder);
return binder.BindStatement(statement).As<TBoundStatement>();
}

Expand All @@ -59,7 +59,7 @@ internal static TBoundMember BindMember<TBoundMember>(
where TBoundMember : BoundMember
{
var syntaxTree = ParseSyntaxTree(inputText);
var binder = Binder.CreateModuleBinder(TestDefaults.DefaultClrTypeCache, diagnosticBuilder);
var binder = Binder.CreateModuleBinder(TestDefaults.DefaultClrTypeCache, TestDefaults.ConstantValueFactory, diagnosticBuilder);
var member = syntaxTree.Members[0];

if (member is FunctionDeclarationMember functionDeclarationMember)
Expand Down
15 changes: 7 additions & 8 deletions src/Todl.Compiler/CodeAnalysis/Binding/BoundModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,19 @@ public static BoundModule Create(
DiagnosticBag.Builder diagnosticBuilder)
{
syntaxTrees ??= Array.Empty<SyntaxTree>();
var binder = Binder.CreateModuleBinder(clrTypeCache, diagnosticBuilder);
var constantValueFactory = new ConstantValueFactory(clrTypeCache.BuiltInTypes);
var binder = Binder.CreateModuleBinder(clrTypeCache, constantValueFactory, diagnosticBuilder);
var entryPointType = binder.BindEntryPointTypeDefinition(syntaxTrees);

var controlFlowAnalyzer = new ControlFlowAnalyzer(diagnosticBuilder);
entryPointType.Accept(controlFlowAnalyzer);

var boundNodeVisitors = new BoundNodeVisitor[]
var boundTreeVisitors = new BoundTreeVisitor[]
{
new ConstantFoldingBoundNodeVisitor(binder.ConstantValueFactory)
new ControlFlowAnalyzer(diagnosticBuilder),
new ConstantFoldingBoundTreeRewriter(binder.ConstantValueFactory)
};

foreach (var v in boundNodeVisitors)
foreach (var boundTreeVisitor in boundTreeVisitors)
{
entryPointType = (BoundEntryPointTypeDefinition)v.VisitBoundTypeDefinition(entryPointType);
entryPointType = (BoundEntryPointTypeDefinition)entryPointType.Accept(boundTreeVisitor);
}

return new()
Expand Down
Loading

0 comments on commit e1f8318

Please sign in to comment.