progs/fun/fun_llvm.scala
changeset 734 5d860ff01938
parent 705 bfc8703b1527
equal deleted inserted replaced
733:022e2cb1668d 734:5d860ff01938
       
     1 // A Small LLVM Compiler for a Simple Functional Language
       
     2 // (includes an external lexer and parser)
       
     3 //
       
     4 // call with 
       
     5 //
       
     6 //     scala fun_llvm.scala fact
       
     7 //
       
     8 //     scala fun_llvm.scala defs
       
     9 //
       
    10 // this will generate a .ll file. You can interpret this file
       
    11 // using lli.
       
    12 //
       
    13 // The optimiser can be invoked as
       
    14 //
       
    15 //      opt -O1 -S in_file.ll > out_file.ll
       
    16 //      opt -O3 -S in_file.ll > out_file.ll
       
    17 //
       
    18 // The code produced for the various architectures can be obtained with
       
    19 //   
       
    20 //   llc -march=x86 -filetype=asm in_file.ll -o -
       
    21 //   llc -march=arm -filetype=asm in_file.ll -o -  
       
    22 //
       
    23 // Producing an executable can be achieved by
       
    24 //
       
    25 //    llc -filetype=obj in_file.ll
       
    26 //    gcc in_file.o -o a.out
       
    27 //    ./a.out
       
    28 
       
    29 
       
    30 
       
    31 object Compiler {
       
    32 
       
    33 import java.io._  
       
    34 import scala.util._
       
    35 import scala.sys.process._
       
    36 
       
    37 // Abstract syntax trees for the Fun language
       
    38 abstract class Exp extends Serializable 
       
    39 abstract class BExp extends Serializable 
       
    40 abstract class Decl extends Serializable 
       
    41 
       
    42 case class Def(name: String, args: List[String], body: Exp) extends Decl
       
    43 case class Main(e: Exp) extends Decl
       
    44 
       
    45 case class Call(name: String, args: List[Exp]) extends Exp
       
    46 case class If(a: BExp, e1: Exp, e2: Exp) extends Exp
       
    47 case class Write(e: Exp) extends Exp
       
    48 case class Var(s: String) extends Exp
       
    49 case class Num(i: Int) extends Exp
       
    50 case class Aop(o: String, a1: Exp, a2: Exp) extends Exp
       
    51 case class Sequence(e1: Exp, e2: Exp) extends Exp
       
    52 case class Bop(o: String, a1: Exp, a2: Exp) extends BExp
       
    53 
       
    54 
       
    55 // for generating new labels
       
    56 var counter = -1
       
    57 
       
    58 def Fresh(x: String) = {
       
    59   counter += 1
       
    60   x ++ "_" ++ counter.toString()
       
    61 }
       
    62 
       
    63 // Internal CPS language for FUN
       
    64 abstract class KExp
       
    65 abstract class KVal
       
    66 
       
    67 case class KVar(s: String) extends KVal
       
    68 case class KNum(i: Int) extends KVal
       
    69 case class Kop(o: String, v1: KVal, v2: KVal) extends KVal
       
    70 case class KCall(o: String, vrs: List[KVal]) extends KVal
       
    71 case class KWrite(v: KVal) extends KVal
       
    72 
       
    73 case class KIf(x1: String, e1: KExp, e2: KExp) extends KExp {
       
    74   override def toString = s"KIf $x1\nIF\n$e1\nELSE\n$e2"
       
    75 }
       
    76 case class KLet(x: String, e1: KVal, e2: KExp) extends KExp {
       
    77   override def toString = s"let $x = $e1 in \n$e2" 
       
    78 }
       
    79 case class KReturn(v: KVal) extends KExp
       
    80 
       
    81 
       
    82 // CPS translation from Exps to KExps using a
       
    83 // continuation k.
       
    84 def CPS(e: Exp)(k: KVal => KExp) : KExp = e match {
       
    85   case Var(s) => k(KVar(s)) 
       
    86   case Num(i) => k(KNum(i))
       
    87   case Aop(o, e1, e2) => {
       
    88     val z = Fresh("tmp")
       
    89     CPS(e1)(y1 => 
       
    90       CPS(e2)(y2 => KLet(z, Kop(o, y1, y2), k(KVar(z)))))
       
    91   }
       
    92   case If(Bop(o, b1, b2), e1, e2) => {
       
    93     val z = Fresh("tmp")
       
    94     CPS(b1)(y1 => 
       
    95       CPS(b2)(y2 => 
       
    96         KLet(z, Kop(o, y1, y2), KIf(z, CPS(e1)(k), CPS(e2)(k)))))
       
    97   }
       
    98   case Call(name, args) => {
       
    99     def aux(args: List[Exp], vs: List[KVal]) : KExp = args match {
       
   100       case Nil => {
       
   101           val z = Fresh("tmp")
       
   102           KLet(z, KCall(name, vs), k(KVar(z)))
       
   103       }
       
   104       case e::es => CPS(e)(y => aux(es, vs ::: List(y)))
       
   105     }
       
   106     aux(args, Nil)
       
   107   }
       
   108   case Sequence(e1, e2) => 
       
   109     CPS(e1)(_ => CPS(e2)(y2 => k(y2)))
       
   110   case Write(e) => {
       
   111     val z = Fresh("tmp")
       
   112     CPS(e)(y => KLet(z, KWrite(y), k(KVar(z))))
       
   113   }
       
   114 }   
       
   115 
       
   116 //initial continuation
       
   117 def CPSi(e: Exp) = CPS(e)(KReturn)
       
   118 
       
   119 // some testcases
       
   120 val e1 = Aop("*", Var("a"), Num(3))
       
   121 CPSi(e1)
       
   122 
       
   123 val e2 = Aop("+", Aop("*", Var("a"), Num(3)), Num(4))
       
   124 CPSi(e2)
       
   125 
       
   126 val e3 = Aop("+", Num(2), Aop("*", Var("a"), Num(3)))
       
   127 CPSi(e3)
       
   128 
       
   129 val e4 = Aop("+", Aop("-", Num(1), Num(2)), Aop("*", Var("a"), Num(3)))
       
   130 CPSi(e4)
       
   131 
       
   132 val e5 = If(Bop("==", Num(1), Num(1)), Num(3), Num(4))
       
   133 CPSi(e5)
       
   134 
       
   135 val e6 = If(Bop("!=", Num(10), Num(10)), e5, Num(40))
       
   136 CPSi(e6)
       
   137 
       
   138 val e7 = Call("foo", List(Num(3)))
       
   139 CPSi(e7)
       
   140 
       
   141 val e8 = Call("foo", List(Aop("*", Num(3), Num(1)), Num(4), Aop("+", Num(5), Num(6))))
       
   142 CPSi(e8)
       
   143 
       
   144 val e9 = Sequence(Aop("*", Var("a"), Num(3)), Aop("+", Var("b"), Num(6)))
       
   145 CPSi(e9)
       
   146 
       
   147 val e = Aop("*", Aop("+", Num(1), Call("foo", List(Var("a"), Num(3)))), Num(4))
       
   148 CPSi(e)
       
   149 
       
   150 
       
   151 
       
   152 
       
   153 // convenient string interpolations 
       
   154 // for instructions, labels and methods
       
   155 import scala.language.implicitConversions
       
   156 import scala.language.reflectiveCalls
       
   157 
       
   158 implicit def sring_inters(sc: StringContext) = new {
       
   159     def i(args: Any*): String = "   " ++ sc.s(args:_*) ++ "\n"
       
   160     def l(args: Any*): String = sc.s(args:_*) ++ ":\n"
       
   161     def m(args: Any*): String = sc.s(args:_*) ++ "\n"
       
   162 }
       
   163 
       
   164 // mathematical and boolean operations
       
   165 def compile_op(op: String) = op match {
       
   166   case "+" => "add i32 "
       
   167   case "*" => "mul i32 "
       
   168   case "-" => "sub i32 "
       
   169   case "/" => "sdiv i32 "
       
   170   case "%" => "srem i32 "
       
   171   case "==" => "icmp eq i32 "
       
   172   case "<=" => "icmp sle i32 "    // signed less or equal
       
   173   case "<" => "icmp slt i32 "     // signed less than
       
   174 }
       
   175 
       
   176 def compile_val(v: KVal) : String = v match {
       
   177   case KNum(i) => s"$i"
       
   178   case KVar(s) => s"%$s"
       
   179   case Kop(op, x1, x2) => 
       
   180     s"${compile_op(op)} ${compile_val(x1)}, ${compile_val(x2)}"
       
   181   case KCall(x1, args) => 
       
   182     s"call i32 @$x1 (${args.map(compile_val).mkString("i32 ", ", i32 ", "")})"
       
   183   case KWrite(x1) =>
       
   184     s"call i32 @printInt (i32 ${compile_val(x1)})"
       
   185 }
       
   186 
       
   187 // compile K expressions
       
   188 def compile_exp(a: KExp) : String = a match {
       
   189   case KReturn(v) =>
       
   190     i"ret i32 ${compile_val(v)}"
       
   191   case KLet(x: String, v: KVal, e: KExp) => 
       
   192     i"%$x = ${compile_val(v)}" ++ compile_exp(e)
       
   193   case KIf(x, e1, e2) => {
       
   194     val if_br = Fresh("if_branch")
       
   195     val else_br = Fresh("else_branch")
       
   196     i"br i1 %$x, label %$if_br, label %$else_br" ++
       
   197     l"\n$if_br" ++
       
   198     compile_exp(e1) ++
       
   199     l"\n$else_br" ++ 
       
   200     compile_exp(e2)
       
   201   }
       
   202 }
       
   203 
       
   204 
       
   205 val prelude = """
       
   206 @.str = private constant [4 x i8] c"%d\0A\00"
       
   207 
       
   208 declare i32 @printf(i8*, ...)
       
   209 
       
   210 define i32 @printInt(i32 %x) {
       
   211    %t0 = getelementptr [4 x i8], [4 x i8]* @.str, i32 0, i32 0
       
   212    call i32 (i8*, ...) @printf(i8* %t0, i32 %x) 
       
   213    ret i32 %x
       
   214 }
       
   215 
       
   216 """
       
   217 
       
   218 
       
   219 // compile function for declarations and main
       
   220 def compile_decl(d: Decl) : String = d match {
       
   221   case Def(name, args, body) => { 
       
   222     m"define i32 @$name (${args.mkString("i32 %", ", i32 %", "")}) {" ++
       
   223     compile_exp(CPSi(body)) ++
       
   224     m"}\n"
       
   225   }
       
   226   case Main(body) => {
       
   227     m"define i32 @main() {" ++
       
   228     compile_exp(CPSi(body)) ++
       
   229     m"}\n"
       
   230   }
       
   231 }
       
   232 
       
   233 // main compiler functions
       
   234 
       
   235 def time_needed[T](i: Int, code: => T) = {
       
   236   val start = System.nanoTime()
       
   237   for (j <- 1 to i) code
       
   238   val end = System.nanoTime()
       
   239   (end - start)/(i * 1.0e9)
       
   240 }
       
   241 
       
   242 // for Scala 2.12
       
   243 /*
       
   244 def deserialise[T](file: String) : Try[T] = {
       
   245     val in = new ObjectInputStream(new FileInputStream(new File(file)))
       
   246     val obj = Try(in.readObject().asInstanceOf[T])
       
   247     in.close()
       
   248     obj
       
   249 }
       
   250 */
       
   251 
       
   252 def deserialise[T](fname: String) : Try[T] = {
       
   253   import scala.util.Using
       
   254   Using(new ObjectInputStream(new FileInputStream(fname))) {
       
   255     in => in.readObject.asInstanceOf[T]
       
   256   }
       
   257 }
       
   258 
       
   259 def compile(fname: String) : String = {
       
   260   val ast = deserialise[List[Decl]](fname ++ ".prs").getOrElse(Nil) 
       
   261   prelude ++ (ast.map(compile_decl).mkString)
       
   262 }
       
   263 
       
   264 def compile_to_file(fname: String) = {
       
   265   val output = compile(fname)
       
   266   scala.tools.nsc.io.File(s"${fname}.ll").writeAll(output)
       
   267 }
       
   268 
       
   269 def compile_and_run(fname: String) : Unit = {
       
   270   compile_to_file(fname)
       
   271   (s"llc -filetype=obj ${fname}.ll").!!
       
   272   (s"gcc ${fname}.o -o a.out").!!
       
   273   println("Time: " + time_needed(2, (s"./a.out").!))
       
   274 }
       
   275 
       
   276 // some examples of .fun files
       
   277 //compile_to_file("fact")
       
   278 //compile_and_run("fact")
       
   279 //compile_and_run("defs")
       
   280 
       
   281 
       
   282 def main(args: Array[String]) : Unit = 
       
   283    //println(compile(args(0)))
       
   284    compile_and_run(args(0))
       
   285 }
       
   286 
       
   287 
       
   288 
       
   289 
       
   290