// A Small Compiler for the WHILE Language
// (it does not use a parser and lexer)

// the abstract syntax trees

/** Base class of all WHILE statements. */
abstract class Stmt

/** Base class of all arithmetic expressions. */
abstract class AExp

/** Base class of all boolean expressions. */
abstract class BExp

/** A block is a sequence of statements. */
type Block = List[Stmt]


// statements

/** The statement that does nothing. */
case object Skip extends Stmt

/** Conditional: `a` is the test, `bl1` the then-branch, `bl2` the else-branch. */
case class If(a: BExp, bl1: Block, bl2: Block) extends Stmt

/** Loop with test `b` and body `bl`. */
case class While(b: BExp, bl: Block) extends Stmt

/** Assignment of the arithmetic expression `a` to the variable named `s`. */
case class Assign(s: String, a: AExp) extends Stmt

/** Write statement for the variable named `s`. */
case class Write(s: String) extends Stmt

/** Read statement for the variable named `s`. */
case class Read(s: String) extends Stmt


// arithmetic expressions

/** Reference to the variable named `s`. */
case class Var(s: String) extends AExp

/** Integer literal. */
case class Num(i: Int) extends AExp

/** Binary arithmetic operation; `o` is the operator symbol. */
case class Aop(o: String, a1: AExp, a2: AExp) extends AExp


// boolean expressions

case object True extends BExp

case object False extends BExp

/** Binary comparison of two arithmetic expressions; `o` is the operator symbol. */
case class Bop(o: String, a1: AExp, a2: AExp) extends BExp


// compiler headers needed for the JVM
// (contains an init method, as well as methods for read and write)
//
// Jasmin assembly prelude emitted before the compiled program. Every
// occurrence of XXX is later replaced by the class name. The text ends
// inside the `main` method, so the compiled instructions can be appended
// directly after it.
val beginning = """
.class public XXX.XXX
.super java/lang/Object

.method public <init>()V
    aload_0
    invokenonvirtual java/lang/Object/<init>()V
    return
.end method

.method public static write(I)V
    .limit locals 1
    .limit stack 2
    getstatic java/lang/System/out Ljava/io/PrintStream;
    iload 0
    invokevirtual java/io/PrintStream/println(I)V
    return
.end method

.method public static read()I
    .limit locals 10
    .limit stack 10

    ldc 0
    istore 1 ; this will hold our final integer
Label1:
    getstatic java/lang/System/in Ljava/io/InputStream;
    invokevirtual java/io/InputStream/read()I
    istore 2
    iload 2
    ldc 10 ; the newline delimiter
    isub
    ifeq Label2
    iload 2
    ldc 32 ; the space delimiter
    isub
    ifeq Label2

    iload 2
    ldc 48 ; we have our digit in ASCII, have to subtract it from 48
    isub
    ldc 10
    iload 1
    imul
    iadd
    istore 1
    goto Label1
Label2:
    ;when we come here we have our integer computed in local variable 1
    iload 1
    ireturn
.end method

.method public static main([Ljava/lang/String;)V
    .limit locals 200
    .limit stack 200

"""

// Jasmin assembly epilogue appended after the compiled program:
// returns from `main` and closes the method.
val ending = """

    return

.end method
"""

// Progress message printed when the script is loaded
// (compile_run prints the same message again when it starts).
Console.out.println("Start compilation")


// for generating new labels
// (starts at -1 so that the first label gets the suffix 0)
var counter = -1

/** Returns `x` followed by an underscore and the next value of `counter`.
  *
  * Each call increments the shared `counter`, so successive calls never
  * yield the same label.
  */
def Fresh(x: String) = {
  counter += 1
  s"${x}_${counter}"
}

// environments and instructions

/** Maps a variable name to the index of its local variable, kept as a string. */
type Env = Map[String, String]

/** Generated assembly text, as a list of string fragments. */
type Instrs = List[String]

// arithmetic expression compilation
//
// Emits code that leaves the value of `a` on the operand stack.
// A variable that is not in `env` makes the map lookup fail, and an
// operator other than +, - and * results in a MatchError.
def compile_aexp(a: AExp, env : Env) : Instrs = {
  // JVM instruction for each supported arithmetic operator
  val opcodes = Map("+" -> "iadd", "-" -> "isub", "*" -> "imul")
  a match {
    case Num(i) => List(s"ldc ${i}\n")
    case Var(s) => List(s"iload ${env(s)}\n")
    // both operands are pushed first, then combined by the operator
    case Aop(op, a1, a2) if opcodes.contains(op) =>
      compile_aexp(a1, env) ++ compile_aexp(a2, env) ++ List(opcodes(op) + "\n")
  }
}

// boolean expression compilation
//
// Emits code that jumps to the label `jmp` when `b` is false and falls
// through when it is true. Operators other than =, != and < result in a
// MatchError.
def compile_bexp(b: BExp, env : Env, jmp: String) : Instrs = {
  // each comparison maps to the jump instruction for its negation,
  // because the jump is taken when the test fails
  val negatedJump = Map("=" -> "if_icmpne", "!=" -> "if_icmpeq", "<" -> "if_icmpge")
  b match {
    case True => Nil
    case False => List(s"goto ${jmp}\n")
    case Bop(op, a1, a2) if negatedJump.contains(op) =>
      compile_aexp(a1, env) ++ compile_aexp(a2, env) ++
      List(s"${negatedJump(op)} ${jmp}\n")
  }
}

