Skip to content

Commit

Permalink
Tail recursion elimination for GenC (#1275) (#1626)
Browse files Browse the repository at this point in the history
This is to address #1275 . We added a new phase to perform tail
recursion elimination.

We added test cases in the
[GenCSuite](https://github.com/epfl-lara/stainless/pull/1626/files#diff-2091c70888d42120d35c352d0b060558b0c5ad02baa02de0ecd77b0fd0ded464).

As discussed during the presentation, we may want to take a closer look
at ghost elimination and see whether it is doing the job correctly.

---------

Co-authored-by: Kacper Korban <[email protected]>
  • Loading branch information
zhekai-jiang and KacperFKorban authored Jan 10, 2025
1 parent c92fee2 commit 233cdcd
Show file tree
Hide file tree
Showing 39 changed files with 673 additions and 6 deletions.
4 changes: 4 additions & 0 deletions core/src/main/scala/stainless/genc/CAST.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ object CAST { // C Abstract Syntax Tree

case class Block(exprs: Seq[Expr]) extends Expr // Can be empty

case class Labeled(label: String, block: Expr) extends Expr

case class Lit(lit: Literal) extends Expr

case class EnumLiteral(id: Id) extends Expr
Expand Down Expand Up @@ -212,6 +214,8 @@ object CAST { // C Abstract Syntax Tree
require(cond.isValue, s"Condition ($cond) of while loop must be a value")
}

case class Goto(name: String) extends Expr

case object Break extends Expr

case class Return(value: Expr) extends Expr {
Expand Down
9 changes: 9 additions & 0 deletions core/src/main/scala/stainless/genc/CPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,12 @@ class CPrinter(
|"""
}

case Labeled(name, block) =>
// In C, a label cannot be followed by a variable declaration
// So we add a semicolon to add an empty statement to work around this
c"""|$name: ;
| $block"""

case Lit(lit) => c"$lit"

case EnumLiteral(lit) => c"$lit"
Expand Down Expand Up @@ -319,6 +325,9 @@ class CPrinter(
c"""|while ($cond) {
| $body
|}"""

case Goto(label) =>
c"goto $label"

case Break => c"break"

Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/stainless/genc/GenerateC.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ object GenerateC {
NamedLeonPhase("Lifting", new LiftingPhase) `andThen`
NamedLeonPhase("Referencing", new ReferencingPhase) `andThen`
NamedLeonPhase("StructInlining", new StructInliningPhase) `andThen`
NamedLeonPhase("TailRecElim", new TailRecElimPhase) `andThen`
NamedLeonPhase("IR2C", new IR2CPhase)
}

Expand Down
6 changes: 5 additions & 1 deletion core/src/main/scala/stainless/genc/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ private[genc] sealed trait IR { ir =>
case Binding(vd) => vd.getType
case c: Callable => c.typ
case Block(exprs) => exprs.last.getType
case Labeled(_, block) => block.getType
case MemSet(_, _, _) => NoType
case SizeOf(_) => PrimitiveType(UInt32Type)
case Decl(_, _) => NoType
Expand Down Expand Up @@ -221,6 +222,7 @@ private[genc] sealed trait IR { ir =>
case If(_, _) => NoType
case IfElse(_, thenn, _) => thenn.getType // same as elze
case While(_, _) => NoType
case Goto(_) => NoType
case IsA(_, _) => PrimitiveType(BoolType)
case AsA(_, ct) => ct
case IntegralCast(_, newIntegralType) => PrimitiveType(newIntegralType)
Expand Down Expand Up @@ -260,6 +262,7 @@ private[genc] sealed trait IR { ir =>
case class Block(exprs: Seq[Expr]) extends Expr {
require(exprs.nonEmpty, "GenC IR blocks must be non-empty")
}
case class Labeled(name: String, expr: Expr) extends Expr

case class MemSet(pointer: Expr, value: Expr, size: Expr) extends Expr
case class SizeOf(tpe: Type) extends Expr
Expand Down Expand Up @@ -296,6 +299,7 @@ private[genc] sealed trait IR { ir =>
case class If(cond: Expr, thenn: Expr) extends Expr
case class IfElse(cond: Expr, thenn: Expr, elze: Expr) extends Expr
case class While(cond: Expr, body: Expr) extends Expr
case class Goto(label: String) extends Expr

// Type probindg + casting
case class IsA(expr: Expr, ct: ClassType) extends Expr
Expand Down Expand Up @@ -323,7 +327,6 @@ private[genc] sealed trait IR { ir =>

case object Break extends Expr


/****************************************************************************************************
* Expression Helpers *
****************************************************************************************************/
Expand Down Expand Up @@ -484,6 +487,7 @@ private[genc] sealed trait IR { ir =>
}

object IRs {
object TIR extends IR
object SIR extends IR
object CIR extends IR
object RIR extends IR
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/scala/stainless/genc/ir/IRPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ final class IRPrinter[S <: IR](val ir: S) {
case MemSet(pointer, value, size) => s"memset(${rec(pointer)}, ${rec(value)}, ${rec(size)})"
case SizeOf(tpe) => s"sizeof(${rec(tpe)})"
case Block(exprs) => "{{ " + (exprs map rec mkString ptx.newLine) + " }}"
case Labeled(label, expr) =>
s"""|{{ $label:
| ${rec(expr)} }}""".stripMargin
case Decl(vd, None) => (if (vd.isVar) "var" else "val") + " " + rec(vd)
case Decl(vd, Some(value)) => (if (vd.isVar) "var" else "val") + " " + rec(vd) + " = " + rec(value)
case App(callable, extra, args) =>
Expand All @@ -112,6 +115,8 @@ final class IRPrinter[S <: IR](val ir: S) {
"else {" + ptx.newLine + " " + rec(elze)(using ptx + 1) + ptx.newLine + "}"
case While(cond, body) =>
"while (" + rec(cond) + ") {" + ptx.newLine + " " + rec(body)(using ptx + 1) + ptx.newLine + "}"
case Goto(label) =>
s"goto $label"
case IsA(expr, ct) => "¿" + ct.clazz.id + "?" + rec(expr)
case AsA(expr, ct) => "(" + ct.clazz.id + ")" + rec(expr)
case IntegralCast(expr, newType) => "(" + newType + ")" + rec(expr)
Expand Down
89 changes: 89 additions & 0 deletions core/src/main/scala/stainless/genc/ir/TailRecSimpTransformer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package stainless
package genc
package ir

import PrimitiveTypes.{ PrimitiveType => PT, _ } // For desambiguation
import Literals._
import Operators._
import IRs._
import scala.collection.mutable

final class TailRecSimpTransformer extends Transformer(SIR, SIR) with NoEnv {
import from._

private given givenDebugSection: DebugSectionGenC.type = DebugSectionGenC

/**
* Replace a variable assignment that is immediately
* returned
*
* val i = f(...);
* return i;
*
* ==>
*
* return f(...);
*
*/
private def replaceImmediateReturn(fd: Expr): Expr = {
val transformer = new ir.Transformer(from, to) with NoEnv {
override protected def recImpl(expr: Expr)(using Env): (Expr, Env) = expr match {
case Block(stmts) =>
Block(stmts.zipWithIndex.flatMap {
case (expr @ Decl(id, Some(rhs)), idx) =>
stmts.lift(idx + 1) match {
case Some(Return(Binding(retId))) if retId == id =>
List(Return(rhs))
case _ => List(recImpl(expr)._1)
}
case (expr @ Return(Binding(retId)), idx) =>
stmts.lift(idx - 1) match {
case Some(Decl(id, rhs)) if id == retId =>
Nil
case _ => List(recImpl(expr)._1)
}
case (expr, idx) => List(recImpl(expr)._1)
}) -> ()
case expr => super.recImpl(expr)
}
}
transformer(fd)
}

/**
* Remove all statements after a return statement
*
* return f(...);
* someStmt;
*
* ==>
*
* return f(...);
*
*/
private def removeAfterReturn(fd: Expr): Expr = {
val transformer = new ir.Transformer(from, to) with NoEnv {
override protected def recImpl(expr: Expr)(using Env): (Expr, Env) = expr match {
case Block(stmts) =>
val transformedStmts = stmts.map(recImpl(_)._1)
val firstReturn = transformedStmts.find {
case Return(_) => true
case _ => false
}.toList
val newStmts = transformedStmts.takeWhile {
case Return(_) => false
case _ => true
}
Block(newStmts ++ firstReturn) -> ()
case expr => super.recImpl(expr)
}
}
transformer(fd)
}

override protected def recImpl(fd: Expr)(using Env): (to.Expr, Env) = {
val afterReturn = removeAfterReturn(fd)
val immediateReturn = replaceImmediateReturn(afterReturn)
immediateReturn -> ()
}
}
181 changes: 181 additions & 0 deletions core/src/main/scala/stainless/genc/ir/TailRecTransformer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package stainless
package genc
package ir

import PrimitiveTypes.{ PrimitiveType => PT, _ } // For desambiguation
import Literals._
import Operators._
import IRs._
import scala.collection.mutable

final class TailRecTransformer(val ctx: inox.Context) extends Transformer(SIR, TIR) with NoEnv {
import from._

private given givenDebugSection: DebugSectionGenC.type = DebugSectionGenC

private given printer.Context = printer.Context(0)

/**
* If the function returns Unit type and the last one statement is a recursive call,
* put the recursive call in a return statement.
*
* Example:
* def countDown(n: Int): Unit =
* if (n == 0) return
* countDown(n - 1)
*
* ==>
*
* def countDown(n: Int): Unit =
* if (n == 0) return
* return countDown(n - 1)
*/
private def putTailRecursiveUnitCallInReturn(fd: FunDef): FunDef = {
def go(expr: Expr): Expr = expr match {
case Block(stmts) if stmts.nonEmpty =>
Block(stmts.init :+ go(stmts.last))
case IfElse(cond, thenn, elze) =>
IfElse(cond, go(thenn), go(elze))
case app @ App(FunVal(calledFd), _, _) if calledFd.id == fd.id =>
Return(app)
case _ => expr
}
fd.body match {
case FunBodyAST(expr) if fd.returnType.isUnitType =>
fd.copy(body = FunBodyAST(go(expr)))
case _ => fd
}
}

private def isTailRecursive(fd: FunDef): Boolean = {
var functionRefs = mutable.ListBuffer.empty[FunDef]
val functionRefVisitor = new ir.Visitor(from) {
override protected def visit(expr: Expr): Unit = expr match {
case FunVal(fd) => functionRefs += fd
case _ =>
}
}
var tailFunctionRefs = mutable.ListBuffer.empty[FunDef]
val tailRecCallVisitor = new ir.Visitor(from) {
override protected def visit(expr: Expr): Unit = expr match {
case Return(App(FunVal(fdcall), _, _)) => tailFunctionRefs += fdcall

case _ =>
}
}
functionRefVisitor(fd)
tailRecCallVisitor(fd)
functionRefs.contains(fd) && functionRefs.filter(_ == fd).size == tailFunctionRefs.filter(_ == fd).size
}

/* Rewrite a tail recursive function to a while loop
* Example:
* def fib(n: Int, i: Int = 0, j: Int = 1): Int =
* if (n == 0)
* return i
* else
* return fib(n-1, j, i+j)
*
* ==>
*
* def fib(n: Int, i: Int = 0, j: Int = 1): Int = {
*
* var n$ = n
* var i$ = i
* var j$ = j
* while (true) {
* someLabel:
* if (n$ == 0) {
* return i$
* } else {
* val n$1 = n$ - 1
* val i$1 = j$
* val j$1 = i$ + j$
* n$ = n$1
* i$ = i$1
* j$ = j$1
* goto someLabel
* }
* }
* }
* Steps:
* - Create a new variable for each parameter of the function
* - Replace existing parameter references with the new variables
* - Create a while loop with a condition true
* - Replace the recursive return with a variable assignments (updating the state) and a continue statement
*/
private def rewriteToAWhileLoop(fd: FunDef): FunDef = fd.body match {
case FunBodyAST(body) =>
val newParams = fd.params.map(p => ValDef(freshId(p.id), p.typ, isVar = true))
val newParamMap = fd.params.zip(newParams).toMap
val labelName = freshId("label")
val bodyWithNewParams = replaceBindings(newParamMap, body)
val bodyWithUnitReturn = bodyWithNewParams match {
case Block(stmts) =>
if fd.returnType.isUnitType then
Block(stmts :+ Return(Lit(UnitLit)))
else
bodyWithNewParams
case _ => bodyWithNewParams
}
val declarations = newParamMap.toList.map { case (old, nw) => Decl(nw, Some(Binding(old))) }
val newBody = replaceRecursiveCalls(fd, bodyWithUnitReturn, newParams.toList, labelName)
val newBodyWithALabel = Labeled(labelName, newBody)
val newBodyWithAWhileLoop = While(True, newBodyWithALabel)
FunDef(fd.id, fd.returnType, fd.ctx, fd.params, FunBodyAST(Block(declarations :+ newBodyWithAWhileLoop)), fd.isExported, fd.isPure)
case _ => fd
}

private def replaceRecursiveCalls(fd: FunDef, body: Expr, valdefs: List[ValDef], labelName: String): Expr = {
val replacer = new Transformer(from, from) with NoEnv {
override def recImpl(e: Expr)(using Env): (Expr, Env) = e match {
case Return(App(FunVal(fdcall), _, args)) if fdcall == fd =>
val tmpValDefs = valdefs.map(vd => ValDef(freshId(vd.id), vd.typ, isVar = false))
val tmpDecls = tmpValDefs.zip(args).map { case (vd, arg) => Decl(vd, Some(arg)) }
val valdefAssign = valdefs.zip(tmpValDefs).map { case (vd, tmp) => Assign(Binding(vd), Binding(tmp)) }
Block(tmpDecls ++ valdefAssign :+ Goto(labelName)) -> ()
case _ =>
super.recImpl(e)
}
}
replacer(body)
}

/* Replace the bindings in the function body with the mapped variables */
private def replaceBindings(mapping: Map[ValDef, ValDef], funBody: Expr): Expr = {
val replacer = new Transformer(from, from) with NoEnv {
override protected def rec(vd: ValDef)(using Env): to.ValDef =
mapping.getOrElse(vd, vd)
}
replacer(funBody)
}

private def replaceWithNewFuns(prog: Prog, newFdsMap: Map[FunDef, FunDef]): Prog = {
val replacer = new Transformer(from, from) with NoEnv {
override protected def recImpl(fd: FunDef)(using Env): FunDef =
super.recImpl(newFdsMap.getOrElse(fd, fd))
}
replacer(prog)
}

override protected def rec(prog: from.Prog)(using Unit): to.Prog = {
super.rec {
val newFdsMap = prog.functions.map { fd =>
val fdWithTailRecUnitInReturn = putTailRecursiveUnitCallInReturn(fd)
if isTailRecursive(fdWithTailRecUnitInReturn) then
val fdRewrittenToLoop = rewriteToAWhileLoop(fdWithTailRecUnitInReturn)
// val irPrinter = IRPrinter(SIR)
// print(irPrinter.apply(newFd)(using irPrinter.Context(0)))
fd -> fdRewrittenToLoop
else
fd -> fdWithTailRecUnitInReturn
}.toMap
val newProg = Prog(prog.decls, newFdsMap.values.toSeq, prog.classes)
replaceWithNewFuns(newProg, newFdsMap)
}
}

private def freshId(id: String): to.Id = id + "_" + freshCounter.next(id)

private val freshCounter = new utils.UniqueCounter[String]()
}
Loading

0 comments on commit 233cdcd

Please sign in to comment.