// A Small Compiler for a Simple Functional Language
// (it does not include a parser or a lexer)
//
// call with
//
//    amm fun.sc
//
// this will print out the JVM instructions for two
// factorial functions
|
// abstract syntax trees
//
// The roots of the syntax trees are sealed so that the pattern matches over
// them (stack-size estimation and code generation below) are checked for
// exhaustiveness by the compiler; the leaf case classes are final.
sealed abstract class Exp
sealed abstract class BExp
sealed abstract class Decl

// functions and declarations
final case class Def(name: String, args: List[String], body: Exp) extends Decl
final case class Main(e: Exp) extends Decl

// expressions
final case class Call(name: String, args: List[Exp]) extends Exp
final case class If(a: BExp, e1: Exp, e2: Exp) extends Exp
final case class Write(e: Exp) extends Exp
final case class Var(s: String) extends Exp
final case class Num(i: Int) extends Exp
final case class Aop(o: String, a1: Exp, a2: Exp) extends Exp
final case class Sequ(e1: Exp, e2: Exp) extends Exp

// boolean expressions
final case class Bop(o: String, a1: Exp, a2: Exp) extends BExp
|
// calculating the maximal needed stack size
//
// These return an upper bound (not necessarily tight) on the number of
// operand-stack slots needed to run the code that compile_exp / compile_bexp
// emit for an expression. Over-estimating only wastes a little frame space;
// under-estimating makes the JVM reject the method.
def max_stack_exp(e: Exp): Int = e match {
  // The arguments sit on the stack until invokestatic consumes them, and the
  // call then leaves its int result behind, so even a call with no
  // arguments needs one slot (the plain sum would be 0 in that case).
  case Call(_, args) => math.max(args.map(max_stack_exp).sum, 1)
  case If(a, e1, e2) =>
    max_stack_bexp(a) + (List(max_stack_exp(e1), max_stack_exp(e2)).max)
  // the value is dup'ed before being passed to write
  case Write(e) => max_stack_exp(e) + 1
  case Var(_) => 1
  case Num(_) => 1
  case Aop(_, a1, a2) => max_stack_exp(a1) + max_stack_exp(a2)
  // e1's value is popped before e2 is evaluated
  case Sequ(e1, e2) => List(max_stack_exp(e1), max_stack_exp(e2)).max
}

// Both operands of a comparison are on the stack before the conditional jump.
def max_stack_bexp(e: BExp): Int = e match {
  case Bop(_, a1, a2) => max_stack_exp(a1) + max_stack_exp(a2)
}
|
// compiler - built-in functions
// copied from http://www.ceng.metu.edu.tr/courses/ceng444/link/jvm-cpm.html
//
// Jasmin preamble prepended to every compiled program. The placeholder XXX
// is substituted by the class name in `compile`. It declares the class and
// a static `write(I)V` helper that prints its int argument through
// System.out.println(I)V.

val library = """
.class public XXX.XXX
.super java/lang/Object

.method public static write(I)V
    .limit locals 5
    .limit stack 5
    iload 0
    getstatic java/lang/System/out Ljava/io/PrintStream;
    swap
    invokevirtual java/io/PrintStream/println(I)V
    return
.end method

"""
|
// for generating new labels
//
// Every call bumps one global counter, so all labels produced during a run
// are distinct: the first call yields "<x>_0", the next "<y>_1", and so on.
var counter = -1

def Fresh(x: String): String = {
  counter = counter + 1
  s"${x}_${counter}"
}
|
// convenient string interpolations for
// generating instructions, labels etc
import scala.language.implicitConversions
import scala.language.reflectiveCalls

// convenience for code-generation (string interpolations)
//
// The interpolators live on a named class rather than an anonymous
// `new { ... }`: a structural type would make every i"..", l"..", m".."
// call go through Java reflection.
class JasminInterpolators(sc: StringContext) {
  // an instruction line: prefixed with a space, terminated by a newline
  def i(args: Any*): String = " " ++ sc.s(args: _*) ++ "\n"
  // a label definition: the text followed by a colon and a newline
  def l(args: Any*): String = sc.s(args: _*) ++ ":\n"
  // a method-level directive line: no prefix, terminated by a newline
  def m(args: Any*): String = sc.s(args: _*) ++ "\n"
}

implicit def sring_inters(sc: StringContext): JasminInterpolators =
  new JasminInterpolators(sc)
