progs/fun/fun_llvm.sc
changeset 789 f0696713177b
parent 734 5d860ff01938
child 813 059f970287d1
equal deleted inserted replaced
788:3b1136fb6bee 789:f0696713177b
       
     1 // A Small LLVM Compiler for a Simple Functional Language
       
     2 // (includes an external lexer and parser)
       
     3 //
       
     4 // call with 
       
     5 //
       
     6 //     amm fun_llvm.sc write fact.fun
       
     7 //
       
     8 //     amm fun_llvm.sc write defs.fun
       
     9 //
       
    10 // this will generate a .ll file. Other options are compile and run.
       
    11 //
       
    12 // You can interpret an .ll file using lli.
       
    13 //
       
    14 // The optimiser can be invoked as
       
    15 //
       
    16 //      opt -O1 -S in_file.ll > out_file.ll
       
    17 //      opt -O3 -S in_file.ll > out_file.ll
       
    18 //
       
     19 // The code produced for the various architectures can be obtained with
       
    20 //   
       
    21 //   llc -march=x86 -filetype=asm in_file.ll -o -
       
    22 //   llc -march=arm -filetype=asm in_file.ll -o -  
       
    23 //
       
    24 // Producing an executable can be achieved by
       
    25 //
       
    26 //    llc -filetype=obj in_file.ll
       
    27 //    gcc in_file.o -o a.out
       
    28 //    ./a.out
       
    29 
       
    30 
       
    31 import $file.fun_tokens, fun_tokens._
       
    32 import $file.fun_parser, fun_parser._ 
       
    33 import scala.util._
       
    34 
       
    35 
       
    36 // for generating new labels
       
    37 var counter = -1
       
    38 
       
    39 def Fresh(x: String) = {
       
    40   counter += 1
       
    41   x ++ "_" ++ counter.toString()
       
    42 }
       
    43 
       
    44 // Internal CPS language for FUN
       
    45 abstract class KExp
       
    46 abstract class KVal
       
    47 
       
    48 case class KVar(s: String) extends KVal
       
    49 case class KNum(i: Int) extends KVal
       
    50 case class Kop(o: String, v1: KVal, v2: KVal) extends KVal
       
    51 case class KCall(o: String, vrs: List[KVal]) extends KVal
       
    52 case class KWrite(v: KVal) extends KVal
       
    53 
       
    54 case class KIf(x1: String, e1: KExp, e2: KExp) extends KExp {
       
    55   override def toString = s"KIf $x1\nIF\n$e1\nELSE\n$e2"
       
    56 }
       
    57 case class KLet(x: String, e1: KVal, e2: KExp) extends KExp {
       
    58   override def toString = s"let $x = $e1 in \n$e2" 
       
    59 }
       
    60 case class KReturn(v: KVal) extends KExp
       
    61 
       
    62 
       
    63 // CPS translation from Exps to KExps using a
       
    64 // continuation k.
       
    65 def CPS(e: Exp)(k: KVal => KExp) : KExp = e match {
       
    66   case Var(s) => k(KVar(s)) 
       
    67   case Num(i) => k(KNum(i))
       
    68   case Aop(o, e1, e2) => {
       
    69     val z = Fresh("tmp")
       
    70     CPS(e1)(y1 => 
       
    71       CPS(e2)(y2 => KLet(z, Kop(o, y1, y2), k(KVar(z)))))
       
    72   }
       
    73   case If(Bop(o, b1, b2), e1, e2) => {
       
    74     val z = Fresh("tmp")
       
    75     CPS(b1)(y1 => 
       
    76       CPS(b2)(y2 => 
       
    77         KLet(z, Kop(o, y1, y2), KIf(z, CPS(e1)(k), CPS(e2)(k)))))
       
    78   }
       
    79   case Call(name, args) => {
       
    80     def aux(args: List[Exp], vs: List[KVal]) : KExp = args match {
       
    81       case Nil => {
       
    82           val z = Fresh("tmp")
       
    83           KLet(z, KCall(name, vs), k(KVar(z)))
       
    84       }
       
    85       case e::es => CPS(e)(y => aux(es, vs ::: List(y)))
       
    86     }
       
    87     aux(args, Nil)
       
    88   }
       
    89   case Sequence(e1, e2) => 
       
    90     CPS(e1)(_ => CPS(e2)(y2 => k(y2)))
       
    91   case Write(e) => {
       
    92     val z = Fresh("tmp")
       
    93     CPS(e)(y => KLet(z, KWrite(y), k(KVar(z))))
       
    94   }
       
    95 }   
       
    96 
       
    97 //initial continuation
       
    98 def CPSi(e: Exp) = CPS(e)(KReturn)
       
    99 
       
