syntactic convenience for recursive functions
authorChristian Urban <christian dot urban at kcl dot ac dot uk>
Tue, 26 Feb 2013 23:44:57 +0000
changeset 201 09befdf4fc99
parent 200 8dde2e46c69d
child 202 7cfc83879fc9
syntactic convenience for recursive functions
scala/ex.scala
scala/recs.scala
--- a/scala/ex.scala	Tue Feb 26 17:39:47 2013 +0000
+++ b/scala/ex.scala	Tue Feb 26 23:44:57 2013 +0000
@@ -48,7 +48,7 @@
 
 // Recursive function examples 
 println("Add 3 4:   " + Add.eval(3, 4))
-println("Mult 3 4:   " + recs.Mult.eval(3, 4))
+println("Mult 3 4:  " + recs.Mult.eval(3, 4))
 println("Twice 4:   " + Twice.eval(4))
 println("FourTm 4:  " + Fourtimes.eval(4))
 println("Pred 9:    " + Pred.eval(9))
--- a/scala/recs.scala	Tue Feb 26 17:39:47 2013 +0000
+++ b/scala/recs.scala	Tue Feb 26 23:44:57 2013 +0000
@@ -5,6 +5,13 @@
 abstract class Rec {
   def eval(ns: List[Int]) : Int
   def eval(ns: Int*) : Int = eval(ns.toList)
+
+  //syntactic convenience for composition
+  def o(r: Rec) = Cn(r.arity, this, List(r))
+  def o(r: Rec, f: Rec) = Cn(r.arity, this, List(r, f))
+  def o(r: Rec, f: Rec, g: Rec) = Cn(r.arity, this, List(r, f, g))
+
+  def arity : Int
 }
 
 case object Z extends Rec {
@@ -12,6 +19,7 @@
     case n::Nil => 0
     case _ => throw new IllegalArgumentException("Z args: " + ns)
   }
+  override def arity = 1
 } 
 
 case object S extends Rec {
@@ -19,18 +27,28 @@
     case n::Nil => n + 1
     case _ => throw new IllegalArgumentException("S args: " + ns)
   }
+  override def arity = 1 
 } 
 
 case class Id(n: Int, m: Int) extends Rec {
   override def eval(ns: List[Int]) = 
     if (ns.length == n && m < n) ns(m)
     else throw new IllegalArgumentException("Id args: " + ns + "," + n + "," + m)
+
+  override def arity = n
 }
 
 case class Cn(n: Int, f: Rec, gs: List[Rec]) extends Rec {
   override def eval(ns: List[Int]) = 
-    if (ns.length == n) f.eval(gs.map(_.eval(ns)))
-    else throw new IllegalArgumentException("Cn: args")
+    if (ns.length == n && gs.forall(_.arity == n) && f.arity == gs.length) f.eval(gs.map(_.eval(ns)))
+    else throw new IllegalArgumentException("Cn args: " + ns + "," + n)
+
+  override def arity = n
+}
+
+// syntactic convenience
+object Cn {
+  def apply(n: Int, f: Rec, g: Rec) : Rec = new Cn(n, f, List(g))
 }
 
 case class Pr(n: Int, f: Rec, g: Rec) extends Rec {
@@ -43,6 +61,13 @@
       }
     }
     else throw new IllegalArgumentException("Pr: args")
+
+  override def arity = n + 1
+}
+
+// syntactic convenience
+object Pr {
+  def apply(r: Rec, f: Rec) : Rec = Pr(r.arity, r, f) 
 }
 
 case class Mn(n: Int, f: Rec) extends Rec {
@@ -52,41 +77,36 @@
   override def eval(ns: List[Int]) = 
     if (ns.length == n) evaln(ns, 0) 
     else throw new IllegalArgumentException("Mn: args")
+
+  override def arity = n
 }
 
 
 
 // Recursive Function examples
