progs/fun-bare.scala
changeset 323 4ce07c4abdb4
child 624 8d0af38389bc
equal deleted inserted replaced
322:698ed1c96cd0 323:4ce07c4abdb4
       
     1 // A Small Compiler for a simple functional language
       
     2 
       
     // Abstract syntax trees

/** Expressions of the language (the compiled code leaves their int value on the JVM stack). */
abstract class Exp
/** Boolean expressions, used as the condition of an `If`. */
abstract class BExp
/** Top-level declarations making up a program. */
abstract class Decl

// function declarations

/** A function definition `name(args...) = body`. */
case class Def(name: String, args: List[String], body: Exp) extends Decl
/** The main expression of the program. */
case class Main(e: Exp) extends Decl

// expressions (atoms first, then compound forms)

case class Num(i: Int) extends Exp                          // integer literal
case class Var(s: String) extends Exp                       // variable reference
case class Aop(o: String, a1: Exp, a2: Exp) extends Exp     // a1 o a2, o an arithmetic operator
case class If(a: BExp, e1: Exp, e2: Exp) extends Exp        // if a then e1 else e2
case class Call(name: String, args: List[Exp]) extends Exp  // name(args...)
case class Sequ(e1: Exp, e2: Exp) extends Exp               // e1 ; e2
case class Write(e: Exp) extends Exp                        // write e

// boolean expressions

case class Bop(o: String, a1: Exp, a2: Exp) extends BExp    // a1 o a2, o a comparison operator
       
    23 
       
    // calculating the maximal needed stack size
//
// Both functions return an upper bound on the operand-stack depth needed to
// evaluate an expression. Operands of binary operators are simply added up,
// which over-approximates but is always safe.

def max_stack_exp(e: Exp): Int = e match {
  // arguments are pushed one after another; a call without any arguments
  // still leaves its result on the stack, so at least 1 slot is needed
  case Call(_, args) => math.max(1, args.map(max_stack_exp).sum)
  case If(a, e1, e2) => max_stack_bexp(a) + math.max(max_stack_exp(e1), max_stack_exp(e2))
  // the value of w is dup'ed before write is invoked, hence the extra slot
  case Write(w) => max_stack_exp(w) + 1
  case Var(_) => 1
  case Num(_) => 1
  case Aop(_, a1, a2) => max_stack_exp(a1) + max_stack_exp(a2)
  // the value of e1 is popped before e2 is evaluated
  case Sequ(e1, e2) => math.max(max_stack_exp(e1), max_stack_exp(e2))
}

def max_stack_bexp(e: BExp): Int = e match {
  case Bop(_, a1, a2) => max_stack_exp(a1) + max_stack_exp(a2)
}
       
    37 
       
    // compiler - built-in functions
// copied from http://www.ceng.metu.edu.tr/courses/ceng444/link/jvm-cpm.html
//
// Jasmin preamble placed in front of every compiled program. It declares the
// class (XXX is a placeholder for the class name), a default constructor, and
// a static write(I)V method that prints an int via PrintStream.println.
// The text is assembled one line per element. The lines of write(I)V
// deliberately keep a single trailing blank, exactly as in the original text.
val library = List(
  "",
  ".class public XXX.XXX",
  ".super java/lang/Object",
  "",
  ".method public <init>()V",
  "        aload_0",
  "        invokenonvirtual java/lang/Object/<init>()V",
  "        return",
  ".end method",
  "",
  ".method public static write(I)V ",
  "        .limit locals 5 ",
  "        .limit stack 5 ",
  "        iload 0 ",
  "        getstatic java/lang/System/out Ljava/io/PrintStream; ",
  "        swap ",
  "        invokevirtual java/io/PrintStream/println(I)V ",
  "        return ",
  ".end method",
  "",
  ""
).mkString("\n")
       
    63 
       
    // for generating new labels
// `counter` holds the most recently issued suffix; -1 means none has been
// issued yet, so the first label produced gets the suffix 0.
var counter = -1

/** Returns `x` extended with "_" and a new number; the first call yields e.g. "If_else_0". */
def Fresh(x: String) = {
  counter = counter + 1
  s"${x}_${counter}"
}
       
    71 
       
    72 
       
    /** Variable environment: maps each variable name to its JVM local-variable slot. */
type Env = scala.collection.immutable.Map[String, Int]
/** Generated Jasmin text, one newline-terminated piece (instruction, label or directive) per element. */
type Instrs = scala.List[String]
       
    75 
       
    // compile expressions
