diff --git a/core/src/main/scala/stainless/extraction/imperative/AntiAliasing.scala b/core/src/main/scala/stainless/extraction/imperative/AntiAliasing.scala index 175f598e10..283eaf4afa 100644 --- a/core/src/main/scala/stainless/extraction/imperative/AntiAliasing.scala +++ b/core/src/main/scala/stainless/extraction/imperative/AntiAliasing.scala @@ -173,12 +173,12 @@ trait AntiAliasing def mapApplication(formalArgs: Seq[ValDef], args: Seq[Expr], nfi: Expr, nfiType: Type, fiEffects: Set[Effect], env: Env): Expr = { if (fiEffects.exists(e => formalArgs contains e.receiver.toVal)) { - val localEffects = (formalArgs zip args) + val localEffects: Seq[Set[(Effect, Set[Effect])]] = (formalArgs zip args) .map { case (vd, arg) => (fiEffects.filter(_.receiver == vd.toVariable), arg) } .filter { case (effects, _) => effects.nonEmpty } .map { case (effects, arg) => val rArg = exprOps.replaceFromSymbols(env.rewritings, arg) - effects map (e => (e, e on rArg)) + effects map (e => (e, e.effectsOn(rArg))) } for ((_, effects) <- localEffects.flatMap(_.flatMap(_._2)).groupBy(_.receiver)) { @@ -267,7 +267,7 @@ trait AntiAliasing override def transform(e: Expr, env: Env): Expr = (e match { case l @ Let(vd, e, b) if isMutableType(vd.tpe) => val newExpr = transform(e, env) - if (getKnownEffects(newExpr).nonEmpty) { + if (computeKnownTargets(newExpr).nonEmpty) { val newBody = transform(b, env withRewritings Map(vd -> newExpr)) Let(vd, newExpr, newBody).copiedFrom(l) } else { @@ -281,13 +281,13 @@ trait AntiAliasing LetVar(vd, newExpr, newBody).copiedFrom(l) case m @ MatchExpr(scrut, cses) if isMutableType(scrut.getType) => - if (effects(scrut).nonEmpty) { + if (exprEffects(scrut).nonEmpty) { def liftEffects(e: Expr): (Seq[(ValDef, Expr)], Expr) = e match { - case ArraySelect(e, i) if effects(i).nonEmpty => + case ArraySelect(e, i) if exprEffects(i).nonEmpty => val (eBindings, eLift) = liftEffects(e) val vd = ValDef(FreshIdentifier("index", true), Int32Type().copiedFrom(i)).copiedFrom(i) (eBindings :+ (vd -> i), ArraySelect(eLift, vd.toVariable).copiedFrom(e)) - case _ if effects(e).nonEmpty => + case _ if exprEffects(e).nonEmpty => throw MalformedStainlessCode(m, "Unexpected effects in match scrutinee") case _ => (Seq.empty, e) } @@ -310,7 +310,7 @@ trait AntiAliasing case up @ ArrayUpdate(a, i, v) => val ra = exprOps.replaceFromSymbols(env.rewritings, a) - val effects = getExactEffects(ra) + val effects = computeExactTargets(ra) if (effects.exists(eff => !env.bindings.contains(eff.receiver.toVal))) throw MalformedStainlessCode(up, "Unsupported form of array update") @@ -322,7 +322,7 @@ trait AntiAliasing case up @ MutableMapUpdate(map, k, v) => val rmap = exprOps.replaceFromSymbols(env.rewritings, map) - val effects = getExactEffects(rmap) + val effects = computeExactTargets(rmap) if (effects.exists(eff => !env.bindings.contains(eff.receiver.toVal))) throw MalformedStainlessCode(up, "Unsupported form of map update") @@ -334,7 +334,7 @@ trait AntiAliasing case as @ FieldAssignment(o, id, v) => val so = exprOps.replaceFromSymbols(env.rewritings, o) - val effects = getExactEffects(so) + val effects = computeExactTargets(so) if (effects.exists(eff => !env.bindings.contains(eff.receiver.toVal))) throw MalformedStainlessCode(as, "Unsupported form of field assignment") @@ -390,7 +390,7 @@ trait AntiAliasing id, tps, args.map(arg => transform(exprOps.replaceFromSymbols(env.rewritings, arg), env)) ).copiedFrom(fi) - mapApplication(fd.params, args, nfi, fi.tfd.instantiate(analysis.getReturnType(fd)), effects(fd), env) + mapApplication(fd.params, args, nfi, fi.tfd.instantiate(analysis.getReturnType(fd)), funEffects(fd), env) case alr @ ApplyLetRec(id, tparams, tpe, tps, args) => val fd = Inner(env.locals(id)) @@ -406,7 +406,7 @@ trait AntiAliasing ).copiedFrom(alr) val resultType = typeOps.instantiateType(analysis.getReturnType(fd), (tparams zip tps).toMap) - mapApplication(fd.params, args, nfi, resultType, effects(fd), env) + mapApplication(fd.params, args, nfi, resultType, funEffects(fd), env) case app @ Application(callee, args) => val ft @ FunctionType(from, to) = callee.getType diff --git a/core/src/main/scala/stainless/extraction/imperative/EffectsAnalyzer.scala b/core/src/main/scala/stainless/extraction/imperative/EffectsAnalyzer.scala index 577d4d7961..09fa6e31f6 100644 --- a/core/src/main/scala/stainless/extraction/imperative/EffectsAnalyzer.scala +++ b/core/src/main/scala/stainless/extraction/imperative/EffectsAnalyzer.scala @@ -85,10 +85,10 @@ trait EffectsAnalyzer extends oo.CachingPhase { trait EffectsAnalysis { self: TransformerContext => implicit val symbols: s.Symbols - private[this] def functionEffects(fd: FunAbstraction, current: Result): Set[Effect] = + private[this] def computeFunEffects(fd: FunAbstraction, current: Result): Set[Effect] = exprOps.withoutSpecs(fd.fullBody) match { case Some(body) => - expressionEffects(body, current) + computeExprEffects(body, current) case None if !fd.flags.contains(IsPure) => fd.params .filter(vd => symbols.isMutableType(vd.getType) && !vd.flags.contains(IsPure)) @@ -121,7 +121,7 @@ trait EffectsAnalyzer extends oo.CachingPhase { inners.flatMap { case (_, inners) => inners.map(fun => fun.id -> fun) }) val result = inox.utils.fixpoint[Result] { case res @ Result(effects, locals) => - Result(effects.map { case (fd, _) => fd -> functionEffects(fd, res) }, locals) + Result(effects.map { case (fd, _) => fd -> computeFunEffects(fd, res) }, locals) } (prevResult merge baseResult) for ((fd, inners) <- inners) { @@ -137,14 +137,14 @@ trait EffectsAnalyzer extends oo.CachingPhase { results merge newResult } - def effects(fd: FunDef): Set[Effect] = result.effects(Outer(fd)) - def effects(fun: FunAbstraction): Set[Effect] = result.effects(fun) - def effects(expr: Expr): Set[Effect] = expressionEffects(expr, result) + def funEffects(fd: FunDef): Set[Effect] = result.effects(Outer(fd)) + def funEffects(fun: FunAbstraction): Set[Effect] = result.effects(fun) + def exprEffects(expr: Expr): Set[Effect] = computeExprEffects(expr, result) private[imperative] def local(id: Identifier): FunAbstraction = result.locals(id) private[imperative] def getAliasedParams(fd: FunAbstraction): Seq[ValDef] = { - val receivers = effects(fd).map(_.receiver) + val receivers = funEffects(fd).map(_.receiver) fd.params.filter(vd => receivers(vd.toVariable)) } @@ -170,7 +170,8 @@ trait EffectsAnalyzer extends oo.CachingPhase { def +:(elem: Accessor): Path = Path(elem +: path) def ++(that: Path): Path = Path(this.path ++ that.path) - def on(that: Expr)(implicit symbols: Symbols): Set[Target] = { + // Compute all the effect targets of a given expression extended by selections of this path. + def targetsOn(that: Expr)(implicit symbols: Symbols): Set[Target] = { def rec(expr: Expr, path: Seq[Accessor]): Option[Expr] = path match { case ADTFieldAccessor(id) +: xs => rec(ADTSelector(expr, id), xs) @@ -198,7 +199,21 @@ trait EffectsAnalyzer extends oo.CachingPhase { Some(expr) } - rec(that, path).toSet.flatMap(getEffects) + // Check that this path is valid on the given expression, otherwise return the empty set. + // Note that if this path is valid on `expr`, we end up simply computing the targets of + // `expr` with `this` appended, since computeTargets will initially strip away the selectors. + rec(that, path) match { + case None => + // NOTE(gsps): I believe this check is redundant except for the cases where we try to + // select class fields on expressions whose type is an abstract type def. + // We have a test case (extraction/valid/TypeMembers2.scala) in which we expand an + // abstract base class' method to a dispatcher, and such type defs appear. Simply omitting + // the effect by returning the empty set in such cases actually seems an odd choice to me, + // though I can't see how to exploit it for unsoundness. + Set.empty + case Some(pathOnExpr) => + computeTargets(pathOnExpr) + } } def prefixOf(that: Path): Boolean = { @@ -283,9 +298,9 @@ trait EffectsAnalyzer extends oo.CachingPhase { case class Effect(receiver: Variable, path: Path) { def +(elem: Accessor) = Effect(receiver, path :+ elem) - def on(that: Expr)(implicit symbols: Symbols): Set[Effect] = for { - Target(receiver, _, path) <- this.path on that - } yield Effect(receiver, path) + // Compute all the effects one gets from replacing this effect's receiver by a given expression. + def effectsOn(that: Expr)(implicit symbols: Symbols): Set[Effect] = + path.targetsOn(that).map(_.toEffect) def prefixOf(that: Effect): Boolean = receiver == that.receiver && (path prefixOf that.path) @@ -301,7 +316,23 @@ trait EffectsAnalyzer extends oo.CachingPhase { override def toString: String = asString } - def getEffects(expr: Expr)(implicit symbols: Symbols): Set[Target] = { + // Computes the effect targets of a given expression. + // The effect targets of an expression represent all the possible variables in scope that might + // be affected, if one were to mutate the given expression. + // + // Individual cases are represented by one Target each. Targets consist of + // - the variable being modified (the receiver), + // - the path condition under which this happens, and + // - the selection path at which this variable is being modified (if any). + // + // There is no support for the case where we don't have a name for the target's receiver. + // (For instance, imagine computing the targets of some temporary ADT value that isn't bound to a + // variable.) + // + // This function provides an under-approximation in some cases; in particular, it might return + // an empty set for certain "non-local" constructs like function invocations. These seem to be + // handled specially elsewhere. + def computeTargets(expr: Expr)(implicit symbols: Symbols): Set[Target] = { def rec(expr: Expr, path: Seq[Accessor]): Set[Target] = expr match { case v: Variable => Set(Target(v, None, Path(path))) case _ if variablesOf(expr).forall(v => !symbols.isMutableType(v.tpe)) => Set.empty @@ -336,6 +367,9 @@ trait EffectsAnalyzer extends oo.CachingPhase { And(Not(cnd).setPos(cnd), e.setPos(cnd)).setPos(cnd) } getOrElse(Not(cnd).setPos(cnd)) + // FIXME: This seems wrong: why are we ignoring t.condition? + // FIXME: This also seems inefficient: is it really a good idea to "ground" all the paths + // rather than representing them as some sort of tree? for { t <- rec(thn, path) e <- rec(els, path) @@ -362,8 +396,9 @@ trait EffectsAnalyzer extends oo.CachingPhase { rec(b, path) case Let(vd, e, b) => + // FIXME(gsps): This seems exceedingly cryptic. val bEffects = rec(b, path) - val res = for (ee <- getEffects(e); be <- bEffects) yield { + val res = for (ee <- computeTargets(e); be <- bEffects) yield { if (be.receiver == vd.toVariable) ee.append(be) else be } @@ -379,13 +414,15 @@ trait EffectsAnalyzer extends oo.CachingPhase { rec(expr, Seq.empty) } - def getExactEffects(expr: Expr)(implicit symbols: Symbols): Set[Target] = getEffects(expr) match { - case effects if effects.nonEmpty => effects + // Like computeTargets, but never under-approximates. + def computeExactTargets(expr: Expr)(implicit symbols: Symbols): Set[Target] = computeTargets(expr) match { + case targets if targets.nonEmpty => targets case _ => throw MalformedStainlessCode(expr, s"Couldn't compute exact effect targets in: $expr") } - def getKnownEffects(expr: Expr)(implicit symbols: Symbols): Set[Target] = try { - getEffects(expr) + // Like computeTargets, but replaces some unsupported cases by the (empty) under-approximation. + def computeKnownTargets(expr: Expr)(implicit symbols: Symbols): Set[Target] = try { + computeTargets(expr) } catch { case _: MalformedStainlessCode => Set.empty } @@ -410,7 +447,7 @@ trait EffectsAnalyzer extends oo.CachingPhase { * * We are assuming no aliasing. */ - private def expressionEffects(expr: Expr, result: Result)(implicit symbols: Symbols): Set[Effect] = { + private def computeExprEffects(expr: Expr, result: Result)(implicit symbols: Symbols): Set[Effect] = { import symbols._ val freeVars = variablesOf(expr) @@ -418,7 +455,7 @@ trait EffectsAnalyzer extends oo.CachingPhase { env.get(effect.receiver).map(e => e.copy(path = e.path ++ effect.path)) def effect(expr: Expr, env: Map[Variable, Effect]): Set[Effect] = - getEffects(expr) flatMap { (target: Target) => + computeTargets(expr) flatMap { (target: Target) => inEnv(target.toEffect, env).toSet } @@ -474,7 +511,7 @@ trait EffectsAnalyzer extends oo.CachingPhase { val currentEffects: Set[Effect] = result.effects(fun) val paramSubst = (fun.params.map(_.toVariable) zip args).toMap val invocEffects = currentEffects.flatMap(e => paramSubst.get(e.receiver) match { - case Some(arg) => (e on arg).flatMap(inEnv(_, env)) + case Some(arg) => e.effectsOn(arg).flatMap(inEnv(_, env)) case None => Seq(e) // This effect occurs on some variable captured from scope }) diff --git a/core/src/main/scala/stainless/extraction/imperative/EffectsChecker.scala b/core/src/main/scala/stainless/extraction/imperative/EffectsChecker.scala index 0dba173a98..c5ca822970 100644 --- a/core/src/main/scala/stainless/extraction/imperative/EffectsChecker.scala +++ b/core/src/main/scala/stainless/extraction/imperative/EffectsChecker.scala @@ -48,7 +48,7 @@ trait EffectsChecker { self: EffectsAnalyzer => case l @ Let(vd, e, b) => if (!isExpressionFresh(e) && isMutableType(vd.tpe)) try { // Check if a precise effect can be computed - getEffects(e) + computeTargets(e) } catch { case _: MalformedStainlessCode => throw ImperativeEliminationException(e, "Illegal aliasing: " + e.asString) @@ -87,7 +87,7 @@ trait EffectsChecker { self: EffectsAnalyzer => case l @ Lambda(args, body) => if (isMutableType(body.getType) && !isExpressionFresh(body)) throw ImperativeEliminationException(l, "Illegal aliasing in lambda body") - if (effects(body).exists(e => !args.contains(e.receiver.toVal))) + if (exprEffects(body).exists(e => !args.contains(e.receiver.toVal))) throw ImperativeEliminationException(l, "Illegal effects in lambda body") super.traverse(l) @@ -142,22 +142,22 @@ trait EffectsChecker { self: EffectsAnalyzer => if (isMutableType(fd.returnType)) throw ImperativeEliminationException(fd, "A field cannot refer to a mutable object") - if (effects(fd.fullBody).nonEmpty) - throw ImperativeEliminationException(fd, s"A field must be pure, but ${fd.id.asString} has effects: ${effects(fd.fullBody).map(_.asString).mkString(", ")}") + if (exprEffects(fd.fullBody).nonEmpty) + throw ImperativeEliminationException(fd, s"A field must be pure, but ${fd.id.asString} has effects: ${exprEffects(fd.fullBody).map(_.asString).mkString(", ")}") } def checkEffectsLocations(fd: FunAbstraction): Unit = exprOps.preTraversal { case Require(pre, _) => - val preEffects = effects(pre) + val preEffects = exprEffects(pre) if (preEffects.nonEmpty) throw ImperativeEliminationException(pre, "Precondition has effects on: " + preEffects.head.receiver.asString) case Ensuring(_, post @ Lambda(_, body)) => - val bodyEffects = effects(body) + val bodyEffects = exprEffects(body) if (bodyEffects.nonEmpty) throw ImperativeEliminationException(post, "Postcondition has effects on: " + bodyEffects.head.receiver.asString) - val oldEffects = effects(exprOps.postMap { + val oldEffects = exprEffects(exprOps.postMap { case Old(e) => Some(e) case _ => None } (body)) @@ -165,36 +165,36 @@ trait EffectsChecker { self: EffectsAnalyzer => throw ImperativeEliminationException(post, s"Postcondition tries to mutate ${Old(oldEffects.head.receiver).asString}") case Decreases(meas, _) => - val measEffects = effects(meas) + val measEffects = exprEffects(meas) if (measEffects.nonEmpty) throw ImperativeEliminationException(meas, "Decreases has effects on: " + measEffects.head.receiver.asString) case Assert(pred, _, _) => - val predEffects = effects(pred) + val predEffects = exprEffects(pred) if (predEffects.nonEmpty) throw ImperativeEliminationException(pred, "Assertion has effects on: " + predEffects.head.receiver.asString) case Forall(_, pred) => - val predEffects = effects(pred) + val predEffects = exprEffects(pred) if (predEffects.nonEmpty) throw ImperativeEliminationException(pred, "Quantifier has effects on: " + predEffects.head.receiver.asString) case wh @ While(_, _, Some(invariant)) => - val invEffects = effects(invariant) + val invEffects = exprEffects(invariant) if (invEffects.nonEmpty) throw ImperativeEliminationException(invariant, "Loop invariant has effects on: " + invEffects.head.receiver.asString) case m @ MatchExpr(_, cses) => cses.foreach { cse => cse.optGuard.foreach { guard => - val guardEffects = effects(guard) + val guardEffects = exprEffects(guard) if (guardEffects.nonEmpty) throw ImperativeEliminationException(guard, "Pattern guard has effects on: " + guardEffects.head.receiver.asString) } patternOps.preTraversal { case up: UnapplyPattern => - val upEffects = effects(Outer(up.getFunction.fd)) + val upEffects = funEffects(up.getFunction.fd) if (upEffects.nonEmpty) throw ImperativeEliminationException(up, "Pattern unapply has effects on: " + upEffects.head.receiver.asString) @@ -203,7 +203,7 @@ trait EffectsChecker { self: EffectsAnalyzer => } case Let(vd, v, rest) if vd.flags.contains(Lazy) => - val eff = effects(v) + val eff = exprEffects(v) if (eff.nonEmpty) throw ImperativeEliminationException(v, "Stainless does not support effects in lazy val's on: " + eff.head.receiver.asString) @@ -211,7 +211,7 @@ trait EffectsChecker { self: EffectsAnalyzer => }(fd.fullBody) def checkPurity(fd: FunAbstraction): Unit = { - val effs = effects(fd.fullBody) + val effs = exprEffects(fd.fullBody) if ((fd.flags contains IsPure) && !effs.isEmpty) throw ImperativeEliminationException(fd, s"Functions marked @pure cannot have side-effects") @@ -276,7 +276,7 @@ trait EffectsChecker { self: EffectsAnalyzer => def checkSort(sort: ADTSort)(analysis: EffectsAnalysis): Unit = { for (fd <- sort.invariant(analysis.symbols)) { - val invEffects = analysis.effects(fd) + val invEffects = analysis.funEffects(fd) if (invEffects.nonEmpty) throw ImperativeEliminationException(fd, "Invariant has effects on: " + invEffects.head.asString) } diff --git a/core/src/main/scala/stainless/extraction/imperative/GhostChecker.scala b/core/src/main/scala/stainless/extraction/imperative/GhostChecker.scala index 76ff5b878d..c49e7e5d0a 100644 --- a/core/src/main/scala/stainless/extraction/imperative/GhostChecker.scala +++ b/core/src/main/scala/stainless/extraction/imperative/GhostChecker.scala @@ -84,7 +84,7 @@ trait GhostChecker { self: EffectsAnalyzer => if (fun.flags contains Synthetic) { () // Synthetic functions should always be fine with respect to ghost flow } else if (fun.flags contains Ghost) { - effects(fun).find(!isGhostEffect(_)) match { + funEffects(fun).find(!isGhostEffect(_)) match { case Some(eff) => throw ImperativeEliminationException(fun, s"Ghost function cannot have effect on non-ghost state: ${eff.targetString}") case None => () } @@ -109,7 +109,7 @@ trait GhostChecker { self: EffectsAnalyzer => override def traverse(expr: Expr): Unit = expr match { case Let(vd, e, b) if vd.flags contains Ghost => - effects(e).find(!isGhostEffect(_)) match { + exprEffects(e).find(!isGhostEffect(_)) match { case Some(eff) => throw ImperativeEliminationException(expr, s"Right-hand side of ghost variable must only have effects on ghost fields (${eff.targetString} is not ghost)") @@ -128,7 +128,7 @@ trait GhostChecker { self: EffectsAnalyzer => "Right-hand side of non-ghost variable cannot be ghost") case LetVar(vd, e, b) if vd.flags contains Ghost => - effects(e).find(!isGhostEffect(_)) match { + exprEffects(e).find(!isGhostEffect(_)) match { case Some(eff) => throw ImperativeEliminationException(expr, s"Right-hand side of ghost variable must only have effects on ghost fields (${eff.targetString} is not ghost)") @@ -147,7 +147,7 @@ trait GhostChecker { self: EffectsAnalyzer => "Right-hand side of non-ghost variable cannot be ghost") case Assignment(v, e) if v.flags contains Ghost => - effects(e).find(!isGhostEffect(_)) match { + exprEffects(e).find(!isGhostEffect(_)) match { case Some(eff) => throw ImperativeEliminationException(expr, s"Right-hand side of ghost variable assignment must only have effects on ghost fields (${eff.targetString} is not ghost)") @@ -165,7 +165,7 @@ trait GhostChecker { self: EffectsAnalyzer => "Snapshots can only be used in ghost contexts") case FieldAssignment(obj, id, e) if isADT(obj) && isGhostExpression(ADTSelector(obj, id)) => - effects(e).find(!isGhostEffect(_)) match { + exprEffects(e).find(!isGhostEffect(_)) match { case Some(eff) => throw ImperativeEliminationException(expr, s"Right-hand side of ghost field assignment must only have effects on ghost fields (${eff.targetString} is not ghost)") @@ -176,7 +176,7 @@ trait GhostChecker { self: EffectsAnalyzer => } case FieldAssignment(obj, id, e) if isObject(obj) && isGhostExpression(ClassSelector(obj, id)) => - effects(e).find(!isGhostEffect(_)) match { + exprEffects(e).find(!isGhostEffect(_)) match { case Some(eff) => throw ImperativeEliminationException(expr, s"Right-hand side of ghost field assignment must only have effects on ghost fields (${eff.targetString} is not ghost)") @@ -211,7 +211,7 @@ trait GhostChecker { self: EffectsAnalyzer => (lookupFunction(id).map(Outer(_)).getOrElse(analysis.local(id)).params zip args) .foreach { case (vd, arg) => if (vd.flags contains Ghost) { - effects(arg).find(!isGhostEffect(_)) match { + exprEffects(arg).find(!isGhostEffect(_)) match { case Some(eff) => throw ImperativeEliminationException(arg, s"Argument to ghost parameter `${vd.id}` of `${id}` must only have effects on ghost fields (${eff.targetString} is not ghost)") @@ -230,7 +230,7 @@ trait GhostChecker { self: EffectsAnalyzer => (adt.getConstructor.fields zip args) .foreach { case (vd, arg) => if (vd.flags contains Ghost) { - if (!effects(arg).forall(isGhostEffect)) + if (!exprEffects(arg).forall(isGhostEffect)) throw ImperativeEliminationException(arg, s"Argument to ghost field `${vd.id.asString}` of class `${id.asString}` must only have effects on ghost fields") new Checker(true).traverse(arg) @@ -245,7 +245,7 @@ trait GhostChecker { self: EffectsAnalyzer => (ct.tcd.fields zip args) .foreach { case (vd, arg) => if (vd.flags contains Ghost) { - if (!effects(arg).forall(isGhostEffect)) + if (!exprEffects(arg).forall(isGhostEffect)) throw ImperativeEliminationException(arg, s"Argument to ghost field `${vd.id.asString}` of class `${ct.id.asString}` must only have effects on ghost fields") new Checker(true).traverse(arg)