|         |      1 // A Small Compiler for a Simple Functional Language | 
|         |      2 // (it does not include a parser and lexer) | 
|         |      3  | 
// Abstract syntax trees of the source language.
//
// NOTE(review): these three base classes could be made `sealed` so that the
// compiler checks the pattern matches over them for exhaustiveness; that is
// only safe if every subclass is always compiled in the same unit as its base
// class (e.g. not loaded statement-by-statement into a REPL) - confirm before
// changing.
abstract class Exp
abstract class BExp
abstract class Decl

// Top-level declarations: a function definition (name, parameter names,
// body expression) and the program's main expression.
case class Def(name: String, args: List[String], body: Exp) extends Decl
case class Main(e: Exp) extends Decl

// Expressions. In the generated code each of them leaves exactly one int on
// the JVM operand stack (see compile_exp).
case class Call(name: String, args: List[Exp]) extends Exp  // call of a defined function
case class If(a: BExp, e1: Exp, e2: Exp) extends Exp        // if a then e1 else e2
case class Write(e: Exp) extends Exp                        // prints e; e's value is also the result
case class Var(s: String) extends Exp                       // variable (only function parameters are ever bound)
case class Num(i: Int) extends Exp                          // integer literal
case class Aop(o: String, a1: Exp, a2: Exp) extends Exp     // arithmetic a1 o a2, e.g. o = "+"
case class Sequ(e1: Exp, e2: Exp) extends Exp               // e1 ; e2 - e1's value is discarded

// Boolean expressions (they only occur as the condition of an If).
case class Bop(o: String, a1: Exp, a2: Exp) extends BExp    // comparison a1 o a2, e.g. o = "=="
|         |     24  | 
// Upper bounds for the operand-stack depth needed by the code that
// compile_exp / compile_bexp generate. They are used for `.limit stack`.
//
// The estimates are deliberately conservative: the needs of sub-expressions
// are added up (Aop, Call arguments, Bop operands). That over-approximates,
// but never under-approximates, as long as every expression counts at least
// one slot for the value it leaves on the stack.
def max_stack_exp(e: Exp): Int = e match {
  // The arguments are pushed one after another, then the call replaces them
  // by its single int result. A call WITHOUT arguments still pushes that
  // result, so it needs at least one slot (the sum alone would give 0).
  case Call(_, args) => args.map(max_stack_exp).sum max 1
  // The conditional jump pops both comparison operands before either branch
  // runs, so adding them to the larger branch is a safe over-estimate.
  case If(a, e1, e2) => max_stack_bexp(a) + (max_stack_exp(e1) max max_stack_exp(e2))
  // The value is duplicated (dup) before being handed to write.
  case Write(e) => max_stack_exp(e) + 1
  case Var(_) => 1
  case Num(_) => 1
  case Aop(_, a1, a2) => max_stack_exp(a1) + max_stack_exp(a2)
  // e1's result is popped before e2 is evaluated.
  case Sequ(e1, e2) => max_stack_exp(e1) max max_stack_exp(e2)
}

def max_stack_bexp(e: BExp): Int = e match {
  // Both operands are on the stack when the comparison executes.
  case Bop(_, a1, a2) => max_stack_exp(a1) + max_stack_exp(a2)
}
|         |     38  | 
// Compiler - built-in functions.
// Copied from http://www.ceng.metu.edu.tr/courses/ceng444/link/jvm-cpm.html
//
// Jasmin-style prologue emitted at the start of every compiled program:
// - the class header;
// - a static helper `write(I)V`, which prints its int argument with a
//   trailing newline via System.out.println(I)V.
//
// The stack/locals limits of 5 are generous: the helper only uses local
// slot 0 and at most two stack slots. The placeholder XXX is replaced with
// the real class name by `compile`.
val library = """
.class public XXX.XXX
.super java/lang/Object

.method public static write(I)V 
        .limit locals 5 
        .limit stack 5 
        iload 0 
        getstatic java/lang/System/out Ljava/io/PrintStream; 
        swap 
        invokevirtual java/io/PrintStream/println(I)V 
        return 
.end method

"""
|         |     58  | 
// Label generation.
// Global counter shared by all calls of Fresh. It starts at -1, so the first
// generated label gets the suffix 0.
var counter = -1

