progs/countdown.scala
changeset 505 c06b45a52d50
equal deleted inserted replaced
504:7653dd662db3 505:c06b45a52d50
       
     1 // Countdown Game using the Powerset Function
       
     2 //============================================
       
     3 
       
     4 
       
     5 def powerset(xs: Set[Int]) : Set[Set[Int]] = {
       
     6   if (xs == Set()) Set(Set())
       
     7   else {
       
     8     val ps = powerset(xs.tail)  
       
     9     ps ++ ps.map(_ + xs.head)
       
    10   }
       
    11 }  
       
    12 
       
    13 powerset(Set(1,2,3)).mkString("\n")
       
    14 
       
    15 // proper subsets
       
    16 def psubsets(xs: Set[Int]) = 
       
    17   powerset(xs) -- Set(Set(), xs) 
       
    18 
       
    19 psubsets(Set(1,2,3)).mkString("\n")
       
    20 
       
    21 // pairs of subsets and their "complement"
       
    22 def splits(xs: Set[Int]) : Set[(Set[Int], Set[Int])] =
       
    23   psubsets(xs).map(s => (s, xs -- s))
       
    24 
       
    25 splits(Set(1,2,3,4)).mkString("\n")
       
    26 
       
    27 
       
    28 // ususal trees representing infix notation for expressions
       
    29 enum Tree {
       
    30   case Num(i: Int)
       
    31   case Add(l: Tree, r: Tree)
       
    32   case Mul(l: Tree, r: Tree)
       
    33   case Sub(l: Tree, r: Tree)
       
    34   case Div(l: Tree, r: Tree)
       
    35 }
       
    36 import Tree._
       
    37 
       
    38 //pretty printing for trees
       
    39 def pp(tr: Tree) : String = tr match {
       
    40   case Num(n) => s"$n"
       
    41   case Add(l, r) => s"(${pp(l)} + ${pp(r)})"
       
    42   case Mul(l, r) => s"(${pp(l)} * ${pp(r)})"
       
    43   case Sub(l, r) => s"(${pp(l)} - ${pp(r)})"
       
    44   case Div(l, r) => s"(${pp(l)} / ${pp(r)})"
       
    45 }
       
    46 
       
    47 // evaluating a tree - might fail when dividing by 0 
       
    48 // (the for-notation makes it easy to deal with Options)
       
    49 def eval(tr: Tree) : Option[Int] = tr match {
       
    50   case Num(n) => Some(n)
       
    51   case Add(l, r) => 
       
    52     for (ln <- eval(l); rn <- eval(r)) yield ln + rn
       
    53   case Mul(l, r) => 
       
    54     for (ln <- eval(l); rn <- eval(r)) yield ln * rn 
       
    55   case Sub(l, r) => 
       
    56     for (ln <- eval(l); rn <- eval(r); if rn <= ln) yield ln - rn
       
    57   case Div(l, r) => 
       
    58     for (ln <- eval(l); rn <- eval(r); if rn != 0) 
       
    59       yield ln / rn 
       
    60 }
       
    61 
       
    62 
       
    63 // simple-minded generation of nearly all possible expressions
       
    64 // (the nums argument determines the set of numbers over which
       
    65 //  the expressions are generated)
       
    66 def gen1(nums: Set[Int]) : Set[Tree] = nums.size match {
       
    67   case 0 => Set()
       
    68   case 1 => Set(Num(nums.head))
       
    69   case 2 => {
       
    70     val ln = Num(nums.head)
       
    71     val rn = Num(nums.tail.head)
       
    72     Set(Add(ln, rn), Mul(ln, rn),
       
    73         Sub(ln, rn), Sub(rn, ln),
       
    74         Div(ln, rn), Div(rn, ln))
       
    75   }
       
    76   case n => {
       
    77     val res = 
       
    78       for ((ls, rs) <- splits(nums);
       
    79             lt <- gen1(ls);
       
    80             rt <- gen1(rs)) yield {
       
    81             Set(Add(lt, rt), Mul(lt, rt),
       
    82                 Sub(lt, rt), Sub(rt, lt),
       
    83                 Div(lt, rt), Div(rt, lt))
       
    84            }
       
    85     res.flatten
       
    86   }
       
    87 }
       
    88 
       
    89 
       
    90 // some testcases
       
    91 gen1(Set(1))
       
    92 gen1(Set(1, 2)).mkString("\n")
       
    93 gen1(Set(1, 2, 3)).map(pp).mkString("\n")
       
    94 gen1(Set(1, 2, 3)).map(tr => s"${pp(tr)} = ${eval(tr)}").mkString("\n")
       
    95 
       
    96 gen1(Set(1, 2, 3)).size             // => 168
       
    97 gen1(Set(1, 2, 3, 4, 5, 6)).size    // => 26951040
       
    98 
       
    99 /*
       
   100     It is clear that gen1 generates too many expressions
       
   101     to be fast overall.
       
   102 
       
   103     An easy fix is to not generate improper Subs and 
       
   104     Divs when they are at the leaves.
       
   105 */
       
   106 
       
   107 
       
   108 def gen2(nums: Set[Int]) : Set[Tree] =  nums.size match {
       
   109   case 0 => Set()
       
   110   case 1 => Set(Num(nums.head))
       
   111   case 2 => {
       
   112     val l = nums.head
       
   113     val r = nums.tail.head
       
   114     Set(Add(Num(l), Num(r)), 
       
   115         Mul(Num(l), Num(r)))
       
   116         ++ Option.when(l <= r)(Sub(Num(r), Num(l)))
       
   117         ++ Option.when(l > r)(Sub(Num(l), Num(r)))
       
   118         ++ Option.when(r > 0 && l % r == 0)(Div(Num(l), Num(r)))
       
   119         ++ Option.when(l > 0 && r % l == 0)(Div(Num(r), Num(l)))
       
   120   }
       
   121   case xs => {
       
   122     val res = 
       
   123       for ((lspls, rspls) <- splits(nums);
       
   124            lt <- gen2(lspls); 
       
   125            rt <- gen2(rspls)) yield {
       
   126         Set(Add(lt, rt), Sub(lt, rt),
       
   127             Mul(lt, rt), Div(lt, rt))
       
   128     } 
       
   129     res.flatten
       
   130   }
       
   131 }
       
   132 
       
   133 gen2(Set(1, 2, 3)).size             // => 68
       
   134 gen2(Set(1, 2, 3, 4, 5, 6)).size    // => 6251936
       
   135 
       
   136 // OK, the numbers decreased in gen2 as it does 
       
   137 // not generate leaves like (2 - 3) and (4 / 3).
       
   138 // It might still generate such expressions "higher
       
   139 // up" though, for example (1 + 2) - (4 + 3)
       
   140 
       
   141 println(gen2(Set(1,2,3,4)).map(pp).mkString("\n"))
       
   142 
       
   143 // before taking any time measure, it might be good
       
   144 // to check that no "essential" expression has been
       
   145 // lost by the optimisation in gen2...some eyeballing
       
   146 
       
   147 gen1(Set(1,2,3,4)).map(eval) == gen2(Set(1,2,3,4)).map(eval)
       
   148 gen1(Set(1,2,3,4,5,6)).map(eval) == gen2(Set(1,2,3,4,5,6)).map(eval)
       
   149 
       
   150 
       
   151 // lets start some timing
       
   152 
       
   153 def time_needed[T](n: Int, code: => T) = {
       
   154   val start = System.nanoTime()
       
   155   for (i <- (0 to n)) code
       
   156   val end = System.nanoTime()
       
   157   (end - start) / 1.0e9
       
   158 }
       
   159 
       
   160 // check to reach a target
       
   161 def check(xs: Set[Int], target: Int) =
       
   162   gen2(xs).find(eval(_) == Some(target))
       
   163 
       
   164 // example 1
       
   165 check(Set(50, 5, 4, 9, 10, 8), 560).foreach { sol =>
       
   166   println(s"${pp(sol)} => ${eval(sol)}")
       
   167 }
       
   168 // => ((50 + ((4 / (10 - 9)) * 5)) * 8) => Some(560)
       
   169 
       
   170 time_needed(1, check(Set(50, 5, 4, 9, 10, 8), 560))
       
   171 // => ~14 secs
       
   172 
       
   173 
       
   174 // example 2
       
   175 check(Set(25, 5, 2, 10, 7, 1), 986).foreach { sol =>
       
   176   println(s"${pp(sol)} => ${eval(sol)}")
       
   177 }
       
   178 
       
   179 time_needed(1, check(Set(25, 5, 2, 10, 7, 1), 986))
       
   180 // => ~15 secs
       
   181 
       
   182 // example 3 (unsolvable)
       
   183 check(Set(25, 5, 2, 10, 7, 1), -1)
       
   184 time_needed(1, check(Set(25, 5, 2, 10, 7, 1), -1))
       
   185 // => ~22 secs
       
   186 
       
   187 // example 4
       
   188 check(Set(100, 25, 75, 50, 7, 10), 360).foreach { sol =>
       
   189   println(s"${pp(sol)} => ${eval(sol)}")
       
   190 }
       
   191 time_needed(1, check(Set(100, 25, 75, 50, 7, 10), 360))
       
   192 // => ~14 secs
       
   193 
       
   194 /*
       
   195   Twenty-two seconds in the worst case...this does not yet 
       
   196   look competitive enough.
       
   197 
       
   198   Lets generate the expression together with the result 
       
   199   and restrict the number of expressions in this way.
       
   200 */
       
   201 
       
   202 
       
   203 def gen3(nums: Set[Int]) : Set[(Tree, Int)] =  nums.size match {
       
   204   case 0 => Set()
       
   205   case 1 => Set((Num(nums.head), nums.head))
       
   206   case xs => {
       
   207     val res =
       
   208       for ((lspls, rspls) <- splits(nums);
       
   209            (lt, ln) <- gen3(lspls); 
       
   210            (rt, rn) <- gen3(rspls)) yield {
       
   211         Set((Add(lt, rt), ln + rn),
       
   212             (Mul(lt, rt), ln * rn))
       
   213         ++ Option.when(ln <= rn)((Sub(rt, lt), rn - ln)) 
       
   214         ++ Option.when(ln > rn)((Sub(lt, rt), ln - rn))
       
   215         ++ Option.when(rn > 0 && ln % rn == 0)((Div(lt, rt), ln / rn))
       
   216         ++ Option.when(ln > 0 && rn % ln == 0)((Div(rt, lt), rn / ln))
       
   217     } 
       
   218     res.flatten
       
   219   }
       
   220 }
       
   221 
       
   222 // eyeballing that we did not lose any solutions...
       
   223 gen2(Set(1,2,3,4)).map(eval).flatten == gen3(Set(1,2,3,4)).map(_._2)
       
   224 gen2(Set(1,2,3,4,5,6)).map(eval).flatten == gen3(Set(1,2,3,4,5,6)).map(_._2)
       
   225 
       
   226 // the number of generated expressions
       
   227 gen3(Set(1, 2, 3)).size             // => 104
       
   228 gen3(Set(1, 2, 3))
       
   229 gen3(Set(1, 2, 3, 4, 5, 6)).size    // => 5092300
       
   230 
       
   231 // ...while this version does not "optimise" the leaves
       
   232 // as in gen2, the number of generated expressions grows
       
   233 // slower
       
   234 
       
   235 def check2(xs: Set[Int], target: Int) =
       
   236   gen3(xs).find(_._2 == target)
       
   237 
       
   238 
       
   239 // example 1
       
   240 time_needed(1, check2(Set(50, 5, 4, 9, 10, 8), 560))
       
   241 // => ~14 secs
       
   242 
       
   243 // example 2
       
   244 time_needed(1, check2(Set(25, 5, 2, 10, 7, 1), 986))
       
   245 // => ~18 secs
       
   246 
       
   247 // example 3 (unsolvable)
       
   248 time_needed(1, check2(Set(25, 5, 2, 10, 7, 1), -1))
       
   249 // => ~19 secs
       
   250 
       
   251 // example 4
       
   252 time_needed(1, check2(Set(100, 25, 75, 50, 7, 10), 360))
       
   253 // => ~16 secs
       
   254 
       
   255 // ...we are getting better, but not yet competitive enough.
       
   256 // 
       
   257 // The problem is that splits generates for sets say {1,2,3,4}
       
   258 // the splits ({1,2}, {3,4}) and ({3,4}, {1,2}). This means that
       
   259 // we consider terms (1 + 2) * (3 + 4) and (3 + 4) * (1 + 2) as
       
   260 // separate candidates. We can avoid this duplication by returning
       
   261 // sets of sets of numbers, like 
       
   262 
       
   263 // pairs of subsets and their "complement"
       
   264 def splits2(xs: Set[Int]) : Set[Set[Set[Int]]] =
       
   265   psubsets(xs).map(s => Set(s, xs -- s))
       
   266 
       
   267 splits(Set(1,2,3,4)).mkString("\n")
       
   268 splits2(Set(1,2,3,4)).mkString("\n")
       
   269 
       
   270 def gen4(nums: Set[Int]) : Set[(Tree, Int)] =  nums.size match {
       
   271   case 0 => Set()
       
   272   case 1 => Set((Num(nums.head), nums.head))
       
   273   case xs => {
       
   274     val res =
       
   275       for (spls <- splits2(nums);
       
   276            (lt, ln) <- gen4(spls.head); 
       
   277            (rt, rn) <- gen4(spls.tail.head)) yield {
       
   278         Set((Add(lt, rt), ln + rn),
       
   279             (Mul(lt, rt), ln * rn))
       
   280         ++ Option.when(ln <= rn)((Sub(rt, lt), rn - ln)) 
       
   281         ++ Option.when(ln > rn)((Sub(lt, rt), ln - rn)) 
       
   282         ++ Option.when(rn > 0 && ln % rn == 0)((Div(lt, rt), ln / rn))
       
   283         ++ Option.when(ln > 0 && rn % ln == 0)((Div(rt, lt), rn / ln)) 
       
   284     } 
       
   285     res.flatten
       
   286   }
       
   287 }
       
   288 
       
   289 // eyeballing that we did not lose any solutions...
       
   290 gen4(Set(1,2,3,4)).map(_._2) == gen3(Set(1,2,3,4)).map(_._2)
       
   291 gen4(Set(1,2,3,4,5,6)).map(_._2) == gen3(Set(1,2,3,4,5,6)).map(_._2)
       
   292 
       
   293 
       
   294 gen4(Set(1, 2, 3)).size             // => 43
       
   295 gen4(Set(1, 2, 3, 4))
       
   296 gen4(Set(1, 2, 3, 4, 5, 6)).size    // => 550218
       
   297 
       
   298 def check3(xs: Set[Int], target: Int) =
       
   299   gen4(xs).find(_._2 == target)
       
   300 
       
   301 // example 1
       
   302 check3(Set(50, 5, 4, 9, 10, 8), 560).foreach { sol =>
       
   303   println(s"${pp(sol._1)} => ${eval(sol._1)}")
       
   304 }
       
   305 // (((10 - 5) * (9 * 8)) + (4 * 50)) => Some(560)
       
   306 
       
   307 time_needed(1, check3(Set(50, 5, 4, 9, 10, 8), 560))
       
   308 // => ~1 sec
       
   309 
       
   310 
       
   311 // example 2
       
   312 check3(Set(25, 5, 2, 10, 7, 1), 986).foreach { sol =>
       
   313   println(s"${pp(sol._1)} => ${eval(sol._1)}")
       
   314 }
       
   315 
       
   316 time_needed(1, check3(Set(25, 5, 2, 10, 7, 1), 986))
       
   317 // => ~1 sec
       
   318 
       
   319 // example 3 (unsolvable)
       
   320 check3(Set(25, 5, 2, 10, 7, 1), -1)
       
   321 time_needed(1, check3(Set(25, 5, 2, 10, 7, 1), -1))
       
   322 // => ~2 secs
       
   323 
       
   324 // example 4
       
   325 check3(Set(100, 25, 75, 50, 7, 10), 360).foreach { sol =>
       
   326   println(s"${pp(sol._1)} => ${eval(sol._1)}")
       
   327 }
       
   328 time_needed(1, check3(Set(100, 25, 75, 50, 7, 10), 360))
       
   329 // => ~1 secs
       
   330 
       
   331 
       
   332 
       
   333 time_needed(1, check3(Set(50, 5, 4, 9, 10, 8), 560))
       
   334 time_needed(1, check3(Set(25, 5, 2, 10, 7, 1), 986))
       
   335 time_needed(1, check3(Set(25, 5, 2, 10, 7, 1), -1))
       
   336 time_needed(1, check3(Set(100, 25, 75, 50, 7, 10), 360))
       
   337