solutions/cw5/fun_llvm.sc
changeset 894 02ef5c3abc51
parent 868 8fb3b6d3be70
child 903 2f86ebda3629
equal deleted inserted replaced
893:54a483a33763 894:02ef5c3abc51
       
     1 // A Small LLVM Compiler for a Simple Functional Language
       
     2 // (includes an external lexer and parser)
       
     3 //
       
     4 //
       
     5 // call with                 -- prints out llvm code
       
     6 //
       
     7 //     amm fun_llvm.sc main fact.fun
       
     8 //     amm fun_llvm.sc main defs.fun
       
     9 //
       
    10 // or                        -- writes llvm code to disk
       
    11 //
       
    12 //     amm fun_llvm.sc write fact.fun
       
    13 //     amm fun_llvm.sc write defs.fun
       
    14 //
       
    15 //       this will generate an .ll file. 
       
    16 //
       
    17 // or                       -- compiles the generated llvm code (llc + gcc) and runs the executable
       
    18 //
       
    19 //     amm fun_llvm.sc run fact.fun
       
    20 //     amm fun_llvm.sc run defs.fun
       
    21 //
       
    22 //
       
    23 // You can interpret an .ll file using lli, for example
       
    24 //
       
    25 //      lli fact.ll
       
    26 //
       
    27 // The optimiser can be invoked as
       
    28 //
       
    29 //      opt -O1 -S in_file.ll > out_file.ll
       
    30 //      opt -O3 -S in_file.ll > out_file.ll
       
    31 //
       
    32 // The code produced for the various architectures can be obtained with
       
    33 //   
       
    34 //   llc -march=x86 -filetype=asm in_file.ll -o -
       
    35 //   llc -march=arm -filetype=asm in_file.ll -o -  
       
    36 //
       
    37 // Producing an executable can be achieved by
       
    38 //
       
    39 //    llc -filetype=obj in_file.ll
       
    40 //    gcc in_file.o -o a.out
       
    41 //    ./a.out
       
    42 
       
    43 
       
    44 import $file.fun_tokens, fun_tokens._
       
    45 import $file.fun_parser, fun_parser._ 
       
    46 
       
    47 
       
    48 // for generating new labels
       
    49 var counter = -1
       
    50 
       
    51 def Fresh(x: String) = {
       
    52   counter += 1
       
    53   x ++ "_" ++ counter.toString()
       
    54 }
       
    55 
       
// Internal CPS language for FUN
//
// K-values (KVal) are the operands and right-hand sides of the CPS form;
// K-expressions (KExp) sequence them. The `ty` fields default to "UNDEF" and
// are filled in by the typing pass (typ_val / typ_exp).
abstract class KExp
abstract class KVal

// FUN types are represented by their names, e.g. "Int", "Double", "Void".
type Ty = String
// Maps variable, constant and function names to their FUN type names.
type TyEnv = Map[String, Ty]

// a named variable (parameter, local or CPS temporary) with its type
case class KVar(s: String, ty: Ty = "UNDEF") extends KVal
// reads the global named by v (compiled to an LLVM `load` from @name)
case class KLoad(v: KVal) extends KVal
// integer literal
case class KNum(i: Int) extends KVal
// floating-point literal
case class KFNum(i: Double) extends KVal
// character literal, held as its Int code
case class KChr(c: Int) extends KVal
// binary operation o applied to v1 and v2, annotated with its type
case class Kop(o: String, v1: KVal, v2: KVal, ty: Ty = "UNDEF") extends KVal
// call of function o with arguments vrs; ty is the function's return type
case class KCall(o: String, vrs: List[KVal], ty: Ty = "UNDEF") extends KVal