// Returns the label name "<x>_<n>", where n is the counter value after it
// has been bumped. Repeated calls with the same prefix therefore give
// distinct names.
def Fresh(x: String): String = {
  counter = counter + 1
  s"${x}_${counter}"
}
|         |     66  | 
|         |     67 // convenient string interpolations for | 
|         |     68 // generating instructions, labels etc | 
|         |     69 import scala.language.implicitConversions | 
|         |     70 import scala.language.reflectiveCalls | 
|         |     71  | 
// Convenience for code generation: custom string interpolators.
//   i"..."  one instruction, indented by three spaces, newline-terminated
//   l"..."  a label, followed by ":" and a newline
//   m"..."  a directive / method line, newline-terminated
// The interpolated text itself is built by the standard `s` interpolator.
//
// A named class is used instead of an anonymous structural type
// (`new { ... }`). With the class, calls to i/l/m are ordinary method calls
// rather than reflective ones, and the implicit conversion gets an explicit
// result type.
class InstrInterpolators(sc: StringContext) {
  def i(args: Any*): String = "   " ++ sc.s(args: _*) ++ "\n"  // instructions
  def l(args: Any*): String = sc.s(args: _*) ++ ":\n"          // labels
  def m(args: Any*): String = sc.s(args: _*) ++ "\n"           // methods
}

// The name (a typo of "string_inters") is kept for compatibility with any
// explicit uses.
implicit def sring_inters(sc: StringContext): InstrInterpolators = new InstrInterpolators(sc)
|         |     78  | 
// Variable environment: maps each variable name to the index of the JVM
// local-variable slot holding it. compile_decl numbers a function's
// parameters 0, 1, 2, ... in declaration order, and compile_exp reads a
// slot with iload.
type Env = scala.collection.immutable.Map[String, Int]
|         |     81  | 
// Compile expressions.

// Maps an arithmetic operator symbol to the JVM int instruction that
// implements it. Throws IllegalArgumentException for unsupported operators.
def compile_aop(op: String): String = op match {
  case "+" => "iadd"
  case "-" => "isub"
  case "*" => "imul"
  case "/" => "idiv"
  case "%" => "irem"
  case _ => throw new IllegalArgumentException(s"unsupported arithmetic operator: $op")
}