-def arity(f: Rec) = f match {
-  case Z => 1
-  case S => 1
-  case Id(n, _) => n
-  case Cn(n, _, _) => n
-  case Pr(n, _, _) => n + 1
-  case Mn(n, _) => n 
+def Const(n: Int) : Rec = n match {
+  case 0 => Z
+  case n => S o Const(n - 1)
 }
 
-val Add = Pr(1, Id(1, 0), Cn(3, S, List(Id(3, 2))))
-val Mult = Pr(1, Z, Cn(3, Add, List(Id(3, 0), Id(3, 2))))
-val Twice = Cn(1, Mult, List(Id(1, 0), Const(2)))
-val Fourtimes = Cn(1, Mult, List(Id(1, 0), Const(4)))
-val Pred = Cn(1, Pr(1, Z, Id(3, 1)), List(Id(1, 0), Id(1, 0)))
-val Minus = Pr(1, Id(1, 0), Cn(3, Pred, List(Id(3, 2))))
-def Const(n: Int) : Rec = n match {
-  case 0 => Z
-  case n => Cn(1, S, List(Const(n - 1)))
-}
+val Add = Pr(Id(1, 0), S o Id(3, 2))
+val Mult = Pr(Z, Add o (Id(3, 0), Id(3, 2)))
+val Twice = Mult o (Id(1, 0), Const(2))
+val Fourtimes = Mult o (Id(1, 0), Const(4))
+val Pred = Pr(Z, Id(3, 1)) o (Id(1, 0), Id(1, 0))
+val Minus = Pr(Id(1, 0), Pred o Id(3, 2))
+val Power = Pr(Const(1), Mult o (Id(3, 0), Id(3, 2)))
+val Fact = Pr(Const(1), Mult o (Id(3, 2), S o Id(3, 1))) o (Id(1, 0), Id(1, 0))
 
-val Power = Pr(1, Const(1), Cn(3, Mult, List(Id(3, 0), Id(3, 2))))
-val Sign = Cn(1, Minus, List(Const(1), Cn(1, Minus, List(Const(1), Id(1, 0)))))
-val Less = Cn(2, Sign, List(Cn(2, Minus, List(Id(2, 1), Id(2, 0)))))
-val Not = Cn(1, Minus, List(Const(1), Id(1, 0)))
-val Eq = Cn(2, Minus, List(Cn(2, Const(1), List(Id(2, 0))), 
-           Cn(2, Add, List(Cn(2, Minus, List(Id(2, 0), Id(2, 1))), 
-             Cn(2, Minus, List(Id(2, 1), Id(2, 0)))))))
-val Noteq = Cn(2, Not, List(Cn(2, Eq, List(Id(2, 0), Id(2, 1))))) 
-val Conj = Cn(2, Sign, List(Cn(2, Mult, List(Id(2, 0), Id(2, 1)))))
-val Disj = Cn(2, Sign, List(Cn(2, Add, List(Id(2, 0), Id(2, 1)))))
+val Sign = Minus o (Const(1), Minus o (Const(1), Id(1, 0)))
+val Less = Sign o (Minus o (Id(2, 1), Id(2, 0)))
+val Not = Minus o (Const(1), Id(1, 0))
+val Eq = Minus o (Const(1) o Id(2, 0), 
+                  Add o (Minus o (Id(2, 0), Id(2, 1)), 
+                         Minus o (Id(2, 1), Id(2, 0))))
+val Noteq = Not o (Eq o (Id(2, 0), Id(2, 1)))
+val Conj = Sign o (Mult o (Id(2, 0), Id(2, 1)))
+val Disj = Sign o (Add o (Id(2, 0), Id(2, 1)))
 
 def Nargs(n: Int, m: Int) : List[Rec] = m match {
   case 0 => Nil
@@ -94,112 +114,95 @@
 }
 
 def Sigma(f: Rec) = {
-  val ar = arity(f)  
-  Pr(ar - 1, Cn(ar - 1, f, Nargs(ar - 1, ar - 1) :::
-                    List(Cn(ar - 1, Const(0), List(Id(ar - 1, 0))))), 
-             Cn(ar + 1, Add, List(Id(ar + 1, ar), 
-                    Cn(ar + 1, f, Nargs(ar + 1, ar - 1) :::
-                        List(Cn(ar + 1, S, List(Id(ar + 1, ar - 1))))))))
+  val ar = f.arity
+  Pr(Cn(ar - 1, f, Nargs(ar - 1, ar - 1) ::: List(Const(0) o Id(ar - 1, 0))), 
+     Add o (Id(ar + 1, ar), 
+            Cn(ar + 1, f, Nargs(ar + 1, ar - 1) ::: List(S o (Id(ar + 1, ar - 1))))))
 }
 
 def Accum(f: Rec) = {
-  val ar = arity(f)  
-  Pr(ar - 1, Cn(ar - 1, f, Nargs(ar - 1, ar - 1) :::
-                    List(Cn(ar - 1, Const(0), List(Id(ar - 1, 0))))), 
-             Cn(ar + 1, Mult, List(Id(ar + 1, ar), 
-                    Cn(ar + 1, f, Nargs(ar + 1, ar - 1) :::
-                        List(Cn(ar + 1, S, List(Id(ar + 1, ar - 1))))))))
+  val ar = f.arity
+  Pr(Cn(ar - 1, f, Nargs(ar - 1, ar - 1) ::: List(Const(0) o Id(ar - 1, 0))), 
+     Mult o (Id(ar + 1, ar), 
+             Cn(ar + 1, f, Nargs(ar + 1, ar - 1) ::: List(S o Id(ar + 1, ar - 1)))))
 }
 
 def All(t: Rec, f: Rec) = {
-  val ar = arity(f)
-  Cn(ar - 1, Sign, List(Cn(ar - 1, Accum(f), Nargs(ar - 1, ar - 1) ::: List(t))))
+  val ar = f.arity
+  Sign o (Cn(ar - 1, Accum(f), Nargs(ar - 1, ar - 1) ::: List(t)))
 }
 
 def Ex(t: Rec, f: Rec) = {
-  val ar = arity(f)
-  Cn(ar - 1, Sign, List(Cn(ar - 1, Sigma(f), Nargs(ar - 1, ar - 1) ::: List(t))))
+  val ar = f.arity
+  Sign o (Cn(ar - 1, Sigma(f), Nargs(ar - 1, ar - 1) ::: List(t)))
 }
 
 //Definition on page 77 of Boolos's book.
 def Minr(f: Rec) = {
-  val ar = arity(f)
-  val rq = All(Id(ar, ar - 1), 
-    Cn(ar + 1, Not, List(Cn(ar + 1, f, Nargs(ar + 1, ar - 1) ::: List(Id(ar + 1, ar)))))) 
-  Sigma(rq)
+  val ar = f.arity
+  Sigma(All(Id(ar, ar - 1), Not o (Cn(ar + 1, f, Nargs(ar + 1, ar - 1) ::: List(Id(ar + 1, ar))))))
 }
 
 //Definition on page 77 of Boolos's book.
 def Maxr(f: Rec) = {
-  val ar  = arity(f) 
+  val ar  = f.arity
   val rt  = Id(ar + 1, ar - 1) 
-  val rf1 = Cn(ar + 2, Less, List(Id(ar + 2, ar + 1), Id(ar + 2, ar))) 
-  val rf2 = Cn(ar + 2, Not, List(Cn (ar + 2, f, Nargs(ar + 2, ar - 1) ::: List(Id(ar + 2, ar + 1))))) 
-  val rf  = Cn(ar + 2, Disj, List(rf1, rf2)) 
-  val rq  = All(rt, rf) 
-  val Qf  = Cn(ar + 1, Not, List(rq))
+  val rf1 = Less o (Id(ar + 2, ar + 1), Id(ar + 2, ar)) 
+  val rf2 = Not o (Cn (ar + 2, f, Nargs(ar + 2, ar - 1) ::: List(Id(ar + 2, ar + 1)))) 
+  val Qf  = Not o All(rt, Disj o (rf1, rf2)) 
   Cn(ar, Sigma(Qf), Nargs(ar, ar) ::: List(Id(ar, ar - 1)))
 }
 
-//Mutli-way branching statement on page 79 of Boolos's book
-def Branch(rs: List[(Rec, Rec)]) = {
-  val ar = arity(rs.head._1)
-  
-  def Branch_aux(rs: List[(Rec, Rec)], l: Int) : Rec = rs match {
-    case Nil => Cn(l, Z, List(Id(l, l - 1)))
-    case (rg, rc)::recs => Cn(l, Add, List(Cn(l, Mult, List(rg, rc)), Branch_aux(recs, l)))
-  }
-
-  Branch_aux(rs, ar)
-}
-
-//Factorial
-val Fact = {
-  val Fact_aux = Pr(1, Const(1), Cn(3, Mult, List(Id(3, 2), Cn(3, S, List(Id(3, 1))))))
-  Cn(1, Fact_aux, List(Id(1, 0), Id(1, 0)))
-}
 
 //Prime test
-val Prime = Cn(1, Conj, List(Cn(1, Less, List(Const(1), Id(1, 0))),
-                             All(Cn(1, Minus, List(Id(1, 0), Const(1))), 
-                                 All(Cn(2, Minus, List(Id(2, 0), Cn(2, Const(1), List(Id(2, 0))))), 
-                                     Cn(3, Noteq, List(Cn(3, Mult, List(Id(3, 1), Id(3, 2))), Id(3, 0)))))))
+val Prime = Conj o (Less o (Const(1), Id(1, 0)),
+                    All(Minus o (Id(1, 0), Const(1)), 
+                        All(Minus o (Id(2, 0), Const(1) o Id(2, 0)), 
+                            Noteq o (Mult o (Id(3, 1), Id(3, 2)), Id(3, 0)))))
 
-//Returns the first prime number after n
+//Returns the first prime number after n (very slow for n > 4)
 val NextPrime = {
-  val R = Cn(2, Conj, List(Cn(2, Less, List(Id(2, 0), Id(2, 1))), 
-                           Cn(2, Prime, List(Id(2, 1)))))
-  Cn(1, Minr(R), List(Id(1, 0), Cn(1, S, List(Fact))))
+  val R = Conj o (Less o (Id(2, 0), Id(2, 1)), Prime o Id(2, 1))
+  Minr(R) o (Id(1, 0), S o Fact)
 }
 
-val NthPrime = {
-  val NthPrime_aux = Pr(1, Const(2), Cn(3, NextPrime, List(Id(3, 2))))
-  Cn(1, NthPrime_aux, List(Id(1, 0), Id(1, 0)))
-}
+val NthPrime = Pr(Const(2), NextPrime o Id(3, 2)) o (Id(1, 0), Id(1, 0))
 
 def Listsum(k: Int, m: Int) : Rec = m match {
-  case 0 => Cn(k, Z, List(Id(k, 0)))
-  case n => Cn(k, Add, List(Listsum(k, n - 1), Id(k, n - 1)))
+  case 0 => Z o Id(k, 0)
+  case n => Add o (Listsum(k, n - 1), Id(k, n - 1))
 }
 
 //strt-function on page 90 of Boolos, but our definition generalises 
 //the original one in order to deal with multiple input-arguments
 
 def Strt(n: Int) = {
+  
   def Strt_aux(l: Int, k: Int) : Rec = k match {
-    case 0 => Cn(l, Z, List(Id(l, 0)))
+    case 0 => Z o Id(l, 0)
     case n => {
-      val rec_dbound = Cn(l, Add, List(Listsum(l, n - 1), Cn(l, Const(n - 1), List(Id(l, 0)))))
-      Cn(l, Add, List(Strt_aux(l, n - 1), 
-                      Cn(l, Minus, List(Cn(l, Power, List(Cn(l, Const(2), List(Id(l, 0))), 
-                                                          Cn(l, Add, List(Id(l, n - 1), rec_dbound)))), 
-                                        Cn(l, Power, List(Cn(l, Const(2), List(Id(l, 0))), rec_dbound))))))
+      val rec_dbound = Add o (Listsum(l, n - 1), Const(n - 1) o Id(l, 0))
+      Add o (Strt_aux(l, n - 1), 
+             Minus o (Power o (Const(2) o Id(l, 0), Add o (Id(l, n - 1), rec_dbound)), 
+                      Power o (Const(2) o Id(l, 0), rec_dbound)))
     }
   }
 
-  def Rmap(f: Rec, k: Int) = (0 until k).map{i => Cn(k, f, List(Id(k, i)))}.toList
+  def Rmap(f: Rec, k: Int) = (0 until k).map{i => f o Id(k, i)}.toList
  
   Cn(n, Strt_aux(n, n), Rmap(S, n))
 }
 
+
+//Mutli-way branching statement on page 79 of Boolos's book
+def Branch(rs: List[(Rec, Rec)]) = {
+
+  def Branch_aux(rs: List[(Rec, Rec)], l: Int) : Rec = rs match {
+    case Nil => Z o Id(l, l - 1)
+    case (rg, rc)::recs => Add o (Mult o (rg, rc), Branch_aux(recs, l))
+  }
+
+  Branch_aux(rs, rs.head._1.arity)
 }
+
+}