progs/fun_llvm.scala
changeset 653 9d7843934d30
parent 650 3031e3379ea3
child 654 fb6192488b91
--- a/progs/fun_llvm.scala	Thu Oct 10 21:04:38 2019 +0100
+++ b/progs/fun_llvm.scala	Sat Oct 12 14:11:10 2019 +0100
@@ -56,60 +56,84 @@
 
 // Abstract syntax trees for the Fun language
 abstract class KExp
+abstract class KVal
 
-case class KVar(s: String) extends KExp
-case class KNum(i: Int) extends KExp
-case class KAop(o: String, x1: String, x2: String) extends KExp
-case class KIfeq(x1: String, x2: String, e1: KExp, e2: KExp) extends KExp {
-  override def toString = s"KIf $x1 == $x2 \nIF\n$e1\nELSE\n$e2"
+case class KVar(s: String) extends KVal
+case class KNum(i: Int) extends KVal
+case class KAop(o: String, v1: KVal, v2: KVal) extends KVal
+case class KBop(o: String, v1: KVal, v2: KVal) extends KVal
+case class KCall(o: String, vrs: List[KVal]) extends KVal
 
+case class KIf(x1: String, e1: KExp, e2: KExp) extends KExp {
+  override def toString = s"KIf $x1\nIF\n$e1\nELSE\n$e2"
 }
-case class KCall(o: String, vrs: List[String]) extends KExp
-case class KLet(x: String, e1: KExp, e2: KExp) extends KExp {
+case class KLet(x: String, e1: KVal, e2: KExp) extends KExp {
   override def toString = s"let $x = $e1 in \n$e2" 
 }
+case class KReturn(v: KVal) extends KExp
 
-def K(e: Exp) : KExp = e match {
-  case Var(s) => KVar(s) 
-  case Num(i) => KNum(i)
-  case Aop(o, a1, a2) => {
-    val x1 = Fresh("tmp")
-    val x2 = Fresh("tmp") 
-    KLet(x1, K(a1), KLet(x2, K(a2), KAop(o, x1, x2)))
-  } 
-  case Call(name: String, args: List[Exp]) => {
-    val args_new = args.map{a => (Fresh("tmp"), K(a))}
-    def aux(as: List[(String, KExp)]) : KExp = as match {
-      case Nil => KCall(name, args_new.map(_._1))
-      case (x, a)::rest => KLet(x, a, aux(rest))
+def CPS(e: Exp)(k: KVal => KExp) : KExp = e match {
+  case Var(s) => k(KVar(s)) 
+  case Num(i) => k(KNum(i))
+  case Aop(o, e1, e2) => {
+    val z = Fresh("tmp")
+    CPS(e1)(y1 => 
+      CPS(e2)(y2 => KLet(z, KAop(o, y1, y2), k(KVar(z)))))
+  }
+  case If(Bop(o, b1, b2), e1, e2) => {
+    val z = Fresh("tmp")
+    CPS(b1)(y1 => 
+      CPS(b2)(y2 => KLet(z, KBop(o, y1, y2), KIf(z, CPS(e1)(k), CPS(e2)(k)))))
+  }
+  case Call(name, args) => {
+    def aux(args: List[Exp], vs: List[KVal]) : KExp = args match {
+      case Nil => {
+          val z = Fresh("tmp")
+          KLet(z, KCall(name, vs), k(KVar(z)))
+      }
+      case e::es => CPS(e)(y => aux(es, vs ::: List(y)))
     }
-    aux(args_new)
+    aux(args, Nil)
+  }
+  case Sequence(e1, e2) => {
+    val z = Fresh("tmp")
+    CPS(e1)(y1 => 
+      CPS(e2)(y2 => KLet("_", y1, KLet(z, y2, k(KVar(z))))))
   }
-  case If(Bop("==", b1, b2), e1, e2) => {
-    val x1 = Fresh("tmp")
-    val x2 = Fresh("tmp") 
-    KLet(x1, K(b1), KLet(x2, K(b2), KIfeq(x1, x2, K(e1), K(e2))))
-  }
-}
+}   
+
+def CPSi(e: Exp) = CPS(e)(KReturn)
+
+val e1 = Aop("*", Var("a"), Num(3))
+CPS(e1)(KReturn)
+
+val e2 = Aop("+", Aop("*", Var("a"), Num(3)), Num(4))
+CPS(e2)(KReturn)
+
+val e3 = Aop("+", Num(2), Aop("*", Var("a"), Num(3)))
+CPS(e3)(KReturn)
 
-def Denest(e: KExp) : KExp = e match {
-  case KLet(xt, e1, e2) => {
-    def insert(e: KExp) : KExp = e match {
-      case KLet(yt, e3, e4) => KLet(yt, e3, insert(e4))
-      case e => KLet(xt, e, Denest(e2))
-    }
-    insert(Denest(e1))  
-  }
-  case KIfeq(x1, x2, e1, e2) =>  KIfeq(x1, x2, Denest(e1), Denest(e2))
-  case _ => e
-}
+val e4 = Aop("+", Aop("-", Num(1), Num(2)), Aop("*", Var("a"), Num(3)))
+CPS(e4)(KReturn)
+
+val e5 = If(Bop("==", Num(1), Num(1)), Num(3), Num(4))
+CPS(e5)(KReturn)
+
+val e6 = If(Bop("!=", Num(10), Num(10)), e5, Num(40))
+CPS(e6)(KReturn)
 
+val e7 = Call("foo", List(Num(3)))
+CPS(e7)(KReturn)
+
+val e8 = Call("foo", List(Num(3), Num(4), Aop("+", Num(5), Num(6))))
+CPS(e8)(KReturn)
+
+val e9 = Sequence(Aop("*", Var("a"), Num(3)), Aop("+", Var("b"), Num(6)))
+CPS(e9)(KReturn)
 
 val e = Aop("*", Aop("+", Num(1), Call("foo", List(Var("a"), Num(3)))), Num(4))
-println(K(e))
-println()
-println(Denest(K(e)))
-println()
+CPS(e)(KReturn)
+
 
 
 
@@ -124,88 +148,64 @@
     def m(args: Any*): String = sc.s(args:_*) ++ "\n"
 }
 
+def compile_op(op: String) = op match {
+  case "+" => "add i32 "
+  case "*" => "mul i32 "
+  case "-" => "sub i32 "
+  case "==" => "icmp eq i32 "
+}
 
-type Env = Map[String, Int]
-
-
+def compile_val(v: KVal) : String = v match {
+  case KNum(i) => s"$i"
+  case KVar(s) => s"%$s"
+  case KAop(op, x1, x2) => 
+    s"${compile_op(op)} ${compile_val(x1)}, ${compile_val(x2)}"
+  case KBop(op, x1, x2) => 
+    s"${compile_op(op)} ${compile_val(x1)}, ${compile_val(x2)}"
+  case KCall(x1, args) => 
+    s"call i32 @$x1 (${args.map(compile_val).mkString("i32 ", ", i32 ", "")})"
+}
 
 // compile K expressions
 def compile_exp(a: KExp) : String = a match {
-  case KNum(i) => s"?$i?"
-  case KVar(s) => s"?$s?"
-  case KAop("+", x1, x2) => s"add i32 %$x1, i32 %$x2"
-  case KAop("-", x1, x2) => s"sub i32 %$x1, i32 %$x2"
-  case KAop("*", x1, x2) => s"mul i32 %$x1, i32 %$x2"
-  case KLet(x: String, e1: KExp, e2: KExp) => {
-    val is1 = compile_exp(e1)
-    val is2 = compile_exp(e2)
-    i"%$x = $is1" ++ is2
-  }
-  case KLet(x: String, e1: KExp, e2: KExp) => {
-    val is1 = compile_exp(e1)
-    val is2 = compile_exp(e2)
-    i"%$x = $is1" ++ is2
-  }
-  case KIfeq(x1, x2, a1, a2) => {
+  case KReturn(v) =>
+    i"ret i32 ${compile_val(v)}"
+  case KLet(x: String, v: KVal, e: KExp) => 
+    i"%$x = ${compile_val(v)}" ++ compile_exp(e)
+  case KIf(x, e1, e2) => {
     val if_br = Fresh("if_br")
     val else_br = Fresh("else_br")
-    val x = Fresh("tmp")
-    i"%$x = icmp eq i32 %$x1, i32 %$x2" ++
     i"br i1 %$x, label %$if_br, label %$else_br" ++
     l"\n$if_br" ++
-    compile_exp(a1) ++
+    compile_exp(e1) ++
     l"\n$else_br" ++ 
-    compile_exp(a2)
-  }
-  case KCall(x1, args) => {
-    s"Call $x1 ($args)"
+    compile_exp(e2)
   }
-/*
-  case Call(name, args) => {
-    val is = "I" * args.length
-    args.map(a => compile_exp(a, env)).mkString ++
-    i"invokestatic XXX/XXX/$name($is)I"
-  }
-  case Sequence(a1, a2) => {
-    compile_exp(a1, env) ++ i"pop" ++ compile_exp(a2, env)
-  }
-  case Write(a1) => {
+}
+
+/*  case Write(a1) => {
     compile_exp(a1, env) ++
     i"dup" ++
     i"invokestatic XXX/XXX/write(I)V"
   }
-  */
-}
+*/
 
-/*
-// compile boolean expressions
-def compile_bexp(b: BExp, env : Env, jmp: String) : String = b match {
-  case Bop("==", a1, a2) => 
-    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmpne $jmp"
-  case Bop("!=", a1, a2) => 
-    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmpeq $jmp"
-  case Bop("<", a1, a2) => 
-    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmpge $jmp"
-  case Bop("<=", a1, a2) => 
-    compile_exp(a1, env) ++ compile_exp(a2, env) ++ i"if_icmpgt $jmp"
-}
-*/
+
 
 // compile function for declarations and main
 def compile_decl(d: Decl) : String = d match {
   case Def(name, args, body) => { 
-    println(s"DEF\n $name ($args) = \nBODY:")
-    println(Denest(K(body)))
-    println()
-    counter = -1
+    //println(s"DEF\n $name ($args) = \nBODY:")
+    //println(CPSi(body))
+    //println()
+    //counter = -1
     m"define i32 @$name (${args.mkString("i32 %", ", i32 %", "")}) {" ++
-    compile_exp(Denest(K(body))) ++
+    compile_exp(CPSi(body)) ++
     m"}\n"
   }
   case Main(body) => {
     m"define i32 @main() {" ++
-    compile_exp(Denest(K(body))) ++
-    i"ret i32 0" ++
+    compile_exp(CPSi(body)) ++
     m"}\n"
   }
 }
@@ -228,9 +228,8 @@
 
 def compile(class_name: String) : String = {
   val ast = deserialise[List[Decl]](class_name ++ ".prs").getOrElse(Nil) 
-  println(ast(0).toString ++ "\n")
-  val instructions = List(ast(0), ast(2)).map(compile_decl).mkString
-  instructions
+  //println(ast(0).toString ++ "\n")
+  ast.map(compile_decl).mkString
 }
 
 /*