|
// A small compiler for a simple functional language
|
2 |
|
// Abstract syntax trees
//
// The three base classes are sealed: every node type is declared directly
// below, which lets the compiler check pattern matches over the node types
// for exhaustiveness.

/** Expressions of the source language; each one evaluates to an Int. */
sealed abstract class Exp

/** Boolean expressions, used as the condition of an `If`. */
sealed abstract class BExp

/** Top-level declarations of a program. */
sealed abstract class Decl

// function declarations

/** Definition of the function `name` with parameter names `args` and body `body`. */
case class Def(name: String, args: List[String], body: Exp) extends Decl

/** The main expression `e` of a program. */
case class Main(e: Exp) extends Decl

// expressions

/** Call of the function `name` with the argument expressions `args`. */
case class Call(name: String, args: List[Exp]) extends Exp

/** Conditional: evaluates `e1` if the condition `a` holds, otherwise `e2`. */
case class If(a: BExp, e1: Exp, e2: Exp) extends Exp

/** Prints the value of `e`; the value is also the result of the expression. */
case class Write(e: Exp) extends Exp

/** Reference to the variable named `s`. */
case class Var(s: String) extends Exp

/** Integer literal `i`. */
case class Num(i: Int) extends Exp

/** Arithmetic operation `o` applied to `a1` and `a2`; the code generator
  * handles "+", "-", "*", "/" and "%".
  */
case class Aop(o: String, a1: Exp, a2: Exp) extends Exp

/** Sequence: evaluates `e1`, discards its value, then evaluates `e2`. */
case class Sequ(e1: Exp, e2: Exp) extends Exp

// boolean expressions

/** Comparison `o` of `a1` and `a2`; the code generator handles "==", "!=",
  * "<" and "<=".
  */
case class Bop(o: String, a1: Exp, a2: Exp) extends BExp
|
23 |
|
// calculating the maximal needed stack size

/** Estimates how many operand-stack slots the code generated for `e` needs.
  *
  * The estimate is deliberately generous (for instance, the operands of an
  * arithmetic operation are simply added up); `compile_decl` uses it to emit
  * the `.limit stack` directive of a function.
  *
  * @param e the expression to analyse
  * @return the estimated stack size, at least 1 for every expression
  */
def max_stack_exp(e: Exp): Int = e match {
  // The arguments are pushed one after another. The call itself leaves its
  // Int result on the stack, so a call without arguments still needs one
  // slot; without the lower bound such a call would count as 0 and e.g.
  // `f() + g()` would be given too small a stack.
  case Call(_, args) => math.max(1, args.map(max_stack_exp).sum)
  case If(a, e1, e2) => max_stack_bexp(a) + math.max(max_stack_exp(e1), max_stack_exp(e2))
  // one extra slot for the `dup` that keeps the written value as the result
  case Write(e1) => max_stack_exp(e1) + 1
  case Var(_) => 1
  case Num(_) => 1
  case Aop(_, a1, a2) => max_stack_exp(a1) + max_stack_exp(a2)
  // the value of the first expression is popped before the second one runs
  case Sequ(e1, e2) => math.max(max_stack_exp(e1), max_stack_exp(e2))
}
|
/** Estimates the operand-stack slots needed for the comparison `e`: the sum
  * of the estimates for its two operands.
  */
def max_stack_bexp(e: BExp): Int = e match {
  case Bop(_, lhs, rhs) => List(lhs, rhs).map(max_stack_exp).sum
}
|
37 |
|
// compiler - built-in functions
// copied from http://www.ceng.metu.edu.tr/courses/ceng444/link/jvm-cpm.html
//
// Assembly prelude that `compile` puts in front of the generated methods.
// Every occurrence of XXX is a placeholder that `compile` replaces with the
// class name. The prelude declares the class, a constructor that only calls
// the constructor of java.lang.Object, and the static method write(I)V,
// which prints its int argument with System.out.println.
// NOTE(review): the directives look like Jasmin syntax - confirm which
// assembler consumes this output.

