// A Small Compiler for a Simple Functional Language
// (the lexer and parser are external: this compiler reads the
// serialised parse tree from a .prs file)
//
// call with
//
//     scala fun.scala fact
//
//     scala fun.scala defs
//
// this reads fact.prs (resp. defs.prs), generates a .j file,
// runs the jasmin assembler (installed at jvm/jasmin-2.4/jasmin.jar)
// on it, and then runs the resulting JVM class once, printing
// the time the run took.


object Compiler {

import java.io._
import scala.util._
import scala.sys.process._

// Abstract syntax trees for the Fun language.
// They are Serializable because the parse tree is read back from a .prs
// file (see deserialise / compile). The base classes are sealed so that the
// pattern matches below are checked for exhaustiveness. The case classes are
// deliberately left non-final: the default serialVersionUID is computed from
// the JVM class modifiers, so adding `final` could make previously written
// .prs files unreadable.
sealed abstract class Exp extends Serializable
sealed abstract class BExp extends Serializable
sealed abstract class Decl extends Serializable

case class Def(name: String, args: List[String], body: Exp) extends Decl
case class Main(e: Exp) extends Decl

case class Call(name: String, args: List[Exp]) extends Exp
case class If(a: BExp, e1: Exp, e2: Exp) extends Exp
case class Write(e: Exp) extends Exp
case class Var(s: String) extends Exp
case class Num(i: Int) extends Exp
case class Aop(o: String, a1: Exp, a2: Exp) extends Exp
case class Sequence(e1: Exp, e2: Exp) extends Exp
case class Bop(o: String, a1: Exp, a2: Exp) extends BExp


// compiler - built-in functions
// copied from http://www.ceng.metu.edu.tr/courses/ceng444/link/jvm-cpm.html
//
// XXX is a placeholder that compile() replaces with the class name.
val library = """
.class public XXX.XXX
.super java/lang/Object

.method public static write(I)V
        .limit locals 1
        .limit stack 2
        getstatic java/lang/System/out Ljava/io/PrintStream;
        iload 0
        invokevirtual java/io/PrintStream/println(I)V
        return
.end method

"""

// Upper bound on the JVM operand-stack depth needed to evaluate e.
// The bound is conservative (e.g. operand needs are summed rather than
// combined precisely); over-estimating is harmless, whereas an
// under-estimate can make the JVM verifier reject the generated class.
def max_stack_exp(e: Exp): Int = e match {
  // room for all arguments, and at least one slot for the returned int
  // even when the function takes no arguments
  case Call(_, args) => math.max(1, args.map(max_stack_exp).sum)
  case If(a, e1, e2) => max_stack_bexp(a) + List(max_stack_exp(e1), max_stack_exp(e2)).max
  // + 1 for the `dup` emitted by compile_exp for Write
  case Write(e) => max_stack_exp(e) + 1
  case Var(_) => 1
  case Num(_) => 1
  case Aop(_, a1, a2) => max_stack_exp(a1) + max_stack_exp(a2)
  // the value of e1 is popped before e2 is evaluated
  case Sequence(e1, e2) => List(max_stack_exp(e1), max_stack_exp(e2)).max
}

// Stack bound for a comparison: both operands are on the stack at once.
def max_stack_bexp(e: BExp): Int = e match {
  case Bop(_, a1, a2) => max_stack_exp(a1) + max_stack_exp(a2)
}


// for generating new labels
var counter = -1

// Returns a label name of the form x_N, where N increases on every call.
def Fresh(x: String) = {
  counter += 1
  x ++ "_" ++ counter.toString()
}

// convenient string interpolations
// for instructions, labels and methods
import scala.language.implicitConversions
import scala.language.reflectiveCalls

// String interpolators for emitting Jasmin code:
//   i"..."  an indented instruction line
//   l"..."  a label line (a ':' is appended)
//   m"..."  an unindented line, e.g. a .method / .limit directive
// An implicit class avoids the reflective calls that an implicit
// conversion to an anonymous structural type would need.
implicit class sring_inters(sc: StringContext) {
    def i(args: Any*): String = "   " ++ sc.s(args:_*) ++ "\n"
    def l(args: Any*): String = sc.s(args:_*) ++ ":\n"
    def m(args: Any*): String = sc.s(args:_*) ++ "\n"
}


// maps a function parameter name to its JVM local-variable slot
type Env = Map[String, Int]

// Local-variable slot of x, with a readable error for unbound variables.
private def lookup(x: String, env: Env): Int =
  env.getOrElse(x, throw new IllegalArgumentException(s"unbound variable: $x"))

// JVM instruction implementing a binary arithmetic operator.
private def arith_instr(op: String): String = op match {
  case "+" => "iadd"
  case "-" => "isub"
  case "*" => "imul"
  case "/" => "idiv"
  case "%" => "irem"
  case _ => throw new IllegalArgumentException(s"unknown arithmetic operator: $op")
}

// Compiles expression a. The generated code leaves exactly one int
// (the value of a) on the operand stack.
def compile_exp(a: Exp, env : Env) : String = a match {
  case Num(i) => i"ldc $i"
  case Var(s) => i"iload ${lookup(s, env)}"
  case Aop(op, a1, a2) => {
    val instr = arith_instr(op)
    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"$instr"
  }
  case If(b, a1, a2) => {
    val if_else = Fresh("If_else")
    val if_end = Fresh("If_end")
    // compile_bexp jumps to if_else when the condition is false
    compile_bexp(b, env, if_else) ++
    compile_exp(a1, env) ++
    i"goto $if_end" ++
    l"$if_else" ++
    compile_exp(a2, env) ++
    l"$if_end"
  }
  case Call(name, args) => {
    // every Fun function takes and returns ints
    val is = "I" * args.length
    args.map(a => compile_exp(a, env)).mkString ++
    i"invokestatic XXX/XXX/$name($is)I"
  }
  case Sequence(a1, a2) => {
    // discard the value of a1; the sequence has the value of a2
    compile_exp(a1, env) ++ i"pop" ++ compile_exp(a2, env)
  }
  case Write(a1) => {
    // dup so that write(e) prints e and also evaluates to e
    compile_exp(a1, env) ++
    i"dup" ++
    i"invokestatic XXX/XXX/write(I)V"
  }
}

// Conditional jump taken when the comparison `op` is FALSE, i.e. the
// JVM comparison with the negated condition.
private def negated_jump(op: String): String = op match {
  case "==" => "if_icmpne"
  case "!=" => "if_icmpeq"
  case "<"  => "if_icmpge"
  case "<=" => "if_icmpgt"
  case ">"  => "if_icmple"
  case ">=" => "if_icmplt"
  case _ => throw new IllegalArgumentException(s"unknown comparison operator: $op")
}

// Compiles boolean expression b so that control jumps to label jmp when b
// is false and falls through when b is true.
def compile_bexp(b: BExp, env : Env, jmp: String) : String = b match {
  case Bop(op, a1, a2) => {
    val jump = negated_jump(op)
    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"$jump $jmp"
  }
}

// Compiles a top-level declaration into a Jasmin method.
def compile_decl(d: Decl) : String = d match {
  case Def(name, args, a) => {
    // parameters occupy local slots 0 .. n-1 in declaration order
    val env = args.zipWithIndex.toMap
    val is = "I" * args.length
    m".method public static $name($is)I" ++
    m".limit locals ${args.length}" ++
    m".limit stack ${1 + max_stack_exp(a)}" ++
    // label at the method start; not referenced by any code generated here
    l"${name}_Start" ++
    compile_exp(a, env) ++
    i"ireturn" ++
    m".end method\n"
  }
  case Main(a) => {
    // The value of the main expression is printed by the trailing write.
    // NOTE(review): compile_exp(Write(e)) leaves e's value on the stack,
    // so a main expression of the form write(e) prints e twice - confirm
    // this is intended.
    m".method public static main([Ljava/lang/String;)V" ++
    m".limit locals 200" ++
    m".limit stack 200" ++
    compile_exp(a, Map()) ++
    i"invokestatic XXX/XXX/write(I)V" ++
    i"return" ++
    m".end method\n"
  }
}

// main compiler functions

// Runs `code` i times and returns the average wall-clock time per run,
// in seconds.
def time_needed[T](i: Int, code: => T) = {
  val start = System.nanoTime()
  for (j <- 1 to i) code
  val end = System.nanoTime()
  (end - start)/(i * 1.0e9)
}

// Reads a Java-serialised object from file fname. Both streams are closed
// even if the ObjectInputStream constructor fails (e.g. on an empty or
// corrupt file). The cast to T is unchecked because of erasure.
def deserialise[T](fname: String) : Try[T] = {
  import scala.util.Using
  Using.Manager { use =>
    val file = use(new FileInputStream(fname))
    val in = use(new ObjectInputStream(file))
    in.readObject.asInstanceOf[T]
  }
}

// Compiles the declarations stored in <class_name>.prs into Jasmin source,
// with the built-in library prepended and XXX replaced by class_name.
// Throws a RuntimeException if the .prs file cannot be read, instead of
// silently compiling an empty program.
def compile(class_name: String) : String = {
  val ast = deserialise[List[Decl]](class_name ++ ".prs") match {
    case Success(decls) => decls
    case Failure(err) =>
      throw new RuntimeException(s"cannot read parse tree ${class_name}.prs: ${err.getMessage}", err)
  }
  val instructions = ast.map(compile_decl).mkString
  (library ++ instructions).replace("XXX", class_name)
}

// Writes the compiled Jasmin code to <class_name>.j, overwriting any
// existing file.
def compile_to_file(class_name: String) : Unit = {
  import java.nio.file.{Files, Paths}
  import java.nio.charset.StandardCharsets
  val output = compile(class_name)
  Files.write(Paths.get(s"${class_name}.j"), output.getBytes(StandardCharsets.UTF_8))
  ()
}

// Compiles, assembles with jasmin, and runs the resulting class once,
// printing the elapsed time. `!!` throws if jasmin exits with a non-zero
// status. Commands are passed as argument lists so that names containing
// spaces are not split.
def compile_and_run(class_name: String) : Unit = {
  compile_to_file(class_name)
  Seq("java", "-jar", "jvm/jasmin-2.4/jasmin.jar", s"${class_name}.j").!!
  println("Time: " + time_needed(1, Seq("java", s"${class_name}/${class_name}").!))
}


// some examples of .fun files
//compile_to_file("fact")
//compile_and_run("fact")
//compile_and_run("defs")


// Entry point: expects the base name of the .prs file as first argument.
def main(args: Array[String]) : Unit = args.headOption match {
  case Some(class_name) => compile_and_run(class_name)
  case None =>
    System.err.println("usage: scala fun.scala <name>   (reads <name>.prs)")
    sys.exit(1)
}


}