// progs/fun/simple-cps.sc
// author Christian Urban <christian.urban@kcl.ac.uk>
// Fri, 29 Nov 2024 18:59:32 +0000
// changeset 976 e9eac62928f5
// parent 911 df8660143051
// permissions -rw-r--r--
// updated

// Source language: arithmetic expressions with function calls 
/** Source-language expressions. */
enum Expr {
  /** Integer literal `n`. */
  case Num(n: Int)
  /** Binary operation named by `op` (e.g. `"+"`, `"*"`) applied to `e1` and `e2`. */
  case Aop(op: String, e1: Expr, e2: Expr)
  /** Call of the function named `fname` with the argument expressions `args`. */
  case Call(fname: String, args: List[Expr])
}
import Expr._

// Target language 
// "trivial" KValues
/** Trivial values of the target language: the operands a `KLet` can bind. */
enum KVal {
  /** Reference to a variable named `s`. */
  case KVar(s: String)
  /** Integer constant `n`. */
  case KNum(n: Int)
  /** Binary operation named by `op` applied to two trivial values. */
  case KAop(op: String, v1: KVal, v2: KVal)
  /** Call of the function named `fname` on trivial argument values. */
  case KCall(fname: String, args: List[KVal])
}
import KVal._

// KExpressions 
/** Target-language expressions: a chain of let-bindings ending in a return. */
enum KExp {
  /** Produces the trivial value `v` as the result. */
  case KReturn(v: KVal)
  /** Binds the name `x` to the value `v` and continues with `e`. */
  case KLet(x: String, v: KVal, e: KExp)
}
import KExp._

/** Renders a [[KExp]] as text.
 *
 *  A `KLet` prints its bound name and value, a line break, and then its
 *  rendered body. A `KReturn` prints its value. Values are shown with
 *  their default `toString`.
 */
def pexp(e: KExp): String = e match {
  case KLet(name, value, body) => s"KLet($name = ${value} \n in ${pexp(body)})"
  case KReturn(value)          => s"KReturn($value)"
}

/** Counter behind [[Fresh]]: the suffix most recently handed out (-1 before the first call). */
var cnt = -1

/** Returns `s` followed by `_` and the next counter value, e.g. `z_0`, then `tmp_1`.
 *
 *  Advances the global counter `cnt` on every call.
 */
def Fresh(s: String): String = {
  cnt += 1
  s + "_" + cnt
}

/** Translates a source expression into a [[KExp]] in continuation-passing style.
 *
 *  Every intermediate arithmetic or call result is bound to a fresh variable
 *  (named via [[Fresh]]) with a `KLet`. The continuation `k` receives the
 *  trivial [[KVal]] that denotes the value of `e` and builds the rest of the
 *  program.
 *
 *  @param e the expression to translate
 *  @param k continuation that receives the trivial value of `e`
 *  @return the translated target-language expression
 */
def CPS(e: Expr)(k: KVal => KExp): KExp = e match {
  case Num(i) => k(KNum(i))
  case Aop(op, e1, e2) => {
    // The result name is allocated before the operands are translated, so it
    // gets a smaller counter value than any name introduced by the operands.
    val z = Fresh("z")
    CPS(e1)(v1 =>
      CPS(e2)(v2 => KLet(z, KAop(op, v1, v2), k(KVar(z)))))
  }
  case Call(fname, args) => {
    // Translates the arguments left to right. `acc` holds the already
    // translated argument values in REVERSE order, so each step is an O(1)
    // prepend; the list is reversed once when the call node is built.
    def aux(rest: List[Expr], acc: List[KVal]): KExp = rest match {
      case Nil => {
        val z = Fresh("tmp")
        KLet(z, KCall(fname, acc.reverse), k(KVar(z)))
      }
      case a :: as => CPS(a)(v => aux(as, v :: acc))
    }
    aux(args, Nil)
  }
}

/** Translates a whole program, returning its final value with `KReturn`. */
def CPSi(e: Expr): KExp = CPS(e)(v => KReturn(v))


// Test expression: 1 + foo(bar(4 * -7), 3, id(12))
val etest: Expr = {
  val barCall = Call("bar", List(Aop("*", Num(4), Num(-7))))
  val fooCall = Call("foo", List(barCall, Num(3), Call("id", List(Num(12)))))
  Aop("+", Num(1), fooCall)
}

println(pexp(CPSi(etest)))

// Constant Folding
/** Simplifies a trivial value under the constant environment `env`.
 *
 *  - Variables with an entry in `env` are replaced by that constant.
 *  - `+` and `*` applied to two constants are folded into a single `KNum`.
 *  - Any other operation, and every function call, is rebuilt from its
 *    simplified sub-values.
 *
 *  @param v   the value to simplify
 *  @param env known integer values of variables
 *  @return the simplified value
 */
def opt(v: KVal, env: Map[String, Int]): KVal = v match {
  case KVar(name) => env.get(name).fold(v)(c => KNum(c))
  case KNum(_)    => v
  case KAop(op, lhs, rhs) =>
    val l = opt(lhs, env)
    val r = opt(rhs, env)
    (l, r) match {
      case (KNum(a), KNum(b)) if op == "+" => KNum(a + b)
      case (KNum(a), KNum(b)) if op == "*" => KNum(a * b)
      case _                               => KAop(op, l, r)
    }
  case KCall(fname, args) => KCall(fname, args.map(arg => opt(arg, env)))
}
    
/** Constant-propagates and constant-folds a [[KExp]].
 *
 *  A `KLet` whose value simplifies to a constant is removed, and the constant
 *  is substituted into its body. Any other `KLet` is kept with its simplified
 *  value.
 *
 *  @param ke  the expression to optimise
 *  @param env constants known for variables in scope (empty at the top level)
 *  @return the optimised expression
 */
def Koptimise(ke: KExp, env: Map[String, Int] = Map()): KExp = ke match {
  case KReturn(v) => KReturn(opt(v, env))
  case KLet(x, v, e) => opt(v, env) match {
    case KNum(n) => Koptimise(e, env + (x -> n))
    // `x` is rebound to a non-constant here. Any constant recorded for an
    // outer binding of the same name must stop being substituted inside `e`.
    case vo => KLet(x, vo, Koptimise(e, env - x))
  }
}

// Print the constant-folded translation of the test expression.
println("\n" + pexp(Koptimise(CPSi(etest))))