val library = """
.class public XXX.XXX
.super java/lang/Object

.method public <init>()V
aload_0
invokenonvirtual java/lang/Object/<init>()V
return
.end method

.method public static write(I)V
.limit locals 5
.limit stack 5
iload 0
getstatic java/lang/System/out Ljava/io/PrintStream;
swap
invokevirtual java/io/PrintStream/println(I)V
return
.end method

"""
|
63 |
|
// for generating new labels

// Number used for the most recently generated label; -1 until the first
// label has been generated.
var counter = -1

/** Generates a new label name: `x`, an underscore and the next value of
  * `counter`. Each call increments `counter`, so labels built from the same
  * prefix differ.
  *
  * @param x the prefix of the label
  * @return the label, e.g. "If_else_0" for the first call with "If_else"
  */
def Fresh(x: String): String = {
  counter = counter + 1
  s"${x}_${counter}"
}
|
71 |
|
72 |
|
// Maps a variable name to the index of the local variable it is loaded from
// (`compile_decl` numbers the parameters of a function from 0).
type Env = Map[String, Int]
// Generated assembly fragments (instructions and labels), each ending in a
// newline.
type Instrs = List[String]
|
75 |
|
// compile expressions

/** Generates the instructions that evaluate the expression `a` and leave its
  * value on the operand stack.
  *
  * @param a   the expression to compile
  * @param env the local-variable index of every variable that may occur in `a`
  * @return the generated instructions, in execution order
  */
def compile_exp(a: Exp, env : Env) : Instrs = {
  // Binary arithmetic: code for the left operand, code for the right
  // operand, then the instruction that combines the two values.
  def arith(lhs: Exp, rhs: Exp, opcode: String): Instrs = {
    val left = compile_exp(lhs, env)
    val right = compile_exp(rhs, env)
    left ++ right ++ List(opcode + "\n")
  }

  a match {
    case Num(i) => List(s"ldc ${i}\n")
    case Var(s) => List(s"iload ${env(s)}\n")
    case Aop("+", lhs, rhs) => arith(lhs, rhs, "iadd")
    case Aop("-", lhs, rhs) => arith(lhs, rhs, "isub")
    case Aop("*", lhs, rhs) => arith(lhs, rhs, "imul")
    case Aop("/", lhs, rhs) => arith(lhs, rhs, "idiv")
    case Aop("%", lhs, rhs) => arith(lhs, rhs, "irem")
    case If(cond, thenExp, elseExp) =>
      // The labels are requested before any sub-expression is compiled, so
      // the numbering of nested labels stays as it was.
      val if_else = Fresh("If_else")
      val if_end = Fresh("If_end")
      // the condition jumps to the else label when it does not hold
      val condCode = compile_bexp(cond, env, if_else)
      val thenCode = compile_exp(thenExp, env)
      val elseCode = compile_exp(elseExp, env)
      condCode ++
        thenCode ++
        List(s"goto ${if_end}\n", s"\n${if_else}:\n\n") ++
        elseCode ++
        List(s"\n${if_end}:\n\n")
    case Call(n, args) =>
      // one "I" per argument: all parameters and the result are ints
      val signature = "I" * args.length
      val argCode = args.flatMap(arg => compile_exp(arg, env))
      argCode ++ List(s"invokestatic XXX/XXX/${n}(${signature})I\n")
    case Sequ(first, second) =>
      val firstCode = compile_exp(first, env)
      val secondCode = compile_exp(second, env)
      // the value of the first expression is not needed
      firstCode ++ List("pop\n") ++ secondCode
    case Write(inner) =>
      // `dup` keeps a copy of the value as the result of the expression
      compile_exp(inner, env) ++
        List("dup\n", "invokestatic XXX/XXX/write(I)V\n")
  }
}
|
109 |
|
// compile boolean expressions

/** Generates the instructions for the comparison `b`: both operands are
  * evaluated, then a conditional jump to `jmp` is taken when the comparison
  * does NOT hold (hence each operator is paired with the opposite test).
  *
  * @param b   the comparison to compile
  * @param env the local-variable index of every variable that may occur in `b`
  * @param jmp the label to jump to when the comparison is false
  * @return the generated instructions, in execution order
  */
