Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for pattern alternatives #1627

Merged
merged 4 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/src/main/scala/stainless/ast/Deconstructors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ trait TreeDeconstructor extends inox.ast.TreeDeconstructor {
(Seq(id), binder.map(_.toVariable).toSeq, recs, tps, subs, (ids, vs, es, tps, pats) => {
t.UnapplyPattern(vs.headOption.map(_.toVal), es, ids.head, tps, pats)
})
case s.AlternativePattern(binder, subs) =>
(Seq(), binder.map(_.toVariable).toSeq, Seq(), Seq(), subs, (_, vs, _, _, pats) => {
t.AlternativePattern(vs.headOption.map(_.toVal), pats)
})
}

/** Rebuild a match case from the given set of identifiers, variables, expressions and types */
Expand Down
8 changes: 8 additions & 0 deletions core/src/main/scala/stainless/ast/ExprOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,14 @@ class ExprOps(override val trees: Trees) extends inox.ast.ExprOps(trees) { self
(_ ++ _)
)

case AlternativePattern(vdOpt, subPatterns) =>
val freshVdOpt = vdOpt.map(vd => transform(vd.freshen, env))
// We don't need to freshen the subPatterns here, as they are not bound
(
AlternativePattern(freshVdOpt, subPatterns),
env ++ freshVdOpt.map(freshVd => vdOpt.get.id -> freshVd.id)
)

case LiteralPattern(vdOpt, lit) =>
val freshVdOpt = vdOpt.map(vd => transform(vd.freshen, env))
val newEnv = env ++ freshVdOpt.map(freshVd => vdOpt.get.id -> freshVd.id)
Expand Down
7 changes: 7 additions & 0 deletions core/src/main/scala/stainless/ast/Expressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ trait Expressions extends inox.ast.Expressions with Types { self: Trees =>
val subPatterns = Seq()
}

/**
* Pattern encoding like `case binder @ (subPattern1 | subPattern2 | ...) => ...`
*
* If [[binder]] is empty, consider a wildcard `_` in its place.
*/
sealed case class AlternativePattern(binder: Option[ValDef], subPatterns: Seq[Pattern]) extends Pattern

