| 
     1 // A Small Compiler for the WHILE Language  | 
         | 
// (it does not use a lexer or parser; programs are written directly as abstract syntax trees)
         | 
     3   | 
         | 
     4 // the abstract syntax trees  | 
         | 
     5 abstract class Stmt  | 
         | 
     6 abstract class AExp  | 
         | 
     7 abstract class BExp   | 
         | 
     8 type Block = List[Stmt]  | 
         | 
     9   | 
         | 
    10 // statements  | 
         | 
    11 case object Skip extends Stmt  | 
         | 
    12 case class If(a: BExp, bl1: Block, bl2: Block) extends Stmt  | 
         | 
    13 case class While(b: BExp, bl: Block) extends Stmt  | 
         | 
    14 case class Assign(s: String, a: AExp) extends Stmt  | 
         | 
    15 case class Write(s: String) extends Stmt  | 
         | 
    16 case class Read(s: String) extends Stmt  | 
         | 
    17   | 
         | 
    18 // arithmetic expressions  | 
         | 
    19 case class Var(s: String) extends AExp  | 
         | 
    20 case class Num(i: Int) extends AExp  | 
         | 
    21 case class Aop(o: String, a1: AExp, a2: AExp) extends AExp  | 
         | 
    22   | 
         | 
    23 // boolean expressions  | 
         | 
    24 case object True extends BExp  | 
         | 
    25 case object False extends BExp  | 
         | 
    26 case class Bop(o: String, a1: AExp, a2: AExp) extends BExp  | 
         | 
    27   | 
         | 
    28   | 
         | 
    29 // compiler headers needed for the JVM  | 
         | 
    30 // (contains an init method, as well as methods for read and write)  | 
         | 
    31 val beginning = """  | 
         | 
    32 .class public XXX.XXX  | 
         | 
    33 .super java/lang/Object  | 
         | 
    34   | 
         | 
    35 .method public <init>()V  | 
         | 
    36    aload_0  | 
         | 
    37    invokenonvirtual java/lang/Object/<init>()V  | 
         | 
    38    return  | 
         | 
    39 .end method  | 
         | 
    40   | 
         | 
    41 .method public static write(I)V   | 
         | 
    42     .limit locals 1   | 
         | 
    43     .limit stack 2   | 
         | 
    44     getstatic java/lang/System/out Ljava/io/PrintStream;   | 
         | 
    45     iload 0  | 
         | 
    46     invokevirtual java/io/PrintStream/println(I)V   | 
         | 
    47     return   | 
         | 
    48 .end method  | 
         | 
    49   | 
         | 
    50 .method public static read()I   | 
         | 
    51     .limit locals 10   | 
         | 
    52     .limit stack 10  | 
         | 
    53   | 
         | 
    54     ldc 0   | 
         | 
    55     istore 1  ; this will hold our final integer   | 
         | 
    56 Label1:   | 
         | 
    57     getstatic java/lang/System/in Ljava/io/InputStream;   | 
         | 
    58     invokevirtual java/io/InputStream/read()I   | 
         | 
    59     istore 2   | 
         | 
    60     iload 2   | 
         | 
    61     ldc 10   ; the newline delimiter   | 
         | 
    62     isub   | 
         | 
    63     ifeq Label2   | 
         | 
    64     iload 2   | 
         | 
    65     ldc 32   ; the space delimiter   | 
         | 
    66     isub   | 
         | 
    67     ifeq Label2  | 
         | 
    68   | 
         | 
    69     iload 2   | 
         | 
    70     ldc 48   ; we have our digit in ASCII, have to subtract it from 48   | 
         | 
    71     isub   | 
         | 
    72     ldc 10   | 
         | 
    73     iload 1   | 
         | 
    74     imul   | 
         | 
    75     iadd   | 
         | 
    76     istore 1   | 
         | 
    77     goto Label1   | 
         | 
    78 Label2:   | 
         | 
    79     ;when we come here we have our integer computed in local variable 1   | 
         | 
    80     iload 1   | 
         | 
    81     ireturn   | 
         | 
    82 .end method  | 
         | 
    83   | 
         | 
    84 .method public static main([Ljava/lang/String;)V  | 
         | 
    85    .limit locals 200  | 
         | 
    86    .limit stack 200  | 
         | 
    87   | 
         | 
    88 """  | 
         | 
    89   | 
         | 
    90 val ending = """  | 
         | 
    91   | 
         | 
    92    return  | 
         | 
    93   | 
         | 
    94 .end method  | 
         | 
    95 """  | 
         | 
    96   | 
         | 
    97 println("Start compilation") | 
         | 
    98   | 
         | 
    99   | 
         | 
   100 // for generating new labels  | 
         | 
   101 var counter = -1  | 
         | 
   102   | 
         | 
   103 def Fresh(x: String) = { | 
         | 
   104   counter += 1  | 
         | 
   105   x ++ "_" ++ counter.toString()  | 
         | 
   106 }  | 
         | 
   107   | 
         | 
   108 // environments and instructions  | 
         | 
   109 type Env = Map[String, String]  | 
         | 
   110 type Instrs = List[String]  | 
         | 
   111   | 
         | 
   112 // arithmetic expression compilation  | 
         | 
   113 def compile_aexp(a: AExp, env : Env) : Instrs = a match { | 
         | 
   114   case Num(i) => List("ldc " + i.toString + "\n") | 
         | 
   115   case Var(s) => List("iload " + env(s) + "\n") | 
         | 
   116   case Aop("+", a1, a2) =>  | 
         | 
   117     compile_aexp(a1, env) ++   | 
         | 
   118     compile_aexp(a2, env) ++ List("iadd\n") | 
         | 
   119   case Aop("-", a1, a2) =>  | 
         | 
   120     compile_aexp(a1, env) ++ compile_aexp(a2, env) ++ List("isub\n") | 
         | 
   121   case Aop("*", a1, a2) =>  | 
         | 
   122     compile_aexp(a1, env) ++ compile_aexp(a2, env) ++ List("imul\n") | 
         | 
   123 }  | 
         | 
   124   | 
         | 
   125 // boolean expression compilation  | 
         | 
   126 def compile_bexp(b: BExp, env : Env, jmp: String) : Instrs = b match { | 
         | 
   127   case True => Nil  | 
         | 
   128   case False => List("goto " + jmp + "\n") | 
         | 
   129   case Bop("=", a1, a2) =>  | 
         | 
   130     compile_aexp(a1, env) ++ compile_aexp(a2, env) ++   | 
         | 
   131     List("if_icmpne " + jmp + "\n") | 
         | 
   132   case Bop("!=", a1, a2) =>  | 
         | 
   133     compile_aexp(a1, env) ++ compile_aexp(a2, env) ++   | 
         | 
   134     List("if_icmpeq " + jmp + "\n") | 
         | 
   135   case Bop("<", a1, a2) =>  | 
         | 
   136     compile_aexp(a1, env) ++ compile_aexp(a2, env) ++   | 
         | 
   137     List("if_icmpge " + jmp + "\n") | 
         | 
   138 }  | 
         | 
   139   | 
         | 
   140 // statement compilation  | 
         | 
   141 def compile_stmt(s: Stmt, env: Env) : (Instrs, Env) = s match { | 
         | 
   142   case Skip => (Nil, env)  | 
         | 
   143   case Assign(x, a) => { | 
         | 
   144     val index = if (env.isDefinedAt(x)) env(x) else   | 
         | 
   145                     env.keys.size.toString  | 
         | 
   146     (compile_aexp(a, env) ++   | 
         | 
   147      List("istore " + index + "\n"), env + (x -> index)) | 
         | 
   148   }   | 
         | 
   149   case If(b, bl1, bl2) => { | 
         | 
   150     val if_else = Fresh("If_else") | 
         | 
   151     val if_end = Fresh("If_end") | 
         | 
   152     val (instrs1, env1) = compile_block(bl1, env)  | 
         | 
   153     val (instrs2, env2) = compile_block(bl2, env1)  | 
         | 
   154     (compile_bexp(b, env, if_else) ++  | 
         | 
   155      instrs1 ++  | 
         | 
   156      List("goto " + if_end + "\n") ++ | 
         | 
   157      List("\n" + if_else + ":\n\n") ++ | 
         | 
   158      instrs2 ++  | 
         | 
   159      List("\n" + if_end + ":\n\n"), env2) | 
         | 
   160   }  | 
         | 
   161   case While(b, bl) => { | 
         | 
   162     val loop_begin = Fresh("Loop_begin") | 
         | 
   163     val loop_end = Fresh("Loop_end") | 
         | 
   164     val (instrs1, env1) = compile_block(bl, env)  | 
         | 
   165     (List("\n" + loop_begin + ":\n\n") ++ | 
         | 
   166      compile_bexp(b, env, loop_end) ++  | 
         | 
   167      instrs1 ++  | 
         | 
   168      List("goto " + loop_begin + "\n") ++ | 
         | 
   169      List("\n" + loop_end + ":\n\n"), env1) | 
         | 
   170   }  | 
         | 
   171   case Write(x) =>   | 
         | 
   172     (List("iload " + env(x) + "\n" +  | 
         | 
   173            "invokestatic XXX/XXX/write(I)V\n"), env)  | 
         | 
   174   case Read(x) => { | 
         | 
   175     val index = if (env.isDefinedAt(x)) env(x) else   | 
         | 
   176                     env.keys.size.toString  | 
         | 
   177     (List("invokestatic XXX/XXX/read()I\n" +  | 
         | 
   178           "istore " + index + "\n"), env + (x -> index))  | 
         | 
   179   }  | 
         | 
   180 }  | 
         | 
   181   | 
         | 
   182 // compilation of a block (i.e. list of instructions)  | 
         | 
   183 def compile_block(bl: Block, env: Env) : (Instrs, Env) = bl match { | 
         | 
   184   case Nil => (Nil, env)  | 
         | 
   185   case s::bl => { | 
         | 
   186     val (instrs1, env1) = compile_stmt(s, env)  | 
         | 
   187     val (instrs2, env2) = compile_block(bl, env1)  | 
         | 
   188     (instrs1 ++ instrs2, env2)  | 
         | 
   189   }  | 
         | 
   190 }  | 
         | 
   191   | 
         | 
   192 // main compilation function for blocks  | 
         | 
   193 def compile(bl: Block, class_name: String) : String = { | 
         | 
   194   val instructions = compile_block(bl, Map.empty)._1  | 
         | 
   195   (beginning ++ instructions.mkString ++ ending).replaceAllLiterally("XXX", class_name) | 
         | 
   196 }  | 
         | 
   197   | 
         | 
   198   | 
         | 
   199 // compiling and running files  | 
         | 
   200 //  | 
         | 
