// Author: Zhuo Ying Jiang Li
// Starting code by Dr Christian Urban
//
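// Usage: amm compiler.sc run XXX.fun
//   (use `write` instead of `run` to only generate XXX.ll,
//    or `main` to print the generated LLVM IR to stdout)
// `run` generates XXX.ll, XXX.o and the executable XXX.bin, then runs XXX.bin.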
//
// lexer + parser
import $file.fun_tokens, fun_tokens._
import $file.fun_parser, fun_parser._
// for generating new labels
var counter = -1
// Returns a name that has not been handed out before: the prefix `x`
// followed by an underscore and the next value of the global counter.
def Fresh(x: String) = {
  counter += 1
  s"${x}_$counter"
}
// typing
type Ty = String
type TyEnv = Map[String, Ty]
// Types of the built-in functions defined in the prelude; none returns a value.
val initialEnv: TyEnv =
  List("skip", "print_int", "print_char", "print_space", "print_star", "new_line")
    .map(fn => fn -> "Void")
    .toMap
// FUN type names and the LLVM IR types they are compiled to.
val typeConversion = Map(
  "Int"    -> "i32",
  "Double" -> "double",
  "Void"   -> "void"
)
// Internal CPS language for FUN
/** A CPS-converted expression: a chain of let-bindings that ends in a return,
  * possibly split into two branches by an if. Sealed so that every match over
  * it is checked for exhaustiveness by the compiler. */
sealed abstract class KExp
/** A value: the right-hand side of a let, an operand, or a returned result.
  * Sealed for the same reason as [[KExp]]. */
sealed abstract class KVal
// Variables, globals, operations and calls carry a FUN type (Ty); "UNDEF" marks
// a value that has not yet been annotated by the typing pass.
final case class KVar(s: String, ty: Ty = "UNDEF") extends KVal
final case class KConst(s: String, ty: Ty = "UNDEF") extends KVal
final case class KNum(i: Int) extends KVal // known type
final case class KFNum(d: Float) extends KVal // known type
final case class KChConst(c: Int) extends KVal // known type (character code)
final case class Kop(o: String, v1: KVal, v2: KVal, ty: Ty = "UNDEF") extends KVal
final case class KCall(o: String, vrs: List[KVal], ty: Ty = "UNDEF") extends KVal
// LET x = e1 in e2
final case class KLet(x: String, e1: KVal, e2: KExp) extends KExp {
  override def toString = s"LET $x = $e1 in \n$e2"
}
// Branch on the boolean held in variable x1.
final case class KIf(x1: String, e1: KExp, e2: KExp) extends KExp {
  // indent every line of a sub-expression for pretty-printing
  def pad(e: KExp) = e.toString.replaceAll("(?m)^", " ")
  override def toString =
    s"IF $x1\nTHEN\n${pad(e1)}\nELSE\n${pad(e2)}"
}
final case class KReturn(v: KVal) extends KExp
// CPS translation from Exps to KExps using a
// continuation k.
// Every intermediate result is bound to a fresh "tmp_N" variable, so the
// resulting KExp is a flat sequence of lets ending in whatever k produces.
def CPS(e: Exp)(k: KVal => KExp) : KExp = e match {
  case Var(s) if s.head.isUpper =>
    // upper-case names are globals; load them into a temporary first
    val tmp = Fresh("tmp")
    KLet(tmp, KConst(s), k(KVar(tmp)))
  case Var(s) => k(KVar(s))
  case Num(i) => k(KNum(i))
  case FNum(d) => k(KFNum(d))
  case ChConst(c) => k(KChConst(c))
  case Aop(o, e1, e2) =>
    val tmp = Fresh("tmp")
    CPS(e1) { lhs =>
      CPS(e2) { rhs =>
        KLet(tmp, Kop(o, lhs, rhs), k(KVar(tmp)))
      }
    }
  case If(Bop(o, b1, b2), e1, e2) =>
    val cond = Fresh("tmp")
    CPS(b1) { lhs =>
      CPS(b2) { rhs =>
        // both branches are finished by the same continuation
        val thenPart = CPS(e1)(k)
        val elsePart = CPS(e2)(k)
        KLet(cond, Kop(o, lhs, rhs), KIf(cond, thenPart, elsePart))
      }
    }
  case Call(name, args) =>
    // translate the arguments left to right, collecting their values
    def go(pending: List[Exp], done: List[KVal]) : KExp = pending match {
      case Nil =>
        val tmp = Fresh("tmp")
        KLet(tmp, KCall(name, done), k(KVar(tmp)))
      case next :: rest => CPS(next)(v => go(rest, done :+ v))
    }
    go(args, Nil)
  case Sequence(e1, e2) =>
    // the value of e1 is discarded; e2's value goes to k
    CPS(e1)(_ => CPS(e2)(k))
}
// initial continuation
// Translate e, returning its final value.
def CPSi(e: Exp) : KExp = CPS(e)(v => KReturn(v))
// get type of KVal
// Literals have a fixed type; every other value reports the type it was annotated with.
def get_typ_val(v: KVal) : Ty = v match {
  case KNum(_) | KChConst(_) => "Int"
  case KFNum(_)              => "Double"
  case KVar(_, t)            => t
  case KConst(_, t)          => t
  case Kop(_, _, _, t)       => t
  case KCall(_, _, t)        => t
}
// update type information for KValues
// Annotates a value (recursively) with types taken from the environment ts.
// Throws if a name is unknown or if the operands of an operation disagree.
def typ_val(v: KVal, ts: TyEnv) : KVal = {
  def lookup(name: String) : Ty =
    ts.getOrElse(name, throw new Exception(s"Compile error: unknown type for $name"))

  v match {
    case KVar(name, _)   => KVar(name, lookup(name))
    case KConst(name, _) => KConst(name, lookup(name))
    case Kop(o, v1, v2, _) =>
      val (left, right) = (typ_val(v1, ts), typ_val(v2, ts))
      val (lt, rt) = (get_typ_val(left), get_typ_val(right))
      if (lt != rt) throw new Exception(s"Compile error: cannot compare $lt with $rt")
      Kop(o, left, right, lt)
    case KCall(o, vrs, _) =>
      // arguments are typed before the function name is looked up
      val typedArgs = vrs.map(arg => typ_val(arg, ts))
      KCall(o, typedArgs, lookup(o))
    case literal => literal // KNum, KFNum, KChConst already have a fixed type
  }
}
// update type information for KExpressions
// Annotates every value in a KExp. Each let extends the environment for its
// body only; both branches of an if are typed in the environment before it.
def typ_exp(a: KExp, ts: TyEnv) : KExp = a match {
  case KReturn(v) => KReturn(typ_val(v, ts))
  case KIf(cond, thn, els) =>
    val typedThen = typ_exp(thn, ts)
    KIf(cond, typedThen, typ_exp(els, ts))
  case KLet(x, rhs, body) =>
    val typedRhs = typ_val(rhs, ts)
    KLet(x, typedRhs, typ_exp(body, ts.updated(x, get_typ_val(typedRhs))))
}
// prelude
// LLVM IR prepended to every compiled program (see fun_compile). It declares
// printf, defines the format strings it is called with, and defines the
// built-in functions whose types are listed in initialEnv: new_line,
// print_star, print_space, print_int, print_char and skip.
// NOTE(review): several definitions reference attribute group `#0`, but no
// `attributes #0 = { ... }` definition appears in this string.
val prelude = """
declare i32 @printf(i8*, ...)
@.str_nl = private constant [2 x i8] c"\0A\00"
@.str_star = private constant [2 x i8] c"*\00"
@.str_space = private constant [2 x i8] c" \00"
@.str_int = private constant [3 x i8] c"%d\00"
@.str_c = private constant [3 x i8] c"%c\00"
define void @new_line() #0 {
%t0 = getelementptr [2 x i8], [2 x i8]* @.str_nl, i32 0, i32 0
call i32 (i8*, ...) @printf(i8* %t0)
ret void
}
define void @print_star() #0 {
%t0 = getelementptr [2 x i8], [2 x i8]* @.str_star, i32 0, i32 0
call i32 (i8*, ...) @printf(i8* %t0)
ret void
}
define void @print_space() #0 {
%t0 = getelementptr [2 x i8], [2 x i8]* @.str_space, i32 0, i32 0
call i32 (i8*, ...) @printf(i8* %t0)
ret void
}
define void @print_int(i32 %x) {
%t0 = getelementptr [3 x i8], [3 x i8]* @.str_int, i32 0, i32 0
call i32 (i8*, ...) @printf(i8* %t0, i32 %x)
ret void
}
define void @print_char(i32 %x) {
%t0 = getelementptr [3 x i8], [3 x i8]* @.str_c, i32 0, i32 0
call i32 (i8*, ...) @printf(i8* %t0, i32 %x)
ret void
}
define void @skip() #0 {
ret void
}
; END OF BUILT-IN FUNCTIONS (prelude)
"""
// convenient string interpolations
// for instructions, labels and methods
extension (sc: StringContext) {
  // i"..." : an instruction, indented by one space and newline-terminated
  def i(args: Any*): String = s" ${sc.s(args: _*)}\n"
  // l"..." : a label, i.e. the text followed by a colon and a newline
  def l(args: Any*): String = s"${sc.s(args: _*)}:\n"
  // m"..." : a top-level line (definitions, closing braces), newline-terminated
  def m(args: Any*): String = s"${sc.s(args: _*)}\n"
}
// mathematical and boolean operations
// Maps a FUN operator to the LLVM i32 instruction prefix (with a trailing
// space); the operands are appended by the caller. Unknown operators raise a MatchError.
def compile_op(op: String) = {
  val instr = op match {
    case "+"  => "add"
    case "*"  => "mul"
    case "-"  => "sub"
    case "/"  => "sdiv"
    case "%"  => "srem"
    case "==" => "icmp eq"
    case "!=" => "icmp ne"
    case "<=" => "icmp sle"
    case "<"  => "icmp slt"
    case ">=" => "icmp sge"
    case ">"  => "icmp sgt"
  }
  s"$instr i32 "
}
// Maps a FUN operator to the LLVM double instruction prefix (with a trailing
// space); the operands are appended by the caller. Floating-point comparisons
// must use `fcmp` (`icmp` only accepts integer/pointer operands), so every
// comparison here is an ordered fcmp predicate.
def compile_dop(op: String) = op match {
  case "+" => "fadd double "
  case "*" => "fmul double "
  case "-" => "fsub double "
  case "/" => "fdiv double "
  case "%" => "frem double "
  case "==" => "fcmp oeq double "
  case "!=" => "fcmp one double "
  case "<=" => "fcmp ole double "
  case "<" => "fcmp olt double "
  case ">=" => "fcmp oge double "
  case ">" => "fcmp ogt double "
}
// Renders each call argument as "<llvm type> <value>".
def compile_args(vrs: List[KVal]) : List[String] =
  vrs.map { arg =>
    val llvmTy = typeConversion(get_typ_val(arg))
    s"$llvmTy ${compile_val(arg)}"
  }
// compile K values
// Renders a typed KVal as an LLVM operand or as the right-hand side of an instruction.
def compile_val(v: KVal) : String = v match {
  case KNum(n)       => n.toString
  case KFNum(f)      => f.toString
  case KChConst(c)   => c.toString // characters are emitted as their integer code
  case KVar(name, _) => "%" + name
  case KConst(name, ty) =>
    // globals are read through a load from their @-symbol
    val llvmTy = typeConversion(ty)
    s"load $llvmTy, $llvmTy* @$name"
  case Kop(op, lhs, rhs, ty) =>
    val opcode = ty match {
      case "Double" => compile_dop(op)
      case "Int"    => compile_op(op)
      case _        => throw new Exception("Compile error: unknown type for comparison")
    }
    s"$opcode ${compile_val(lhs)}, ${compile_val(rhs)}"
  case KCall(fn, args, ty) =>
    s"call ${typeConversion(ty)} @$fn (${compile_args(args).mkString(", ")})"
}
// compile K expressions
// Emits the LLVM instructions for a typed KExp.
def compile_exp(a: KExp) : String = a match {
  case KReturn(v) =>
    get_typ_val(v) match {
      case "Void" => i"ret void"
      case ty     => i"ret ${typeConversion(ty)} ${compile_val(v)}"
    }
  case KLet(x, v, e) =>
    // a Void value produces no result, so it is emitted without an assignment
    val instr =
      if (get_typ_val(v) == "Void") i"${compile_val(v)}"
      else i"%$x = ${compile_val(v)}"
    instr ++ compile_exp(e)
  case KIf(x, e1, e2) =>
    val thenLabel = Fresh("if_branch")
    val elseLabel = Fresh("else_branch")
    val branch = i"br i1 %$x, label %$thenLabel, label %$elseLabel"
    val thenCode = l"\n$thenLabel" ++ compile_exp(e1)
    val elseCode = l"\n$elseLabel" ++ compile_exp(e2)
    branch ++ thenCode ++ elseCode
}
// Renders the parameter list of a function definition and returns it together
// with ts extended by the parameter types. Void parameters are rejected.
def compile_def_args(args: List[(String, String)], ts: TyEnv) : (List[String], TyEnv) = {
  if (args.exists { case (_, t) => t == "Void" })
    throw new Exception("Compile error: argument of type void is invalid")
  // foldRight on List runs from the last parameter to the first
  val params = args.foldRight(List.empty[String]) { case ((n, t), acc) =>
    s"${typeConversion(t)} %$n" :: acc
  }
  (params, ts ++ args)
}
// Compiles one top-level declaration. Returns its LLVM code and the global
// environment extended with the declared name; function parameters are not
// kept in the returned environment.
def compile_decl(d: Decl, ts: TyEnv) : (String, TyEnv) = d match {
  case Const(name, value) =>
    (m"@$name = global i32 $value\n", ts.updated(name, "Int"))
  case FConst(name, value) =>
    (m"@$name = global double $value\n", ts.updated(name, "Double"))
  case Def(name, args, ty, body) =>
    // the function's own name is in scope in its body
    val withFn = ts.updated(name, ty)
    val (params, localEnv) = compile_def_args(args, withFn)
    val header = m"define ${typeConversion(ty)} @$name (${params.mkString(", ")}) {"
    val code = compile_exp(typ_exp(CPSi(body), localEnv))
    (header ++ code ++ m"}\n", withFn)
  case Main(body) =>
    val env = ts.updated("main", "Int")
    // main discards the value of its body and returns 0
    val code = compile_exp(typ_exp(CPS(body)(_ => KReturn(KNum(0))), env))
    (m"define i32 @main() {" ++ code ++ m"}\n", env)
}
// recursively update the typing environment while compiling
// Compiles declarations in order, threading the environment from each one to the next.
def compile_block(prog: List[Decl], ts: TyEnv) : (String, TyEnv) =
  prog.foldLeft(("", ts)) { case ((code, env), decl) =>
    val (compiled, nextEnv) = compile_decl(decl, env)
    (code ++ compiled, nextEnv)
  }
// Compiles a whole program, starting from the built-in environment, and
// prepends the prelude.
def fun_compile(prog: List[Decl]) : String = {
  val (body, _) = compile_block(prog, initialEnv)
  prelude ++ body
}
/** Compiles the FUN file `fname` (relative to the working directory) and
  * prints the generated LLVM IR to stdout. */
@main
def main(fname: String) = {
  val path = os.pwd / fname
  val tks = tokenise(os.read(path))
  // parse_tks returns a collection of results; report an empty one as a parse
  // error instead of crashing with an opaque exception from `.head`
  val ast = parse_tks(tks).headOption.getOrElse(
    throw new Exception(s"Parse error: could not parse $fname"))
  val code = fun_compile(ast)
  println(code)
}
/** Compiles the FUN file `fname` and writes the LLVM IR to a `.ll` file with
  * the same base name (the original extension is replaced). */
@main
def write(fname: String) = {
  val path = os.pwd / fname
  val file = fname.stripSuffix("." ++ path.ext)
  val tks = tokenise(os.read(path))
  // parse_tks returns a collection of results; report an empty one as a parse
  // error instead of crashing with an opaque exception from `.head`
  val ast = parse_tks(tks).headOption.getOrElse(
    throw new Exception(s"Parse error: could not parse $fname"))
  val code = fun_compile(ast)
  os.write.over(os.pwd / (file ++ ".ll"), code)
}
// Compiles `fname` to LLVM IR, builds `<base>.bin` with llc and gcc, runs it
// with its output forwarded to this process's stdout, then prints "done.".
@main
def run(fname: String) = {
  val base = fname.stripSuffix("." ++ (os.pwd / fname).ext)
  write(fname)
  val objFile = base ++ ".o"
  val binFile = base ++ ".bin"
  os.proc("llc", "-filetype=obj", base ++ ".ll").call()
  os.proc("gcc", objFile, "-o", binFile).call()
  os.proc(os.pwd / binFile).call(stdout = os.Inherit)
  println("done.")
}
// for automated testing
// Only emits the .ll file; nothing is compiled or executed.
@main
def test(fname: String) = {
  write(fname)
}