diff -r 54a483a33763 -r 02ef5c3abc51 solutions/cw5/fun_llvm.sc --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/solutions/cw5/fun_llvm.sc Fri Nov 04 12:07:40 2022 +0000 @@ -0,0 +1,412 @@ +// A Small LLVM Compiler for a Simple Functional Language +// (includes an external lexer and parser) +// +// +// call with -- prints out llvm code +// +// amm fun_llvm.sc main fact.fun +// amm fun_llvm.sc main defs.fun +// +// or -- writes llvm code to disk +// +// amm fun_llvm.sc write fact.fun +// amm fun_llvm.sc write defs.fun +// +// this will generate an .ll file. +// +// or -- runs the generated llvm code via lli +// +// amm fun_llvm.sc run fact.fun +// amm fun_llvm.sc run defs.fun +// +// +// You can interpret an .ll file using lli, for example +// +// lli fact.ll +// +// The optimiser can be invoked as +// +// opt -O1 -S in_file.ll > out_file.ll +// opt -O3 -S in_file.ll > out_file.ll +// +// The code produced for the various architectures can be obtain with +// +// llc -march=x86 -filetype=asm in_file.ll -o - +// llc -march=arm -filetype=asm in_file.ll -o - +// +// Producing an executable can be achieved by +// +// llc -filetype=obj in_file.ll +// gcc in_file.o -o a.out +// ./a.out + + +import $file.fun_tokens, fun_tokens._ +import $file.fun_parser, fun_parser._ + + +// for generating new labels +var counter = -1 + +def Fresh(x: String) = { + counter += 1 + x ++ "_" ++ counter.toString() +} + +// Internal CPS language for FUN +abstract class KExp +abstract class KVal + +type Ty = String +type TyEnv = Map[String, Ty] + +case class KVar(s: String, ty: Ty = "UNDEF") extends KVal +case class KLoad(v: KVal) extends KVal +case class KNum(i: Int) extends KVal +case class KFNum(i: Double) extends KVal +case class KChr(c: Int) extends KVal +case class Kop(o: String, v1: KVal, v2: KVal, ty: Ty = "UNDEF") extends KVal +case class KCall(o: String, vrs: List[KVal], ty: Ty = "UNDEF") extends KVal + +case class KIf(x1: String, e1: KExp, e2: KExp) extends KExp { + override def toString = s"KIf $x1\nIF\n$e1\nELSE\n$e2" +} +case class KLet(x: String, e1: KVal, e2: KExp) extends KExp { + override def toString = s"let $x = $e1 in \n$e2" +} +case class KReturn(v: KVal) extends KExp + +// typing K values +def typ_val(v: KVal, ts: TyEnv) : (KVal, Ty) = v match { + case KVar(s, _) => { + val ty = ts.getOrElse(s, "TUNDEF") + (KVar(s, ty), ty) + } + case Kop(op, v1, v2, _) => { + val (tv1, ty1) = typ_val(v1, ts) + val (tv2, ty2) = typ_val(v2, ts) + if (ty1 == ty2) (Kop(op, tv1, tv2, ty1), ty1) else (Kop(op, tv1, tv2, "TMISMATCH"), "TMISMATCH") + } + case KCall(fname, args, _) => { + val ty = ts.getOrElse(fname, "TCALLUNDEF" ++ fname) + (KCall(fname, args.map(typ_val(_, ts)._1), ty), ty) + } + case KLoad(v) => { + val (tv, ty) = typ_val(v, ts) + (KLoad(tv), ty) + } + case KNum(i) => (KNum(i), "Int") + case KFNum(i) => (KFNum(i), "Double") + case KChr(c) => (KChr(c), "Int") +} + +def typ_exp(a: KExp, ts: TyEnv) : KExp = a match { + case KReturn(v) => KReturn(typ_val(v, ts)._1) + case KLet(x: String, v: KVal, e: KExp) => { + val (tv, ty) = typ_val(v, ts) + KLet(x, tv, typ_exp(e, ts + (x -> ty))) + } + case KIf(b, e1, e2) => KIf(b, typ_exp(e1, ts), typ_exp(e2, ts)) +} + + + + +// CPS translation from Exps to KExps using a +// continuation k. +def CPS(e: Exp)(k: KVal => KExp) : KExp = e match { + case Var(s) if (s.head.isUpper) => { + val z = Fresh("tmp") + KLet(z, KLoad(KVar(s)), k(KVar(z))) + } + case Var(s) => k(KVar(s)) + case Num(i) => k(KNum(i)) + case ChConst(c) => k(KChr(c)) + case FNum(i) => k(KFNum(i)) + case Aop(o, e1, e2) => { + val z = Fresh("tmp") + CPS(e1)(y1 => + CPS(e2)(y2 => KLet(z, Kop(o, y1, y2), k(KVar(z))))) + } + case If(Bop(o, b1, b2), e1, e2) => { + val z = Fresh("tmp") + CPS(b1)(y1 => + CPS(b2)(y2 => + KLet(z, Kop(o, y1, y2), KIf(z, CPS(e1)(k), CPS(e2)(k))))) + } + case Call(name, args) => { + def aux(args: List[Exp], vs: List[KVal]) : KExp = args match { + case Nil => { + val z = Fresh("tmp") + KLet(z, KCall(name, vs), k(KVar(z))) + } + case e::es => CPS(e)(y => aux(es, vs ::: List(y))) + } + aux(args, Nil) + } + case Sequence(e1, e2) => + CPS(e1)(_ => CPS(e2)(y2 => k(y2))) +} + +//initial continuation +def CPSi(e: Exp) = CPS(e)(KReturn) + +// some testcases +val e1 = Aop("*", Var("a"), Num(3)) +CPSi(e1) + +val e2 = Aop("+", Aop("*", Var("a"), Num(3)), Num(4)) +CPSi(e2) + +val e3 = Aop("+", Num(2), Aop("*", Var("a"), Num(3))) +CPSi(e3) + +val e4 = Aop("+", Aop("-", Num(1), Num(2)), Aop("*", Var("a"), Num(3))) +CPSi(e4) + +val e5 = If(Bop("==", Num(1), Num(1)), Num(3), Num(4)) +CPSi(e5) + +val e6 = If(Bop("!=", Num(10), Num(10)), e5, Num(40)) +CPSi(e6) + +val e7 = Call("foo", List(Num(3))) +CPSi(e7) + +val e8 = Call("foo", List(Aop("*", Num(3), Num(1)), Num(4), Aop("+", Num(5), Num(6)))) +CPSi(e8) + +val e9 = Sequence(Aop("*", Var("a"), Num(3)), Aop("+", Var("b"), Num(6))) +CPSi(e9) + +val e = Aop("*", Aop("+", Num(1), Call("foo", List(Var("a"), Num(3)))), Num(4)) +CPSi(e) + + + + +// convenient string interpolations +// for instructions, labels and methods +import scala.language.implicitConversions +import scala.language.reflectiveCalls + + + + +implicit def sring_inters(sc: StringContext) = new { + def i(args: Any*): String = " " ++ sc.s(args:_*) ++ "\n" + def l(args: Any*): String = sc.s(args:_*) ++ ":\n" + def m(args: Any*): String = sc.s(args:_*) ++ "\n" +} + +def get_ty(s: String) = s match { + case "Double" => "double" + case "Void" => "void" + case "Int" => "i32" + case "Bool" => "i2" + case _ => s +} + +def compile_call_arg(a: KVal) = a match { + case KNum(i) => s"i32 $i" + case KFNum(i) => s"double $i" + case KChr(c) => s"i32 $c" + case KVar(s, ty) => s"${get_ty(ty)} %$s" +} + +def compile_arg(s: (String, String)) = s"${get_ty(s._2)} %${s._1}" + + +// mathematical and boolean operations +def compile_op(op: String) = op match { + case "+" => "add i32 " + case "*" => "mul i32 " + case "-" => "sub i32 " + case "/" => "sdiv i32 " + case "%" => "srem i32 " + case "==" => "icmp eq i32 " + case "!=" => "icmp ne i32 " // not equal + case "<=" => "icmp sle i32 " // signed less or equal + case "<" => "icmp slt i32 " // signed less than +} + +def compile_dop(op: String) = op match { + case "+" => "fadd double " + case "*" => "fmul double " + case "-" => "fsub double " + case "==" => "fcmp oeq double " + case "<=" => "fcmp ole double " + case "<" => "fcmp olt double " +} + +// compile K values +def compile_val(v: KVal) : String = v match { + case KNum(i) => s"$i" + case KFNum(i) => s"$i" + case KChr(c) => s"$c" + case KVar(s, ty) => s"%$s" + case KLoad(KVar(s, ty)) => s"load ${get_ty(ty)}, ${get_ty(ty)}* @$s" + case Kop(op, x1, x2, ty) => ty match { + case "Int" => s"${compile_op(op)} ${compile_val(x1)}, ${compile_val(x2)}" + case "Double" => s"${compile_dop(op)} ${compile_val(x1)}, ${compile_val(x2)}" + case _ => Kop(op, x1, x2, ty).toString + } + case KCall(fname, args, ty) => + s"call ${get_ty(ty)} @$fname (${args.map(compile_call_arg).mkString(", ")})" +} + +// compile K expressions +def compile_exp(a: KExp) : String = a match { + case KReturn(KVar("void", _)) => + i"ret void" + case KReturn(KVar(x, ty)) => + i"ret ${get_ty(ty)} %$x" + case KReturn(KNum(i)) => + i"ret i32 $i" + case KLet(x: String, KCall(o: String, vrs: List[KVal], "Void"), e: KExp) => + i"${compile_val(KCall(o: String, vrs: List[KVal], "Void"))}" ++ compile_exp(e) + case KLet(x: String, v: KVal, e: KExp) => + i"%$x = ${compile_val(v)}" ++ compile_exp(e) + case KIf(x, e1, e2) => { + val if_br = Fresh("if_branch") + val else_br = Fresh("else_branch") + i"br i1 %$x, label %$if_br, label %$else_br" ++ + l"\n$if_br" ++ + compile_exp(e1) ++ + l"\n$else_br" ++ + compile_exp(e2) + } +} + + +val prelude = """ +declare i32 @printf(i8*, ...) + +@.str_nl = private constant [2 x i8] c"\0A\00" +@.str_star = private constant [2 x i8] c"*\00" +@.str_space = private constant [2 x i8] c" \00" + +define void @new_line() #0 { + %t0 = getelementptr [2 x i8], [2 x i8]* @.str_nl, i32 0, i32 0 + %1 = call i32 (i8*, ...) @printf(i8* %t0) + ret void +} + +define void @print_star() #0 { + %t0 = getelementptr [2 x i8], [2 x i8]* @.str_star, i32 0, i32 0 + %1 = call i32 (i8*, ...) @printf(i8* %t0) + ret void +} + +define void @print_space() #0 { + %t0 = getelementptr [2 x i8], [2 x i8]* @.str_space, i32 0, i32 0 + %1 = call i32 (i8*, ...) @printf(i8* %t0) + ret void +} + +define void @skip() #0 { + ret void +} + +@.str_int = private constant [3 x i8] c"%d\00" + +define void @print_int(i32 %x) { + %t0 = getelementptr [3 x i8], [3 x i8]* @.str_int, i32 0, i32 0 + call i32 (i8*, ...) @printf(i8* %t0, i32 %x) + ret void +} + +@.str_char = private constant [3 x i8] c"%c\00" + +define void @print_char(i32 %x) { + %t0 = getelementptr [3 x i8], [3 x i8]* @.str_char, i32 0, i32 0 + call i32 (i8*, ...) @printf(i8* %t0, i32 %x) + ret void +} + +; END OF BUILD-IN FUNCTIONS (prelude) + +""" + +def get_cont(ty: Ty) = ty match { + case "Int" => KReturn + case "Double" => KReturn + case "Void" => { (_: KVal) => KReturn(KVar("void", "Void")) } +} + +// compile function for declarations and main +def compile_decl(d: Decl, ts: TyEnv) : (String, TyEnv) = d match { + case Def(name, args, ty, body) => { + val ts2 = ts + (name -> ty) + val tkbody = typ_exp(CPS(body)(get_cont(ty)), ts2 ++ args.toMap) + (m"define ${get_ty(ty)} @$name (${args.map(compile_arg).mkString(",")}) {" ++ + compile_exp(tkbody) ++ + m"}\n", ts2) + } + case Main(body) => { + val tbody = typ_exp(CPS(body)(_ => KReturn(KNum(0))), ts) + (m"define i32 @main() {" ++ + compile_exp(tbody) ++ + m"}\n", ts) + } + case Const(name, n) => { + (m"@$name = global i32 $n\n", ts + (name -> "Int")) + } + case FConst(name, x) => { + (m"@$name = global double $x\n", ts + (name -> "Double")) + } +} + +def compile_prog(prog: List[Decl], ty: TyEnv) : String = prog match { + case Nil => "" + case d::ds => { + val (s2, ty2) = compile_decl(d, ty) + s2 ++ compile_prog(ds, ty2) + } +} +// main compiler functions +def compile(prog: List[Decl]) : String = + prelude ++ compile_prog(prog, Map("new_line" -> "Void", "skip" -> "Void", + "print_star" -> "Void", "print_space" -> "Void", + "print_int" -> "Void", "print_char" -> "Void")) + + +//import ammonite.ops._ + + +@main +def main(fname: String) = { + val path = os.pwd / fname + val file = fname.stripSuffix("." ++ path.ext) + val tks = tokenise(os.read(path)) + val ast = parse_tks(tks) + val code = compile(ast) + println(code) +} + +@main +def write(fname: String) = { + val path = os.pwd / fname + val file = fname.stripSuffix("." ++ path.ext) + val tks = tokenise(os.read(path)) + val ast = parse_tks(tks) + val code = compile(ast) + //println(code) + os.write.over(os.pwd / (file ++ ".ll"), code) +} + +@main +def run(fname: String) = { + val path = os.pwd / fname + val file = fname.stripSuffix("." ++ path.ext) + write(fname) + os.proc("llc", "-filetype=obj", file ++ ".ll").call() + os.proc("gcc", file ++ ".o", "-o", file ++ ".bin").call() + os.proc(os.pwd / (file ++ ".bin")).call(stdout = os.Inherit) + println(s"done.") +} + + + + +