// Jasmin assembly files (.j) can be assembled with
         | 
   202 //  | 
         | 
   203 //    java -jar jvm/jasmin-2.4/jasmin.jar fib.j  | 
         | 
   204 //  | 
         | 
   205 // and started with  | 
         | 
   206 //  | 
         | 
   207 //    java fib/fib  | 
         | 
   208   | 
         | 
   209   | 
         | 
   210   | 
         | 
   211 import scala.util._  | 
         | 
   212 import scala.sys.process._  | 
         | 
   213 import scala.io  | 
         | 
   214   | 
         | 
   215 def compile_tofile(bl: Block, class_name: String) = { | 
         | 
   216   val output = compile(bl, class_name)  | 
         | 
   217   val fw = new java.io.FileWriter(class_name + ".j")   | 
         | 
   218   fw.write(output)   | 
         | 
   219   fw.close()  | 
         | 
   220 }  | 
         | 
   221   | 
         | 
   222 def compile_all(bl: Block, class_name: String) : Unit = { | 
         | 
   223   compile_tofile(bl, class_name)  | 
         | 
   224   println("compiled ") | 
         | 
   225   val test = ("java -jar jvm/jasmin-2.4/jasmin.jar " + class_name + ".j").!! | 
         | 
   226   println("assembled ") | 
         | 
   227 }  | 
         | 
   228   | 
         | 
   229 def time_needed[T](i: Int, code: => T) = { | 
         | 
   230   val start = System.nanoTime()  | 
         | 
   231   for (j <- 1 to i) code  | 
         | 
   232   val end = System.nanoTime()  | 
         | 
   233   (end - start)/(i * 1.0e9)  | 
         | 
   234 }  | 
         | 
   235   | 
         | 
   236   | 
         | 
   237 def compile_run(bl: Block, class_name: String) : Unit = { | 
         | 
   238   println("Start compilation") | 
         | 
   239   compile_all(bl, class_name)  | 
         | 
   240   println("Time: " + time_needed(1, ("java " + class_name + "/" + class_name).!)) | 
         | 
   241 }  | 
         | 
   242   | 
         | 
   243   | 
         | 
   244 // Fibonacci numbers as a test-case  | 
         | 
   245 val fib_test =   | 
         | 
   246   List(Assign("n", Num(10)),            //  n := 10;                      | 
         | 
   247        Assign("minus1",Num(0)),         //  minus1 := 0; | 
         | 
   248        Assign("minus2",Num(1)),         //  minus2 := 1; | 
         | 
   249        Assign("temp",Num(0)),           //  temp := 0; | 
         | 
   250        While(Bop("<",Num(0),Var("n")),  //  while n > 0 do  { | 
         | 
   251           List(Assign("temp",Var("minus2")),    //  temp := minus2; | 
         | 
   252                Assign("minus2",Aop("+",Var("minus1"),Var("minus2"))),  | 
         | 
   253                                         //  minus2 := minus1 + minus2;  | 
         | 
   254                Assign("minus1",Var("temp")), //  minus1 := temp; | 
         | 
   255                Assign("n",Aop("-",Var("n"),Num(1))))), //  n := n - 1 }; | 
         | 
   256        Write("minus1"))                 //  write minus1 | 
         | 
   257   | 
         | 
   258   | 
         | 
   259 compile_run(fib_test, "fib")  | 
         | 
   260   | 
         | 
   261   |