// branches on the boolean held in variable x1: e1 if true, e2 otherwise
case class KIf(x1: String, e1: KExp, e2: KExp) extends KExp {
  override def toString = s"KIf $x1\nIF\n$e1\nELSE\n$e2"
}
// binds x to the value e1, then continues with e2
case class KLet(x: String, e1: KVal, e2: KExp) extends KExp {
  override def toString = s"let $x = $e1 in \n$e2" 
}
// returns v from the current function
case class KReturn(v: KVal) extends KExp
       
    78 
       
    79 // typing K values
       
    80 def typ_val(v: KVal, ts: TyEnv) : (KVal, Ty) = v match {
       
    81   case KVar(s, _) => {
       
    82     val ty = ts.getOrElse(s, "TUNDEF")
       
    83     (KVar(s, ty), ty)  
       
    84   }
       
    85   case Kop(op, v1, v2, _) => {
       
    86     val (tv1, ty1) = typ_val(v1, ts)
       
    87     val (tv2, ty2) = typ_val(v2, ts)
       
    88     if (ty1 == ty2) (Kop(op, tv1, tv2, ty1), ty1) else (Kop(op, tv1, tv2, "TMISMATCH"), "TMISMATCH") 
       
    89   }
       
    90   case KCall(fname, args, _) => {
       
    91     val ty = ts.getOrElse(fname, "TCALLUNDEF" ++ fname)
       
    92     (KCall(fname, args.map(typ_val(_, ts)._1), ty), ty)
       
    93   }  
       
    94   case KLoad(v) => {
       
    95     val (tv, ty) = typ_val(v, ts)
       
    96     (KLoad(tv), ty)
       
    97   }
       
    98   case KNum(i) => (KNum(i), "Int")
       
    99   case KFNum(i) => (KFNum(i), "Double")
       
   100   case KChr(c) => (KChr(c), "Int")
       
   101 }
       
   102 
       
   103 def typ_exp(a: KExp, ts: TyEnv) : KExp = a match {
       
   104   case KReturn(v) => KReturn(typ_val(v, ts)._1)
       
   105   case KLet(x: String, v: KVal, e: KExp) => {
       
   106     val (tv, ty) = typ_val(v, ts)
       
   107     KLet(x, tv, typ_exp(e, ts + (x -> ty)))
       
   108   }
       
   109   case KIf(b, e1, e2) => KIf(b, typ_exp(e1, ts), typ_exp(e2, ts))
       
   110 }
       
   111 
       
   112 
       
   113 
       
   114 
       
   115 // CPS translation from Exps to KExps using a
       
   116 // continuation k.
       
   117 def CPS(e: Exp)(k: KVal => KExp) : KExp = e match {
       
   118   case Var(s) if (s.head.isUpper) => {
       
   119       val z = Fresh("tmp")
       
   120       KLet(z, KLoad(KVar(s)), k(KVar(z)))
       
   121   }
       
   122   case Var(s) => k(KVar(s))
       
   123   case Num(i) => k(KNum(i))
       
   124   case ChConst(c) => k(KChr(c))
       
   125   case FNum(i) => k(KFNum(i))
       
   126   case Aop(o, e1, e2) => {
       
   127     val z = Fresh("tmp")
       
   128     CPS(e1)(y1 => 
       
   129       CPS(e2)(y2 => KLet(z, Kop(o, y1, y2), k(KVar(z)))))
       
   130   }
       
   131   case If(Bop(o, b1, b2), e1, e2) => {
       
   132     val z = Fresh("tmp")
       
   133     CPS(b1)(y1 => 
       
   134       CPS(b2)(y2 => 
       
   135         KLet(z, Kop(o, y1, y2), KIf(z, CPS(e1)(k), CPS(e2)(k)))))
       
   136   }
       
   137   case Call(name, args) => {
       
   138     def aux(args: List[Exp], vs: List[KVal]) : KExp = args match {
       
   139       case Nil => {
       
   140           val z = Fresh("tmp")
       
   141           KLet(z, KCall(name, vs), k(KVar(z)))
       
   142       }
       
   143       case e::es => CPS(e)(y => aux(es, vs ::: List(y)))
       
   144     }
       
   145     aux(args, Nil)
       
   146   }
       
   147   case Sequence(e1, e2) => 
       
   148     CPS(e1)(_ => CPS(e2)(y2 => k(y2)))
       
   149 }   
       
   150 
       
   151 //initial continuation
       
   152 def CPSi(e: Exp) = CPS(e)(KReturn)
       
   153 
       
   154 // some testcases
       
   155 val e1 = Aop("*", Var("a"), Num(3))
       
   156 CPSi(e1)
       
   157 
       
   158 val e2 = Aop("+", Aop("*", Var("a"), Num(3)), Num(4))
       
   159 CPSi(e2)
       
   160 
       
   161 val e3 = Aop("+", Num(2), Aop("*", Var("a"), Num(3)))
       
   162 CPSi(e3)
       
   163 
       
   164 val e4 = Aop("+", Aop("-", Num(1), Num(2)), Aop("*", Var("a"), Num(3)))
       
   165 CPSi(e4)
       
   166 
       
   167 val e5 = If(Bop("==", Num(1), Num(1)), Num(3), Num(4))
       
   168 CPSi(e5)
       
   169 
       
   170 val e6 = If(Bop("!=", Num(10), Num(10)), e5, Num(40))
       
   171 CPSi(e6)
       
   172 
       
   173 val e7 = Call("foo", List(Num(3)))
       
   174 CPSi(e7)
       
   175 
       
   176 val e8 = Call("foo", List(Aop("*", Num(3), Num(1)), Num(4), Aop("+", Num(5), Num(6))))
       
   177 CPSi(e8)
       
   178 
       
   179 val e9 = Sequence(Aop("*", Var("a"), Num(3)), Aop("+", Var("b"), Num(6)))
       
   180 CPSi(e9)
       
   181 
       
   182 val e = Aop("*", Aop("+", Num(1), Call("foo", List(Var("a"), Num(3)))), Num(4))
       
   183 CPSi(e)
       
   184 
       
   185 
       
   186 
       
   187 
       
   188 // convenient string interpolations 
       
   189 // for instructions, labels and methods
       
   190 import scala.language.implicitConversions
       
   191 import scala.language.reflectiveCalls
       
   192 
       
   193 
       
   194 
       
   195 
       
   196 implicit def sring_inters(sc: StringContext) = new {
       
   197     def i(args: Any*): String = "   " ++ sc.s(args:_*) ++ "\n"
       
   198     def l(args: Any*): String = sc.s(args:_*) ++ ":\n"
       
   199     def m(args: Any*): String = sc.s(args:_*) ++ "\n"
       
   200 }
       
   201 
       
   202 def get_ty(s: String) = s match {
       
   203   case "Double" => "double"
       
   204   case "Void" => "void"
       
   205   case "Int" => "i32"
       
   206   case "Bool" => "i2"
       
   207   case _ => s
       
   208 }
       
   209 
       
   210 def compile_call_arg(a: KVal) = a match {
       
   211   case KNum(i) => s"i32 $i"
       
   212   case KFNum(i) => s"double $i"
       
   213   case KChr(c) => s"i32 $c"
       
   214   case KVar(s, ty) => s"${get_ty(ty)} %$s" 
       
   215 }
       
   216 
       
   217 def compile_arg(s: (String, String)) = s"${get_ty(s._2)} %${s._1}" 
       
   218 
       
   219 
       
   220 // mathematical and boolean operations
       
   221 def compile_op(op: String) = op match {
       
   222   case "+" => "add i32 "
       
   223   case "*" => "mul i32 "
       
   224   case "-" => "sub i32 "
       
   225   case "/" => "sdiv i32 "
       
   226   case "%" => "srem i32 "
       
   227   case "==" => "icmp eq i32 "
       
   228   case "!=" => "icmp ne i32 "      // not equal 
       
   229   case "<=" => "icmp sle i32 "     // signed less or equal
       
   230   case "<"  => "icmp slt i32 "     // signed less than
       
   231 }
       
   232 
       
   233 def compile_dop(op: String) = op match {
       
   234   case "+" => "fadd double "
       
   235   case "*" => "fmul double "
       
   236   case "-" => "fsub double "
       
   237   case "==" => "fcmp oeq double "
       
   238   case "<=" => "fcmp ole double "   
       
   239   case "<"  => "fcmp olt double "   
       
   240 }
       
   241 
       
   242 // compile K values
       
   243 def compile_val(v: KVal) : String = v match {
       
   244   case KNum(i) => s"$i"
       
   245   case KFNum(i) => s"$i"
       
   246   case KChr(c) => s"$c"
       
   247   case KVar(s, ty) => s"%$s" 
       
   248   case KLoad(KVar(s, ty)) => s"load ${get_ty(ty)}, ${get_ty(ty)}* @$s"
       
   249   case Kop(op, x1, x2, ty) => ty match { 
       
   250     case "Int" => s"${compile_op(op)} ${compile_val(x1)}, ${compile_val(x2)}"
       
   251     case "Double" => s"${compile_dop(op)} ${compile_val(x1)}, ${compile_val(x2)}"
       
   252     case _ => Kop(op, x1, x2, ty).toString
       
   253   }
       
   254   case KCall(fname, args, ty) => 
       
   255     s"call ${get_ty(ty)} @$fname (${args.map(compile_call_arg).mkString(", ")})"
       
   256 }
       
   257 
       
   258 // compile K expressions
       
   259 def compile_exp(a: KExp) : String = a match {
       
   260   case KReturn(KVar("void", _)) =>
       
   261     i"ret void"
       
   262   case KReturn(KVar(x, ty)) =>
       
   263     i"ret ${get_ty(ty)} %$x"
       
   264   case KReturn(KNum(i)) =>
       
   265     i"ret i32 $i"
       
   266   case KLet(x: String, KCall(o: String, vrs: List[KVal], "Void"), e: KExp) => 
       
   267     i"${compile_val(KCall(o: String, vrs: List[KVal], "Void"))}" ++ compile_exp(e)
       
   268   case KLet(x: String, v: KVal, e: KExp) => 
       
   269     i"%$x = ${compile_val(v)}" ++ compile_exp(e)
       
   270   case KIf(x, e1, e2) => {
       
   271     val if_br = Fresh("if_branch")
       
   272     val else_br = Fresh("else_branch")
       
   273     i"br i1 %$x, label %$if_br, label %$else_br" ++
       
   274     l"\n$if_br" ++
       
   275     compile_exp(e1) ++
       
   276     l"\n$else_br" ++ 
       
   277     compile_exp(e2)
       
   278   }
       
   279 }
       
   280 
       
   281 
       