// Compiles an expression to instructions that leave exactly one int (the
// expression's value) on the operand stack.
//
// env maps every variable in scope to its local-variable slot.
// Throws IllegalArgumentException for an unbound variable or an unsupported
// arithmetic operator, naming the offending item (instead of an anonymous
// MatchError / NoSuchElementException).
def compile_exp(a: Exp, env : Env) : String = a match {
  case Num(i) => i"ldc $i"
  case Var(s) =>
    val slot = env.getOrElse(s, throw new IllegalArgumentException(s"unbound variable: $s"))
    i"iload $slot"
  case Aop(o, a1, a2) =>
    // Resolve the opcode first, so that an unsupported operator is reported
    // before any sub-expression is compiled (and before labels are allocated).
    val op = compile_aop(o)
    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"$op"
  case If(b, a1, a2) => {
    val if_else = Fresh("If_else")
    val if_end = Fresh("If_end")
    // compile_bexp jumps to if_else when the condition is false; otherwise
    // control falls through into the then-branch.
    compile_bexp(b, env, if_else) ++
    compile_exp(a1, env) ++
    i"goto $if_end" ++
    l"$if_else" ++
    compile_exp(a2, env) ++
    l"$if_end"
  }
  case Call(name, args) => {
    // Every argument and the result are ints: descriptor (I...I)I.
    val is = "I" * args.length
    args.map(a => compile_exp(a, env)).mkString ++
    i"invokestatic XXX/XXX/$name($is)I"
  }
  case Sequ(a1, a2) => {
    // The value of a1 is discarded; the sequence evaluates to a2.
    compile_exp(a1, env) ++ i"pop" ++ compile_exp(a2, env)
  }
  case Write(a1) => {
    // dup keeps a copy of the value on the stack, so Write evaluates to its
    // argument after printing it.
    compile_exp(a1, env) ++
    i"dup" ++
    i"invokestatic XXX/XXX/write(I)V"
  }
}
|         |    115  | 
// Compile boolean expressions.
//
// A condition is compiled as a conditional jump. Both operands are pushed,
// then the generated instruction jumps to `jmp` when the comparison is FALSE.
// When it is true, execution falls through to the next instruction (the
// then-branch in compile_exp's If case).
//
// Supported operators: ==, !=, <, <=, >, >=. Any other operator throws
// IllegalArgumentException before any code is generated.
def compile_bexp(b: BExp, env : Env, jmp: String) : String = b match {
  case Bop(o, a1, a2) =>
    // JVM jump instruction for the *negated* comparison.
    val jump_if_false = o match {
      case "==" => "if_icmpne"
      case "!=" => "if_icmpeq"
      case "<"  => "if_icmpge"
      case "<=" => "if_icmpgt"
      case ">"  => "if_icmple"
      case ">=" => "if_icmplt"
      case _ => throw new IllegalArgumentException(s"unsupported comparison operator: $o")
    }
    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"$jump_if_false $jmp"
}
|         |    127  | 
// Compiles one top-level declaration to a method.
//
// Def:  a static method taking one int per parameter and returning an int.
//       The parameters occupy local slots 0, 1, ... in declaration order.
// Main: the JVM entry point. It evaluates the main expression (with no
//       variables in scope) and passes the final value to the library's
//       write method.
def compile_decl(d: Decl) : String = d match {
  case Main(body) =>
    List(
      m".method public static main([Ljava/lang/String;)V",
      m".limit locals 200",
      m".limit stack 200",
      compile_exp(body, Map()),
      i"invokestatic XXX/XXX/write(I)V",
      i"return",
      m".end method\n"
    ).mkString
  case Def(name, params, body) =>
    val slots = params.zipWithIndex.toMap
    val signature = params.map(_ => "I").mkString
    val header = List(
      m".method public static $name($signature)I",
      m".limit locals ${params.length}",
      m".limit stack ${max_stack_exp(body) + 1}",
      // NOTE(review): no jump in this file targets this label; presumably
      // reserved for tail-call optimisation - confirm.
      l"${name}_Start"
    )
    val footer = List(i"ireturn", m".end method\n")
    (header ++ List(compile_exp(body, slots)) ++ footer).mkString
}
|         |    151  | 
// The main compilation function. It prepends the built-in library to the
// code for all declarations, then substitutes class_name for every
// occurrence of the placeholder XXX.
//
// NOTE(review): the substitution is purely textual, so a user-defined
// function name containing "XXX" is rewritten as well.
def compile(prog: List[Decl], class_name: String) : String = {
  val instructions = prog.map(compile_decl).mkString
  // String.replace substitutes every occurrence literally (no regex), just
  // like the deprecated StringOps.replaceAllLiterally it replaces.
  (library + instructions).replace("XXX", class_name)
}
|         |    157  | 
|         |    158  | 
|         |    159  | 
|         |    160  | 
|         |    161 // An example program (factorials) | 
|         |    162 // | 
|         |    163 // def fact(n) = if n == 0 then 1 else n * fact(n - 1); | 
|         |    164 // def facT(n, acc) = if n == 0 then acc else facT(n - 1, n * acc); | 
|         |    165 //  | 
// fact(10) ; facT(10, 1)
|         |    167 //  | 
|         |    168  | 
|         |    169  | 
// The factorial example program as an abstract syntax tree.
val test_prog = {
  val n = Var("n")
  val n_is_zero = Bop("==", n, Num(0))
  val n_minus_1 = Aop("-", n, Num(1))

  // def fact(n) = if n == 0 then 1 else n * fact(n - 1);
  val fact_def =
    Def("fact", List("n"),
        If(n_is_zero,
           Num(1),
           Aop("*", n, Call("fact", List(n_minus_1)))))

  // def facT(n, acc) = if n == 0 then acc else facT(n - 1, n * acc);
  val facT_def =
    Def("facT", List("n", "acc"),
        If(n_is_zero,
           Var("acc"),
           Call("facT", List(n_minus_1, Aop("*", n, Var("acc"))))))

  // Writes fact(10), then facT(10, 1).
  val main_decl =
    Main(Sequ(Write(Call("fact", List(Num(10)))),
              Write(Call("facT", List(Num(10), Num(1))))))

  List(fact_def, facT_def, main_decl)
}
|         |    185  | 
// Emit the generated assembly for the example program on standard output.
Console.println(compile(prog = test_prog, class_name = "fact"))