|
// variable / index environments
//
// Maps each variable name to the JVM local-variable slot that holds it
// (compile_decl numbers a function's parameters 0, 1, 2, ... in order).
type Env = scala.collection.immutable.Map[String, Int]
|
// compile expressions
//
// Emits Jasmin instructions that leave the int value of `a` on top of the
// operand stack. `env` gives the local-variable slot of every variable; an
// unbound variable or an unsupported operator raises an exception.
def compile_exp(a: Exp, env : Env) : String = {
  // JVM integer instruction for each supported arithmetic operator
  val arith = Map(
    "+" -> "iadd", "-" -> "isub", "*" -> "imul", "/" -> "idiv", "%" -> "irem")

  a match {
    case Num(n) => i"ldc $n"
    case Var(x) => i"iload ${env(x)}"
    case Aop(op, lhs, rhs) if arith.contains(op) =>
      val left = compile_exp(lhs, env)
      val right = compile_exp(rhs, env)
      left ++ right ++ i"${arith(op)}"
    case If(cond, thn, els) =>
      // labels are drawn before any sub-expression is compiled
      val elseLabel = Fresh("If_else")
      val endLabel = Fresh("If_end")
      val condCode = compile_bexp(cond, env, elseLabel)
      val thenCode = compile_exp(thn, env)
      val elseCode = compile_exp(els, env)
      condCode ++ thenCode ++ i"goto $endLabel" ++
        l"$elseLabel" ++ elseCode ++ l"$endLabel"
    case Call(name, params) =>
      val sig = "I" * params.length
      val argCode = params.map(p => compile_exp(p, env)).mkString
      argCode ++ i"invokestatic XXX/XXX/$name($sig)I"
    case Sequ(first, second) =>
      // the value of the first expression is discarded
      compile_exp(first, env) ++ i"pop" ++ compile_exp(second, env)
    case Write(arg) =>
      // dup keeps the written value as the result of the Write expression
      compile_exp(arg, env) ++ i"dup" ++ i"invokestatic XXX/XXX/write(I)V"
  }
}
|
// compile boolean expressions
//
// Emits code that evaluates both operands and jumps to the label `jmp` when
// the comparison is FALSE; when it holds, execution falls through to the
// instruction that follows. Each case therefore uses the JVM conditional
// jump for the negated comparison.
def compile_bexp(b: BExp, env : Env, jmp: String) : String = b match {
  case Bop("==", a1, a2) =>
    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmpne $jmp"
  case Bop("!=", a1, a2) =>
    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmpeq $jmp"
  case Bop("<", a1, a2) =>
    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmpge $jmp"
  case Bop("<=", a1, a2) =>
    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmpgt $jmp"
  // `>` and `>=` previously fell through to a MatchError:
  // not (a1 > a2) is a1 <= a2, and not (a1 >= a2) is a1 < a2.
  case Bop(">", a1, a2) =>
    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmple $jmp"
  case Bop(">=", a1, a2) =>
    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmplt $jmp"
}
|
// compile functions and declarations
//
// A Def becomes a static method taking and returning ints, whose parameters
// occupy local slots 0, 1, ... in declaration order. Main becomes the
// class's static main method.
def compile_decl(d: Decl) : String = d match {
  case Def(name, args, body) =>
    val arity = args.length
    val env = args.zipWithIndex.toMap
    val sig = "I" * arity
    // one slot on top of the expression estimate
    val header =
      m".method public static $name($sig)I" ++
      m".limit locals ${arity}" ++
      m".limit stack ${1 + max_stack_exp(body)}"
    header ++ l"${name}_Start" ++ compile_exp(body, env) ++
      i"ireturn" ++ m".end method\n"
  case Main(body) =>
    // NOTE(review): the value of the main expression is written out here; when
    // that expression already ends in a Write (as in test_prog), the last
    // value is printed twice -- confirm this is intended.
    val code = compile_exp(body, Map())
    m".method public static main([Ljava/lang/String;)V" ++
    m".limit locals 200" ++
    m".limit stack 200" ++
    code ++
    i"invokestatic XXX/XXX/write(I)V" ++
    i"return" ++
    m".end method\n"
}
|
// the main compilation function
//
// Compiles every declaration, prepends the `library` preamble, and then
// substitutes `class_name` for every XXX placeholder. The placeholder
// occurs in both the class declaration and the invokestatic targets.
def compile(prog: List[Decl], class_name: String) : String = {
  val instructions = prog.map(compile_decl).mkString
  // String.replace(CharSequence, CharSequence) replaces all occurrences
  // literally (no regex), like the deprecated replaceAllLiterally.
  (library + instructions).replace("XXX", class_name)
}
|
// An example program (two versions of factorial)
//
// def fact(n) =
//   if n == 0 then 1 else n * fact(n - 1);
//
// def facT(n, acc) =
//   if n == 0 then acc else facT(n - 1, n * acc);
//
// fact(10) ; facT(10, 1)
//

val test_prog = {
  // sub-trees shared by both definitions
  val nIsZero = Bop("==", Var("n"), Num(0))
  val nMinus1 = Aop("-", Var("n"), Num(1))

  val fact = Def("fact", List("n"),
    If(nIsZero,
       Num(1),
       Aop("*", Var("n"), Call("fact", List(nMinus1)))))

  val facT = Def("facT", List("n", "acc"),
    If(nIsZero,
       Var("acc"),
       Call("facT", List(nMinus1, Aop("*", Var("n"), Var("acc"))))))

  val entry = Main(Sequ(Write(Call("fact", List(Num(10)))),
                        Write(Call("facT", List(Num(10), Num(1))))))

  List(fact, facT, entry)
}
|
// prints out the JVM instructions
//
// Entry point: compiles test_prog into Jasmin assembly for a class named
// "fact" and writes the result to standard output.
@main
def test(): Unit = {
  val jasmin = compile(test_prog, "fact")
  println(jasmin)
}