// LLVM IR prepended to every compiled program. It declares printf, defines
// the format-string constants, and defines the void built-ins new_line,
// print_star, print_space, skip, print_int and print_char.
val prelude = """
declare i32 @printf(i8*, ...)

@.str_nl = private constant [2 x i8] c"\0A\00"
@.str_star = private constant [2 x i8] c"*\00"
@.str_space = private constant [2 x i8] c" \00"

define void @new_line() #0 {
  %t0 = getelementptr [2 x i8], [2 x i8]* @.str_nl, i32 0, i32 0
  %1 = call i32 (i8*, ...) @printf(i8* %t0)
  ret void
}

define void @print_star() #0 {
  %t0 = getelementptr [2 x i8], [2 x i8]* @.str_star, i32 0, i32 0
  %1 = call i32 (i8*, ...) @printf(i8* %t0)
  ret void
}

define void @print_space() #0 {
  %t0 = getelementptr [2 x i8], [2 x i8]* @.str_space, i32 0, i32 0
  %1 = call i32 (i8*, ...) @printf(i8* %t0)
  ret void
}

define void @skip() #0 {
  ret void
}

@.str_int = private constant [3 x i8] c"%d\00"

define void @print_int(i32 %x) {
   %t0 = getelementptr [3 x i8], [3 x i8]* @.str_int, i32 0, i32 0
   call i32 (i8*, ...) @printf(i8* %t0, i32 %x) 
   ret void
}

@.str_char = private constant [3 x i8] c"%c\00"

define void @print_char(i32 %x) {
   %t0 = getelementptr [3 x i8], [3 x i8]* @.str_char, i32 0, i32 0
   call i32 (i8*, ...) @printf(i8* %t0, i32 %x) 
   ret void
}

; END OF BUILD-IN FUNCTIONS (prelude)

"""
       
   330 
       
   331 def get_cont(ty: Ty) = ty match {
       
   332   case "Int" =>    KReturn
       
   333   case "Double" => KReturn
       
   334   case "Void" =>   { (_: KVal) => KReturn(KVar("void", "Void")) }
       
   335 } 
       
   336 
       
   337 // compile function for declarations and main
       
   338 def compile_decl(d: Decl, ts: TyEnv) : (String, TyEnv) = d match {
       
   339   case Def(name, args, ty, body) => { 
       
   340     val ts2 = ts + (name -> ty)
       
   341     val tkbody = typ_exp(CPS(body)(get_cont(ty)), ts2 ++ args.toMap)
       
   342     (m"define ${get_ty(ty)} @$name (${args.map(compile_arg).mkString(",")}) {" ++
       
   343      compile_exp(tkbody) ++
       
   344      m"}\n", ts2)
       
   345   }
       
   346   case Main(body) => {
       
   347     val tbody = typ_exp(CPS(body)(_ => KReturn(KNum(0))), ts)
       
   348     (m"define i32 @main() {" ++
       
   349      compile_exp(tbody) ++
       
   350      m"}\n", ts)
       
   351   }
       
   352   case Const(name, n) => {
       
   353     (m"@$name = global i32 $n\n", ts + (name -> "Int"))
       
   354   }
       
   355   case FConst(name, x) => {
       
   356     (m"@$name = global double $x\n", ts + (name -> "Double"))
       
   357   }
       
   358 }
       
   359 
       
   360 def compile_prog(prog: List[Decl], ty: TyEnv) : String = prog match {
       
   361   case Nil => ""
       
   362   case d::ds => {
       
   363     val (s2, ty2) = compile_decl(d, ty)
       
   364     s2 ++ compile_prog(ds, ty2)
       
   365   }
       
   366 }
       
   367 // main compiler functions
       
   368 def compile(prog: List[Decl]) : String = 
       
   369   prelude ++ compile_prog(prog, Map("new_line" -> "Void", "skip" -> "Void", 
       
   370 				    "print_star" -> "Void", "print_space" -> "Void",
       
   371                                     "print_int" -> "Void", "print_char" -> "Void"))
       
   372 
       
   373 
       
   374 //import ammonite.ops._
       
   375 
       
   376 
       
   377 @main
       
   378 def main(fname: String) = {
       
   379     val path = os.pwd / fname
       
   380     val file = fname.stripSuffix("." ++ path.ext)
       
   381     val tks = tokenise(os.read(path))
       
   382     val ast = parse_tks(tks)
       
   383     val code = compile(ast)
       
   384     println(code)
       
   385 }
       
   386 
       
   387 @main
       
   388 def write(fname: String) = {
       
   389     val path = os.pwd / fname
       
   390     val file = fname.stripSuffix("." ++ path.ext)
       
   391     val tks = tokenise(os.read(path))
       
   392     val ast = parse_tks(tks)
       
   393     val code = compile(ast)
       
   394     //println(code)
       
   395     os.write.over(os.pwd / (file ++ ".ll"), code)
       
   396 }
       
   397 
       
   398 @main
       
   399 def run(fname: String) = {
       
   400     val path = os.pwd / fname
       
   401     val file = fname.stripSuffix("." ++ path.ext)
       
   402     write(fname)  
       
   403     os.proc("llc", "-filetype=obj", file ++ ".ll").call()
       
   404     os.proc("gcc", file ++ ".o", "-o", file ++ ".bin").call()
       
   405     os.proc(os.pwd / (file ++ ".bin")).call(stdout = os.Inherit)
       
   406     println(s"done.")
       
   407 }
       
   408 
       
   409 
       
   410 
       
   411 
       
   412