| 
// A small compiler for a simple functional language
// (it includes neither a lexer nor a parser; programs are given directly as ASTs)
//
// Run it with
//
//    amm fun.sc
//
// This prints the JVM assembler instructions generated for
// two factorial functions.
         | 
    10   | 
         | 
    11   | 
         | 
    12 // abstract syntax trees  | 
         | 
    13 abstract class Exp  | 
         | 
    14 abstract class BExp   | 
         | 
    15 abstract class Decl  | 
         | 
    16   | 
         | 
    17 // functions and declarations  | 
         | 
    18 case class Def(name: String, args: List[String], body: Exp) extends Decl  | 
         | 
    19 case class Main(e: Exp) extends Decl  | 
         | 
    20   | 
         | 
    21 // expressions  | 
         | 
    22 case class Call(name: String, args: List[Exp]) extends Exp  | 
         | 
    23 case class If(a: BExp, e1: Exp, e2: Exp) extends Exp  | 
         | 
    24 case class Write(e: Exp) extends Exp  | 
         | 
    25 case class Var(s: String) extends Exp  | 
         | 
    26 case class Num(i: Int) extends Exp  | 
         | 
    27 case class Aop(o: String, a1: Exp, a2: Exp) extends Exp  | 
         | 
    28 case class Sequ(e1: Exp, e2: Exp) extends Exp  | 
         | 
    29   | 
         | 
    30 // boolean expressions  | 
         | 
    31 case class Bop(o: String, a1: Exp, a2: Exp) extends BExp  | 
         | 
    32   | 
         | 
    33 // calculating the maximal needed stack size  | 
         | 
    34 def max_stack_exp(e: Exp): Int = e match { | 
         | 
    35   case Call(_, args) => args.map(max_stack_exp).sum  | 
         | 
    36   case If(a, e1, e2) =>   | 
         | 
    37     max_stack_bexp(a) + (List(max_stack_exp(e1), max_stack_exp(e2)).max)  | 
         | 
    38   case Write(e) => max_stack_exp(e) + 1  | 
         | 
    39   case Var(_) => 1  | 
         | 
    40   case Num(_) => 1  | 
         | 
    41   case Aop(_, a1, a2) => max_stack_exp(a1) + max_stack_exp(a2)  | 
         | 
    42   case Sequ(e1, e2) => List(max_stack_exp(e1), max_stack_exp(e2)).max  | 
         | 
    43 }  | 
         | 
    44 def max_stack_bexp(e: BExp): Int = e match { | 
         | 
    45   case Bop(_, a1, a2) => max_stack_exp(a1) + max_stack_exp(a2)  | 
         | 
    46 }  | 
         | 
    47   | 
         | 
// compiler - built-in functions
// copied from http://www.ceng.metu.edu.tr/courses/ceng444/link/jvm-cpm.html
//
// `library` is the assembler prelude that `compile` places in front of the
// compiled declarations. It declares the class and a static helper
// `write(I)V`, which prints its int argument via System.out.println.
// Every "XXX" placeholder is replaced by the class name passed to `compile`.
// NOTE(review): write's `.limit locals 5` / `.limit stack 5` exceed what the
// method uses (one local, two stack slots). This is harmless but could be tightened.

