solutions/cw5/fun_llvm.sc
author Christian Urban <christian.urban@kcl.ac.uk>
Tue, 21 Oct 2025 08:41:46 +0200
changeset 1014 8400bbdef1b7
parent 921 bb54e7aa1a3f
permissions -rw-r--r--
updated

// Author: Zhuo Ying Jiang Li
// Starting code by Dr Christian Urban

// 
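// Usage (the subcommands are the @main functions at the end of this file):
//   amm fun_llvm.sc main XXX.fun    prints the generated LLVM-IR
//   amm fun_llvm.sc write XXX.fun   writes the LLVM-IR to XXX.ll
//   amm fun_llvm.sc run XXX.fun     generates XXX.ll, XXX.o as well as the
//                                   binary program XXX.bin, and runs it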
//

// lexer + parser

import $file.fun_tokens, fun_tokens._
import $file.fun_parser, fun_parser._ 

// for generating new labels and temporaries
//
// The counter is global and starts at -1, so the first name
// handed out carries the suffix 0.
var counter = -1

// Returns the prefix x followed by "_" and the next counter value.
// Side effect: every call increments the global counter.
def Fresh(x: String) = {
  counter = counter + 1
  s"${x}_${counter}"
}

// typing: types are represented by their names (e.g. "Int",
// "Double", "Void"); a typing environment maps names to types
type Ty = String
type TyEnv = Map[String, Ty]

// initial typing environment: every built-in function has type Void
val initialEnv: Map[String, Ty] =
  List("skip", "print_int", "print_char", "print_space", "print_star", "new_line")
    .map(builtin => builtin -> "Void")
    .toMap

// translation of FUN type names into LLVM-IR type names
val typeConversion = Map(
  "Int"    -> "i32",
  "Double" -> "double",
  "Void"   -> "void"
)

// Internal CPS language for FUN
//
//   KVal - values and simple computations that can be bound by a KLet
//   KExp - expressions: let-bindings, conditionals and returns
abstract class KExp
abstract class KVal

// Where a ty field is present it defaults to "UNDEF"; typ_val
// later replaces it with the type found in the typing environment.
case class KVar(s: String, ty: Ty = "UNDEF") extends KVal
case class KConst(s: String, ty: Ty = "UNDEF") extends KVal
case class KNum(i: Int) extends KVal  // known type
case class KFNum(d: Float) extends KVal  // known type
case class KChConst(c: Int) extends KVal  // known type
case class Kop(o: String, v1: KVal, v2: KVal, ty: Ty = "UNDEF") extends KVal
case class KCall(o: String, vrs: List[KVal], ty: Ty = "UNDEF") extends KVal

// binds the value e1 to the name x in the expression e2
case class KLet(x: String, e1: KVal, e2: KExp) extends KExp {
  // the body is printed on the line below the binding
  override def toString = s"LET $x = $e1 in " + "\n" + e2
}
// conditional on the variable named x1
case class KIf(x1: String, e1: KExp, e2: KExp) extends KExp {
  // indents every line of the printed expression by two spaces
  def pad(e: KExp) = e.toString.replaceAll("(?m)^", "  ")

  override def toString =
    List(s"IF $x1", "THEN", pad(e1), "ELSE", pad(e2)).mkString("\n")
}
case class KReturn(v: KVal) extends KExp

// CPS translation from Exps to KExps. The continuation k is
// given the KVal that holds the result of e and builds the
// remainder of the computation around it.
def CPS(e: Exp)(k: KVal => KExp) : KExp = e match {
  case Var(s) =>
    // names starting with an upper-case letter are globals; they
    // are first bound to a temporary through a KConst
    if (!s.head.isUpper) k(KVar(s))
    else {
      val tmp = Fresh("tmp")
      KLet(tmp, KConst(s), k(KVar(tmp)))
    }
  case Num(i) => k(KNum(i))
  case FNum(d) => k(KFNum(d))
  case ChConst(c) => k(KChConst(c))
  case Aop(o, e1, e2) => {
    // the result name is requested before the operands are translated
    val res = Fresh("tmp")
    CPS(e1) { lhs =>
      CPS(e2) { rhs =>
        KLet(res, Kop(o, lhs, rhs), k(KVar(res)))
      }
    }
  }
  case If(Bop(o, b1, b2), e1, e2) => {
    val cond = Fresh("tmp")
    CPS(b1) { lhs =>
      CPS(b2) { rhs =>
        // the same continuation k is applied in both branches
        val thenBranch = CPS(e1)(k)
        val elseBranch = CPS(e2)(k)
        KLet(cond, Kop(o, lhs, rhs), KIf(cond, thenBranch, elseBranch))
      }
    }
  }
  case Call(name, args) => {
    // translates the arguments from left to right, collecting
    // their values, and finally emits the call itself
    def loop(rest: List[Exp], done: List[KVal]) : KExp = rest match {
      case arg :: more => CPS(arg)(v => loop(more, done :+ v))
      case Nil => {
        val res = Fresh("tmp")
        KLet(res, KCall(name, done), k(KVar(res)))
      }
    }
    loop(args, Nil)
  }
  case Sequence(e1, e2) =>
    // the value of e1 is discarded
    CPS(e1)(_ => CPS(e2)(k))
}

// CPS translation with the initial continuation: the value of
// the whole expression is simply returned
def CPSi(e: Exp) = CPS(e)(result => KReturn(result))


// get type of KVal: literals have a fixed type, all other
// KVals carry their type in the ty field
def get_typ_val(v: KVal) : Ty = v match {
  case KNum(_) | KChConst(_) => "Int"  // character constants count as Int
  case KFNum(_) => "Double"
  case KVar(_, ty) => ty
  case KConst(_, ty) => ty
  case Kop(_, _, _, ty) => ty
  case KCall(_, _, ty) => ty
}

// update type information for KValues
//
// Returns a copy of v in which every ty field is filled in from
// the typing environment ts. Throws an Exception if a name is
// not in ts or if the operands of an operation differ in type.
def typ_val(v: KVal, ts: TyEnv) : KVal = {
  // type of name in ts; an unknown name is a compile error
  def lookup(name: String) : Ty =
    ts.getOrElse(name, throw new Exception(s"Compile error: unknown type for $name"))

  v match {
    case KVar(name, _) => KVar(name, lookup(name))
    case KConst(name, _) => KConst(name, lookup(name))
    case Kop(o, v1, v2, _) => {
      val lhs = typ_val(v1, ts)
      val rhs = typ_val(v2, ts)
      val (t1, t2) = (get_typ_val(lhs), get_typ_val(rhs))
      // the operation takes over the (common) type of its operands
      if (t1 == t2) Kop(o, lhs, rhs, t1)
      else throw new Exception(s"Compile error: cannot compare $t1 with $t2")
    }
    case KCall(o, vrs, _) => {
      // arguments are typed first, so a problem in an argument
      // is reported before an unknown function name
      val typedArgs = vrs.map(arg => typ_val(arg, ts))
      KCall(o, typedArgs, lookup(o))
    }
    case literal => literal  // no changes: KNum, KFNum, KChConst
  }
}

// update type information for KExpressions
//
// Walks the expression and types every KVal in it with typ_val;
// a KLet extends the environment for its body by the bound name.
def typ_exp(a: KExp, ts: TyEnv) : KExp = a match {
  case KReturn(v) => KReturn(typ_val(v, ts))
  case KIf(x1, e1, e2) => {
    val thenTyped = typ_exp(e1, ts)
    val elseTyped = typ_exp(e2, ts)
    KIf(x1, thenTyped, elseTyped)
  }
  case KLet(x, e1, e2) => {
    val bound = typ_val(e1, ts)
    val bodyEnv = ts + (x -> get_typ_val(bound))
    KLet(x, bound, typ_exp(e2, bodyEnv))
  }
}

// prelude: LLVM-IR text that fun_compile puts in front of every
// compiled program. It declares printf, the format strings, and
// defines the built-in functions new_line, print_star, print_space,
// print_int, print_char and skip (the names listed in initialEnv).
val prelude = """
declare i32 @printf(i8*, ...)

@.str_nl = private constant [2 x i8] c"\0A\00"
@.str_star = private constant [2 x i8] c"*\00"
@.str_space = private constant [2 x i8] c" \00"
@.str_int = private constant [3 x i8] c"%d\00"
@.str_c = private constant [3 x i8] c"%c\00"

define void @new_line() #0 {
  %t0 = getelementptr [2 x i8], [2 x i8]* @.str_nl, i32 0, i32 0
  call i32 (i8*, ...) @printf(i8* %t0)
  ret void
}

define void @print_star() #0 {
  %t0 = getelementptr [2 x i8], [2 x i8]* @.str_star, i32 0, i32 0
  call i32 (i8*, ...) @printf(i8* %t0)
  ret void
}

define void @print_space() #0 {
  %t0 = getelementptr [2 x i8], [2 x i8]* @.str_space, i32 0, i32 0
  call i32 (i8*, ...) @printf(i8* %t0)
  ret void
}

define void @print_int(i32 %x) {
  %t0 = getelementptr [3 x i8], [3 x i8]* @.str_int, i32 0, i32 0
  call i32 (i8*, ...) @printf(i8* %t0, i32 %x) 
  ret void
}

define void @print_char(i32 %x) {
  %t0 = getelementptr [3 x i8], [3 x i8]* @.str_c, i32 0, i32 0
  call i32 (i8*, ...) @printf(i8* %t0, i32 %x)
  ret void
}

define void @skip() #0 {
  ret void
}

; END OF BUILT-IN FUNCTIONS (prelude)
"""

// convenient string interpolations for emitting LLVM-IR text
//
//   i"..."  instruction: three spaces of indentation, then a newline
//   l"..."  label: followed by ":" and a newline
//   m"..."  method line: followed by a newline
extension (sc: StringContext) {
    def i(args: Any*): String = {
        val instruction = sc.s(args:_*)
        "   " + instruction + "\n"
    }
    def l(args: Any*): String = {
        val label = sc.s(args:_*)
        label + ":\n"
    }
    def m(args: Any*): String = {
        val line = sc.s(args:_*)
        line + "\n"
    }
}

// mathematical and boolean operations on Ints
//
// Maps a FUN operator to the LLVM-IR instruction prefix for i32
// operands (the result ends in a space so that the operands can be
// appended). An operator that is not listed causes a MatchError.
def compile_op(op: String) = {
  val mnemonic = op match {
    case "+" => "add"
    case "-" => "sub"
    case "*" => "mul"
    case "/" => "sdiv"
    case "%" => "srem"
    case "==" => "icmp eq"
    case "!=" => "icmp ne"
    case "<" => "icmp slt"
    case "<=" => "icmp sle"
    case ">" => "icmp sgt"
    case ">=" => "icmp sge"
  }
  s"$mnemonic i32 "
}

// mathematical and boolean operations on Doubles
//
// Maps a FUN operator to the LLVM-IR instruction prefix for double
// operands (the result ends in a space so that the operands can be
// appended). An operator that is not listed causes a MatchError.
//
// All comparisons use fcmp with the ordered predicates (oeq, one,
// ole, olt, oge, ogt). icmp is the integer comparison and must not
// be used with double operands.
def compile_dop(op: String) = op match {
  case "+" => "fadd double "
  case "*" => "fmul double "
  case "-" => "fsub double "
  case "/" => "fdiv double "
  case "%" => "frem double "
  case "==" => "fcmp oeq double "
  case "!=" => "fcmp one double "
  case "<=" => "fcmp ole double "
  case "<" => "fcmp olt double "
  case ">=" => "fcmp oge double "
  case ">" => "fcmp ogt double "
}

// Renders the arguments of a call, in order, as
// "<LLVM type> <compiled value>" strings.
def compile_args(vrs: List[KVal]) : List[String] =
  vrs.map { arg =>
    val llvmTy = typeConversion(get_typ_val(arg))
    s"$llvmTy ${compile_val(arg)}"
  }

// compile K values into LLVM-IR operands / right-hand sides
def compile_val(v: KVal) : String = v match {
  case KNum(n) => n.toString
  case KFNum(d) => d.toString
  case KChConst(c) => c.toString  // as integer
  case KVar(name, _) => "%" + name
  case KConst(name, ty) => {
    // globals are read with a load from @name
    val llvmTy = typeConversion(ty)
    s"load $llvmTy, $llvmTy* @$name"
  }
  case Kop(op, x1, x2, ty) => {
    // the instruction family is selected by the type of the operation
    val instr = ty match {
      case "Double" => compile_dop(op)
      case "Int" => compile_op(op)
      case _ => throw new Exception("Compile error: unknown type for comparison")
    }
    s"$instr ${compile_val(x1)}, ${compile_val(x2)}"
  }
  case KCall(x1, args, ty) => {
    val retTy = typeConversion(ty)
    val actuals = compile_args(args).mkString(", ")
    s"call $retTy @$x1 ($actuals)"
  }
}

// compile K expressions into a sequence of LLVM-IR lines
def compile_exp(a: KExp) : String = a match {
  case KReturn(v) =>
    get_typ_val(v) match {
      case "Void" => i"ret void"
      case ty => i"ret ${typeConversion(ty)} ${compile_val(v)}"
    }
  case KLet(x: String, v: KVal, e: KExp) => {
    // a value of type Void yields no result, so nothing is assigned
    val binding =
      if (get_typ_val(v) == "Void") i"${compile_val(v)}"
      else i"%$x = ${compile_val(v)}"
    binding ++ compile_exp(e)
  }
  case KIf(x, e1, e2) => {
    // both labels are requested before either branch is compiled
    val thenLabel = Fresh("if_branch")
    val elseLabel = Fresh("else_branch")
    val jump = i"br i1 %$x, label %$thenLabel, label %$elseLabel"
    val thenCode = l"\n$thenLabel" ++ compile_exp(e1)
    val elseCode = l"\n$elseLabel" ++ compile_exp(e2)
    jump ++ thenCode ++ elseCode
  }
}

// Compiles the formal parameters (name, type) of a function.
// Returns the LLVM-IR parameters ("<type> %<name>", in order)
// together with ts extended by the types of the parameters.
// Throws an Exception if a parameter has type Void.
def compile_def_args(args: List[(String, String)], ts: TyEnv) : (List[String], TyEnv) = {
  if (args.exists { case (_, t) => t == "Void" })
    throw new Exception("Compile error: argument of type void is invalid")
  // built from the last parameter backwards
  val params = args.foldRight(List.empty[String]) {
    case ((n, t), rest) => s"${typeConversion(t)} %$n" :: rest
  }
  (params, ts ++ args)
}

// Compiles a single top-level declaration. Returns the generated
// LLVM-IR together with the typing environment extended by the
// declared name.
def compile_decl(d: Decl, ts: TyEnv) : (String, TyEnv) = d match {
  case Const(name, value) => {
    val code = m"@$name = global i32 $value\n"
    (code, ts + (name -> "Int"))
  }
  case FConst(name, value) => {
    val code = m"@$name = global double $value\n"
    (code, ts + (name -> "Double"))
  }
  case Def(name, args, ty, body) => {
    // the function's own type is part of the environment used for its body
    val outerEnv = ts + (name -> ty)
    val (params, bodyEnv) = compile_def_args(args, outerEnv)
    val paramList = params.mkString(", ")
    val header = m"define ${typeConversion(ty)} @$name ($paramList) {"
    val code = compile_exp(typ_exp(CPSi(body), bodyEnv))
    // don't preserve local variables in environment
    (header ++ code ++ m"}\n", outerEnv)
  }
  case Main(body) => {
    val env = ts + ("main" -> "Int")
    // the value of the main body is ignored; main always returns 0
    val cps = CPS(body)(_ => KReturn(KNum(0)))
    val code = compile_exp(typ_exp(cps, env))
    (m"define i32 @main() {" ++ code ++ m"}\n", env)
  }
}

// Compiles the declarations in order; the typing environment
// produced by one declaration is handed on to the next one.
// Returns the concatenated code and the final environment.
def compile_block(prog: List[Decl], ts: TyEnv) : (String, TyEnv) =
  prog.foldLeft(("", ts)) { case ((code, env), decl) =>
    val (declCode, nextEnv) = compile_decl(decl, env)
    (code ++ declCode, nextEnv)
  }

// Compiles a whole program, starting from the initial typing
// environment, and puts the prelude in front of the result.
def fun_compile(prog: List[Decl]) : String =
  prelude ++ compile_block(prog, initialEnv)._1


// Reads the file fname (relative to the current working directory),
// tokenises and parses it, compiles the first parse result and
// prints the generated LLVM-IR to stdout.
@main
def main(fname: String) = {
    val path = os.pwd / fname
    val tks = tokenise(os.read(path))
    val ast = parse_tks(tks).head
    val code = fun_compile(ast)
    println(code)
}

// Compiles the file fname (relative to the current working
// directory) and (over)writes the generated LLVM-IR to a file with
// the same name, but with the extension replaced by ".ll".
@main
def write(fname: String) = {
    val source = os.pwd / fname
    val base = fname.stripSuffix("." ++ source.ext)
    val ast = parse_tks(tokenise(os.read(source))).head
    val llvm = fun_compile(ast)
    os.write.over(os.pwd / (base ++ ".ll"), llvm)
}

// Compiles fname to an .ll file (via write), calls llc on it, links
// the .o file with gcc into a .bin file and finally runs that
// binary with its output shown on stdout.
@main
def run(fname: String) = {
    val base = fname.stripSuffix("." ++ (os.pwd / fname).ext)
    val llFile = base ++ ".ll"
    val objFile = base ++ ".o"
    val binFile = base ++ ".bin"
    write(fname)
    os.proc("llc", "-filetype=obj", llFile).call()
    os.proc("gcc", objFile, "-o", binFile).call()
    os.proc(os.pwd / binFile).call(stdout = os.Inherit)
    println("done.")
}


// for automated testing: only generates the .ll file (same as
// write); nothing is compiled further or run

@main
def test(fname: String) = {
    write(fname)
}