solutions5/bfc.scala
changeset 232 0855c4478f27
child 233 38ea26f227af
equal deleted inserted replaced
231:eecbc9ae73c2 232:0855c4478f27
       
     1 // Part 2 about a "Compiler" for the Brainf*** language
       
     2 //======================================================
       
     3 
       
     4 //object CW10b {
       
     5 
       
     6 // !!! Copy any function you need from file bf.scala !!!
       
     7 //
       
     8 // If you need any auxiliary function, feel free to 
       
     9 // implement it, but do not make any changes to the
       
    10 // templates below.
       
    11 
       
    12 
       
    13 def time_needed[T](n: Int, code: => T) = {
       
    14   val start = System.nanoTime()
       
    15   for (i <- 0 until n) code
       
    16   val end = System.nanoTime()
       
    17   (end - start)/(n * 1.0e9)
       
    18 }
       
    19 
       
    20 type Mem = Map[Int, Int]
       
    21 
       
    22 
       
    23 import io.Source
       
    24 import scala.util._
       
    25 
       
    26 def load_bff(name: String) : String = 
       
    27   Try(Source.fromFile(name)("ISO-8859-1").mkString).getOrElse("")
       
    28 
       
    29 def sread(mem: Mem, mp: Int) : Int = 
       
    30   mem.getOrElse(mp, 0)
       
    31 
       
    32 def write(mem: Mem, mp: Int, v: Int) : Mem =
       
    33   mem.updated(mp, v)
       
    34 
       
    35 def jumpRight(prog: String, pc: Int, level: Int) : Int = {
       
    36   if (prog.length <= pc) pc 
       
    37   else (prog(pc), level) match {
       
    38     case (']', 0) => pc + 1
       
    39     case (']', l) => jumpRight(prog, pc + 1, l - 1)
       
    40     case ('[', l) => jumpRight(prog, pc + 1, l + 1)
       
    41     case (_, l) => jumpRight(prog, pc + 1, l)
       
    42   }
       
    43 }
       
    44 
       
    45 def jumpLeft(prog: String, pc: Int, level: Int) : Int = {
       
    46   if (pc < 0) pc 
       
    47   else (prog(pc), level) match {
       
    48     case ('[', 0) => pc + 1
       
    49     case ('[', l) => jumpLeft(prog, pc - 1, l - 1)
       
    50     case (']', l) => jumpLeft(prog, pc - 1, l + 1)
       
    51     case (_, l) => jumpLeft(prog, pc - 1, l)
       
    52   }
       
    53 }
       
    54 
       
    55 def compute(prog: String, pc: Int, mp: Int, mem: Mem) : Mem = {
       
    56   if (0 <= pc && pc < prog.length) { 
       
    57     val (new_pc, new_mp, new_mem) = prog(pc) match {
       
    58       case '>' => (pc + 1, mp + 1, mem)
       
    59       case '<' => (pc + 1, mp - 1, mem)
       
    60       case '+' => (pc + 1, mp, write(mem, mp, sread(mem, mp) + 1))
       
    61       case '-' => (pc + 1, mp, write(mem, mp, sread(mem, mp) - 1))
       
    62       case '.' => { print(sread(mem, mp).toChar); (pc + 1, mp, mem) }
       
    63       case ',' => (pc + 1, mp, write(mem, mp, Console.in.read().toByte))
       
    64       case '['  => 
       
    65 	if (sread(mem, mp) == 0) (jumpRight(prog, pc + 1, 0), mp, mem) else (pc + 1, mp, mem) 
       
    66       case ']'  => 
       
    67 	if (sread(mem, mp) != 0) (jumpLeft(prog, pc - 1, 0), mp, mem) else (pc + 1, mp, mem) 
       
    68       case _ => (pc + 1, mp, mem)
       
    69     }		     
       
    70     compute(prog, new_pc, new_mp, new_mem)	
       
    71   }
       
    72   else mem
       
    73 }
       
    74 
       
    75 def run(prog: String, m: Mem = Map()) = compute(prog, 0, 0, m)
       
    76 
       
    77 
       
// The baseline against which we compare our "compiler";
// it takes roughly 60 seconds on my laptop
       
    80 //
       
    81 //time_needed(1, run(load_bff("benchmark.bf")))
       
    82 
       
    83 
       
    84 
       
    85 
       
    86 // (5) Write a function jtable that precomputes the "jump
       
    87 //     table" for a bf-program. This function takes a bf-program 
       
    88 //     as an argument and Returns a Map[Int, Int]. The 
       
