Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tail recursion elimination for GenC (#1275) #1626

Merged
merged 32 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
39d935a
Scaffold for a genc tailrec phase
KacperFKorban Nov 14, 2024
5e4f8c5
Detect tail recursive functions correctly
KacperFKorban Nov 21, 2024
df2ab0a
Draft implementation of local tailrec elimination
KacperFKorban Nov 21, 2024
a2e51b4
Fix tailrec phase for genc; fib works now
KacperFKorban Nov 29, 2024
d56b729
Remove debug from TailRecTransformer.scala
KacperFKorban Nov 29, 2024
f092b33
Implement continue
zhekai-jiang Nov 29, 2024
3c44385
Add bad example with tail recursive call in a loop
zhekai-jiang Nov 29, 2024
c2a8b30
Rename files and classes for test programs
zhekai-jiang Nov 29, 2024
1d01989
Fix new example
zhekai-jiang Nov 29, 2024
3fff420
Add labeled + goto; Implement jumping to label after tail call
KacperFKorban Nov 29, 2024
920b85b
Add value checks for tail recursive examples, add back the full GenCS…
zhekai-jiang Nov 29, 2024
935d770
Support aliased tail recursive calls
zhekai-jiang Dec 5, 2024
f0b7e1a
Change handling almost-tail-recursive functions to a simplification p…
KacperFKorban Dec 9, 2024
b49f92b
Add more tests for tailrec (with one for mutual recursion that is fai…
KacperFKorban Dec 10, 2024
d6cf090
detect mutually recursive functions; fix mutual recursion test
KacperFKorban Dec 12, 2024
5efdcec
Update the functions in the rest of the program when rewriting tailre…
KacperFKorban Dec 12, 2024
1883c7c
Add a test that runs into stack overflow (segmentation fault) without…
zhekai-jiang Dec 12, 2024
0032ae4
Add a simple test to check if tail recursive calls are indeed replace…
zhekai-jiang Dec 12, 2024
d802134
Add test case with unit-type tail-recursive function
zhekai-jiang Dec 26, 2024
395c56c
Remove checks for mutual recursion
zhekai-jiang Jan 2, 2025
03191b3
Add back replacement of function calls
zhekai-jiang Jan 2, 2025
1703fcf
For unit functions, put tail recursive calls in a return statement
zhekai-jiang Jan 2, 2025
46b61af
Rename phase
zhekai-jiang Jan 3, 2025
0ae41de
Make the unit return heuristic handle more cases; copy over some tail…
KacperFKorban Jan 4, 2025
3f77a04
Add one more Unit return test
KacperFKorban Jan 4, 2025
b58efe5
Make test class name consistent with the slides
zhekai-jiang Jan 5, 2025
6559119
Remove unused things
zhekai-jiang Jan 5, 2025
a77c13b
Remove unused things
zhekai-jiang Jan 5, 2025
4aa27ef
Remove unused things
zhekai-jiang Jan 5, 2025
021037e
Add return at the end of Unit functions
zhekai-jiang Jan 5, 2025
f2a60f9
Add measures in genc tailrec test cases, add empty statement after la…
zhekai-jiang Jan 8, 2025
ba2e58b
Merge branch 'main' into genc-tailrec
vkuncak Jan 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/src/main/scala/stainless/genc/CAST.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ object CAST { // C Abstract Syntax Tree

case class Block(exprs: Seq[Expr]) extends Expr // Can be empty

case class Labeled(label: String, block: Expr) extends Expr

case class Lit(lit: Literal) extends Expr

case class EnumLiteral(id: Id) extends Expr
Expand Down Expand Up @@ -212,6 +214,8 @@ object CAST { // C Abstract Syntax Tree
require(cond.isValue, s"Condition ($cond) of while loop must be a value")
}

case class Goto(name: String) extends Expr

case object Break extends Expr

case class Return(value: Expr) extends Expr {
Expand Down
9 changes: 9 additions & 0 deletions core/src/main/scala/stainless/genc/CPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,12 @@ class CPrinter(
|"""
}

case Labeled(name, block) =>
// In C, a label cannot be followed by a variable declaration
// So we add a semicolon to add an empty statement to work around this
c"""|$name: ;
| $block"""

case Lit(lit) => c"$lit"

case EnumLiteral(lit) => c"$lit"
Expand Down Expand Up @@ -319,6 +325,9 @@ class CPrinter(
c"""|while ($cond) {
| $body
|}"""

case Goto(label) =>
c"goto $label"

case Break => c"break"

Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/stainless/genc/GenerateC.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ object GenerateC {
NamedLeonPhase("Lifting", new LiftingPhase) `andThen`
NamedLeonPhase("Referencing", new ReferencingPhase) `andThen`
NamedLeonPhase("StructInlining", new StructInliningPhase) `andThen`
NamedLeonPhase("TailRecElim", new TailRecElimPhase) `andThen`
NamedLeonPhase("IR2C", new IR2CPhase)
}

Expand Down
6 changes: 5 additions & 1 deletion core/src/main/scala/stainless/genc/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ private[genc] sealed trait IR { ir =>
case Binding(vd) => vd.getType
case c: Callable => c.typ
case Block(exprs) => exprs.last.getType
case Labeled(_, block) => block.getType
case MemSet(_, _, _) => NoType
case SizeOf(_) => PrimitiveType(UInt32Type)
case Decl(_, _) => NoType
Expand Down Expand Up @@ -221,6 +222,7 @@ private[genc] sealed trait IR { ir =>
case If(_, _) => NoType
case IfElse(_, thenn, _) => thenn.getType // same as elze
case While(_, _) => NoType
case Goto(_) => NoType
case IsA(_, _) => PrimitiveType(BoolType)
case AsA(_, ct) => ct
case IntegralCast(_, newIntegralType) => PrimitiveType(newIntegralType)
Expand Down Expand Up @@ -260,6 +262,7 @@ private[genc] sealed trait IR { ir =>
case class Block(exprs: Seq[Expr]) extends Expr {
require(exprs.nonEmpty, "GenC IR blocks must be non-empty")
}
case class Labeled(name: String, expr: Expr) extends Expr

case class MemSet(pointer: Expr, value: Expr, size: Expr) extends Expr
case class SizeOf(tpe: Type) extends Expr
Expand Down Expand Up @@ -296,6 +299,7 @@ private[genc] sealed trait IR { ir =>
case class If(cond: Expr, thenn: Expr) extends Expr
case class IfElse(cond: Expr, thenn: Expr, elze: Expr) extends Expr
case class While(cond: Expr, body: Expr) extends Expr
case class Goto(label: String) extends Expr

// Type probindg + casting
case class IsA(expr: Expr, ct: ClassType) extends Expr
Expand Down Expand Up @@ -323,7 +327,6 @@ private[genc] sealed trait IR { ir =>

case object Break extends Expr


/****************************************************************************************************
* Expression Helpers *
****************************************************************************************************/
Expand Down Expand Up @@ -484,6 +487,7 @@ private[genc] sealed trait IR { ir =>
}

object IRs {
object TIR extends IR
object SIR extends IR
object CIR extends IR
object RIR extends IR
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/scala/stainless/genc/ir/IRPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ final class IRPrinter[S <: IR](val ir: S) {
case MemSet(pointer, value, size) => s"memset(${rec(pointer)}, ${rec(value)}, ${rec(size)})"
case SizeOf(tpe) => s"sizeof(${rec(tpe)})"
case Block(exprs) => "{{ " + (exprs map rec mkString ptx.newLine) + " }}"
case Labeled(label, expr) =>
s"""|{{ $label:
| ${rec(expr)} }}""".stripMargin
case Decl(vd, None) => (if (vd.isVar) "var" else "val") + " " + rec(vd)
case Decl(vd, Some(value)) => (if (vd.isVar) "var" else "val") + " " + rec(vd) + " = " + rec(value)
case App(callable, extra, args) =>
Expand All @@ -112,6 +115,8 @@ final class IRPrinter[S <: IR](val ir: S) {
"else {" + ptx.newLine + " " + rec(elze)(using ptx + 1) + ptx.newLine + "}"
case While(cond, body) =>
"while (" + rec(cond) + ") {" + ptx.newLine + " " + rec(body)(using ptx + 1) + ptx.newLine + "}"
case Goto(label) =>
s"goto $label"
case IsA(expr, ct) => "¿" + ct.clazz.id + "?" + rec(expr)
case AsA(expr, ct) => "(" + ct.clazz.id + ")" + rec(expr)
case IntegralCast(expr, newType) => "(" + newType + ")" + rec(expr)
Expand Down
89 changes: 89 additions & 0 deletions core/src/main/scala/stainless/genc/ir/TailRecSimpTransformer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package stainless
package genc
package ir

import PrimitiveTypes.{ PrimitiveType => PT, _ } // For desambiguation
import Literals._
import Operators._
import IRs._
import scala.collection.mutable

final class TailRecSimpTransformer extends Transformer(SIR, SIR) with NoEnv {
import from._

private given givenDebugSection: DebugSectionGenC.type = DebugSectionGenC

/**
* Replace a variable assignment that is immediately
* returned
*
* val i = f(...);
* return i;
*
* ==>
*
* return f(...);
*
*/
private def replaceImmediateReturn(fd: Expr): Expr = {
val transformer = new ir.Transformer(from, to) with NoEnv {
override protected def recImpl(expr: Expr)(using Env): (Expr, Env) = expr match {
case Block(stmts) =>
Block(stmts.zipWithIndex.flatMap {
case (expr @ Decl(id, Some(rhs)), idx) =>
stmts.lift(idx + 1) match {
case Some(Return(Binding(retId))) if retId == id =>
List(Return(rhs))
case _ => List(recImpl(expr)._1)
}
case (expr @ Return(Binding(retId)), idx) =>
stmts.lift(idx - 1) match {
case Some(Decl(id, rhs)) if id == retId =>
Nil
case _ => List(recImpl(expr)._1)
}
case (expr, idx) => List(recImpl(expr)._1)
}) -> ()
case expr => super.recImpl(expr)
}
}
transformer(fd)
}

/**
* Remove all statements after a return statement
*
* return f(...);
* someStmt;
*
* ==>
*
* return f(...);
*
*/
private def removeAfterReturn(fd: Expr): Expr = {
val transformer = new ir.Transformer(from, to) with NoEnv {
override protected def recImpl(expr: Expr)(using Env): (Expr, Env) = expr match {
case Block(stmts) =>
val transformedStmts = stmts.map(recImpl(_)._1)
val firstReturn = transformedStmts.find {
case Return(_) => true
case _ => false
}.toList
val newStmts = transformedStmts.takeWhile {
case Return(_) => false
case _ => true
}
Block(newStmts ++ firstReturn) -> ()
case expr => super.recImpl(expr)
}
}
transformer(fd)
}

override protected def recImpl(fd: Expr)(using Env): (to.Expr, Env) = {
val afterReturn = removeAfterReturn(fd)
val immediateReturn = replaceImmediateReturn(afterReturn)
immediateReturn -> ()
}
}
181 changes: 181 additions & 0 deletions core/src/main/scala/stainless/genc/ir/TailRecTransformer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package stainless
package genc
package ir

import PrimitiveTypes.{ PrimitiveType => PT, _ } // For desambiguation
import Literals._
import Operators._
import IRs._
import scala.collection.mutable

final class TailRecTransformer(val ctx: inox.Context) extends Transformer(SIR, TIR) with NoEnv {
import from._

private given givenDebugSection: DebugSectionGenC.type = DebugSectionGenC

private given printer.Context = printer.Context(0)

/**
* If the function returns Unit type and the last one statement is a recursive call,
* put the recursive call in a return statement.
*
* Example:
* def countDown(n: Int): Unit =
* if (n == 0) return
* countDown(n - 1)
*
* ==>
*
* def countDown(n: Int): Unit =
* if (n == 0) return
* return countDown(n - 1)
*/
private def putTailRecursiveUnitCallInReturn(fd: FunDef): FunDef = {
def go(expr: Expr): Expr = expr match {
case Block(stmts) if stmts.nonEmpty =>
Block(stmts.init :+ go(stmts.last))
case IfElse(cond, thenn, elze) =>
IfElse(cond, go(thenn), go(elze))
case app @ App(FunVal(calledFd), _, _) if calledFd.id == fd.id =>
Return(app)
case _ => expr
}
fd.body match {
case FunBodyAST(expr) if fd.returnType.isUnitType =>
fd.copy(body = FunBodyAST(go(expr)))
case _ => fd
}
}

private def isTailRecursive(fd: FunDef): Boolean = {
var functionRefs = mutable.ListBuffer.empty[FunDef]
val functionRefVisitor = new ir.Visitor(from) {
override protected def visit(expr: Expr): Unit = expr match {
case FunVal(fd) => functionRefs += fd
case _ =>
}
}
var tailFunctionRefs = mutable.ListBuffer.empty[FunDef]
val tailRecCallVisitor = new ir.Visitor(from) {
override protected def visit(expr: Expr): Unit = expr match {
case Return(App(FunVal(fdcall), _, _)) => tailFunctionRefs += fdcall

case _ =>
}
}
functionRefVisitor(fd)
tailRecCallVisitor(fd)
functionRefs.contains(fd) && functionRefs.filter(_ == fd).size == tailFunctionRefs.filter(_ == fd).size
}

/* Rewrite a tail recursive function to a while loop
* Example:
* def fib(n: Int, i: Int = 0, j: Int = 1): Int =
* if (n == 0)
* return i
* else
* return fib(n-1, j, i+j)
*
* ==>
*
* def fib(n: Int, i: Int = 0, j: Int = 1): Int = {
*
* var n$ = n
* var i$ = i
* var j$ = j
* while (true) {
* someLabel:
* if (n$ == 0) {
* return i$
* } else {
* val n$1 = n$ - 1
* val i$1 = j$
* val j$1 = i$ + j$
* n$ = n$1
* i$ = i$1
* j$ = j$1
* goto someLabel
* }
* }
* }
* Steps:
* - Create a new variable for each parameter of the function
* - Replace existing parameter references with the new variables
* - Create a while loop with a condition true
* - Replace the recursive return with a variable assignments (updating the state) and a continue statement
*/
private def rewriteToAWhileLoop(fd: FunDef): FunDef = fd.body match {
case FunBodyAST(body) =>
val newParams = fd.params.map(p => ValDef(freshId(p.id), p.typ, isVar = true))
val newParamMap = fd.params.zip(newParams).toMap
val labelName = freshId("label")
val bodyWithNewParams = replaceBindings(newParamMap, body)
val bodyWithUnitReturn = bodyWithNewParams match {
case Block(stmts) =>
if fd.returnType.isUnitType then
Block(stmts :+ Return(Lit(UnitLit)))
else
bodyWithNewParams
case _ => bodyWithNewParams
}
val declarations = newParamMap.toList.map { case (old, nw) => Decl(nw, Some(Binding(old))) }
val newBody = replaceRecursiveCalls(fd, bodyWithUnitReturn, newParams.toList, labelName)
val newBodyWithALabel = Labeled(labelName, newBody)
val newBodyWithAWhileLoop = While(True, newBodyWithALabel)
FunDef(fd.id, fd.returnType, fd.ctx, fd.params, FunBodyAST(Block(declarations :+ newBodyWithAWhileLoop)), fd.isExported, fd.isPure)
case _ => fd
}

private def replaceRecursiveCalls(fd: FunDef, body: Expr, valdefs: List[ValDef], labelName: String): Expr = {
val replacer = new Transformer(from, from) with NoEnv {
override def recImpl(e: Expr)(using Env): (Expr, Env) = e match {
case Return(App(FunVal(fdcall), _, args)) if fdcall == fd =>
val tmpValDefs = valdefs.map(vd => ValDef(freshId(vd.id), vd.typ, isVar = false))
val tmpDecls = tmpValDefs.zip(args).map { case (vd, arg) => Decl(vd, Some(arg)) }
val valdefAssign = valdefs.zip(tmpValDefs).map { case (vd, tmp) => Assign(Binding(vd), Binding(tmp)) }
Block(tmpDecls ++ valdefAssign :+ Goto(labelName)) -> ()
case _ =>
super.recImpl(e)
}
}
replacer(body)
}

/* Replace the bindings in the function body with the mapped variables */
private def replaceBindings(mapping: Map[ValDef, ValDef], funBody: Expr): Expr = {
val replacer = new Transformer(from, from) with NoEnv {
override protected def rec(vd: ValDef)(using Env): to.ValDef =
mapping.getOrElse(vd, vd)
}
replacer(funBody)
}

private def replaceWithNewFuns(prog: Prog, newFdsMap: Map[FunDef, FunDef]): Prog = {
val replacer = new Transformer(from, from) with NoEnv {
override protected def recImpl(fd: FunDef)(using Env): FunDef =
super.recImpl(newFdsMap.getOrElse(fd, fd))
}
replacer(prog)
}

override protected def rec(prog: from.Prog)(using Unit): to.Prog = {
super.rec {
val newFdsMap = prog.functions.map { fd =>
val fdWithTailRecUnitInReturn = putTailRecursiveUnitCallInReturn(fd)
if isTailRecursive(fdWithTailRecUnitInReturn) then
val fdRewrittenToLoop = rewriteToAWhileLoop(fdWithTailRecUnitInReturn)
// val irPrinter = IRPrinter(SIR)
// print(irPrinter.apply(newFd)(using irPrinter.Context(0)))
fd -> fdRewrittenToLoop
else
fd -> fdWithTailRecUnitInReturn
}.toMap
val newProg = Prog(prog.decls, newFdsMap.values.toSeq, prog.classes)
replaceWithNewFuns(newProg, newFdsMap)
}
}

private def freshId(id: String): to.Id = id + "_" + freshCounter.next(id)

private val freshCounter = new utils.UniqueCounter[String]()
}
Loading