From 233cdcd07c6cd4ae526aefab2a40fe2ed9bf6efe Mon Sep 17 00:00:00 2001 From: Zhekai Jiang Date: Sat, 11 Jan 2025 01:53:02 +0800 Subject: [PATCH] Tail recursion elimination for GenC (#1275) (#1626) This is to address #1275 . We added a new phase to perform tail recursion elimination. We added test cases in the [GenCSuite](https://github.com/epfl-lara/stainless/pull/1626/files#diff-2091c70888d42120d35c352d0b060558b0c5ad02baa02de0ecd77b0fd0ded464). As discussed during the presentation, we may want to take a closer look at ghost elimination and see whether it is doing the job correctly. --------- Co-authored-by: Kacper Korban --- core/src/main/scala/stainless/genc/CAST.scala | 4 + .../main/scala/stainless/genc/CPrinter.scala | 9 + .../main/scala/stainless/genc/GenerateC.scala | 1 + .../src/main/scala/stainless/genc/ir/IR.scala | 6 +- .../scala/stainless/genc/ir/IRPrinter.scala | 5 + .../genc/ir/TailRecSimpTransformer.scala | 89 +++++++++ .../genc/ir/TailRecTransformer.scala | 181 ++++++++++++++++++ .../scala/stainless/genc/ir/Transformer.scala | 5 +- .../scala/stainless/genc/ir/Visitor.scala | 3 + .../stainless/genc/phases/IR2CPhase.scala | 12 +- .../genc/phases/TailRecElimPhase.scala | 19 ++ .../genc/valid/TailRecAliasedArgs.check | 1 + .../genc/valid/TailRecAliasedArgs.scala | 17 ++ .../genc/valid/TailRecComplexArgs.check | 1 + .../genc/valid/TailRecComplexArgs.scala | 17 ++ .../genc/valid/TailRecCountDown.check | 1 + .../genc/valid/TailRecCountDown.scala | 17 ++ .../genc/valid/TailRecEarlyReturn.check | 1 + .../genc/valid/TailRecEarlyReturn.scala | 18 ++ .../benchmarks/genc/valid/TailRecFib.check | 1 + .../benchmarks/genc/valid/TailRecFib.scala | 19 ++ .../genc/valid/TailRecFibAliased.check | 1 + .../genc/valid/TailRecFibAliased.scala | 21 ++ .../genc/valid/TailRecMultipleCalls.check | 1 + .../genc/valid/TailRecMultipleCalls.scala | 19 ++ .../benchmarks/genc/valid/TailRecNested.check | 1 + .../benchmarks/genc/valid/TailRecNested.scala | 20 ++ .../genc/valid/TailRecPatternMatching.check | 1 + .../genc/valid/TailRecPatternMatching.scala | 26 +++ .../genc/valid/TailRecStackOverflow.check | 1 + .../genc/valid/TailRecStackOverflow.scala | 18 ++ .../benchmarks/genc/valid/TailRecUnit.check | 0 .../benchmarks/genc/valid/TailRecUnit.scala | 17 ++ .../benchmarks/genc/valid/TailRecUnitIf.check | 0 .../benchmarks/genc/valid/TailRecUnitIf.scala | 17 ++ .../genc/valid/TailRecUnitNoExplicitEnd.check | 0 .../genc/valid/TailRecUnitNoExplicitEnd.scala | 16 ++ .../src/it/scala/stainless/GenCSuite.scala | 21 ++ .../it/scala/stainless/TailRecGenCSuite.scala | 72 +++++++ 39 files changed, 673 insertions(+), 6 deletions(-) create mode 100644 core/src/main/scala/stainless/genc/ir/TailRecSimpTransformer.scala create mode 100644 core/src/main/scala/stainless/genc/ir/TailRecTransformer.scala create mode 100644 core/src/main/scala/stainless/genc/phases/TailRecElimPhase.scala create mode 100644 frontends/benchmarks/genc/valid/TailRecAliasedArgs.check create mode 100644 frontends/benchmarks/genc/valid/TailRecAliasedArgs.scala create mode 100644 frontends/benchmarks/genc/valid/TailRecComplexArgs.check create mode 100644 frontends/benchmarks/genc/valid/TailRecComplexArgs.scala create mode 100644 frontends/benchmarks/genc/valid/TailRecCountDown.check create mode 100644 frontends/benchmarks/genc/valid/TailRecCountDown.scala create mode 100644 frontends/benchmarks/genc/valid/TailRecEarlyReturn.check create mode 100644 frontends/benchmarks/genc/valid/TailRecEarlyReturn.scala create mode 100644 frontends/benchmarks/genc/valid/TailRecFib.check create mode 100644 frontends/benchmarks/genc/valid/TailRecFib.scala create mode 100644 frontends/benchmarks/genc/valid/TailRecFibAliased.check create mode 100644 frontends/benchmarks/genc/valid/TailRecFibAliased.scala create mode 100644 frontends/benchmarks/genc/valid/TailRecMultipleCalls.check create mode 100644 frontends/benchmarks/genc/valid/TailRecMultipleCalls.scala create mode 100644 frontends/benchmarks/genc/valid/TailRecNested.check create mode 100644 frontends/benchmarks/genc/valid/TailRecNested.scala create mode 100644 frontends/benchmarks/genc/valid/TailRecPatternMatching.check create mode 100644 frontends/benchmarks/genc/valid/TailRecPatternMatching.scala create mode 100644 frontends/benchmarks/genc/valid/TailRecStackOverflow.check create mode 100644 frontends/benchmarks/genc/valid/TailRecStackOverflow.scala create mode 100644 frontends/benchmarks/genc/valid/TailRecUnit.check create mode 100644 frontends/benchmarks/genc/valid/TailRecUnit.scala create mode 100644 frontends/benchmarks/genc/valid/TailRecUnitIf.check create mode 100644 frontends/benchmarks/genc/valid/TailRecUnitIf.scala create mode 100644 frontends/benchmarks/genc/valid/TailRecUnitNoExplicitEnd.check create mode 100644 frontends/benchmarks/genc/valid/TailRecUnitNoExplicitEnd.scala create mode 100644 frontends/common/src/it/scala/stainless/TailRecGenCSuite.scala diff --git a/core/src/main/scala/stainless/genc/CAST.scala b/core/src/main/scala/stainless/genc/CAST.scala index 71ee31c5f2..129624efe6 100644 --- a/core/src/main/scala/stainless/genc/CAST.scala +++ b/core/src/main/scala/stainless/genc/CAST.scala @@ -114,6 +114,8 @@ object CAST { // C Abstract Syntax Tree case class Block(exprs: Seq[Expr]) extends Expr // Can be empty + case class Labeled(label: String, block: Expr) extends Expr + case class Lit(lit: Literal) extends Expr case class EnumLiteral(id: Id) extends Expr @@ -212,6 +214,8 @@ object CAST { // C Abstract Syntax Tree require(cond.isValue, s"Condition ($cond) of while loop must be a value") } + case class Goto(name: String) extends Expr + case object Break extends Expr case class Return(value: Expr) extends Expr { diff --git a/core/src/main/scala/stainless/genc/CPrinter.scala b/core/src/main/scala/stainless/genc/CPrinter.scala index 076d8ce918..7a726bbb5c 100644 --- a/core/src/main/scala/stainless/genc/CPrinter.scala +++ b/core/src/main/scala/stainless/genc/CPrinter.scala @@ -243,6 +243,12 @@ class CPrinter( |""" } + case Labeled(name, block) => + // In C, a label cannot be followed by a variable declaration + // So we add a semicolon to add an empty statement to work around this + c"""|$name: ; + | $block""" + case Lit(lit) => c"$lit" case EnumLiteral(lit) => c"$lit" @@ -319,6 +325,9 @@ class CPrinter( c"""|while ($cond) { | $body |}""" + + case Goto(label) => + c"goto $label" case Break => c"break" diff --git a/core/src/main/scala/stainless/genc/GenerateC.scala b/core/src/main/scala/stainless/genc/GenerateC.scala index 7f5e5c687b..6451b876e2 100644 --- a/core/src/main/scala/stainless/genc/GenerateC.scala +++ b/core/src/main/scala/stainless/genc/GenerateC.scala @@ -21,6 +21,7 @@ object GenerateC { NamedLeonPhase("Lifting", new LiftingPhase) `andThen` NamedLeonPhase("Referencing", new ReferencingPhase) `andThen` NamedLeonPhase("StructInlining", new StructInliningPhase) `andThen` + NamedLeonPhase("TailRecElim", new TailRecElimPhase) `andThen` NamedLeonPhase("IR2C", new IR2CPhase) } diff --git a/core/src/main/scala/stainless/genc/ir/IR.scala b/core/src/main/scala/stainless/genc/ir/IR.scala index 68e8912a54..3958a691f0 100644 --- a/core/src/main/scala/stainless/genc/ir/IR.scala +++ b/core/src/main/scala/stainless/genc/ir/IR.scala @@ -192,6 +192,7 @@ private[genc] sealed trait IR { ir => case Binding(vd) => vd.getType case c: Callable => c.typ case Block(exprs) => exprs.last.getType + case Labeled(_, block) => block.getType case MemSet(_, _, _) => NoType case SizeOf(_) => PrimitiveType(UInt32Type) case Decl(_, _) => NoType @@ -221,6 +222,7 @@ private[genc] sealed trait IR { ir => case If(_, _) => NoType case IfElse(_, thenn, _) => thenn.getType // same as elze case While(_, _) => NoType + case Goto(_) => NoType case IsA(_, _) => PrimitiveType(BoolType) case AsA(_, ct) => ct case IntegralCast(_, newIntegralType) => PrimitiveType(newIntegralType) @@ -260,6 +262,7 @@ private[genc] sealed trait IR { ir => case class Block(exprs: Seq[Expr]) extends Expr { require(exprs.nonEmpty, "GenC IR blocks must be non-empty") } + case class Labeled(name: String, expr: Expr) extends Expr case class MemSet(pointer: Expr, value: Expr, size: Expr) extends Expr case class SizeOf(tpe: Type) extends Expr @@ -296,6 +299,7 @@ private[genc] sealed trait IR { ir => case class If(cond: Expr, thenn: Expr) extends Expr case class IfElse(cond: Expr, thenn: Expr, elze: Expr) extends Expr case class While(cond: Expr, body: Expr) extends Expr + case class Goto(label: String) extends Expr // Type probindg + casting case class IsA(expr: Expr, ct: ClassType) extends Expr @@ -323,7 +327,6 @@ private[genc] sealed trait IR { ir => case object Break extends Expr - /**************************************************************************************************** * Expression Helpers * ****************************************************************************************************/ @@ -484,6 +487,7 @@ private[genc] sealed trait IR { ir => } object IRs { + object TIR extends IR object SIR extends IR object CIR extends IR object RIR extends IR diff --git a/core/src/main/scala/stainless/genc/ir/IRPrinter.scala b/core/src/main/scala/stainless/genc/ir/IRPrinter.scala index 492daec114..7cd86b0d35 100644 --- a/core/src/main/scala/stainless/genc/ir/IRPrinter.scala +++ b/core/src/main/scala/stainless/genc/ir/IRPrinter.scala @@ -93,6 +93,9 @@ final class IRPrinter[S <: IR](val ir: S) { case MemSet(pointer, value, size) => s"memset(${rec(pointer)}, ${rec(value)}, ${rec(size)})" case SizeOf(tpe) => s"sizeof(${rec(tpe)})" case Block(exprs) => "{{ " + (exprs map rec mkString ptx.newLine) + " }}" + case Labeled(label, expr) => + s"""|{{ $label: + | ${rec(expr)} }}""".stripMargin case Decl(vd, None) => (if (vd.isVar) "var" else "val") + " " + rec(vd) case Decl(vd, Some(value)) => (if (vd.isVar) "var" else "val") + " " + rec(vd) + " = " + rec(value) case App(callable, extra, args) => @@ -112,6 +115,8 @@ final class IRPrinter[S <: IR](val ir: S) { "else {" + ptx.newLine + " " + rec(elze)(using ptx + 1) + ptx.newLine + "}" case While(cond, body) => "while (" + rec(cond) + ") {" + ptx.newLine + " " + rec(body)(using ptx + 1) + ptx.newLine + "}" + case Goto(label) => + s"goto $label" case IsA(expr, ct) => "¿" + ct.clazz.id + "?" + rec(expr) case AsA(expr, ct) => "(" + ct.clazz.id + ")" + rec(expr) case IntegralCast(expr, newType) => "(" + newType + ")" + rec(expr) diff --git a/core/src/main/scala/stainless/genc/ir/TailRecSimpTransformer.scala b/core/src/main/scala/stainless/genc/ir/TailRecSimpTransformer.scala new file mode 100644 index 0000000000..80ce500c80 --- /dev/null +++ b/core/src/main/scala/stainless/genc/ir/TailRecSimpTransformer.scala @@ -0,0 +1,89 @@ +package stainless +package genc +package ir + +import PrimitiveTypes.{ PrimitiveType => PT, _ } // For desambiguation +import Literals._ +import Operators._ +import IRs._ +import scala.collection.mutable + +final class TailRecSimpTransformer extends Transformer(SIR, SIR) with NoEnv { + import from._ + + private given givenDebugSection: DebugSectionGenC.type = DebugSectionGenC + + /** + * Replace a variable assignment that is immediately + * returned + * + * val i = f(...); + * return i; + * + * ==> + * + * return f(...); + * + */ + private def replaceImmediateReturn(fd: Expr): Expr = { + val transformer = new ir.Transformer(from, to) with NoEnv { + override protected def recImpl(expr: Expr)(using Env): (Expr, Env) = expr match { + case Block(stmts) => + Block(stmts.zipWithIndex.flatMap { + case (expr @ Decl(id, Some(rhs)), idx) => + stmts.lift(idx + 1) match { + case Some(Return(Binding(retId))) if retId == id => + List(Return(rhs)) + case _ => List(recImpl(expr)._1) + } + case (expr @ Return(Binding(retId)), idx) => + stmts.lift(idx - 1) match { + case Some(Decl(id, rhs)) if id == retId => + Nil + case _ => List(recImpl(expr)._1) + } + case (expr, idx) => List(recImpl(expr)._1) + }) -> () + case expr => super.recImpl(expr) + } + } + transformer(fd) + } + + /** + * Remove all statements after a return statement + * + * return f(...); + * someStmt; + * + * ==> + * + * return f(...); + * + */ + private def removeAfterReturn(fd: Expr): Expr = { + val transformer = new ir.Transformer(from, to) with NoEnv { + override protected def recImpl(expr: Expr)(using Env): (Expr, Env) = expr match { + case Block(stmts) => + val transformedStmts = stmts.map(recImpl(_)._1) + val firstReturn = transformedStmts.find { + case Return(_) => true + case _ => false + }.toList + val newStmts = transformedStmts.takeWhile { + case Return(_) => false + case _ => true + } + Block(newStmts ++ firstReturn) -> () + case expr => super.recImpl(expr) + } + } + transformer(fd) + } + + override protected def recImpl(fd: Expr)(using Env): (to.Expr, Env) = { + val afterReturn = removeAfterReturn(fd) + val immediateReturn = replaceImmediateReturn(afterReturn) + immediateReturn -> () + } +} diff --git a/core/src/main/scala/stainless/genc/ir/TailRecTransformer.scala b/core/src/main/scala/stainless/genc/ir/TailRecTransformer.scala new file mode 100644 index 0000000000..e0b54acdab --- /dev/null +++ b/core/src/main/scala/stainless/genc/ir/TailRecTransformer.scala @@ -0,0 +1,181 @@ +package stainless +package genc +package ir + +import PrimitiveTypes.{ PrimitiveType => PT, _ } // For desambiguation +import Literals._ +import Operators._ +import IRs._ +import scala.collection.mutable + +final class TailRecTransformer(val ctx: inox.Context) extends Transformer(SIR, TIR) with NoEnv { + import from._ + + private given givenDebugSection: DebugSectionGenC.type = DebugSectionGenC + + private given printer.Context = printer.Context(0) + + /** + * If the function returns Unit type and the last one statement is a recursive call, + * put the recursive call in a return statement. + * + * Example: + * def countDown(n: Int): Unit = + * if (n == 0) return + * countDown(n - 1) + * + * ==> + * + * def countDown(n: Int): Unit = + * if (n == 0) return + * return countDown(n - 1) + */ + private def putTailRecursiveUnitCallInReturn(fd: FunDef): FunDef = { + def go(expr: Expr): Expr = expr match { + case Block(stmts) if stmts.nonEmpty => + Block(stmts.init :+ go(stmts.last)) + case IfElse(cond, thenn, elze) => + IfElse(cond, go(thenn), go(elze)) + case app @ App(FunVal(calledFd), _, _) if calledFd.id == fd.id => + Return(app) + case _ => expr + } + fd.body match { + case FunBodyAST(expr) if fd.returnType.isUnitType => + fd.copy(body = FunBodyAST(go(expr))) + case _ => fd + } + } + + private def isTailRecursive(fd: FunDef): Boolean = { + var functionRefs = mutable.ListBuffer.empty[FunDef] + val functionRefVisitor = new ir.Visitor(from) { + override protected def visit(expr: Expr): Unit = expr match { + case FunVal(fd) => functionRefs += fd + case _ => + } + } + var tailFunctionRefs = mutable.ListBuffer.empty[FunDef] + val tailRecCallVisitor = new ir.Visitor(from) { + override protected def visit(expr: Expr): Unit = expr match { + case Return(App(FunVal(fdcall), _, _)) => tailFunctionRefs += fdcall + + case _ => + } + } + functionRefVisitor(fd) + tailRecCallVisitor(fd) + functionRefs.contains(fd) && functionRefs.filter(_ == fd).size == tailFunctionRefs.filter(_ == fd).size + } + + /* Rewrite a tail recursive function to a while loop + * Example: + * def fib(n: Int, i: Int = 0, j: Int = 1): Int = + * if (n == 0) + * return i + * else + * return fib(n-1, j, i+j) + * + * ==> + * + * def fib(n: Int, i: Int = 0, j: Int = 1): Int = { + * + * var n$ = n + * var i$ = i + * var j$ = j + * while (true) { + * someLabel: + * if (n$ == 0) { + * return i$ + * } else { + * val n$1 = n$ - 1 + * val i$1 = j$ + * val j$1 = i$ + j$ + * n$ = n$1 + * i$ = i$1 + * j$ = j$1 + * goto someLabel + * } + * } + * } + * Steps: + * - Create a new variable for each parameter of the function + * - Replace existing parameter references with the new variables + * - Create a while loop with a condition true + * - Replace the recursive return with a variable assignments (updating the state) and a continue statement + */ + private def rewriteToAWhileLoop(fd: FunDef): FunDef = fd.body match { + case FunBodyAST(body) => + val newParams = fd.params.map(p => ValDef(freshId(p.id), p.typ, isVar = true)) + val newParamMap = fd.params.zip(newParams).toMap + val labelName = freshId("label") + val bodyWithNewParams = replaceBindings(newParamMap, body) + val bodyWithUnitReturn = bodyWithNewParams match { + case Block(stmts) => + if fd.returnType.isUnitType then + Block(stmts :+ Return(Lit(UnitLit))) + else + bodyWithNewParams + case _ => bodyWithNewParams + } + val declarations = newParamMap.toList.map { case (old, nw) => Decl(nw, Some(Binding(old))) } + val newBody = replaceRecursiveCalls(fd, bodyWithUnitReturn, newParams.toList, labelName) + val newBodyWithALabel = Labeled(labelName, newBody) + val newBodyWithAWhileLoop = While(True, newBodyWithALabel) + FunDef(fd.id, fd.returnType, fd.ctx, fd.params, FunBodyAST(Block(declarations :+ newBodyWithAWhileLoop)), fd.isExported, fd.isPure) + case _ => fd + } + + private def replaceRecursiveCalls(fd: FunDef, body: Expr, valdefs: List[ValDef], labelName: String): Expr = { + val replacer = new Transformer(from, from) with NoEnv { + override def recImpl(e: Expr)(using Env): (Expr, Env) = e match { + case Return(App(FunVal(fdcall), _, args)) if fdcall == fd => + val tmpValDefs = valdefs.map(vd => ValDef(freshId(vd.id), vd.typ, isVar = false)) + val tmpDecls = tmpValDefs.zip(args).map { case (vd, arg) => Decl(vd, Some(arg)) } + val valdefAssign = valdefs.zip(tmpValDefs).map { case (vd, tmp) => Assign(Binding(vd), Binding(tmp)) } + Block(tmpDecls ++ valdefAssign :+ Goto(labelName)) -> () + case _ => + super.recImpl(e) + } + } + replacer(body) + } + + /* Replace the bindings in the function body with the mapped variables */ + private def replaceBindings(mapping: Map[ValDef, ValDef], funBody: Expr): Expr = { + val replacer = new Transformer(from, from) with NoEnv { + override protected def rec(vd: ValDef)(using Env): to.ValDef = + mapping.getOrElse(vd, vd) + } + replacer(funBody) + } + + private def replaceWithNewFuns(prog: Prog, newFdsMap: Map[FunDef, FunDef]): Prog = { + val replacer = new Transformer(from, from) with NoEnv { + override protected def recImpl(fd: FunDef)(using Env): FunDef = + super.recImpl(newFdsMap.getOrElse(fd, fd)) + } + replacer(prog) + } + + override protected def rec(prog: from.Prog)(using Unit): to.Prog = { + super.rec { + val newFdsMap = prog.functions.map { fd => + val fdWithTailRecUnitInReturn = putTailRecursiveUnitCallInReturn(fd) + if isTailRecursive(fdWithTailRecUnitInReturn) then + val fdRewrittenToLoop = rewriteToAWhileLoop(fdWithTailRecUnitInReturn) + // val irPrinter = IRPrinter(SIR) + // print(irPrinter.apply(newFd)(using irPrinter.Context(0))) + fd -> fdRewrittenToLoop + else + fd -> fdWithTailRecUnitInReturn + }.toMap + val newProg = Prog(prog.decls, newFdsMap.values.toSeq, prog.classes) + replaceWithNewFuns(newProg, newFdsMap) + } + } + + private def freshId(id: String): to.Id = id + "_" + freshCounter.next(id) + + private val freshCounter = new utils.UniqueCounter[String]() +} diff --git a/core/src/main/scala/stainless/genc/ir/Transformer.scala b/core/src/main/scala/stainless/genc/ir/Transformer.scala index 19a0cbda51..c17b477510 100644 --- a/core/src/main/scala/stainless/genc/ir/Transformer.scala +++ b/core/src/main/scala/stainless/genc/ir/Transformer.scala @@ -121,7 +121,7 @@ abstract class Transformer[From <: IR, To <: IR](final val from: From, final val to.ArrayAllocVLA(to.ArrayType(rec(base), optLength), rec(length), rec(valueInit)) } - protected final def rec(e: Expr)(using Env): to.Expr = recImpl(e)._1 + protected def rec(e: Expr)(using Env): to.Expr = recImpl(e)._1 protected final def recCallable(fun: Callable)(using Env): to.Callable = rec(fun).asInstanceOf[to.Callable] @@ -141,6 +141,8 @@ abstract class Transformer[From <: IR, To <: IR](final val from: From, final val e } to.buildBlock(exprs) -> newEnv + case Labeled(name, expr) => + to.Labeled(name, rec(expr)) -> env case MemSet(pointer, value, size) => to.MemSet(rec(pointer), rec(value), rec(size)) -> env case SizeOf(tpe) => to.SizeOf(rec(tpe)) -> env @@ -159,6 +161,7 @@ abstract class Transformer[From <: IR, To <: IR](final val from: From, final val case If(cond, thenn) => to.If(rec(cond), rec(thenn)) -> env case IfElse(cond, thenn, elze) => to.IfElse(rec(cond), rec(thenn), rec(elze)) -> env case While(cond, body) => to.While(rec(cond), rec(body)) -> env + case Goto(label) => to.Goto(label) -> env case IsA(expr, ct) => to.IsA(rec(expr), to.ClassType(rec(ct.clazz))) -> env case AsA(expr, ct) => to.AsA(rec(expr), to.ClassType(rec(ct.clazz))) -> env case IntegralCast(expr, t) => to.IntegralCast(rec(expr), t) -> env diff --git a/core/src/main/scala/stainless/genc/ir/Visitor.scala b/core/src/main/scala/stainless/genc/ir/Visitor.scala index 9c3d648508..75bdd14b75 100644 --- a/core/src/main/scala/stainless/genc/ir/Visitor.scala +++ b/core/src/main/scala/stainless/genc/ir/Visitor.scala @@ -21,6 +21,7 @@ abstract class Visitor[S <: IR](val ir: S) { // Entry point for the visit final def apply(prog: Prog): Unit = rec(prog) final def apply(e: Expr): Unit = rec(e) + final def apply(fd: FunDef): Unit = rec(fd) protected def visit(prog: Prog): Unit = () protected def visit(fd: FunDef): Unit = () @@ -95,6 +96,7 @@ abstract class Visitor[S <: IR](val ir: S) { case FunRef(e) => rec(e) case Assert(e) => rec(e) case Block(exprs) => exprs foreach rec + case Labeled(name, block) => rec(block) case Decl(vd, None) => rec(vd) case Decl(vd, Some(value)) => rec(vd); rec(value) case App(fun, extra, args) => rec(fun); extra foreach rec; args foreach rec @@ -109,6 +111,7 @@ abstract class Visitor[S <: IR](val ir: S) { case If(cond, thenn) => rec(cond); rec(thenn) case IfElse(cond, thenn, elze) => rec(cond); rec(thenn); rec(elze) case While(cond, body) => rec(cond); rec(body) + case Goto(_) => () case IsA(expr, ct) => rec(expr); rec(ct.clazz) case AsA(expr, ct) => rec(expr); rec(ct.clazz) case IntegralCast(expr, _) => rec(expr) diff --git a/core/src/main/scala/stainless/genc/phases/IR2CPhase.scala b/core/src/main/scala/stainless/genc/phases/IR2CPhase.scala index 64ad0c6480..fb06a2f7f5 100644 --- a/core/src/main/scala/stainless/genc/phases/IR2CPhase.scala +++ b/core/src/main/scala/stainless/genc/phases/IR2CPhase.scala @@ -4,22 +4,22 @@ package stainless package genc package phases -import ir.IRs.{ SIR } +import ir.IRs.{ TIR } import ir.PrimitiveTypes._ import ir.Literals._ import ir.Operators._ import genc.{ CAST => C } -import SIR._ +import TIR._ import collection.mutable.{ Map => MutableMap, Set => MutableSet } -class IR2CPhase(using override val context: inox.Context) extends LeonPipeline[SIR.Prog, CAST.Prog](context) { +class IR2CPhase(using override val context: inox.Context) extends LeonPipeline[TIR.Prog, CAST.Prog](context) { val name = "CASTer" val description = "Translate the IR tree into the final C AST" - def run(ir: SIR.Prog): CAST.Prog = new IR2CImpl()(using context)(ir) + def run(ir: TIR.Prog): CAST.Prog = new IR2CImpl()(using context)(ir) } // This implementation is basically a Transformer that produce something which isn't an IR tree. @@ -167,6 +167,9 @@ private class IR2CImpl()(using ctx: inox.Context) { .map(rec(_)) } + case Labeled(name, block) => + C.Labeled(name, rec(block)) + case Decl(vd, None) => C.Decl(rec(vd.id), rec(vd.getType), None) case Decl(vd, Some(ArrayInit(ArrayAllocStatic(arrayType, length, values0)))) if allowFixedArray && vd.typ.isFixedArray => @@ -315,6 +318,7 @@ private class IR2CImpl()(using ctx: inox.Context) { case IfElse(cond, thenn, Lit(UnitLit)) => C.If(rec(cond), C.buildBlock(rec(thenn))) case IfElse(cond, thenn, elze) => C.IfElse(rec(cond), C.buildBlock(rec(thenn)), C.buildBlock(rec(elze))) case While(cond, body) => C.While(rec(cond), C.buildBlock(rec(body))) + case Goto(label) => C.Goto(label) // Find out if we can actually handle IsInstanceOf. case IsA(_, ClassType(cd)) if cd.parent.isEmpty => C.True // Since it has typechecked, it can only be true. diff --git a/core/src/main/scala/stainless/genc/phases/TailRecElimPhase.scala b/core/src/main/scala/stainless/genc/phases/TailRecElimPhase.scala new file mode 100644 index 0000000000..24717a62d6 --- /dev/null +++ b/core/src/main/scala/stainless/genc/phases/TailRecElimPhase.scala @@ -0,0 +1,19 @@ +package stainless +package genc +package phases + +import ir.IRs.{ SIR, TIR } +import ir.TailRecSimpTransformer +import ir.TailRecTransformer + +class TailRecElimPhase(using override val context: inox.Context) extends LeonPipeline[SIR.Prog, TIR.Prog](context) { + val name = "TailRecElim" + + private given givenDebugSection: DebugSectionGenC.type = DebugSectionGenC + + def run(sprog: SIR.Prog): TIR.Prog = + val simplTransformer = new TailRecSimpTransformer + val sprog1 = simplTransformer(sprog) + new TailRecTransformer(context)(sprog1) +} + diff --git a/frontends/benchmarks/genc/valid/TailRecAliasedArgs.check b/frontends/benchmarks/genc/valid/TailRecAliasedArgs.check new file mode 100644 index 0000000000..7813681f5b --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecAliasedArgs.check @@ -0,0 +1 @@ +5 \ No newline at end of file diff --git a/frontends/benchmarks/genc/valid/TailRecAliasedArgs.scala b/frontends/benchmarks/genc/valid/TailRecAliasedArgs.scala new file mode 100644 index 0000000000..3eb743d71c --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecAliasedArgs.scala @@ -0,0 +1,17 @@ +import stainless.annotation._ +import stainless.lang._ +import stainless.io._ + +object TailRecAliasedArgs { + def aliased(n: Int, a: Int, b: Int): Int = + require(n >= 0) + decreases(n) + if n == 0 then a + else aliased(n - 1, b, a + b) + + @cCode.`export` + def main(): Unit = { + implicit val state = stainless.io.newState + StdOut.println(aliased(5, 0, 1)) // Expected: 5 + } +} diff --git a/frontends/benchmarks/genc/valid/TailRecComplexArgs.check b/frontends/benchmarks/genc/valid/TailRecComplexArgs.check new file mode 100644 index 0000000000..56a6051ca2 --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecComplexArgs.check @@ -0,0 +1 @@ +1 \ No newline at end of file diff --git a/frontends/benchmarks/genc/valid/TailRecComplexArgs.scala b/frontends/benchmarks/genc/valid/TailRecComplexArgs.scala new file mode 100644 index 0000000000..da9e08d013 --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecComplexArgs.scala @@ -0,0 +1,17 @@ +import stainless.annotation._ +import stainless.lang._ +import stainless.io._ + +object TailRecComplexArgs { + def complexArgs(n: Int): Int = + require(n >= 0) + decreases(n) + if n <= 0 then 1 + else complexArgs(n - 1 * 2 + 1) // Complex argument + + @cCode.`export` + def main(): Unit = { + implicit val state = stainless.io.newState + StdOut.println(complexArgs(5)) // Expected: 1 + } +} diff --git a/frontends/benchmarks/genc/valid/TailRecCountDown.check b/frontends/benchmarks/genc/valid/TailRecCountDown.check new file mode 100644 index 0000000000..c227083464 --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecCountDown.check @@ -0,0 +1 @@ +0 \ No newline at end of file diff --git a/frontends/benchmarks/genc/valid/TailRecCountDown.scala b/frontends/benchmarks/genc/valid/TailRecCountDown.scala new file mode 100644 index 0000000000..9b414e6e16 --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecCountDown.scala @@ -0,0 +1,17 @@ +import stainless.annotation._ +import stainless.lang._ +import stainless.io._ + +object TailRecNoArguments { + def countDown(n: Int): Int = + require(n >= 0) + decreases(n) + if n == 0 then 0 + else countDown(n - 1) + + @cCode.`export` + def main(): Unit = { + implicit val state = stainless.io.newState + StdOut.println(countDown(1000000)) // Expected: 0 + } +} diff --git a/frontends/benchmarks/genc/valid/TailRecEarlyReturn.check b/frontends/benchmarks/genc/valid/TailRecEarlyReturn.check new file mode 100644 index 0000000000..bf0d87ab1b --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecEarlyReturn.check @@ -0,0 +1 @@ +4 \ No newline at end of file diff --git a/frontends/benchmarks/genc/valid/TailRecEarlyReturn.scala b/frontends/benchmarks/genc/valid/TailRecEarlyReturn.scala new file mode 100644 index 0000000000..444c565de8 --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecEarlyReturn.scala @@ -0,0 +1,18 @@ +import stainless.annotation._ +import stainless.lang._ +import stainless.io._ + +object TailRecEarlyReturn { + def earlyReturn(n: Int, acc: Int): Int = + require(n >= 0) + decreases(n) + if n == 0 then acc + else if n == 3 then return acc * 2 // Early return + else earlyReturn(n - 1, acc + 1) + + @cCode.`export` + def main(): Unit = { + implicit val state = stainless.io.newState + StdOut.println(earlyReturn(5, 0)) // Expected: 4 + } +} diff --git a/frontends/benchmarks/genc/valid/TailRecFib.check b/frontends/benchmarks/genc/valid/TailRecFib.check new file mode 100644 index 0000000000..7c6ba0fe18 --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecFib.check @@ -0,0 +1 @@ +55 \ No newline at end of file diff --git a/frontends/benchmarks/genc/valid/TailRecFib.scala b/frontends/benchmarks/genc/valid/TailRecFib.scala new file mode 100644 index 0000000000..0c69bb7467 --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecFib.scala @@ -0,0 +1,19 @@ +import stainless.annotation._ +import stainless.lang._ +import stainless.io._ + +object TailRecFib { + + def fib(n: Int, i: Int = 0, j: Int = 1): Int = + require(n >= 0) + decreases(n) + if n == 0 then i + else fib(n-1, j, i+j) + + @cCode.`export` + def main(): Unit = { + implicit val state = stainless.io.newState + StdOut.println(fib(10)) + } + +} diff --git a/frontends/benchmarks/genc/valid/TailRecFibAliased.check b/frontends/benchmarks/genc/valid/TailRecFibAliased.check new file mode 100644 index 0000000000..7c6ba0fe18 --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecFibAliased.check @@ -0,0 +1 @@ +55 \ No newline at end of file diff --git a/frontends/benchmarks/genc/valid/TailRecFibAliased.scala b/frontends/benchmarks/genc/valid/TailRecFibAliased.scala new file mode 100644 index 0000000000..6a2fdcf0cb --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecFibAliased.scala @@ -0,0 +1,21 @@ +import stainless.annotation._ +import stainless.lang._ +import stainless.io._ + +object TailRecFibAliased { + + def fib(n: Int, i: Int = 0, j: Int = 1): Int = + require(n >= 0) + decreases(n) + if n == 0 then i + else + val res = fib(n-1, j, i+j) + res + + @cCode.`export` + def main(): Unit = { + implicit val state = stainless.io.newState + StdOut.println(fib(10)) + } + +} diff --git a/frontends/benchmarks/genc/valid/TailRecMultipleCalls.check b/frontends/benchmarks/genc/valid/TailRecMultipleCalls.check new file mode 100644 index 0000000000..9d607966b7 --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecMultipleCalls.check @@ -0,0 +1 @@ +11 \ No newline at end of file diff --git a/frontends/benchmarks/genc/valid/TailRecMultipleCalls.scala b/frontends/benchmarks/genc/valid/TailRecMultipleCalls.scala new file mode 100644 index 0000000000..0b6b6d6ab2 --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecMultipleCalls.scala @@ -0,0 +1,19 @@ +import stainless.annotation._ +import stainless.lang._ +import stainless.io._ + +object TailRecMultipleCalls { + def multipleCalls(n: Int, acc: Int): Int = + require(n >= 0) + decreases(n) + if n == 0 then acc + else if n == 1 then acc + 2 + else if n % 2 == 0 then multipleCalls(n - 1, acc + 1) + else multipleCalls(n - 2, acc + 2) + + @cCode.`export` + def main(): Unit = { + implicit val state = stainless.io.newState + StdOut.println(multipleCalls(10, 0)) // Expected: 11 + } +} diff --git a/frontends/benchmarks/genc/valid/TailRecNested.check b/frontends/benchmarks/genc/valid/TailRecNested.check new file mode 100644 index 0000000000..c227083464 --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecNested.check @@ -0,0 +1 @@ +0 \ No newline at end of file diff --git a/frontends/benchmarks/genc/valid/TailRecNested.scala b/frontends/benchmarks/genc/valid/TailRecNested.scala new file mode 100644 index 0000000000..b4063fec9c --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecNested.scala @@ -0,0 +1,20 @@ +import stainless.annotation._ +import stainless.lang._ +import stainless.io._ + +object TailRecNested { + def outer(n: Int): Int = + require(n >= 0) + def inner(x: Int): Int = + require(x >= 0) + decreases(x) + if x == 0 then 0 + else inner(x - 1) + inner(n) + + @cCode.`export` + def main(): Unit = { + implicit val state = stainless.io.newState + StdOut.println(outer(5)) // Expected: 0 + } +} diff --git a/frontends/benchmarks/genc/valid/TailRecPatternMatching.check b/frontends/benchmarks/genc/valid/TailRecPatternMatching.check new file mode 100644 index 0000000000..7813681f5b --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecPatternMatching.check @@ -0,0 +1 @@ +5 \ No newline at end of file diff --git a/frontends/benchmarks/genc/valid/TailRecPatternMatching.scala b/frontends/benchmarks/genc/valid/TailRecPatternMatching.scala new file mode 100644 index 0000000000..09c540424d --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecPatternMatching.scala @@ -0,0 +1,26 @@ +import stainless.annotation._ +import stainless.lang._ +import stainless.io._ + +object TailRecPatternMatching { + def patternMatch(x: Option[Int], acc: Int): Int = + require(x match { + case None() => true + case Some(n) => n >= 1 + }) + val measure = x match { + case None() => 0 + case Some(n) => n + } + decreases(measure) + x match + case None() => acc + case Some(n) if n == 1 => patternMatch(None(), acc + 1) + case Some(n) => patternMatch(Some(n - 1), acc + 1) + + @cCode.`export` + def main(): Unit = { + implicit val state = stainless.io.newState + StdOut.println(patternMatch(Some(5), 0)) // Expected: 5 + } +} diff --git a/frontends/benchmarks/genc/valid/TailRecStackOverflow.check b/frontends/benchmarks/genc/valid/TailRecStackOverflow.check new file mode 100644 index 0000000000..56a6051ca2 --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecStackOverflow.check @@ -0,0 +1 @@ +1 \ No newline at end of file diff --git a/frontends/benchmarks/genc/valid/TailRecStackOverflow.scala b/frontends/benchmarks/genc/valid/TailRecStackOverflow.scala new file mode 100644 index 0000000000..4f085bfbf1 --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecStackOverflow.scala @@ -0,0 +1,18 @@ +import stainless.annotation._ +import stainless.lang._ +import stainless.io._ + +object TailRecStackOverflow { + def even(n: Int): Int = + require(n >= 0) + decreases(n) + if n == 0 then 1 + else if n == 1 then 0 + else even(n - 2) + + @cCode.`export` + def main(): Unit = { + implicit val state = stainless.io.newState + StdOut.println(even(1000000)) // Expected: 1 + } +} diff --git a/frontends/benchmarks/genc/valid/TailRecUnit.check b/frontends/benchmarks/genc/valid/TailRecUnit.check new file mode 100644 index 0000000000..e69de29bb2 diff --git a/frontends/benchmarks/genc/valid/TailRecUnit.scala b/frontends/benchmarks/genc/valid/TailRecUnit.scala new file mode 100644 index 0000000000..82632eed3a --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecUnit.scala @@ -0,0 +1,17 @@ +import stainless.annotation._ +import stainless.lang._ +import stainless.io._ + +object TailRecUnit { + def countDown(n: Int): Unit = + require(n >= 0) + decreases(n) + if (n == 0) return + countDown(n - 1) + + @cCode.`export` + def main(): Unit = { + implicit val state = stainless.io.newState + countDown(1000000) + } +} diff --git a/frontends/benchmarks/genc/valid/TailRecUnitIf.check b/frontends/benchmarks/genc/valid/TailRecUnitIf.check new file mode 100644 index 0000000000..e69de29bb2 diff --git a/frontends/benchmarks/genc/valid/TailRecUnitIf.scala b/frontends/benchmarks/genc/valid/TailRecUnitIf.scala new file mode 100644 index 0000000000..2aeba56e0c --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecUnitIf.scala @@ -0,0 +1,17 @@ +import stainless.annotation._ +import stainless.lang._ +import stainless.io._ + +object TailRecUnitIf { + def countDown(n: Int): Unit = + require(n >= 0) + decreases(n) + if (n == 0) return + else countDown(n - 1) + + @cCode.`export` + def main(): Unit = { + implicit val state = stainless.io.newState + countDown(1000000) + } +} diff --git a/frontends/benchmarks/genc/valid/TailRecUnitNoExplicitEnd.check b/frontends/benchmarks/genc/valid/TailRecUnitNoExplicitEnd.check new file mode 100644 index 0000000000..e69de29bb2 diff --git a/frontends/benchmarks/genc/valid/TailRecUnitNoExplicitEnd.scala b/frontends/benchmarks/genc/valid/TailRecUnitNoExplicitEnd.scala new file mode 100644 index 0000000000..1445e94e8d --- /dev/null +++ b/frontends/benchmarks/genc/valid/TailRecUnitNoExplicitEnd.scala @@ -0,0 +1,16 @@ +import stainless.annotation._ +import stainless.lang._ +import stainless.io._ + +object TailRecUnitWithNoExplicitEnd { + def countDown(n: Int): Unit = + require(n >= 0) + decreases(n) + if (n > 0) countDown(n - 1) + + @cCode.`export` + def main(): Unit = { + implicit val state = stainless.io.newState + countDown(1000000) + } +} diff --git a/frontends/common/src/it/scala/stainless/GenCSuite.scala b/frontends/common/src/it/scala/stainless/GenCSuite.scala index 4317db1700..d4bc3d15b7 100644 --- a/frontends/common/src/it/scala/stainless/GenCSuite.scala +++ b/frontends/common/src/it/scala/stainless/GenCSuite.scala @@ -16,6 +16,11 @@ import Utils._ class GenCSuite extends AnyFunSuite with inox.ResourceUtils with InputUtils with Matchers { val validFiles = resourceFiles("genc/valid", _.endsWith(".scala"), false).map(_.getPath) val invalidFiles = resourceFiles("genc/invalid", _.endsWith(".scala"), false).map(_.getPath) + val tailrecFiles = validFiles.filter(_.toLowerCase.contains("tailrec".toLowerCase)).map { path => + val checkFile = path.replace(".scala", ".check") + path -> checkFile + } + val tailrecScalaFiles = tailrecFiles.map(_._1) val ctx = TestContext.empty for (file <- invalidFiles) { @@ -69,6 +74,22 @@ class GenCSuite extends AnyFunSuite with inox.ResourceUtils with InputUtils with assert(output == "124443", s"Output '$output' should be '124443'") } + for (case (file, _) <- tailrecFiles) { + test(s"Checking that ${file.split("/").last} has tail recursive function rewritten as loop") { + val cFile = file.replace(".scala", ".c") + val cCode = Files.readAllLines(Paths.get(cFile)).toArray.mkString + assert(cCode.contains("goto"), "Should contain a goto statement") + } + } + + for (case (file, checkFile) <- tailrecFiles) { + test(s"Checking that ${file.split("/").last} outputs ${Files.readAllLines(Paths.get(checkFile)).toArray.mkString}") { + val output = runCHelper(file) + val checkValue = Files.readAllLines(Paths.get(checkFile)).toArray.mkString + assert(output == checkValue, s"Output '$output' should be $checkValue") + } + } + def runCHelper(filename: String): String = { val file = validFiles.find(_.contains(filename)).get val outFile = file.replace(".scala", ".out") diff --git a/frontends/common/src/it/scala/stainless/TailRecGenCSuite.scala b/frontends/common/src/it/scala/stainless/TailRecGenCSuite.scala new file mode 100644 index 0000000000..e11b009d4c --- /dev/null +++ b/frontends/common/src/it/scala/stainless/TailRecGenCSuite.scala @@ -0,0 +1,72 @@ +/* Copyright 2009-2021 EPFL, Lausanne */ + +package stainless + +import utils._ + +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers + +import java.nio.file.{Paths, Files} +import java.io.File +import java.io.PrintWriter + +import Utils._ + +class TailRecGenCSuite extends AnyFunSuite with inox.ResourceUtils with InputUtils with Matchers { + val validFiles = resourceFiles("genc/valid", _.endsWith(".scala"), false).map(_.getPath) + val tailrecFiles = validFiles.filter(_.toLowerCase.contains("tailrec".toLowerCase)).map { path => + val checkFile = path.replace(".scala", ".check") + path -> checkFile + } + val tailrecScalaFiles = tailrecFiles.map(_._1) + val ctx = TestContext.empty + + // for (file <- tailrecFiles) { + // val extraOpts = Seq("--batched", "--solvers=smt-z3", "--strict-arithmetic=false", "--timeout=10") + // test(s"stainless ${extraOpts.mkString(" ")} $file") { + // val (localCtx, optReport) = runMainWithArgs(Array(file) ++ extraOpts) + // assert(localCtx.reporter.errorCount == 0, "No errors") + // assert(optReport.nonEmpty, "Valid report returned by Stainless") + // assert(optReport.get.isSuccess, "Only valid VCs") + // } + // } + + for (file <- tailrecScalaFiles) { + val cFile = file.replace(".scala", ".c") + val outFile = file.replace(".scala", ".out") + test(s"stainless --genc --genc-output=$cFile $file") { + runMainWithArgs(Array(file) :+ "--genc" :+ s"--genc-output=$cFile") + assert(Files.exists(Paths.get(cFile))) + val gccCompile = s"gcc $cFile -o $outFile" + ctx.reporter.info(s"Running: $gccCompile") + val (std, exitCode) = runCommand(gccCompile) + assert(exitCode == 0, "gcc failed with output:\n" + std.mkString("\n")) + } + } + + for (case (file, _) <- tailrecFiles) { + test(s"Checking that ${file.split("/").last} has tail recursive function rewritten as loop") { + val cFile = file.replace(".scala", ".c") + val cCode = Files.readAllLines(Paths.get(cFile)).toArray.mkString + assert(cCode.contains("goto"), "Should contain a goto statement") + } + } + + for (case (file, checkFile) <- tailrecFiles) { + test(s"Checking that ${file.split("/").last} outputs ${Files.readAllLines(Paths.get(checkFile)).toArray.mkString}") { + val output = runCHelper(file) + val checkValue = Files.readAllLines(Paths.get(checkFile)).toArray.mkString + assert(output == checkValue, s"Output '$output' should be $checkValue") + } + } + + def runCHelper(filename: String): String = { + val file = validFiles.find(_.contains(filename)).get + val outFile = file.replace(".scala", ".out") + ctx.reporter.info(s"Running: $outFile") + val (std, _) = runCommand(outFile) + // Note: lines are concatenated without adding newlines between them + std.mkString + } +}