progs/compile_arr2.scala
changeset 708 4980f421b3b0
parent 707 2fcd7c2da729
child 710 183663740fb7
equal deleted inserted replaced
707:2fcd7c2da729 708:4980f421b3b0
       
     1 // A Small Compiler for the WHILE Language
       
     2 //
       
     3 //  - includes arrays and a small parser for
       
     4 //    WHILE programs
       
     5 //
       
     6 //  - transpiles BF programs into WHILE programs
       
     7 //    and compiles and runs them
       
     8 //
       
     9 // Call with
       
    10 //
       
    11 // scala compile_arr2.scala
       
    12 
       
    13 
       
    14 
       
    15 // the abstract syntax trees
       
    16 abstract class Stmt
       
    17 abstract class AExp
       
    18 abstract class BExp 
       
    19 type Block = List[Stmt]
       
    20 
       
    21 // statements
       
    22 case object Skip extends Stmt
       
    23 case class ArrayDef(s: String, n: Int) extends Stmt
       
    24 case class If(a: BExp, bl1: Block, bl2: Block) extends Stmt
       
    25 case class While(b: BExp, bl: Block) extends Stmt
       
    26 case class Assign(s: String, a: AExp) extends Stmt             // var := exp
       
    27 case class AssignA(s: String, a1: AExp, a2: AExp) extends Stmt // arr[exp1] := exp2
       
    28 case class Write(s: String) extends Stmt
       
    29 case class Read(s: String) extends Stmt
       
    30 
       
    31 // arithmetic expressions
       
    32 case class Var(s: String) extends AExp
       
    33 case class Num(i: Int) extends AExp
       
    34 case class Aop(o: String, a1: AExp, a2: AExp) extends AExp
       
    35 case class Ref(s: String, a: AExp) extends AExp
       
    36 
       
    37 // boolean expressions
       
    38 case object True extends BExp
       
    39 case object False extends BExp
       
    40 case class Bop(o: String, a1: AExp, a2: AExp) extends BExp
       
    41 
       
    42 
       
    43 // compiler headers needed for the JVM
       
    44 // (contains an init method, as well as methods for read and write)
       
    45 val beginning = """
       
    46 .class public XXX.XXX
       
    47 .super java/lang/Object
       
    48 
       
    49 .method public static write(I)V 
       
    50     .limit locals 1 
       
    51     .limit stack 2 
       
    52     getstatic java/lang/System/out Ljava/io/PrintStream; 
       
    53     iload 0
       
    54     i2c       ; Int => Char
       
    55     invokevirtual java/io/PrintStream/print(C)V   ; println(I)V => print(C)V    
       
    56     return 
       
    57 .end method
       
    58 
       
    59 .method public static main([Ljava/lang/String;)V
       
    60    .limit locals 200
       
    61    .limit stack 200
       
    62 
       
    63 ; COMPILED CODE STARTS   
       
    64 
       
    65 """
       
    66 
       
    67 val ending = """
       
    68 ; COMPILED CODE ENDS
       
    69    return
       
    70 
       
    71 .end method
       
    72 """
       
    73 
       
    74 
       
    75 
       
    76 
       
    77 // for generating new labels
       
    78 var counter = -1
       
    79 
       
    80 def Fresh(x: String) = {
       
    81   counter += 1
       
    82   x ++ "_" ++ counter.toString()
       
    83 }
       
    84 
       
    85 // environments and instructions
       
    86 type Env = Map[String, Int]
       
    87 
       
    88 // convenient string interpolations 
       
    89 // for instructions and labels
       
    90 import scala.language.implicitConversions
       
    91 import scala.language.reflectiveCalls
       
    92 
       
    93 implicit def sring_inters(sc: StringContext) = new {
       
    94     def i(args: Any*): String = "   " ++ sc.s(args:_*) ++ "\n"
       
    95     def l(args: Any*): String = sc.s(args:_*) ++ ":\n"
       
    96 }
       
    97 
       
    98 def compile_op(op: String) = op match {
       
    99   case "+" => i"iadd"
       
   100   case "-" => i"isub"
       
   101   case "*" => i"imul"
       
   102 }
       
   103 
       
   104 // arithmetic expression compilation
       
   105 def compile_aexp(a: AExp, env : Env) : String = a match {
       
   106   case Num(i) => i"ldc $i"
       
   107   case Var(s) => i"iload ${env(s)}"
       
   108   case Aop(op, a1, a2) => 
       
   109     compile_aexp(a1, env) ++ compile_aexp(a2, env) ++ compile_op(op)
       
   110   case Ref(s, a) =>
       
   111     i"aload ${env(s)}" ++ compile_aexp(a, env) ++  i"iaload"
       
   112 }
       
   113 
       
   114 def compile_bop(op: String, jmp: String) = op match {
       
   115   case "==" => i"if_icmpne $jmp"
       
   116   case "!=" => i"if_icmpeq $jmp"
       
   117   case "<"  => i"if_icmpge $jmp"
       
   118 }
       
   119 
       
   120 // boolean expression compilation
       
   121 def compile_bexp(b: BExp, env : Env, jmp: String) : String = b match {
       
   122   case True => ""
       
   123   case False => i"goto $jmp"
       
   124   case Bop(op, a1, a2) => 
       
   125     compile_aexp(a1, env) ++ compile_aexp(a2, env) ++ compile_bop(op, jmp)
       
   126 }
       
   127 
       
   128 // statement compilation
       
   129 def compile_stmt(s: Stmt, env: Env) : (String, Env) = s match {
       
   130   case Skip => ("", env)
       
   131   case Assign(x, a) => {
       
   132      val index = env.getOrElse(x, env.keys.size)
       
   133     (compile_aexp(a, env) ++ i"istore $index \t\t; $x", env + (x -> index)) 
       
   134   } 
       
   135   case If(b, bl1, bl2) => {
       
   136     val if_else = Fresh("If_else")
       
   137     val if_end = Fresh("If_end")
       
   138     val (instrs1, env1) = compile_block(bl1, env)
       
   139     val (instrs2, env2) = compile_block(bl2, env1)
       
   140     (compile_bexp(b, env, if_else) ++
       
   141      instrs1 ++
       
   142      i"goto $if_end" ++
       
   143      l"$if_else" ++
       
   144      instrs2 ++
       
   145      l"$if_end", env2)
       
   146   }
       
   147   case While(b, bl) => {
       
   148     val loop_begin = Fresh("Loop_begin")
       
   149     val loop_end = Fresh("Loop_end")
       
   150     val (instrs1, env1) = compile_block(bl, env)
       
   151     (l"$loop_begin" ++
       
   152      compile_bexp(b, env, loop_end) ++
       
   153      instrs1 ++
       
   154      i"goto $loop_begin" ++
       
   155      l"$loop_end", env1)
       
   156   }
       
   157   case Write(x) => 
       
   158     (i"iload ${env(x)} \t\t; $x" ++ 
       
   159      i"invokestatic XXX/XXX/write(I)V", env)
       
   160   case Read(x) => {
       
   161     val index = env.getOrElse(x, env.keys.size) 
       
   162     (i"invokestatic XXX/XXX/read()I" ++ 
       
   163      i"istore $index \t\t; $x", env + (x -> index))
       
   164   }
       
   165   case ArrayDef(s: String, n: Int) => {
       
   166     val index = if (env.isDefinedAt(s)) throw new Exception("array def error") else 
       
   167                     env.keys.size
       
   168     (i"ldc $n" ++
       
   169      i"newarray int" ++
       
   170      i"astore $index", env + (s -> index))
       
   171   }
       
   172   case AssignA(s, a1, a2) => {
       
   173     val index = if (env.isDefinedAt(s)) env(s) else 
       
   174                     throw new Exception("array not defined")
       
   175     (i"aload ${env(s)}" ++
       
   176      compile_aexp(a1, env) ++
       
   177      compile_aexp(a2, env) ++
       
   178      i"iastore", env)
       
   179   } 
       
   180 }
       
   181 
       
   182 // compilation of a block (i.e. list of statements)
       
   183 def compile_block(bl: Block, env: Env) : (String, Env) = bl match {
       
   184   case Nil => ("", env)
       
   185   case s::bl => {
       
   186     val (instrs1, env1) = compile_stmt(s, env)
       
   187     val (instrs2, env2) = compile_block(bl, env1)
       
   188     (instrs1 ++ instrs2, env2)
       
   189   }
       
   190 }
       
   191 
       
   192 
       
   193 // main compilation function for blocks
       
   194 def compile(bl: Block, class_name: String) : String = {
       
   195   val instructions = compile_block(bl, Map())._1
       
   196   (beginning ++ instructions ++ ending).replaceAllLiterally("XXX", class_name)
       
   197 }
       
   198 
       
   199 
       
   200 import scala.util._
       
   201 import scala.sys.process._
       
   202 
       
   203 def time_needed[T](i: Int, code: => T) = {
       
   204   val start = System.nanoTime()
       
   205   for (j <- 2 to i) code
       
   206   val result = code
       
   207   val end = System.nanoTime()
       
   208   ((end - start) / (i * 1.0e9), result)
       
   209 }
       
   210 
       
   211 def compile_to_file(bl: Block, class_name: String) : Unit = 
       
   212   Using(new java.io.FileWriter(class_name + ".j")) {
       
   213     _.write(compile(bl, class_name))
       
   214   }
       
   215 
       
   216 def compile_and_run(bl: Block, class_name: String) : Unit = {
       
   217   println(s"Start compilation of $class_name")
       
   218   compile_to_file(bl, class_name)
       
   219   println("generated .j file")
       
   220   (s"java -jar jvm/jasmin-2.4/jasmin.jar ${class_name}.j").!!
       
   221   println("generated .class file ")
       
   222   println("Time: " + time_needed(1, (s"java ${class_name}/${class_name}").!)._1)
       
   223 }
       
   224 
       
   225 
       
   226 val arr_test = 
       
   227   List(ArrayDef("a", 10),               // new(a[10])
       
   228        ArrayDef("b", 2),                // new(b[2])
       
   229        AssignA("a", Num(0), Num(10)),   // a[0] := 10
       
   230        Assign("x", Ref("a", Num(0))),   // x := a[0]
       
   231        Write("x"),                      // write x
       
   232        AssignA("b", Num(2), Num(5)),    // b[2] := 5
       
   233        Assign("x", Ref("b", Num(1))),   // x := b[1]
       
   234        Write("x"))                      // write x
       
   235 
       
   236 
       
   237 compile_and_run(arr_test, "a")
       
   238