val library = """
.class public XXX.XXX
.super java/lang/Object

.method public static write(I)V 
        .limit locals 5 
        .limit stack 5 
        iload 0 
        getstatic java/lang/System/out Ljava/io/PrintStream; 
        swap 
        invokevirtual java/io/PrintStream/println(I)V 
        return 
.end method

"""
         | 
    67   | 
         | 
    68 // for generating new labels  | 
         | 
    69 var counter = -1  | 
         | 
    70   | 
         | 
    71 def Fresh(x: String) = { | 
         | 
    72   counter += 1  | 
         | 
    73   x ++ "_" ++ counter.toString()  | 
         | 
    74 }  | 
         | 
    75   | 
         | 
    76 // convenient string interpolations for  | 
         | 
    77 // generating instructions, labels etc  | 
         | 
    78 import scala.language.implicitConversions  | 
         | 
    79 import scala.language.reflectiveCalls  | 
         | 
    80   | 
         | 
    81 // convenience for code-generation (string interpolations)  | 
         | 
    82 implicit def sring_inters(sc: StringContext) = new { | 
         | 
    83   def i(args: Any*): String = "   " ++ sc.s(args:_*) ++ "\n"  // instructions  | 
         | 
    84   def l(args: Any*): String = sc.s(args:_*) ++ ":\n"          // labels  | 
         | 
    85   def m(args: Any*): String = sc.s(args:_*) ++ "\n"           // methods  | 
         | 
    86 }  | 
         | 
    87   | 
         | 
    88 // variable / index environments  | 
         | 
    89 type Env = Map[String, Int]  | 
         | 
    90   | 
         | 
    91 // compile expressions  | 
         | 
    92 def compile_exp(a: Exp, env : Env) : String = a match { | 
         | 
    93   case Num(i) => i"ldc $i"  | 
         | 
    94   case Var(s) => i"iload ${env(s)}" | 
         | 
    95   case Aop("+", a1, a2) => compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"iadd" | 
         | 
    96   case Aop("-", a1, a2) => compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"isub" | 
         | 
    97   case Aop("*", a1, a2) => compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"imul" | 
         | 
    98   case Aop("/", a1, a2) => compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"idiv" | 
         | 
    99   case Aop("%", a1, a2) => compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"irem" | 
         | 
   100   case If(b, a1, a2) => { | 
         | 
   101     val if_else = Fresh("If_else") | 
         | 
   102     val if_end = Fresh("If_end") | 
         | 
   103     compile_bexp(b, env, if_else) ++  | 
         | 
   104     compile_exp(a1, env) ++  | 
         | 
   105     i"goto $if_end" ++  | 
         | 
   106     l"$if_else" ++  | 
         | 
   107     compile_exp(a2, env) ++  | 
         | 
   108     l"$if_end"  | 
         | 
   109   }  | 
         | 
   110   case Call(name, args) => { | 
         | 
   111     val is = "I" * args.length  | 
         | 
   112     args.map(a => compile_exp(a, env)).mkString ++  | 
         | 
   113     i"invokestatic XXX/XXX/$name($is)I"  | 
         | 
   114   }  | 
         | 
   115   case Sequ(a1, a2) => { | 
         | 
   116     compile_exp(a1, env) ++ i"pop" ++ compile_exp(a2, env)  | 
         | 
   117   }  | 
         | 
   118   case Write(a1) => { | 
         | 
   119     compile_exp(a1, env) ++  | 
         | 
   120     i"dup" ++  | 
         | 
   121     i"invokestatic XXX/XXX/write(I)V"  | 
         | 
   122   }  | 
         | 
   123 }  | 
         | 
   124   | 
         | 
   125 // compile boolean expressions  | 
         | 
   126 def compile_bexp(b: BExp, env : Env, jmp: String) : String = b match { | 
         | 
   127   case Bop("==", a1, a2) =>  | 
         | 
   128     compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmpne $jmp"  | 
         | 
   129   case Bop("!=", a1, a2) =>  | 
         | 
   130     compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmpeq $jmp"  | 
         | 
   131   case Bop("<", a1, a2) =>  | 
         | 
   132     compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmpge $jmp"  | 
         | 
   133   case Bop("<=", a1, a2) =>  | 
         | 
   134     compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmpgt $jmp"  | 
         | 
   135 }  | 
         | 
   136   | 
         | 
   137 // compile functions and declarations  | 
         | 
   138 def compile_decl(d: Decl) : String = d match { | 
         | 
   139   case Def(name, args, a) => {  | 
         | 
   140     val env = args.zipWithIndex.toMap  | 
         | 
   141     val is = "I" * args.length  | 
         | 
   142     m".method public static $name($is)I" ++  | 
         | 
   143     m".limit locals ${args.length.toString}" ++ | 
         | 
   144     m".limit stack ${1 + max_stack_exp(a)}" ++ | 
         | 
   145     l"${name}_Start" ++    | 
         | 
   146     compile_exp(a, env) ++  | 
         | 
   147     i"ireturn" ++  | 
         | 
   148     m".end method\n"  | 
         | 
   149   }  | 
         | 
   150   case Main(a) => { | 
         | 
   151     m".method public static main([Ljava/lang/String;)V" ++  | 
         | 
   152     m".limit locals 200" ++  | 
         | 
   153     m".limit stack 200" ++  | 
         | 
   154     compile_exp(a, Map()) ++  | 
         | 
   155     i"invokestatic XXX/XXX/write(I)V" ++  | 
         | 
   156     i"return" ++  | 
         | 
   157     m".end method\n"  | 
         | 
   158   }  | 
         | 
   159 }  | 
         | 
   160   | 
         | 
   161 // the main compilation function  | 
         | 
   162 def compile(prog: List[Decl], class_name: String) : String = { | 
         | 
   163   val instructions = prog.map(compile_decl).mkString  | 
         | 
   164   (library + instructions).replaceAllLiterally("XXX", class_name) | 
         | 
   165 }  | 
         | 
   166   | 
         | 
   167   | 
         | 
   168   | 
         | 
   169   | 
         | 
   170 // An example program (two versions of factorial)  | 
         | 
   171 //  | 
         | 
   172 // def fact(n) =   | 
         | 
   173 //    if n == 0 then 1 else n * fact(n - 1);  | 
         | 
   174 //  | 
         | 
   175 // def facT(n, acc) =   | 
         | 
   176 //    if n == 0 then acc else facT(n - 1, n * acc);  | 
         | 
   177 //   | 
         | 
   178 // fact(10) ; facT(10, 1)  | 
         | 
   179 //   | 
         | 
   180   | 
         | 
   181   | 
         | 
   182 val test_prog =   | 
         | 
   183   List(Def("fact", List("n"), | 
         | 
   184          If(Bop("==",Var("n"),Num(0)), | 
         | 
   185             Num(1),  | 
         | 
   186             Aop("*",Var("n"), | 
         | 
   187                     Call("fact",List(Aop("-",Var("n"),Num(1))))))), | 
         | 
   188   | 
         | 
   189        Def("facT",List("n", "acc"), | 
         | 
   190          If(Bop("==",Var("n"),Num(0)), | 
         | 
   191             Var("acc"), | 
         | 
   192             Call("facT",List(Aop("-",Var("n"),Num(1)),  | 
         | 
   193                              Aop("*",Var("n"),Var("acc")))))), | 
         | 
   194   | 
         | 
   195        Main(Sequ(Write(Call("fact",List(Num(10)))), | 
         | 
   196                  Write(Call("facT",List(Num(10), Num(1))))))) | 
         | 
   197   | 
         | 
   198 // prints out the JVM instructions  | 
         | 
   199 @main  | 
         | 
   200 def test() =   | 
         | 
   201   println(compile(test_prog, "fact"))  |