progs/while-arrays/compile_arrays2.sc
changeset 817 89f9c68fc417
parent 791 d47041b23498
child 829 dc3c35673e94
equal deleted inserted replaced
816:2b6e23985982 817:89f9c68fc417
       
     1 // A Small Compiler for the WHILE Language
       
     2 //
       
     3 // - this compiler contains support for "static" integer 
       
     4 //   arrays (they are mutable but cannot be re-sized)
       
     5 //
       
     6 // Call with 
       
     7 //
       
     8 // amm compile_arrays2.sc
       
     9   
       
    10 
       
    11 // the abstract syntax trees for WHILE
       
    12 
       
    13 abstract class Stmt
       
    14 abstract class AExp
       
    15 abstract class BExp 
       
    16 type Block = List[Stmt]
       
    17 
       
    18 // statements
       
    19 case object Skip extends Stmt
       
    20 case class ArrayDef(s: String, n: Int) extends Stmt            // array definition
       
    21 case class If(a: BExp, bl1: Block, bl2: Block) extends Stmt
       
    22 case class While(b: BExp, bl: Block) extends Stmt
       
    23 case class Assign(s: String, a: AExp) extends Stmt             // var := exp
       
    24 case class AssignA(s: String, a1: AExp, a2: AExp) extends Stmt // arr[exp1] := exp2
       
    25 case class Write(s: String) extends Stmt
       
    26 
       
    27 
       
    28 // arithmetic expressions
       
    29 case class Var(s: String) extends AExp
       
    30 case class Num(i: Int) extends AExp
       
    31 case class Aop(o: String, a1: AExp, a2: AExp) extends AExp
       
    32 case class Ref(s: String, a: AExp) extends AExp
       
    33 
       
    34 // boolean expressions
       
    35 case object True extends BExp
       
    36 case object False extends BExp
       
    37 case class Bop(o: String, a1: AExp, a2: AExp) extends BExp
       
    38 
       
    39 
       
    40 // compiler headers needed for the JVM
       
    41 //
       
    42 // - contains a main method and a method for writing out an integer
       
    43 //
       
    44 // - the stack and locals are hard-coded
       
    45 //
       
    46 
       
    47 val beginning = """
       
    48 .class public XXX.XXX
       
    49 .super java/lang/Object
       
    50 
       
    51 .method public static write(I)V 
       
    52     .limit locals 1 
       
    53     .limit stack 2 
       
    54     getstatic java/lang/System/out Ljava/io/PrintStream; 
       
    55     iload 0
       
    56     invokevirtual java/io/PrintStream/print(I)V
       
    57     return 
       
    58 .end method
       
    59 
       
    60 .method public static main([Ljava/lang/String;)V
       
    61    .limit locals 200
       
    62    .limit stack 200
       
    63 
       
    64 ; COMPILED CODE STARTS   
       
    65 
       
    66 """
       
    67 
       
    68 val ending = """
       
    69 ; COMPILED CODE ENDS
       
    70    return
       
    71 
       
    72 .end method
       
    73 """
       
    74 
       
    75 
       
    76 
       
    77 // for generating new labels
       
    78 var counter = -1
       
    79 
       
    80 def Fresh(x: String) = {
       
    81   counter += 1
       
    82   x ++ "_" ++ counter.toString()
       
    83 }
       
    84 
       
    85 // environments for variables and indices
       
    86 type Env = Map[String, Int]
       
    87 
       
    88 // convenient string interpolations 
       
    89 // for generating instructions and labels
       
    90 
       
    91 implicit def sring_inters(sc: StringContext) = new {
       
    92     def i(args: Any*): String = "   " ++ sc.s(args:_*) ++ "\n"
       
    93     def l(args: Any*): String = sc.s(args:_*) ++ ":\n"
       
    94 }
       
    95 
       
    96 def compile_num(i: Int) = 
       
    97   if (0 <= i && i <= 5) i"iconst_$i" else 
       
    98   if (-128 <= i && i <= 127) i"bipush $i" else i"ldc $i"
       
    99 
       
   100 def compile_aload(i: Int) = 
       
   101   if (0 <= i && i <= 3) i"aload_$i" else i"aload $i"
       
   102 
       
   103 def compile_astore(i: Int) = 
       
   104   if (0 <= i && i <= 3) i"astore_$i" else i"astore $i"
       
   105 
       
   106 def compile_iload(i: Int) = 
       
   107   if (0 <= i && i <= 3) i"iload_$i" else i"iload $i"
       
   108 
       
   109 def compile_istore(i: Int) = 
       
   110   if (0 <= i && i <= 3) i"istore_$i" else i"istore $i"
       
   111 
       
   112 
       
   113 def compile_op(op: String) = op match {
       
   114   case "+" => i"iadd"
       
   115   case "-" => i"isub"
       
   116   case "*" => i"imul"
       
   117 }
       
   118 
       
   119 // arithmetic expression compilation
       
   120 def compile_aexp(a: AExp, env : Env) : String = a match {
       
   121   case Num(i) => compile_num(i)
       
   122   case Var(s) => compile_iload(env(s))
       
   123   case Aop(op, a1, a2) => 
       
   124     compile_aexp(a1, env) ++ compile_aexp(a2, env) ++ compile_op(op)
       
   125   case Ref(s, a) =>
       
   126     compile_aload(env(s)) ++ compile_aexp(a, env) ++  i"iaload"
       
   127 }
       
   128 
       
   129 // boolean expression compilation
       
   130 def compile_bexp(b: BExp, env : Env, jmp: String) : String = b match {
       
   131   case True => ""
       
   132   case False => i"goto $jmp"
       
   133   case Bop("==", a1, a2) => 
       
   134     compile_aexp(a1, env) ++ compile_aexp(a2, env) ++ i"if_icmpne $jmp"
       
   135   case Bop("!=", a1, a2) => 
       
   136     compile_aexp(a1, env) ++ compile_aexp(a2, env) ++ i"if_icmpeq $jmp"
       
   137   case Bop("<", a1, a2) => 
       
   138     compile_aexp(a1, env) ++ compile_aexp(a2, env) ++ i"if_icmpge $jmp"
       
   139 }
       
   140 
       
   141 // statement compilation
       
   142 def compile_stmt(s: Stmt, env: Env) : (String, Env) = s match {
       
   143   case Skip => ("", env)
       
   144   case Assign(x, a) => {
       
   145      val index = env.getOrElse(x, env.keys.size)
       
   146     (compile_aexp(a, env) ++ compile_istore(index), env + (x -> index)) 
       
   147   } 
       
   148   case If(b, bl1, bl2) => {
       
   149     val if_else = Fresh("If_else")
       
   150     val if_end = Fresh("If_end")
       
   151     val (instrs1, env1) = compile_block(bl1, env)
       
   152     val (instrs2, env2) = compile_block(bl2, env1)
       
   153     (compile_bexp(b, env, if_else) ++
       
   154      instrs1 ++
       
   155      i"goto $if_end" ++
       
   156      l"$if_else" ++
       
   157      instrs2 ++
       
   158      l"$if_end", env2)
       
   159   }
       
   160   case While(b, bl) => {
       
   161     val loop_begin = Fresh("Loop_begin")
       
   162     val loop_end = Fresh("Loop_end")
       
   163     val (instrs1, env1) = compile_block(bl, env)
       
   164     (l"$loop_begin" ++
       
   165      compile_bexp(b, env, loop_end) ++
       
   166      instrs1 ++
       
   167      i"goto $loop_begin" ++
       
   168      l"$loop_end", env1)
       
   169   }
       
   170   case Write(x) => 
       
   171     (compile_iload(env(x)) ++ 
       
   172      i"invokestatic XXX/XXX/write(I)V", env)
       
   173   case ArrayDef(s: String, n: Int) => {
       
   174     val index = if (env.isDefinedAt(s)) throw new Exception("array def error") else 
       
   175                     env.keys.size
       
   176     (compile_num(n) ++
       
   177      i"newarray int" ++
       
   178      compile_astore(index), env + (s -> index))
       
   179   }
       
   180   case AssignA(s, a1, a2) => {
       
   181     val index = if (env.isDefinedAt(s)) env(s) else 
       
   182                     throw new Exception("array not defined")
       
   183     (compile_aload(env(s)) ++
       
   184      compile_aexp(a1, env) ++
       
   185      compile_aexp(a2, env) ++
       
   186      i"iastore", env)
       
   187   } 
       
   188 }
       
   189 
       
   190 // compile a block (i.e. list of statements)
       
   191 def compile_block(bl: Block, env: Env) : (String, Env) = bl match {
       
   192   case Nil => ("", env)
       
   193   case s::bl => {
       
   194     val (instrs1, env1) = compile_stmt(s, env)
       
   195     val (instrs2, env2) = compile_block(bl, env1)
       
   196     (instrs1 ++ instrs2, env2)
       
   197   }
       
   198 }
       
   199 
       
   200 
       
   201 // main compile function for blocks (adds headers and proper JVM names)
       
   202 def compile(bl: Block, class_name: String) : String = {
       
   203   val instructions = compile_block(bl, Map())._1
       
   204   (beginning ++ instructions ++ ending).replace("XXX", class_name)
       
   205 }
       
   206 
       
   207 
       
   208 
       
   209 // contrived example involving arrays
       
   210 val array_test = 
       
   211   List(ArrayDef("a", 10),               // array a[10]
       
   212        ArrayDef("b", 2),                // array b[2]
       
   213        AssignA("a", Num(0), Num(10)),   // a[0] := 10
       
   214        Assign("x", Ref("a", Num(0))),   // x := a[0]
       
   215        Write("x"),            
       
   216        AssignA("b", Num(1), Num(5)),    // b[1] := 5
       
   217        Assign("x", Ref("b", Num(1))),   // x := b[1] 
       
   218        Write("x"))                     
       
   219 
       
   220 
       
   221 // prints out the JVM-assembly instructions for array_test above
       
   222 //
       
   223 //    println(compile(array_test, "arr"))
       
   224 //
       
   225 // can be assembled by hand with 
       
   226 //
       
   227 //    java -jar jasmin.jar arr.j
       
   228 //
       
   229 // and run with
       
   230 //
       
   231 //    java arr/arr
       
   232 
       
   233 // automating the above
       
   234 import ammonite.ops._
       
   235 
       
   236 def compile_to_file(bl: Block, class_name: String) : Unit = 
       
   237   write.over(pwd / s"$class_name.j", compile(bl, class_name))  
       
   238 
       
   239 def compile_and_run(bl: Block, class_name: String) : Unit = {
       
   240   println(s"Start of compilation")
       
   241   compile_to_file(bl, class_name)
       
   242   println(s"generated $class_name.j file")
       
   243   os.proc("java", "-jar", "jasmin.jar", s"$class_name.j").call()
       
   244   println(s"generated $class_name.class file ")
       
   245   //println(os.proc("java", s"${class_name}/${class_name}").call().out.text())
       
   246   os.proc("java", s"${class_name}/${class_name}").call(stdout = os.Inherit)
       
   247   println(s"done.")
       
   248 }
       
   249 
       
   250 
       
   251    
       
   252 @main def main() = {
       
   253   compile_and_run(array_test, "arr")
       
   254 }
       
   255 
       
   256