1 // A Small LLVM Compiler for a Simple Functional Language  | 
         | 
     2 // (includes an external lexer and parser)  | 
         | 
     3 //  | 
         | 
     4 // call with   | 
         | 
     5 //  | 
         | 
     6 //     scala fun_llvm.scala fact  | 
         | 
     7 //  | 
         | 
     8 //     scala fun_llvm.scala defs  | 
         | 
     9 //  | 
         | 
    10 // this will generate a .ll file. You can interpret this file  | 
         | 
    11 // using lli.  | 
         | 
    12 //  | 
         | 
    13 // The optimiser can be invoked as  | 
         | 
    14 //  | 
         | 
    15 //      opt -O1 -S in_file.ll > out_file.ll  | 
         | 
    16 //      opt -O3 -S in_file.ll > out_file.ll  | 
         | 
    17 //  | 
         | 
    18 // The code produced for the various architectures can be obtained with  | 
         | 
    19 //     | 
         | 
    20 //   llc -march=x86 -filetype=asm in_file.ll -o -  | 
         | 
    21 //   llc -march=arm -filetype=asm in_file.ll -o -    | 
         | 
    22 //  | 
         | 
    23 // Producing an executable can be achieved by  | 
         | 
    24 //  | 
         | 
    25 //    llc -filetype=obj in_file.ll  | 
         | 
    26 //    gcc in_file.o -o a.out  | 
         | 
    27 //    ./a.out  | 
         | 
    28   | 
         | 
    29   | 
         | 
    30   | 
         | 
    31 object Compiler { | 
         | 
    32   | 
         | 
    33 import java.io._    | 
         | 
    34 import scala.util._  | 
         | 
    35 import scala.sys.process._  | 
         | 
    36   | 
         | 
    37 // Abstract syntax trees for the Fun language  | 
         | 
    // NOTE(review): compile() obtains a List[Decl] by Java deserialisation of a
    // "<name>.prs" file, presumably written by the external parser mentioned in
    // the file header. Renaming these classes or their fields, or changing
    // JVM-visible modifiers (e.g. adding `final`), would likely change the
    // default serialVersionUID and make existing .prs files unreadable - keep
    // these definitions in sync with the parser.
    38 abstract class Exp extends Serializable   | 
         | 
    39 abstract class BExp extends Serializable   | 
         | 
    40 abstract class Decl extends Serializable   | 
         | 
    41   | 
         | 
    // Top-level declarations: a function definition (function name, parameter
    // names, body) and the program's main expression (compiled to @main).
    42 case class Def(name: String, args: List[String], body: Exp) extends Decl  | 
         | 
    43 case class Main(e: Exp) extends Decl  | 
         | 
    44   | 
         | 
    // Expressions. Aop is a binary arithmetic operation named by its operator
    // string; Sequence evaluates e1, discards its value and yields e2's value;
    // Write prints the value of e (it is compiled to a call of @printInt) and
    // yields that value.
    45 case class Call(name: String, args: List[Exp]) extends Exp  | 
         | 
    46 case class If(a: BExp, e1: Exp, e2: Exp) extends Exp  | 
         | 
    47 case class Write(e: Exp) extends Exp  | 
         | 
    48 case class Var(s: String) extends Exp  | 
         | 
    49 case class Num(i: Int) extends Exp  | 
         | 
    50 case class Aop(o: String, a1: Exp, a2: Exp) extends Exp  | 
         | 
    51 case class Sequence(e1: Exp, e2: Exp) extends Exp  | 
         | 
    // Boolean expressions: a comparison operator applied to two expressions.
    // They only occur as the condition of an If.
    52 case class Bop(o: String, a1: Exp, a2: Exp) extends BExp  | 
         | 
    53   | 
         | 
    54   | 
         | 
    // Counter shared by every call to Fresh. It is never reset, so all names
    // generated during one run (temporaries and branch labels) are distinct.
    var counter = -1

    // Returns `x` followed by an underscore and the next counter value,
    // e.g. Fresh("tmp") gives "tmp_0", a following Fresh("l") gives "l_1".
    def Fresh(x: String) = {
      counter = counter + 1
      s"${x}_${counter}"
    }
         | 
    62   | 
         | 
    // Internal CPS language for FUN.
    //
    // KVal nodes are values that can be bound by a KLet. KExp nodes form the
    // control structure of a function body. Both hierarchies are sealed, so
    // the compiler checks pattern matches over them (compile_val /
    // compile_exp) for exhaustiveness.
    sealed abstract class KExp
    sealed abstract class KVal

    // A named value (a temporary or a Fun variable), emitted as %s.
    final case class KVar(s: String) extends KVal
    // An integer constant.
    final case class KNum(i: Int) extends KVal
    // The binary operator o applied to two values.
    final case class Kop(o: String, v1: KVal, v2: KVal) extends KVal
    // A call of the function named o with the argument values vrs.
    final case class KCall(o: String, vrs: List[KVal]) extends KVal
    // Printing of v (compiled to a call of @printInt).
    final case class KWrite(v: KVal) extends KVal

    // Branch on the boolean variable x1: e1 if it is true, e2 otherwise.
    final case class KIf(x1: String, e1: KExp, e2: KExp) extends KExp {
      override def toString = s"KIf $x1\nIF\n$e1\nELSE\n$e2"
    }
    // Bind x to the value e1, then continue with e2.
    final case class KLet(x: String, e1: KVal, e2: KExp) extends KExp {
      override def toString = s"let $x = $e1 in \n$e2"
    }
    // Return v from the enclosing function.
    final case class KReturn(v: KVal) extends KExp
         | 
    80   | 
         | 
    81   | 
         | 
    // CPS translation from Exps to KExps using a continuation k.
    //
    // k receives an atomic KVal (a variable or a number) denoting the value of
    // e and builds the rest of the program from it. For an If, the
    // continuation is invoked once per branch, so the code that follows an
    // `if` is duplicated into both branches. Every name bound by a KLet is
    // therefore generated *inside* the continuation that emits the KLet. Each
    // duplicated copy then gets its own SSA name, instead of two copies
    // defining the same %tmp_N, which LLVM rejects (e.g. for `1 + (if ...)`).
    def CPS(e: Exp)(k: KVal => KExp) : KExp = e match {
      case Var(s) => k(KVar(s))
      case Num(i) => k(KNum(i))
      case Aop(o, e1, e2) =>
        CPS(e1)(y1 =>
          CPS(e2)(y2 => bindFresh(Kop(o, y1, y2))(k)))
      case If(Bop(o, b1, b2), e1, e2) =>
        CPS(b1)(y1 =>
          CPS(b2)(y2 => {
            val z = Fresh("tmp")
            KLet(z, Kop(o, y1, y2), KIf(z, CPS(e1)(k), CPS(e2)(k)))
          }))
      case Call(name, args) => {
        // translate the arguments left to right, collecting their values
        def aux(args: List[Exp], vs: List[KVal]) : KExp = args match {
          case Nil => bindFresh(KCall(name, vs))(k)
          case e::es => CPS(e)(y => aux(es, vs ::: List(y)))
        }
        aux(args, Nil)
      }
      case Sequence(e1, e2) =>
        // the value of e1 is discarded; the sequence has the value of e2
        CPS(e1)(_ => CPS(e2)(k))
      case Write(e) =>
        CPS(e)(y => bindFresh(KWrite(y))(k))
    }

    // Binds the value v to a freshly generated temporary and passes that
    // temporary on to the continuation k.
    private def bindFresh(v: KVal)(k: KVal => KExp) : KExp = {
      val z = Fresh("tmp")
      KLet(z, v, k(KVar(z)))
    }
         | 
   115   | 
         | 
   // Runs the CPS translation with the initial continuation, which returns
   // the final value from the enclosing function.
   def CPSi(e: Exp) : KExp = CPS(e)(v => KReturn(v))
         | 
   118   | 
         | 
   // Some example expressions for the CPS translation. They are translated
   // once, in this order, when the object is initialised. The resulting KExps
   // are discarded (print e.g. CPSi(e2) in a REPL to inspect one). Each
   // translation advances the shared Fresh counter.
   val e1 = Aop("*", Var("a"), Num(3))
   val e2 = Aop("+", Aop("*", Var("a"), Num(3)), Num(4))
   val e3 = Aop("+", Num(2), Aop("*", Var("a"), Num(3)))
   val e4 = Aop("+", Aop("-", Num(1), Num(2)), Aop("*", Var("a"), Num(3)))
   val e5 = If(Bop("==", Num(1), Num(1)), Num(3), Num(4))
   val e6 = If(Bop("!=", Num(10), Num(10)), e5, Num(40))
   val e7 = Call("foo", List(Num(3)))
   val e8 = Call("foo", List(Aop("*", Num(3), Num(1)), Num(4), Aop("+", Num(5), Num(6))))
   val e9 = Sequence(Aop("*", Var("a"), Num(3)), Aop("+", Var("b"), Num(6)))
   val e = Aop("*", Aop("+", Num(1), Call("foo", List(Var("a"), Num(3)))), Num(4))

   // translate the examples in the order in which they are defined above
   List(e1, e2, e3, e4, e5, e6, e7, e8, e9, e).foreach(ex => CPSi(ex))
         | 
   149   | 
         | 
   150   | 
         | 
   151   | 
         | 
   152   | 
         | 
   153 // convenient string interpolations   | 
         | 
   154 // for instructions, labels and methods  | 
         | 
   155 import scala.language.implicitConversions  | 
         | 
   156 import scala.language.reflectiveCalls  | 
         | 
   157   | 
         | 
   // String interpolators for assembling LLVM IR text:
   //   i"..."  an indented instruction line   ("   " ++ text ++ "\n")
   //   l"..."  a label line                   (text ++ ":\n")
   //   m"..."  an unindented line, e.g. a function header (text ++ "\n")
   // All three first apply the standard s-interpolator, so escapes such as
   // \n are processed. This is an implicit class, not an implicit conversion
   // to a structural type; the latter needs a reflective call on every use.
   implicit class sring_inters(sc: StringContext) {
     def i(args: Any*): String = "   " ++ sc.s(args: _*) ++ "\n"
     def l(args: Any*): String = sc.s(args: _*) ++ ":\n"
     def m(args: Any*): String = sc.s(args: _*) ++ "\n"
   }
         | 
   163   | 
         | 
   // Maps a Fun operator to the corresponding LLVM instruction prefix (opcode
   // plus operand type, with a trailing space). Arithmetic operators produce
   // an i32; comparisons are signed `icmp`s producing the i1 that is used as a
   // branch condition.
   //
   // Throws IllegalArgumentException for an operator without a translation.
   def compile_op(op: String) : String = op match {
     case "+" => "add i32 "
     case "*" => "mul i32 "
     case "-" => "sub i32 "
     case "/" => "sdiv i32 "
     case "%" => "srem i32 "
     case "==" => "icmp eq i32 "
     case "!=" => "icmp ne i32 "     // not equal
     case "<=" => "icmp sle i32 "    // signed less or equal
     case "<" => "icmp slt i32 "     // signed less than
     case ">=" => "icmp sge i32 "    // signed greater or equal
     case ">" => "icmp sgt i32 "     // signed greater than
     case _ => throw new IllegalArgumentException(s"unsupported operator: $op")
   }
         | 
   175   | 
         | 
   // Compiles a K-value to LLVM IR text. Numbers and variables become
   // operands; operations, calls and writes become the right-hand side of an
   // assignment. Values are i32, except that comparisons (see compile_op)
   // produce an i1.
   def compile_val(v: KVal) : String = v match {
     case KNum(i) => s"$i"
     case KVar(s) => s"%$s"
     case Kop(op, x1, x2) =>
       s"${compile_op(op)} ${compile_val(x1)}, ${compile_val(x2)}"
     case KCall(x1, args) =>
       // one "i32 <value>" per argument, so that a call without arguments
       // becomes "()" rather than the invalid "(i32 )"
       val actuals = args.map(arg => s"i32 ${compile_val(arg)}").mkString(", ")
       s"call i32 @$x1 ($actuals)"
     case KWrite(x1) =>
       s"call i32 @printInt (i32 ${compile_val(x1)})"
   }
         | 
   186   | 
         | 
   // Compiles a K-expression to a sequence of LLVM IR instructions.
   //   KLet    -> "%x = <value>" followed by the rest of the program
   //   KIf     -> a conditional branch to two freshly labelled blocks, each
   //              holding the code of one branch
   //   KReturn -> "ret i32 <value>"
   def compile_exp(a: KExp) : String = a match {
     case KLet(name, value, rest) =>
       i"%$name = ${compile_val(value)}" ++ compile_exp(rest)
     case KIf(cond, thenK, elseK) =>
       // the then-label is generated before the else-label, and the then-code
       // before the else-code, so the Fresh counter advances in that order
       val thenLabel = Fresh("if_branch")
       val elseLabel = Fresh("else_branch")
       val branch = i"br i1 %$cond, label %$thenLabel, label %$elseLabel"
       val thenCode = l"\n$thenLabel" ++ compile_exp(thenK)
       val elseCode = l"\n$elseLabel" ++ compile_exp(elseK)
       branch ++ thenCode ++ elseCode
     case KReturn(v) =>
       i"ret i32 ${compile_val(v)}"
   }
         | 
   203   | 
         | 
   204   | 
         | 
   // LLVM IR text placed at the start of every generated .ll file (compile
   // prepends it to the compiled declarations). It declares the external
   // variadic function @printf and defines @printInt(i32 %x), which prints x
   // followed by a newline (format string "%d\0A") and returns x unchanged.
   // Write expressions are compiled to calls of @printInt.
   // NOTE(review): this uses typed-pointer syntax (i8*, [4 x i8]*); recent
   // LLVM releases use opaque `ptr` types - confirm that the llc/lli in use
   // still accepts this form.
   205 val prelude = """  | 
         | 
   206 @.str = private constant [4 x i8] c"%d\0A\00"  | 
         | 
   207   | 
         | 
   208 declare i32 @printf(i8*, ...)  | 
         | 
   209   | 
         | 
   210 define i32 @printInt(i32 %x) { | 
         | 
   211    %t0 = getelementptr [4 x i8], [4 x i8]* @.str, i32 0, i32 0  | 
         | 
   212    call i32 (i8*, ...) @printf(i8* %t0, i32 %x)   | 
         | 
   213    ret i32 %x  | 
         | 
   214 }  | 
         | 
   215   | 
         | 
   216 """  | 
         | 
   217   | 
         | 
   218   | 
         | 
   // Compiles a top-level declaration to an LLVM function definition.
   // Every Fun function takes i32 parameters and returns an i32. The body is
   // CPS-translated with CPSi and then compiled with compile_exp. The main
   // expression becomes the parameterless function @main.
   def compile_decl(d: Decl) : String = d match {
     case Def(name, args, body) =>
       // one "i32 %arg" per parameter, so that a function without
       // parameters gets "()" rather than the invalid "(i32 %)"
       val params = args.map(arg => s"i32 %$arg").mkString(", ")
       m"define i32 @$name ($params) {" ++
       compile_exp(CPSi(body)) ++
       m"}\n"
     case Main(body) =>
       m"define i32 @main() {" ++
       compile_exp(CPSi(body)) ++
       m"}\n"
   }
         | 
   232   | 
         | 
   233 // main compiler functions  | 
         | 
   234   | 
         | 
   // Evaluates `code` i times and returns the average wall-clock time of one
   // evaluation, in seconds.
   def time_needed[T](i: Int, code: => T) = {
     val before = System.nanoTime()
     (1 to i).foreach(_ => code)
     val elapsed = System.nanoTime() - before
     elapsed / (i * 1.0e9)
   }
         | 
   241   | 
         | 
   242 // for Scala 2.12  | 
         | 
   243 /*  | 
         | 
   244 def deserialise[T](file: String) : Try[T] = { | 
         | 
   245     val in = new ObjectInputStream(new FileInputStream(new File(file)))  | 
         | 
   246     val obj = Try(in.readObject().asInstanceOf[T])  | 
         | 
   247     in.close()  | 
         | 
   248     obj  | 
         | 
   249 }  | 
         | 
   250 */  | 
         | 
   251   | 
         | 
   // Reads one Java-serialised object from the file `fname` and casts it to
   // T. Any failure (missing file, corrupt stream, ...) is returned as a
   // Failure rather than thrown. Both streams are registered with a
   // Using.Manager, so the FileInputStream is also closed when constructing
   // the ObjectInputStream fails (e.g. on a corrupt stream header).
   //
   // NOTE: the cast to T is unchecked (type erasure); a mismatch only shows
   // up when the result is used. Java deserialisation must not be applied to
   // untrusted files.
   def deserialise[T](fname: String) : Try[T] = {
     import scala.util.Using
     Using.Manager { use =>
       val file = use(new FileInputStream(fname))
       val in = use(new ObjectInputStream(file))
       in.readObject().asInstanceOf[T]
     }
   }
         | 
   258   | 
         | 
   // Compiles the program stored as a serialised List[Decl] in "<fname>.prs"
   // to LLVM IR text: the prelude followed by one function definition per
   // declaration.
   //
   // Throws IllegalArgumentException (with the underlying error as cause)
   // if the .prs file cannot be read, instead of silently compiling an empty
   // program that consists of the prelude only and has no @main.
   def compile(fname: String) : String = {
     val ast = deserialise[List[Decl]](fname ++ ".prs") match {
       case Success(decls) => decls
       case Failure(err) =>
         throw new IllegalArgumentException(s"cannot read the AST from ${fname}.prs", err)
     }
     prelude ++ ast.map(compile_decl).mkString
   }
         | 
   263   | 
         | 
   // Compiles <fname>.prs and writes the resulting LLVM IR to <fname>.ll,
   // replacing any existing file. java.nio is used because
   // scala.tools.nsc.io.File belongs to the Scala compiler jar, not to the
   // standard library (and does not exist in Scala 3).
   def compile_to_file(fname: String) : Unit = {
     import java.nio.charset.StandardCharsets
     import java.nio.file.{Files, Paths}
     val output = compile(fname)
     Files.write(Paths.get(s"${fname}.ll"), output.getBytes(StandardCharsets.UTF_8))
     ()
   }
         | 
   268   | 
         | 
   // Compiles <fname>.prs to <fname>.ll, turns it into the object file
   // <fname>.o with llc, links that with gcc into ./a.out, then runs a.out
   // twice and prints the average running time in seconds.
   //
   // The external commands are given as argument lists rather than single
   // strings, so a file name containing spaces is not split into several
   // arguments. `!!` throws if llc or gcc exits with a non-zero status.
   def compile_and_run(fname: String) : Unit = {
     compile_to_file(fname)
     Seq("llc", "-filetype=obj", s"${fname}.ll").!!
     Seq("gcc", s"${fname}.o", "-o", "a.out").!!
     println("Time: " + time_needed(2, Seq("./a.out").!))
   }
         | 
   275   | 
         | 
   276 // some examples of .fun files  | 
         | 
   277 //compile_to_file("fact") | 
         | 
   278 //compile_and_run("fact") | 
         | 
   279 //compile_and_run("defs") | 
         | 
   280   | 
         | 
   281   | 
         | 
   // Entry point. Expects the program name without extension as the first
   // argument (the AST is read from <name>.prs), then compiles and runs the
   // program. Without an argument, prints a usage message and exits with
   // status 1.
   // To only print the generated IR, use println(compile(fname)) instead.
   def main(args: Array[String]) : Unit =
     args.headOption match {
       case Some(fname) => compile_and_run(fname)
       case None =>
         System.err.println("usage: scala fun_llvm.scala <name>   (reads <name>.prs)")
         sys.exit(1)
     }
         | 
   285 }  | 
         | 
   286   | 
         | 
   287   | 
         | 
   288   | 
         | 
   289   | 
         | 
   290   | 
         |