// some testcases
//
// Each value below is a small FUN expression; each CPSi call translates it
// with the initial continuation. The translated KExp is not bound to a name.
// Note: every one of these translations goes through Fresh, so these
// statements advance the shared name counter.

// a single arithmetic operation
val e1 = Aop("*", Var("a"), Num(3))
CPSi(e1)

// nested operation as the left operand
val e2 = Aop("+", Aop("*", Var("a"), Num(3)), Num(4))
CPSi(e2)

// nested operation as the right operand
val e3 = Aop("+", Num(2), Aop("*", Var("a"), Num(3)))
CPSi(e3)

// nested operations on both sides
val e4 = Aop("+", Aop("-", Num(1), Num(2)), Aop("*", Var("a"), Num(3)))
CPSi(e4)

// a conditional with constant branches
val e5 = If(Bop("==", Num(1), Num(1)), Num(3), Num(4))
CPSi(e5)

// a conditional whose first branch is itself a conditional
val e6 = If(Bop("!=", Num(10), Num(10)), e5, Num(40))
CPSi(e6)

// a call with one constant argument
val e7 = Call("foo", List(Num(3)))
CPSi(e7)

// a call mixing compound and constant arguments
val e8 = Call("foo", List(Aop("*", Num(3), Num(1)), Num(4), Aop("+", Num(5), Num(6))))
CPSi(e8)

// a sequence of two arithmetic expressions
val e9 = Sequence(Aop("*", Var("a"), Num(3)), Aop("+", Var("b"), Num(6)))
CPSi(e9)