//
// Translates the expression `a` into Jasmin instructions that leave its int
// value on top of the operand stack.
//   a   - expression to compile
//   env - maps variable names to local-variable slots (read with iload)
// Throws IllegalArgumentException for an unbound variable or an
// unsupported arithmetic operator.
def compile_exp(a: Exp, env : Env) : Instrs = a match {
  case Num(i) => List(s"ldc ${i}\n")
  case Var(s) => {
    val slot = env.getOrElse(s, throw new IllegalArgumentException(s"unbound variable: ${s}"))
    List(s"iload ${slot}\n")
  }
  case Aop(o, a1, a2) => {
    // choose the instruction first so an unknown operator is reported before
    // any sub-expression is compiled (and before any label is generated)
    val instr = o match {
      case "+" => "iadd"
      case "-" => "isub"
      case "*" => "imul"
      case "/" => "idiv"
      case "%" => "irem"
      case _ => throw new IllegalArgumentException(s"unknown arithmetic operator: ${o}")
    }
    compile_exp(a1, env) ++ compile_exp(a2, env) ++ List(instr + "\n")
  }
  case If(b, a1, a2) => {
    // compile_bexp jumps to if_else when the condition is false and falls
    // through into the then-branch otherwise
    val if_else = Fresh("If_else")
    val if_end = Fresh("If_end")
    compile_bexp(b, env, if_else) ++
    compile_exp(a1, env) ++
    List(s"goto ${if_end}\n") ++
    List(s"\n${if_else}:\n\n") ++
    compile_exp(a2, env) ++
    List(s"\n${if_end}:\n\n")
  }
  case Call(n, args) => {
    // every argument and the result are ints, hence the (I...)I descriptor
    val is = "I" * args.length
    args.flatMap(arg => compile_exp(arg, env)) ++
    List(s"invokestatic XXX/XXX/${n}(${is})I\n")
  }
  case Sequ(a1, a2) => {
    // the value of a1 is discarded; the sequence yields the value of a2
    compile_exp(a1, env) ++ List("pop\n") ++ compile_exp(a2, env)
  }
  case Write(a1) => {
    // dup keeps the value on the stack, so `write e` also evaluates to e
    compile_exp(a1, env) ++
    List("dup\n",
         "invokestatic XXX/XXX/write(I)V\n")
  }
}
       
   109 
       
   // compile boolean expressions
//
// The generated code jumps to the label `jmp` when the condition is FALSE and
// falls through when it is TRUE. Each comparison therefore maps to the JVM
// if_icmp instruction for the negated test.
// Throws IllegalArgumentException for an unsupported comparison operator.
def compile_bexp(b: BExp, env : Env, jmp: String) : Instrs = b match {
  case Bop(o, a1, a2) => {
    // resolved before compiling the operands so errors surface early
    val negated_jump = o match {
      case "==" => "if_icmpne"
      case "!=" => "if_icmpeq"
      case "<"  => "if_icmpge"
      case "<=" => "if_icmpgt"
      case ">"  => "if_icmple"
      case ">=" => "if_icmplt"
      case _ => throw new IllegalArgumentException(s"unknown comparison operator: ${o}")
    }
    compile_exp(a1, env) ++ compile_exp(a2, env) ++ List(s"${negated_jump} ${jmp}\n")
  }
}
       
   121 
       
   // compile function for declarations and main
//
// A Def becomes a static method that takes and returns ints. Its parameters
// occupy local-variable slots 0, 1, ... in declaration order. Main becomes the
// JVM entry point main([Ljava/lang/String;)V; it writes the value of its
// expression and then returns.
def compile_decl(d: Decl) : Instrs = d match {
  case Def(name, args, a) => {
    val arity = args.size
    val signature = "I" * arity
    val env = args.zip(args.indices).toMap
    val header = List(
      s".method public static ${name}(${signature})I \n",
      s".limit locals ${arity}\n",
      s".limit stack ${1 + max_stack_exp(a)}\n",
      s"${name}_Start:\n")
    val footer = List(
      "ireturn\n",
      ".end method \n\n")
    header ::: compile_exp(a, env) ::: footer
  }
  case Main(a) => {
    val header = List(
      ".method public static main([Ljava/lang/String;)V\n",
      ".limit locals 200\n",
      ".limit stack 200\n")
    val footer = List(
      "invokestatic XXX/XXX/write(I)V\n",
      "return\n",
      ".end method\n")
    header ::: compile_exp(a, Map.empty[String, Int]) ::: footer
  }
}
       
   145 
       
   // main compilation function
//
// Compiles every declaration of `prog` in order and places the built-in
// library in front of the result. Every occurrence of the placeholder XXX is
// then replaced by `class_name`.
def compile(prog: List[Decl], class_name: String) : String = {
  val instructions = prog.flatMap(compile_decl).mkString
  // String.replace substitutes all literal occurrences; it is the documented
  // replacement for replaceAllLiterally, which is deprecated since Scala 2.13
  (library + instructions).replace("XXX", class_name)
}
       
   151 
       
   152 
       
   153 
       
   154 
       
   // example program (factorials)
//
//   def fact(n)      = if n == 0 then 1 else n * fact(n - 1)
//   def facT(n, acc) = if n == 0 then acc else facT(n - 1, n * acc)
//   write(fact(10)); write(facT(10, 1))
val test_prog = {
  val n = Var("n")
  val n_is_zero = Bop("==", n, Num(0))
  val n_minus_one = Aop("-", n, Num(1))

  // naive recursive version
  val fact = Def("fact", List("n"),
    If(n_is_zero, Num(1), Aop("*", n, Call("fact", List(n_minus_one)))))

  // accumulator-passing version (the recursive call is in tail position)
  val facT = Def("facT", List("n", "acc"),
    If(n_is_zero, Var("acc"), Call("facT", List(n_minus_one, Aop("*", n, Var("acc"))))))

  val main = Main(Sequ(Write(Call("fact", List(Num(10)))),
                       Write(Call("facT", List(Num(10), Num(1))))))

  List(fact, facT, main)
}
       
   172 
       
   // prints out the JVM instructions
// (the example program above is bound to `test_prog`; there is no `test`)
println(compile(test_prog, "fact"))