Skip to content

Commit

Permalink
Fixing issues with ControlFlowAnalyzer on loops (#105)
Browse files Browse the repository at this point in the history
* Preparing change for loops in CFG

* fix build after merging master

* rewrite cfg builder to use BoundTreeWalker
  • Loading branch information
ChrisKXu authored Jul 20, 2024
1 parent 884aeac commit b03a75f
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 117 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
using System.Linq;
using FluentAssertions;
using Todl.Compiler.CodeAnalysis.Binding;
using Todl.Compiler.CodeAnalysis.Syntax;
using Todl.Compiler.CodeAnalysis.Text;
using Todl.Compiler.CodeAnalysis.Binding.BoundTree;
using Todl.Compiler.CodeAnalysis.Binding.ControlFlowAnalysis;
using Todl.Compiler.Diagnostics;
using Xunit;

Expand All @@ -28,18 +27,17 @@ public sealed class ControlFlowAnalysisTests
[InlineData("int func() { let i = 0; while i < 10 { ++i; } return i; }")]
public void TestControlFlowAnalysisBasic(string inputText)
{
var syntaxTree = SyntaxTree.Parse(SourceText.FromString(inputText), TestDefaults.DefaultClrTypeCache);
BoundModule.Create(TestDefaults.DefaultClrTypeCache, new[] { syntaxTree }).GetDiagnostics().Should().BeEmpty();
var function = BindMemberAndAnalyze<BoundFunctionMember>(inputText);
function.GetDiagnostics().Should().BeEmpty();
}

[Theory]
[InlineData("int func() { }")]
[InlineData("int func() { int.MaxValue.ToString(); }")]
public void TestControlFlowAnalysisWithNoReturnStatement(string inputText)
{
var syntaxTree = SyntaxTree.Parse(SourceText.FromString(inputText), TestDefaults.DefaultClrTypeCache);
var module = BoundModule.Create(TestDefaults.DefaultClrTypeCache, new[] { syntaxTree });
var diagnostics = module.GetDiagnostics().ToList();
var function = BindMemberAndAnalyze<BoundFunctionMember>(inputText);
var diagnostics = function.GetDiagnostics().ToList();

diagnostics[0].ErrorCode.Should().Be(ErrorCode.NotAllPathsReturn);
diagnostics[0].Level.Should().Be(DiagnosticLevel.Error);
Expand All @@ -52,9 +50,8 @@ public void TestControlFlowAnalysisWithNoReturnStatement(string inputText)
[InlineData("System.Uri func(string a) { const r = new System.Uri(a); return r; r.ToString(); }")]
public void TestControlFlowAnalysisWithUnreachableCode(string inputText)
{
var syntaxTree = SyntaxTree.Parse(SourceText.FromString(inputText), TestDefaults.DefaultClrTypeCache);
var module = BoundModule.Create(TestDefaults.DefaultClrTypeCache, new[] { syntaxTree });
var diagnostics = module.GetDiagnostics().ToList();
var function = BindMemberAndAnalyze<BoundFunctionMember>(inputText);
var diagnostics = function.GetDiagnostics().ToList();

diagnostics[0].ErrorCode.Should().Be(ErrorCode.UnreachableCode);
diagnostics[0].Level.Should().Be(DiagnosticLevel.Warning);
Expand All @@ -66,9 +63,8 @@ public void TestControlFlowAnalysisWithUnreachableCode(string inputText)
[InlineData("int func() { const a = 3; if a == 0 { return int.MaxValue; } else { if a == 1 { return 1; } } }")]
public void TestControlFlowAnalysisWithConditionalStatements(string inputText)
{
var syntaxTree = SyntaxTree.Parse(SourceText.FromString(inputText), TestDefaults.DefaultClrTypeCache);
var module = BoundModule.Create(TestDefaults.DefaultClrTypeCache, new[] { syntaxTree });
var diagnostics = module.GetDiagnostics().ToList();
var function = BindMemberAndAnalyze<BoundFunctionMember>(inputText);
var diagnostics = function.GetDiagnostics().ToList();

diagnostics[0].ErrorCode.Should().Be(ErrorCode.NotAllPathsReturn);
diagnostics[0].Level.Should().Be(DiagnosticLevel.Error);
Expand All @@ -79,12 +75,19 @@ public void TestControlFlowAnalysisWithConditionalStatements(string inputText)
[InlineData("int func() { while true { continue; return 1; } }")]
public void TestControlFlowAnalysisWithLoopStatements(string inputText)
{
var syntaxTree = SyntaxTree.Parse(SourceText.FromString(inputText), TestDefaults.DefaultClrTypeCache);
var module = BoundModule.Create(TestDefaults.DefaultClrTypeCache, new[] { syntaxTree });
var diagnostics = module.GetDiagnostics().ToList();
var function = BindMemberAndAnalyze<BoundFunctionMember>(inputText);
var diagnostics = function.GetDiagnostics().ToList();
diagnostics.Should().NotBeEmpty();
diagnostics.Count.Should().Be(1);

diagnostics[0].ErrorCode.Should().Be(ErrorCode.UnreachableCode);
diagnostics[0].Level.Should().Be(DiagnosticLevel.Warning);
}

private static TBoundMember BindMemberAndAnalyze<TBoundMember>(string inputText) where TBoundMember : BoundMember
{
var boundMember = TestUtils.BindMember<TBoundMember>(inputText);
new ControlFlowAnalyzer().Visit(boundMember);
return boundMember;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ private BoundLoopStatement BindWhileUntilStatement(WhileUntilStatement whileUnti
}

return BoundNodeFactory.CreateBoundLoopStatement(
whileUntilStatement,
condition,
negated,
body,
loopBinder.BoundLoopContext,
diagnosticBuilder);
syntaxNode: whileUntilStatement,
condition: condition,
conditionNegated: negated,
body: body,
boundLoopContext: loopBinder.BoundLoopContext,
diagnosticBuilder: diagnosticBuilder);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ public static BoundUnaryOperator Create(

public partial class Binder
{
private BoundExpression BindUnaryExpression(UnaryExpression unaryExpression)
private BoundUnaryExpression BindUnaryExpression(UnaryExpression unaryExpression)
{
var diagnosticBuilder = new DiagnosticBag.Builder();
var boundOperand = BindExpression(unaryExpression.Operand);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Todl.Compiler.CodeAnalysis.Binding.BoundTree;

Expand All @@ -15,15 +16,12 @@ internal sealed class ControlFlowGraph
internal static ControlFlowGraph Create(BoundFunctionMember boundFunctionMember)
{
var builder = new Builder();
foreach (var statement in boundFunctionMember.Body.Statements)
{
builder.AddStatement(statement);
}
boundFunctionMember.Accept(builder);

return builder.Build();
}

private sealed class Builder
private sealed class Builder : BoundTreeWalker
{
private readonly List<BasicBlock> blocks = new();
private readonly List<BasicBlockBranch> branches = new();
Expand All @@ -35,94 +33,109 @@ private sealed class Builder

private BasicBlock current = new();

public void AddStatement(BoundStatement boundStatement)
public override BoundNode DefaultVisit(BoundNode node)
{
if (node is BoundStatement boundStatement)
{
current.Statements.Add(boundStatement);
}

return base.DefaultVisit(node);
}

public override BoundNode VisitBoundReturnStatement(BoundReturnStatement boundReturnStatement)
{
current.Statements.Add(boundReturnStatement);
StartNewBlock(endBlock);

return boundReturnStatement;
}

public override BoundNode VisitBoundBlockStatement(BoundBlockStatement boundBlockStatement)
{
if (!boundBlockStatement.Statements.Any())
{
current.Statements.Add(new BoundNoOpStatement());
return boundBlockStatement;
}

return base.VisitBoundBlockStatement(boundBlockStatement);
}

public override BoundNode VisitBoundConditionalStatement(BoundConditionalStatement boundConditionalStatement)
{
switch (boundStatement)
var begin = current;

StartNewBlock(endBlock);
Connect(begin, current);
Visit(boundConditionalStatement.Consequence);
var consequence = current;

StartNewBlock(endBlock);
Connect(begin, current);
Visit(boundConditionalStatement.Alternative);
var alternative = current;

StartNewBlock(endBlock);
Connect(consequence, current);
Connect(alternative, current);

if (boundConditionalStatement.Consequence is BoundNoOpStatement || boundConditionalStatement.Alternative is BoundNoOpStatement)
{
case BoundReturnStatement:
current.Statements.Add(boundStatement);
StartNewBlock(endBlock);
break;
case BoundBlockStatement boundBlockStatement:
if (!boundBlockStatement.Statements.Any())
{
AddStatement(new BoundNoOpStatement());
}
else
{
foreach (var innerStatement in boundBlockStatement.Statements)
{
AddStatement(innerStatement);
}
}
break;
case BoundConditionalStatement boundConditionalStatement:
{
var begin = current;

StartNewBlock(endBlock);
Connect(begin, current);
AddStatement(boundConditionalStatement.Consequence);
var consequence = current;

StartNewBlock(endBlock);
Connect(begin, current);
AddStatement(boundConditionalStatement.Alternative);
var alternative = current;

StartNewBlock(endBlock);
Connect(consequence, current);
Connect(alternative, current);

if (boundConditionalStatement.Consequence is BoundNoOpStatement || boundConditionalStatement.Alternative is BoundNoOpStatement)
{
AddStatement(new BoundNoOpStatement());
}
break;
}
case BoundLoopStatement boundLoopStatement:
{
var begin = current;

StartNewBlock(endBlock);
Connect(begin, current);

var end = new BasicBlock();
Connect(begin, end);
loopBlocks[boundLoopStatement.BoundLoopContext] = (begin, end);

AddStatement(boundLoopStatement.Body);
var body = current;

StartNewBlock(end);
Connect(body, current);
Connect(begin, current);

current = end;

break;
}
case BoundBreakStatement boundBreakStatement:
{
current.Statements.Add(boundStatement);
var (_, end) = loopBlocks[boundBreakStatement.BoundLoopContext];

StartNewBlock(end);
break;
}
case BoundContinueStatement boundContinueStatement:
{
current.Statements.Add(boundStatement);
var (begin, end) = loopBlocks[boundContinueStatement.BoundLoopContext];

Connect(current, begin);
StartNewBlock(end);
break;
}
default:
current.Statements.Add(boundStatement);
break;
current.Statements.Add(new BoundNoOpStatement());
}

return boundConditionalStatement;
}

public override BoundNode VisitBoundLoopStatement(BoundLoopStatement boundLoopStatement)
{
var begin = current;

StartNewBlock(endBlock);
Connect(begin, current);

var end = new BasicBlock();
Connect(begin, end);
loopBlocks[boundLoopStatement.BoundLoopContext] = (begin, end);

Visit(boundLoopStatement.Body);
var body = current;

StartNewBlock(end);
Connect(body, current);
Connect(begin, current);

current = end;

return boundLoopStatement;
}

public override BoundNode VisitBoundBreakStatement(BoundBreakStatement boundBreakStatement)
{
current.Statements.Add(boundBreakStatement);
var (_, end) = loopBlocks[boundBreakStatement.BoundLoopContext];

StartNewBlock(end);

return boundBreakStatement;
}

public override BoundNode VisitBoundContinueStatement(BoundContinueStatement boundContinueStatement)
{
current.Statements.Add(boundContinueStatement);
var (begin, end) = loopBlocks[boundContinueStatement.BoundLoopContext];

Connect(current, begin);
StartNewBlock(end);

return boundContinueStatement;
}

public override BoundNode VisitBoundExpressionStatement(BoundExpressionStatement boundExpressionStatement)
{
current.Statements.Add(boundExpressionStatement);
return boundExpressionStatement;
}

private void Connect(BasicBlock from, BasicBlock to)
Expand Down Expand Up @@ -185,6 +198,7 @@ public ControlFlowGraph Build()
}
}

[DebuggerDisplay("{GetDebuggerDisplay()}")]
internal sealed class BasicBlock
{
public List<BoundStatement> Statements { get; } = new();
Expand Down Expand Up @@ -218,7 +232,18 @@ public bool IsReturn
}

public bool Reachable => Incoming.Any();

public string GetDebuggerDisplay()
{
if (!Statements.Any())
{
return "[Empty]";
}

return Statements[0].SyntaxNode.Text.ToString();
}
}

[DebuggerDisplay("{From} ==> {To}")]
internal sealed record BasicBlockBranch(BasicBlock From, BasicBlock To);
}
7 changes: 6 additions & 1 deletion src/Todl.Compiler/Diagnostics/Diagnostic.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
using System.Diagnostics;
using Todl.Compiler.CodeAnalysis.Text;

namespace Todl.Compiler.Diagnostics
{
public sealed class Diagnostic
[DebuggerDisplay("{GetDebuggerDisplay()}")]
public readonly struct Diagnostic
{
public string Message { get; init; }
public DiagnosticLevel Level { get; init; }
public TextLocation TextLocation { get; init; }
public ErrorCode ErrorCode { get; init; }

public string GetDebuggerDisplay()
=> $"ErrorCode = {ErrorCode}, Text = \"{TextLocation.TextSpan}\"";
}
}
3 changes: 1 addition & 2 deletions src/Todl.Compiler/Diagnostics/DiagnosticBag.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
Expand Down Expand Up @@ -54,7 +54,6 @@ public void AddRange<TKey, TValue>(IReadOnlyDictionary<TKey, TValue> diagnosable
private DiagnosticBag(IEnumerable<Diagnostic> unsortedDiagnostics)
{
diagnostics = unsortedDiagnostics
.Where(d => d != null)
.OrderBy(d => d.Level)
.ThenBy(d => d.TextLocation.TextSpan.Start)
.ToImmutableList();
Expand Down

0 comments on commit b03a75f

Please sign in to comment.