// a call nested inside arithmetic
val e = Aop("*", Aop("+", Num(1), Call("foo", List(Var("a"), Num(3)))), Num(4))
CPSi(e)
       
   130 
       
   131 
       
   132 
       
   133 
       
   134 // convenient string interpolations 
       
   135 // for instructions, labels and methods
       
   136 import scala.language.implicitConversions
       
   137 import scala.language.reflectiveCalls
       
   138 
       
   139 implicit def sring_inters(sc: StringContext) = new {
       
   140     def i(args: Any*): String = "   " ++ sc.s(args:_*) ++ "\n"
       
   141     def l(args: Any*): String = sc.s(args:_*) ++ ":\n"
       
   142     def m(args: Any*): String = sc.s(args:_*) ++ "\n"
       
   143 }
       
   144 
       
   145 // mathematical and boolean operations
       
   146 def compile_op(op: String) = op match {
       
   147   case "+" => "add i32 "
       
   148   case "*" => "mul i32 "
       
   149   case "-" => "sub i32 "
       
   150   case "/" => "sdiv i32 "
       
   151   case "%" => "srem i32 "
       
   152   case "==" => "icmp eq i32 "
       
   153   case "<=" => "icmp sle i32 "    // signed less or equal
       
   154   case "<" => "icmp slt i32 "     // signed less than
       
   155 }
       
   156 
       
   157 def compile_val(v: KVal) : String = v match {
       
   158   case KNum(i) => s"$i"
       
   159   case KVar(s) => s"%$s"
       
   160   case Kop(op, x1, x2) => 
       
   161     s"${compile_op(op)} ${compile_val(x1)}, ${compile_val(x2)}"
       
   162   case KCall(x1, args) => 
       
   163     s"call i32 @$x1 (${args.map(compile_val).mkString("i32 ", ", i32 ", "")})"
       
   164   case KWrite(x1) =>
       
   165     s"call i32 @printInt (i32 ${compile_val(x1)})"
       
   166 }
       
   167 
       
   168 // compile K expressions
       
   169 def compile_exp(a: KExp) : String = a match {
       
   170   case KReturn(v) =>
       
   171     i"ret i32 ${compile_val(v)}"
       
   172   case KLet(x: String, v: KVal, e: KExp) => 
       
   173     i"%$x = ${compile_val(v)}" ++ compile_exp(e)
       
   174   case KIf(x, e1, e2) => {
       
   175     val if_br = Fresh("if_branch")
       
   176     val else_br = Fresh("else_branch")
       
   177     i"br i1 %$x, label %$if_br, label %$else_br" ++
       
   178     l"\n$if_br" ++
       
   179     compile_exp(e1) ++
       
   180     l"\n$else_br" ++ 
       
   181     compile_exp(e2)
       
   182   }
       
   183 }
       
   184 
       
   185 
       
   186 val prelude = """
       
   187 @.str = private constant [4 x i8] c"%d\0A\00"
       
   188 
       
   189 declare i32 @printf(i8*, ...)
       
   190 
       
   191 define i32 @printInt(i32 %x) {
       
   192    %t0 = getelementptr [4 x i8], [4 x i8]* @.str, i32 0, i32 0
       
   193    call i32 (i8*, ...) @printf(i8* %t0, i32 %x) 
       
   194    ret i32 %x
       
   195 }
       
   196 
       
   197 """
       
   198 
       
   199 
       
   200 // compile function for declarations and main
       
   201 def compile_decl(d: Decl) : String = d match {
       
   202   case Def(name, args, body) => { 
       
   203     m"define i32 @$name (${args.mkString("i32 %", ", i32 %", "")}) {" ++
       
   204     compile_exp(CPSi(body)) ++
       
   205     m"}\n"
       
   206   }
       
   207   case Main(body) => {
       
   208     m"define i32 @main() {" ++
       
   209     compile_exp(CPS(body)(_ => KReturn(KNum(0)))) ++
       
   210     m"}\n"
       
   211   }
       
   212 }
       
   213 
       
   214 // main compiler functions
       
   215 
       
   216 def compile_prog(prog: List[Decl]) : String = 
       
   217   prelude ++ (prog.map(compile_decl).mkString)
       
   218 
       
   219 
       
   220 @main
       
   221 def compile(fname: String) = {
       
   222     val path = os.pwd / fname
       
   223     val file = fname.stripSuffix("." ++ path.ext)
       
   224     val tks = tokenise(os.read(path))
       
   225     val ast = parse_tks(tks)
       
   226     println(compile_prog(ast))
       
   227 }
       
   228 
       
   229 @main
       
   230 def write(fname: String) = {
       
   231     val path = os.pwd / fname
       
   232     val file = fname.stripSuffix("." ++ path.ext)
       
   233     val tks = tokenise(os.read(path))
       
   234     val ast = parse_tks(tks)
       
   235     val code = compile_prog(ast)
       
   236     os.write.over(os.pwd / (file ++ ".ll"), code)
       
   237 }
       
   238 
       
   239 @main
       
   240 def run(fname: String) = {
       
   241     val path = os.pwd / fname
       
   242     val file = fname.stripSuffix("." ++ path.ext)
       
   243     val tks = tokenise(os.read(path))
       
   244     val ast = parse_tks(tks)
       
   245     val code = compile_prog(ast)
       
   246     os.write.over(os.pwd / (file ++ ".ll"), code)
       
   247     os.proc("llc", "-filetype=obj", file ++ ".ll").call()
       
   248     os.proc("gcc", file ++ ".o", "-o", file).call()
       
   249     print(os.proc(os.pwd / file).call().out.string)
       
   250 }
       
   251 
       
   252 
       
   253 
       
   254 
       
   255