// statement compilation
//
// Returns the code for `s` together with the environment that is in
// effect after `s`: assignments and reads may add a variable to it.
def compile_stmt(s: Stmt, env: Env) : (Instrs, Env) = {
  // index of an already known variable, or the next unused index
  // (the number of variables seen so far) for a new one
  def slot(x: String): String = env.getOrElse(x, env.keys.size.toString)

  s match {
    case Skip => (Nil, env)

    case Assign(x, a) =>
      val index = slot(x)
      val code = compile_aexp(a, env) :+ s"istore ${index}\n"
      (code, env + (x -> index))

    case If(b, bl1, bl2) =>
      val elseLabel = Fresh("If_else")
      val endLabel = Fresh("If_end")
      // the else-branch is compiled with the environment left by the then-branch
      val (thenCode, thenEnv) = compile_block(bl1, env)
      val (elseCode, elseEnv) = compile_block(bl2, thenEnv)
      val test = compile_bexp(b, env, elseLabel)
      val code =
        test ++
        thenCode ++
        List(s"goto ${endLabel}\n", s"\n${elseLabel}:\n\n") ++
        elseCode ++
        List(s"\n${endLabel}:\n\n")
      (code, elseEnv)

    case While(b, bl) =>
      val beginLabel = Fresh("Loop_begin")
      val endLabel = Fresh("Loop_end")
      val (bodyCode, bodyEnv) = compile_block(bl, env)
      val test = compile_bexp(b, env, endLabel)
      val code =
        List(s"\n${beginLabel}:\n\n") ++
        test ++
        bodyCode ++
        List(s"goto ${beginLabel}\n", s"\n${endLabel}:\n\n")
      (code, bodyEnv)

    case Write(x) =>
      (List(s"iload ${env(x)}\n" + "invokestatic XXX/XXX/write(I)V\n"), env)

    case Read(x) =>
      val index = slot(x)
      (List("invokestatic XXX/XXX/read()I\n" + s"istore ${index}\n"), env + (x -> index))
  }
}

// compilation of a block (i.e. list of instructions)
//
// Statements are compiled from left to right; the environment produced by
// one statement is the one used for the next. The result is the
// concatenated code and the environment after the last statement.
def compile_block(bl: Block, env: Env) : (Instrs, Env) = {
  // per-statement code is collected in reverse and flattened at the end
  val (chunks, finalEnv) =
    bl.foldLeft((List.empty[Instrs], env)) { case ((acc, current), stmt) =>
      val (code, next) = compile_stmt(stmt, current)
      (code :: acc, next)
    }
  (chunks.reverse.flatten, finalEnv)
}

// main compilation function for blocks
//
// Compiles `bl` starting from an empty environment, wraps the result in
// the `beginning` and `ending` templates, and substitutes `class_name`
// for every occurrence of the placeholder XXX.
def compile(bl: Block, class_name: String) : String = {
  val instructions = compile_block(bl, Map.empty)._1
  // String.replace substitutes all literal occurrences; it stands in for
  // replaceAllLiterally, which is deprecated since Scala 2.13
  (beginning ++ instructions.mkString ++ ending).replace("XXX", class_name)
}


// compiling and running files
//
// JVM files can be assembled with
//
//    java -jar jvm/jasmin-2.4/jasmin.jar fib.j
//
// and started with
//
//    java fib/fib


import scala.util._
import scala.sys.process._
import scala.io

/** Compiles `bl` and writes the assembly text to the file `class_name`.j.
  *
  * The writer is closed even when writing fails, so the file handle is
  * never leaked.
  */
def compile_tofile(bl: Block, class_name: String) = {
  val output = compile(bl, class_name)
  val fw = new java.io.FileWriter(class_name + ".j")
  try fw.write(output)
  finally fw.close()
}

/** Writes the assembly file for `bl` and assembles it with the Jasmin jar
  * at jvm/jasmin-2.4/jasmin.jar, printing a progress message after each step.
  */
def compile_all(bl: Block, class_name: String) : Unit = {
  compile_tofile(bl, class_name)
  println("compiled ")
  // `!!` throws if the assembler exits with a non-zero code; its captured
  // output is not needed, so it is no longer bound to an unused local
  ("java -jar jvm/jasmin-2.4/jasmin.jar " + class_name + ".j").!!
  println("assembled ")
}

/** Evaluates `code` `i` times and returns the average wall-clock time per
  * evaluation in seconds.
  */
def time_needed[T](i: Int, code: => T) = {
  val start = System.nanoTime()
  (1 to i).foreach(_ => code)
  val elapsed = System.nanoTime() - start
  // nanoseconds -> seconds, divided by the number of runs
  elapsed / (i * 1.0e9)
}


/** Compiles and assembles `bl`, then runs the resulting class with `java`
  * and prints the time the run took in seconds.
  */
def compile_run(bl: Block, class_name: String) : Unit = {
  println("Start compilation")
  compile_all(bl, class_name)
  val command = s"java ${class_name}/${class_name}"
  // the process is started inside time_needed, so only the run is timed
  val seconds = time_needed(1, command.!)
  println(s"Time: ${seconds}")
}


// Fibonacci numbers as a test-case
val fib_test = {
  // body of the loop
  val loopBody = List(
    Assign("temp", Var("minus2")),                            // temp := minus2;
    Assign("minus2", Aop("+", Var("minus1"), Var("minus2"))), // minus2 := minus1 + minus2;
    Assign("minus1", Var("temp")),                            // minus1 := temp;
    Assign("n", Aop("-", Var("n"), Num(1)))                   // n := n - 1
  )

  List(
    Assign("n", Num(10)),                                     // n := 10;
    Assign("minus1", Num(0)),                                 // minus1 := 0;
    Assign("minus2", Num(1)),                                 // minus2 := 1;
    Assign("temp", Num(0)),                                   // temp := 0;
    While(Bop("<", Num(0), Var("n")), loopBody),              // while 0 < n do { ... };
    Write("minus1")                                           // write minus1
  )
}


compile_run(fib_test, "fib")
261 |