protected def unapplyScrut(scrut: Expr, up: UnapplyPattern)(using s: Symbols): Expr = {
FunctionInvocation(up.id, up.tps, up.recs :+ scrut)
}
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/scala/stainless/ast/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ trait Printer extends inox.ast.Printer {
printNameWithPath(id)
p"(${nary(subs)})"

case AlternativePattern(ovd, subs) =>
ovd foreach (vd => p"${vd.toVariable} : ")
p"(${nary(subs, " | ")})"

case Passes(in, out, cases) =>
optP {
p"""|($in, $out) passes {
Expand Down
6 changes: 6 additions & 0 deletions core/src/main/scala/stainless/ast/SymbolOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ trait SymbolOps extends inox.ast.SymbolOps with TypeOps { self =>
val subTests = subps.zipWithIndex.map { case (p, i) => apply(tupleSelect(in, i+1, subps.size), p) }
bind(ob, in) `merge` subTests

case AlternativePattern(ob, subps) =>
// one of the alternatives must hold (disjunction)
// we use A \/ B = ~ (~A /\ ~B)
val disjunction = subps.map(p => apply(in, p).negate).reduce(_ `merge` _).negate
bind(ob, in) `merge` disjunction

case up @ UnapplyPattern(ob, _, _, _, subps) =>
val subs = unwrapTuple(up.get(in), subps.size).zip(subps) map (apply).tupled
bind(ob, in) `withCond` Not(up.isEmpty(in)) `merge` subs
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/scala/stainless/ast/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ trait TypeOps extends inox.ast.TypeOps {
case _ => false
}

case AlternativePattern(ob, subs) =>
ob.forall(vd => isSubtypeOf(vd.getType, in)) &&
(subs exists (patternIsTyped(in, _)))

case up @ UnapplyPattern(ob, recs, id, tps, subs) =>
ob.forall(vd => isSubtypeOf(vd.getType, in)) &&
lookupFunction(id).exists(_.tparams.size == tps.size) && {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ abstract class RecursiveEvaluator(override val program: Program,
None
}

case (AlternativePattern(ob, subs), scrut) =>
subs.map(matchesPattern(_, scrut)).find(_.isDefined) match {
case Some(_) => Some(obind(ob, expr)) // There should be no mapping nested in the alternative
case _ => None
}

case (up @ UnapplyPattern(ob, rec, id, tps, subs), scrut) =>
val eRec = rec map e
val unapp = e(FunctionInvocation(id, tps, eRec :+ scrut))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ trait TransformerWithType extends TreeTransformer {
val rsubs = (subs zip tps).map(p => transform(p._1, p._2))
t.TuplePattern(ob map transform, rsubs).copiedFrom(pat)

case s.AlternativePattern(ob, subs) =>
val rsubs = subs map (transform(_, tpe))
t.AlternativePattern(ob map transform, rsubs).copiedFrom(pat)

case up @ s.UnapplyPattern(ob, recs, id, tps, subs) =>
val rsubs = (subs zip up.subTypes(tpe)).map(p => transform(p._1, p._2))
val rrecs = (recs zip getFunction(id, tps).params.init).map(p => transform(p._1, p._2.getType))
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/scala/stainless/extraction/oo/TypeEncoding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,12 @@ class TypeEncoding(override val s: Trees, override val t: Trees)
case s.TuplePattern(Some(vd), _) =>
instanceOfPattern(super.transform(pat, vd.tpe), tpe, vd.tpe)

case s.AlternativePattern(None, subs) =>
t.AlternativePattern(None, subs.map(transform)).copiedFrom(pat)

case s.AlternativePattern(Some(vd), subs) =>
instanceOfPattern(t.AlternativePattern(Some(transform(vd)), subs.map(transform)).copiedFrom(pat), tpe, vd.tpe)

case up @ s.UnapplyPattern(ob, recs, id, tps, subs) =>
val funScope = this `in` id
val FunInfo(fun, tparams) = functions(id)
Expand Down Expand Up @@ -1028,6 +1034,10 @@ class TypeEncoding(override val s: Trees, override val t: Trees)
super.transform(pat, in)
}

case s.AlternativePattern(ob, subs) =>
simple --= s.typeOps.typeParamsOf(in.getType)
super.transform(pat, in)

case up @ s.UnapplyPattern(ob, recs, id, tps, subs) =>
val tparams = infos(id).tparams
simple --= tps.zipWithIndex.flatMap { case (tp, i) =>
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/stainless/extraction/oo/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ trait TypeOps extends innerfuns.TypeOps { self =>
.map(cons => ADTType(cons.sort, tps))
.getOrElse(Untyped)
case TuplePattern(_, subs) => TupleType(subs map patternInType)
case AlternativePattern(_, subs) => leastUpperBound(subs map patternInType)
case ClassPattern(_, ct, subs) => ct
case UnapplyPattern(_, recs, id, tps, _) =>
lookupFunction(id)
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/stainless/genc/phases/Scala2IRPhase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,9 @@ private class S2IRImpl(override val s: tt.type,
update(b, scrutinee)
buildBinOp(scrutinee, O.Equals, lit)(pat.getPos)

case AlternativePattern(_, _) =>
reporter.fatalError(pat.getPos, s"Alternative Pattern, a.k.a pattern disjunction, is not yet supported by GenC")

case UnapplyPattern(_, _, _, _, _) =>
reporter.fatalError(pat.getPos, s"Unapply Pattern, a.k.a. Extractor Objects, is not supported by GenC")
}
Expand Down
19 changes: 19 additions & 0 deletions core/src/main/scala/stainless/transformers/lattices/Core.scala
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,13 @@ trait Core extends Definitions { ocbsl =>
val subScruts = tupleSubscrutinees(scrut, tt)
val (rsubs, subst2) = recHelper(subScruts, subps)
(LabelledPattern.TuplePattern(rsubs), subst2)
case AlternativePattern(_, subps) =>
val (rsubs, subst2) = subps.foldLeft((Seq.empty[LabelledPattern], subst1)) {
case ((acc, subst), subp) =>
val (rsub, subst2) = convertPattern(scrut, subp, subst)
(acc :+ rsub, subst2)
}
(LabelledPattern.Alternative(rsubs), subst2)
case UnapplyPattern(_, recs, id, tps, subps) =>
if (recs.nonEmpty) throw UnsupportedOperationException("recs is not empty")
val unapp = unapplySubScrutinees(scrut, id, tps)
Expand Down Expand Up @@ -2644,6 +2651,7 @@ trait Core extends Definitions { ocbsl =>
assert(bases.size == subps.size)
val rsubs = recHelper(tupleSubscrutinees(scrut, tt), subps)
TuplePattern(bdg, rsubs)
case LabelledPattern.Alternative(subs) => AlternativePattern(bdg, subs.map(sub => convertPattern(scrut, sub, vds)))
case LabelledPattern.Lit(lit) => LiteralPattern(bdg, lit)
case LabelledPattern.Unapply(recs, id, tps, subps) =>
assert(recs.isEmpty)
Expand Down Expand Up @@ -3286,6 +3294,17 @@ trait Core extends Definitions { ocbsl =>
assert(ctxs.isPrefixOf(newCtxs))
PatBdgsAndConds(newCtxs, subscruts ++ recBdgs, recPatConds)

case LabelledPattern.Alternative(sub) =>
val PatBdgsAndConds(newCtxs, _, subPatConds) =
sub.foldLeft(PatBdgsAndConds(ctxs, Seq.empty, Seq.empty)) {
case (PatBdgsAndConds(ctxs, _, condsAcc), subpat) =>
val PatBdgsAndConds(ctxs2, _, conds2) = addPatternBindingsAndConds(ctxs, scrut, subpat)
PatBdgsAndConds(ctxs2, Seq.empty, condsAcc ++ conds2)
}
val cond = codeOfSig(mkOr(subPatConds), BoolTy)
assert(ctxs.isPrefixOf(newCtxs))
PatBdgsAndConds(newCtxs.withCond(cond), Seq.empty, Seq(cond))

case LabelledPattern.Unapply(recs, id, tps, subps) =>
assert(recs.isEmpty)
val unapp = unapplySubScrutinees(scrut, id, tps)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ trait Definitions {
case Wildcard extends LabelledPattern
case ADT(id: Identifier, tps: Seq[Type], sub: Seq[LabelledPattern]) extends LabelledPattern
case TuplePattern(sub: Seq[LabelledPattern]) extends LabelledPattern
case Alternative(sub: Seq[LabelledPattern]) extends LabelledPattern
case Lit[T](lit: Literal[T]) extends LabelledPattern
case Unapply(recs: Seq[Code], id: Identifier, tps: Seq[Type], sub: Seq[LabelledPattern]) extends LabelledPattern

Expand All @@ -231,6 +232,7 @@ trait Definitions {
case Wildcard => Seq.empty
case ADT(_, _, sub) => sub.flatMap(_.allPatterns)
case TuplePattern(sub) => sub.flatMap(_.allPatterns)
case Alternative(sub) => sub.flatMap(_.allPatterns)
case Lit(_) => Seq.empty
case Unapply(_, _, _, sub) => sub.flatMap(_.allPatterns)
})
Expand Down
5 changes: 3 additions & 2 deletions core/src/main/scala/stainless/utils/Serialization.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ class StainlessSerializer(override val trees: ast.Trees, serializeProducts: Bool
/** An extension to the set of registered classes in the `InoxSerializer`.
* occur within Stainless programs.
*
* The new identifiers in the mapping range from 120 to 172.
* The new identifiers in the mapping range from 120 to 173.
*
* NEXT ID: 173
* NEXT ID: 174
*/
override protected def classSerializers: Map[Class[?], Serializer[?]] =
super.classSerializers ++ Map(
Expand All @@ -40,6 +40,7 @@ class StainlessSerializer(override val trees: ast.Trees, serializeProducts: Bool
stainlessClassSerializer[TuplePattern] (130),
stainlessClassSerializer[LiteralPattern[Any]](131),
stainlessClassSerializer[UnapplyPattern] (132),
stainlessClassSerializer[AlternativePattern] (173),
stainlessClassSerializer[FiniteArray] (133),
stainlessClassSerializer[LargeArray] (134),
stainlessClassSerializer[ArraySelect] (135),
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/stainless/verification/CoqEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ trait CoqEncoder {
ctx.reporter.warning(s"Ignoring type $tpe in the wildcard pattern $p.")
//TODO not tested
CoqTuplePatternVd(ps.map(transformPattern), VariablePattern(Some(makeFresh(id))))
case AlternativePattern(_, _) =>
ctx.reporter.fatalError(s"The translation to Coq does not support disjunctive patterns such as `$p` (${p.getClass}) yet.")
case _ => ctx.reporter.fatalError(s"Coq does not support patterns such as `$p` (${p.getClass}) yet.")
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
object PatternAlternative1 {
sealed trait SignSet
case object None extends SignSet
case object Any extends SignSet
case object Neg extends SignSet
case object Zer extends SignSet
case object Pos extends SignSet
case object NegZer extends SignSet
case object NotZer extends SignSet
case object PosZer extends SignSet

def subsetOf(a: SignSet, b: SignSet): Boolean = (a, b) match {
case (None, _) => true
case (_, Any) => true
case (Neg, NegZer | NotZer) => true
case (Zer, NegZer | PosZer) => true
case (Pos, NotZer | PosZer) => true
case _ => false
}
}
33 changes: 33 additions & 0 deletions frontends/benchmarks/extraction/valid/PatternAlternative2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
object PatternAlternative2 {
sealed trait Tree
case class Node(left: Tree, right: Tree) extends Tree
case class IntLeaf(value: Int) extends Tree
case class StringLeaf(value: String) extends Tree
case class NoneLeaf() extends Tree

def containsNoneLeaf(tree: Tree): Boolean = {
tree match {
case Node(left, right) => containsNoneLeaf(left) || containsNoneLeaf(right)
case NoneLeaf() => true
case _ => false
}
}

def containsOnlyBinaryLeaves(tree: Tree): Boolean = {
tree match {
case Node(left, right) => containsOnlyBinaryLeaves(left) && containsOnlyBinaryLeaves(right)
case IntLeaf(v) => v == 0 || v == 1
case StringLeaf(v) => v == "0" || v == "1"
case _ => true
}
}

def hasBinaryLeaves(tree: Tree): Boolean = {
require(!containsNoneLeaf(tree) && containsOnlyBinaryLeaves(tree))
tree match {
case a @ Node(left: (IntLeaf | StringLeaf), right: (IntLeaf | StringLeaf)) => hasBinaryLeaves(left) && hasBinaryLeaves(right)
case b @ (IntLeaf(0 | 1) | StringLeaf("0" | "1")) => true
case _ => false
}
} ensuring { res => res }
}
33 changes: 33 additions & 0 deletions frontends/benchmarks/verification/valid/PatternAlternative.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
object PatternAlternative {
sealed trait Tree
case class Node(left: Tree, right: Tree) extends Tree
case class IntLeaf(value: Int) extends Tree
case class StringLeaf(value: String) extends Tree
case class NoneLeaf() extends Tree

def containsNoneLeaf(tree: Tree): Boolean = {
tree match {
case Node(left, right) => containsNoneLeaf(left) || containsNoneLeaf(right)
case NoneLeaf() => true
case _ => false
}
}

def containsOnlyBinaryLeaves(tree: Tree): Boolean = {
tree match {
case Node(left, right) => containsOnlyBinaryLeaves(left) && containsOnlyBinaryLeaves(right)
case IntLeaf(v) => v == 0 || v == 1
case StringLeaf(v) => v == "0" || v == "1"
case _ => true
}
}

def hasBinaryLeaves(tree: Tree): Boolean = {
require(!containsNoneLeaf(tree) && containsOnlyBinaryLeaves(tree))
tree match {
case a @ Node(left: (IntLeaf | StringLeaf), right: (IntLeaf | StringLeaf)) => hasBinaryLeaves(left) && hasBinaryLeaves(right)
case b @ (IntLeaf(0 | 1) | StringLeaf("0" | "1")) => true
case _ => false
}
} ensuring { res => res }
}
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,11 @@ class CodeExtraction(inoxCtx: inox.Context,
// Note that this pattern will be correctly rejected as "Unsupported pattern" (in fact, it cannot be even tested at runtime):
// val (aa: B, bb: B) = (a, b)
private def extractPattern(p: tpd.Tree, expectedTpe: Option[xt.Type], binder: Option[xt.ValDef] = None)(using dctx: DefContext): (xt.Pattern, DefContext) = p match {

case a @ Alternative(subpatterns) =>
val (patterns, nctx) = subpatterns.map(extractPattern(_, expectedTpe)).unzip
(xt.AlternativePattern(binder, patterns), dctx)

case b @ Bind(name, t @ Typed(pat, tpt)) =>
val vd = xt.ValDef(FreshIdentifier(name.toString), extractType(tpt), annotationsOf(b.symbol, ignoreOwner = true)).setPos(b.sourcePos)
val pctx = dctx.withNewVar(b.symbol -> (() => vd.toVariable))
Expand Down