//     purpose of this map is to record, for each pc-position n
//     that holds a '[' or a ']', the pc-position we need to
//     jump to.
       
    92 // 
       
    93 //     For example for the program
       
    94 //    
       
    95 //       "+++++[->++++++++++<]>--<+++[->>++++++++++<<]>>++<<----------[+>.>.<+<]"
       
    96 //
       
    97 //     we obtain the map
       
    98 //
       
    99 //       Map(69 -> 61, 5 -> 20, 60 -> 70, 27 -> 44, 43 -> 28, 19 -> 6)
       
   100 //  
       
   101 //     This states that for the '[' on position 5, we need to
       
   102 //     jump to position 20, which is just after the corresponding ']'.
       
   103 //     Similarly, for the ']' on position 19, we need to jump to
       
   104 //     position 6, which is just after the '[' on position 5, and so
       
//     on. The idea is not to recalculate this information each time
//     we hit a bracket, but just to look it up in the
//     jtable.
       
   108 //
       
   109 //     Adapt the compute and run functions from Part 1 in order to
       
   110 //     take advantage of the information in the jtable. 
       
   111  
       
   112 
       
   113 def jtable(pg: String) : Map[Int, Int] = 
       
   114     (0 until pg.length).collect { pc => pg(pc) match {
       
   115       case '[' => (pc -> jumpRight(pg, pc + 1, 0))
       
   116       case ']' => (pc -> jumpLeft(pg, pc - 1, 0))
       
   117     }}.toMap
       
   118 
       
   119 
       
   120 // testcase
       
   121 // jtable("""+++++[->++++++++++<]>--<+++[->>++++++++++<<]>>++<<----------[+>.>.<+<]""")
       
   122 // =>  Map(69 -> 61, 5 -> 20, 60 -> 70, 27 -> 44, 43 -> 28, 19 -> 6)
       
   123 
       
   124 
       
   125 def compute2(pg: String, tb: Map[Int, Int], pc: Int, mp: Int, mem: Mem) : Mem = {
       
   126   if (0 <= pc && pc < pg.length) { 
       
   127     val (new_pc, new_mp, new_mem) = pg(pc) match {
       
   128       case '>' => (pc + 1, mp + 1, mem)
       
   129       case '<' => (pc + 1, mp - 1, mem)
       
   130       case '+' => (pc + 1, mp, write(mem, mp, sread(mem, mp) + 1))
       
   131       case '-' => (pc + 1, mp, write(mem, mp, sread(mem, mp) - 1))
       
   132       case '.' => { print(sread(mem, mp).toChar); (pc + 1, mp, mem) }
       
   133       case ',' => (pc + 1, mp, write(mem, mp, Console.in.read().toByte))
       
   134       case '['  => 
       
   135 	if (sread(mem, mp) == 0) (tb(pc), mp, mem) else (pc + 1, mp, mem) 
       
   136       case ']'  => 
       
   137 	if (sread(mem, mp) != 0) (tb(pc), mp, mem) else (pc + 1, mp, mem) 
       
   138       case _ => (pc + 1, mp, mem)
       
   139     }		     
       
   140     compute2(pg, tb, new_pc, new_mp, new_mem)	
       
   141   }
       
   142   else mem
       
   143 }
       
   144 
       
   145 
       
   146 def run2(pg: String, m: Mem = Map()) = 
       
   147   compute2(pg, jtable(pg), 0, 0, m)
       
   148 
       
   149 //time_needed(1, run2(load_bff("benchmark.bf")))
       
   150 
       
   151 
       
   152 // (6) Write a function optimise which deletes "dead code" (everything
       
   153 // that is not a bf-command) and also replaces substrings of the form
       
// [-] by a new command 0. The idea is that the loop [-] resets the
// memory at the current location to 0. In the compute3 and run3 functions
// below we implement this command by writing 0 to mem(mp), that is
// write(mem, mp, 0).
       
   158 //
       
   159 // The easiest way to modify a string in this way is to use the regular
       