def compile_bexp(b: BExp, env : Env, jmp: String) : Instrs = {
  // code for both operands followed by the given conditional jump
  def compare(lhs: Exp, rhs: Exp, branch: String): Instrs = {
    val left = compile_exp(lhs, env)
    val right = compile_exp(rhs, env)
    left ++ right ++ List(s"${branch} ${jmp}\n")
  }

  b match {
    case Bop("==", lhs, rhs) => compare(lhs, rhs, "if_icmpne")
    case Bop("!=", lhs, rhs) => compare(lhs, rhs, "if_icmpeq")
    case Bop("<", lhs, rhs) => compare(lhs, rhs, "if_icmpge")
    case Bop("<=", lhs, rhs) => compare(lhs, rhs, "if_icmpgt")
  }
}
|
121 |
|
// compile function for declarations and main

/** Generates the assembly text of one method for the declaration `d`.
  *
  * A `Def` becomes a static method that takes one int per parameter and
  * returns the value of its body. `Main` becomes the `main` method, which
  * evaluates its expression and passes the result to `write`.
  *
  * @param d the declaration to compile
  * @return the lines of the method, from `.method` to `.end method`
  */
def compile_decl(d: Decl) : Instrs = d match {
  case Def(name, args, body) =>
    // the parameters occupy the local variables 0, 1, ... in declaration order
    val env = args.zipWithIndex.toMap
    val signature = "I" * args.length
    val header = List(
      s".method public static ${name}(${signature})I \n",
      s".limit locals ${args.length}\n",
      s".limit stack ${1 + max_stack_exp(body)}\n",
      s"${name}_Start:\n")
    val footer = List(
      "ireturn\n",
      ".end method \n\n")
    header ++ compile_exp(body, env) ++ footer
  case Main(body) =>
    val header = List(
      ".method public static main([Ljava/lang/String;)V\n",
      ".limit locals 200\n",
      ".limit stack 200\n")
    val footer = List(
      "invokestatic XXX/XXX/write(I)V\n",
      "return\n",
      ".end method\n")
    // the main expression is compiled without any variables in scope
    header ++ compile_exp(body, Map()) ++ footer
}
|
145 |
|
// main compilation function

/** Compiles a whole program into assembly text.
  *
  * The methods generated for the declarations are appended to `library`, and
  * every occurrence of the placeholder XXX in the result is replaced by
  * `class_name`.
  *
  * @param prog       the declarations of the program
  * @param class_name the name to substitute for XXX
  * @return the complete assembly text
  */
def compile(prog: List[Decl], class_name: String) : String = {
  val instructions = prog.flatMap(compile_decl).mkString
  // String.replace substitutes every occurrence literally (no regular
  // expression); it takes the place of the deprecated replaceAllLiterally.
  (library + instructions).replace("XXX", class_name)
}
|
151 |
|
152 |
|
153 |
|
154 |
|
// example program (factorials)
//
// fact computes the factorial by plain recursion, facT carries the product
// in the extra parameter acc; the main expression writes the result of both
// for the argument 10.

val test_prog: List[Decl] =
  List(Def("fact", List("n"),
         If(Bop("==",Var("n"),Num(0)),
            Num(1),
            Aop("*",Var("n"),
                    Call("fact",List(Aop("-",Var("n"),Num(1))))))),

       Def("facT",List("n", "acc"),
         If(Bop("==",Var("n"),Num(0)),
            Var("acc"),
            Call("facT",List(Aop("-",Var("n"),Num(1)),
                             Aop("*",Var("n"),Var("acc")))))),

       Main(Sequ(Write(Call("fact",List(Num(10)))),
                 Write(Call("facT",List(Num(10), Num(1)))))))

// prints out the JVM instructions
// (the program defined above is called test_prog; the name `test` that was
// used here before is not defined)
println(compile(test_prog, "fact"))