// expression """[^<>+-.,\[\]]""", which recognises everything that is
// not a bf-command, and replace it by the empty string. (Inside the
// character class, +-. is the range '+'..'.', i.e. the characters + , - .,
// which happen to be exactly those bf-commands.) Similarly the
// regular expression """\[-\]""" finds all occurrences of [-], and
// by using the Scala method .replaceAll you can replace them with the
// string "0" standing for the new bf-command.
       
   165 
       
   166 def optimise(s: String) : String = 
       
   167   s.replaceAll("""[^<>+-.,\[\]]""","").replaceAll("""\[-\]""", "0")
       
   168 
       
   169 def compute3(pg: String, tb: Map[Int, Int], pc: Int, mp: Int, mem: Mem) : Mem = {
       
   170   if (0 <= pc && pc < pg.length) { 
       
   171     val (new_pc, new_mp, new_mem) = pg(pc) match {
       
   172       case '0' => (pc + 1, mp, write(mem, mp, 0))
       
   173       case '>' => (pc + 1, mp + 1, mem)
       
   174       case '<' => (pc + 1, mp - 1, mem)
       
   175       case '+' => (pc + 1, mp, write(mem, mp, sread(mem, mp) + 1))
       
   176       case '-' => (pc + 1, mp, write(mem, mp, sread(mem, mp) - 1))
       
   177       case '.' => { print(sread(mem, mp).toChar); (pc + 1, mp, mem) }
       
   178       case ',' => (pc + 1, mp, write(mem, mp, Console.in.read().toByte))
       
   179       case '['  => 
       
   180 	if (sread(mem, mp) == 0) (tb(pc), mp, mem) else (pc + 1, mp, mem) 
       
   181       case ']'  => 
       
   182 	if (sread(mem, mp) != 0) (tb(pc), mp, mem) else (pc + 1, mp, mem) 
       
   183       case _ => (pc + 1, mp, mem)
       
   184     }		     
       
   185     compute3(pg, tb, new_pc, new_mp, new_mem)	
       
   186   }
       
   187   else mem
       
   188 }
       
   189 
       
   190 def run3(pg: String, m: Mem = Map()) = { 
       
   191   val pg_opt = optimise(pg)
       
   192   compute3(pg_opt, jtable(pg_opt), 0, 0, m)
       
   193 }
       
   194 
       
   195 
       
   196 time_needed(1, run3(load_bff("benchmark.bf")))
       
   197 
       
   198 
       
   199 // (7) 
       
   200 
       
   201 def splice(cs: List[Char], acc: List[(Char, Int)]) : List[(Char, Int)] = (cs, acc) match {
       
   202   case (Nil, acc) => acc  
       
   203   case ('[' :: cs, acc) => splice(cs, ('[', 1) :: acc)
       
   204   case (']' :: cs, acc) => splice(cs, (']', 1) :: acc)
       
   205   case ('.' :: cs, acc) => splice(cs, ('.', 1) :: acc)
       
   206   case (',' :: cs, acc) => splice(cs, (',', 1) :: acc)
       
   207   case ('0' :: cs, acc) => splice(cs, ('0', 1) :: acc)
       
   208   case (c :: cs, Nil) => splice(cs, List((c, 1)))
       
   209   case (c :: cs, (d, n) :: acc) => 
       
   210     if (c == d && n < 26) splice(cs, (c, n + 1) :: acc)
       
   211     else splice(cs, (c, 1) :: (d, n) :: acc)
       
   212 }
       
   213 
       
   214 def spl(s: String) = splice(s.toList, Nil).reverse
       
   215 
       
   216 spl(load_bff("benchmark.bf"))
       
   217 
       
   218 def combine(cs: List[(Char, Int)]) : String = {
       
   219   (for ((c, n) <- cs) yield c match {
       
   220     case '>' => List('>', (n + '@').toChar)
       
   221     case '<' => List('<', (n + '@').toChar)
       
   222     case '+' => List('+', (n + '@').toChar)
       
   223     case '-' => List('-', (n + '@').toChar)
       
   224     case _ => List(c)
       
   225   }).flatten.mkString
       
   226 }
       
   227 
       
   228 
       
   229 combine(spl(load_bff("benchmark.bf")))
       
   230 
       
   231 
       
   232 def compute4(pg: String, tb: Map[Int, Int], pc: Int, mp: Int, mem: Mem) : Mem = {
       
   233   if (0 <= pc && pc < pg.length) { 
       
   234     val (new_pc, new_mp, new_mem) = pg(pc) match {
       
   235       case '0' => (pc + 1, mp, write(mem, mp, 0))
       
   236       case '>' => (pc + 2, mp + (pg(pc + 1) - '@'), mem)
       
   237       case '<' => (pc + 2, mp - (pg(pc + 1) - '@'), mem)
       
   238       case '+' => (pc + 2, mp, write(mem, mp, sread(mem, mp) + (pg(pc + 1) - '@')))
       
   239       case '-' => (pc + 2, mp, write(mem, mp, sread(mem, mp) - (pg(pc + 1) - '@')))
       
   240       case '.' => { print(sread(mem, mp).toChar); (pc + 1, mp, mem) }
       
   241       case ',' => (pc + 1, mp, write(mem, mp, Console.in.read().toByte))
       
   242       case '['  => 
       
   243 	if (sread(mem, mp) == 0) (tb(pc), mp, mem) else (pc + 1, mp, mem) 
       
   244       case ']'  => 
       
   245 	if (sread(mem, mp) != 0) (tb(pc), mp, mem) else (pc + 1, mp, mem) 
       
   246       case _ => (pc + 1, mp, mem)
       
   247     }		     
       
   248     compute4(pg, tb, new_pc, new_mp, new_mem)	
       
   249   }
       
   250   else mem
       
   251 }
       
   252 
       
   253 def run4(pg: String, m: Mem = Map()) = { 
       
   254   val pg_opt = combine(spl(optimise(pg)))
       
   255   compute4(pg_opt, jtable(pg_opt), 0, 0, m)
       
   256 }
       
   257 
       
   258 
       
   259 //time_needed(1, run4(load_bff("benchmark.bf")))
       
   260 //time_needed(1, run4(load_bff("mandelbrot.bf")))
       
   261 
       
   262 
       
   263 }