
(* One symbol of a bitcode.  In the visible code Z marks the first branch of
   an alternative and "one more iteration" of a star, while S marks the
   remaining branches of an alternative and the end of a star.  C is declared
   but no function in this file produces or consumes it. *)
datatype bit
  = Z
  | S
  | C of char

(* A bitcode: the bit sequence carried by annotated regular expressions. *)
type bits = bit list

(* Plain (un-annotated) regular expressions. *)
datatype rexp
  = ZERO                    (* never nullable, derivative is always ZERO *)
  | ONE                     (* the empty string only *)
  | CHAR of char            (* a single character *)
  | ALTS of rexp list       (* alternative over a list of branches *)
  | SEQ of rexp * rexp      (* first component followed by second *)
  | STAR of rexp            (* zero or more repetitions *)
  | RECD of string * rexp   (* sub-expression tagged with a label *)

(* Curried binary alternative: builds the two-branch ALTS node. *)
fun alt first second = ALTS (first :: second :: [])

(* Annotated regular expressions: every node except AZERO carries a bitcode.
   There is no counterpart of RECD; internalise drops record labels. *)
datatype arexp
  = AZERO
  | AONE of bits
  | ACHAR of bits * char
  | AALTS of bits * arexp list
  | ASEQ of bits * arexp * arexp
  | ASTAR of bits * arexp

(* Parse values produced by decoding a bitcode against a regular expression. *)
datatype value
  = Empty                   (* decoded from ONE *)
  | Chr of char             (* decoded from CHAR *)
  | Sequ of value * value   (* decoded from SEQ *)
  | Left of value           (* first branch of an ALTS *)
  | Right of value          (* one of the remaining branches of an ALTS *)
  | Stars of value list     (* iterations of a STAR *)
  | Rec of string * value   (* decoded from RECD, keeps the label *)

(* some helper functions for strings *)   
(* string_repeat s n: the string s concatenated with itself n times.
   The copies are built with List.tabulate, so a negative n raises Size. *)
fun string_repeat s n =
  let
    val copies = List.tabulate (n, fn _ => s)
  in
    String.concat copies
  end

(* some helper functions for rexps *)
(* Turns a list of characters into the regular expression matching exactly
   that character sequence: ONE for the empty list, a bare CHAR for a single
   character, otherwise a right-nested chain of SEQ nodes. *)
fun seq [] = ONE
  | seq [c] = CHAR c
  | seq (c :: rest) = SEQ (CHAR c, seq rest)

(* Single-character regular expression.  NOTE: from here on this shadows the
   Basis top-level chr : int -> char. *)
fun chr ch = CHAR ch

(* Regular expression matching exactly the given string.  NOTE: from here on
   this shadows the Basis top-level str : char -> string. *)
fun str text =
  let
    val chars = explode text
  in
    seq chars
  end

(* One or more repetitions of r, expressed as r followed by r*. *)
fun plus r =
  let
    val zero_or_more = STAR r
  in
    SEQ (r, zero_or_more)
  end

(* Infix constructors for regular expressions.  All three are left
   associative at precedence 9. *)
infix 9 ++ -- $

(* r1 ++ r2: binary alternative. *)
fun r1 ++ r2 = ALTS [r1, r2]

(* r1 -- r2: sequence. *)
fun r1 -- r2 = SEQ (r1, r2)

(* x $ r: the regular expression r tagged with the label x. *)
fun x $ r = RECD (x, r)

(* Folds a list of regular expressions into right-nested binary ALTS nodes:
   ZERO for the empty list, the element itself for a singleton. *)
fun alts [] = ZERO
  | alts [only] = only
  | alts (first :: rest) = ALTS [first, alts rest]


(* Sum of a list of integers; 0 for the empty list.  The additions associate
   to the right: x1 + (x2 + (... + (xn + 0))). *)
fun sum xs = List.foldr (op +) 0 xs;

(* size of a regular expression - for testing purposes *)
(* Number of constructor nodes in a regular expression.  NOTE: from here on
   this shadows the Basis top-level size : string -> int. *)
fun size ZERO = 1
  | size ONE = 1
  | size (CHAR _) = 1
  | size (ALTS branches) = 1 + sum (map size branches)
  | size (SEQ (first, second)) = 1 + size first + size second
  | size (STAR body) = 1 + size body
  | size (RECD (_, body)) = 1 + size body


(* Strips every bitcode from an annotated regular expression, giving the
   plain regular expression of the same shape. *)
fun erase AZERO = ZERO
  | erase (AONE _) = ONE
  | erase (ACHAR (_, c)) = CHAR c
  | erase (AALTS (_, branches)) = ALTS (map erase branches)
  | erase (ASEQ (_, first, second)) = SEQ (erase first, erase second)
  | erase (ASTAR (_, body)) = STAR (erase body)


(* fuse bs r: prepends bs to the bitcode stored at the root of r.
   Sub-expressions are left untouched; AZERO carries no bitcode and is
   returned as is. *)
fun fuse _ AZERO = AZERO
  | fuse pre (AONE own) = AONE (pre @ own)
  | fuse pre (ACHAR (own, c)) = ACHAR (pre @ own, c)
  | fuse pre (AALTS (own, branches)) = AALTS (pre @ own, branches)
  | fuse pre (ASEQ (own, first, second)) = ASEQ (pre @ own, first, second)
  | fuse pre (ASTAR (own, body)) = ASTAR (pre @ own, body)

(* Converts a plain regular expression into an annotated one with the
   initial bitcodes.

   Alternatives of any length are supported (the previous version only
   matched two-element ALTS lists and raised Match otherwise).  The coding
   follows what decode_aux expects:
     - ALTS []        has no branch to choose and becomes AZERO;
     - ALTS [r]       consumes no bit and is just r;
     - ALTS (r :: rs) marks r with Z and the remaining alternatives,
                      internalised as ALTS rs, with S.
   For a two-element list this yields exactly
     AALTS ([], [fuse [Z] r1', fuse [S] r2']).

   Record labels are dropped: RECD (x, r) is internalised as r. *)
fun internalise ZERO = AZERO
  | internalise ONE = AONE []
  | internalise (CHAR c) = ACHAR ([], c)
  | internalise (ALTS []) = AZERO
  | internalise (ALTS [r]) = internalise r
  | internalise (ALTS (r :: rs)) =
      AALTS ([], [fuse [Z] (internalise r), fuse [S] (internalise (ALTS rs))])
  | internalise (SEQ (r1, r2)) = ASEQ ([], internalise r1, internalise r2)
  | internalise (STAR r) = ASTAR ([], internalise r)
  | internalise (RECD (_, r)) = internalise r

(* decode_aux r bs: reads a value for r off the front of the bitcode bs and
   returns it together with the unconsumed bits.

   ONE and CHAR consume nothing, and neither does a one-element ALTS.  For
   any other ALTS, Z selects the head of the list (Left) and S continues
   with the tail of the list (Right).  For STAR, Z announces another
   iteration and S ends the iteration list.

   Inputs not covered by a clause (ZERO, or a bit that does not fit the
   expression) raise Match; an empty ALTS list fails inside hd/tl. *)
fun decode_aux ONE bs = (Empty, bs)
  | decode_aux (CHAR c) bs = (Chr c, bs)
  | decode_aux (ALTS [only]) bs = decode_aux only bs
  | decode_aux (ALTS branches) (Z :: rest) =
      let
        val (v, remaining) = decode_aux (hd branches) rest
      in
        (Left v, remaining)
      end
  | decode_aux (ALTS branches) (S :: rest) =
      let
        val (v, remaining) = decode_aux (ALTS (tl branches)) rest
      in
        (Right v, remaining)
      end
  | decode_aux (SEQ (first, second)) bs =
      let
        val (v1, after_first) = decode_aux first bs
        val (v2, after_second) = decode_aux second after_first
      in
        (Sequ (v1, v2), after_second)
      end
  | decode_aux (STAR body) (Z :: rest) =
      let
        val (v, after_item) = decode_aux body rest
        (* decoding a STAR always yields Stars, so this binding succeeds *)
        val (Stars vs, after_all) = decode_aux (STAR body) after_item
      in
        (Stars (v :: vs), after_all)
      end
  | decode_aux (STAR _) (S :: rest) = (Stars [], rest)
  | decode_aux (RECD (label, body)) bs =
      let
        val (v, remaining) = decode_aux body bs
      in
        (Rec (label, v), remaining)
      end

(* Raised by decode when bits are left over after decoding. *)
exception DecodeError

(* decode r bs: the value described by the bitcode bs for r.  The whole
   bitcode must be consumed, otherwise DecodeError is raised. *)
fun decode r bs =
  let
    val (v, leftover) = decode_aux r bs
  in
    if null leftover then v else raise DecodeError
  end

(* True iff the annotated regular expression accepts the empty string. *)
fun bnullable AZERO = false
  | bnullable (AONE _) = true
  | bnullable (ACHAR _) = false
  | bnullable (AALTS (_, branches)) = List.exists bnullable branches
  | bnullable (ASEQ (_, first, second)) = bnullable first andalso bnullable second
  | bnullable (ASTAR _) = true

(* Bitcode describing how a nullable annotated regular expression matches
   the empty string.  For an alternative the first nullable branch is used;
   for a star the code is its own bits followed by S (no iterations).
   Only meant to be called on nullable expressions: AZERO and ACHAR have no
   clause (Match), and an alternative without a nullable branch raises Bind. *)
fun bmkeps (AONE bs) = bs
  | bmkeps (AALTS (bs, branches)) =
      let
        val SOME first_nullable = List.find bnullable branches
      in
        bs @ bmkeps first_nullable
      end
  | bmkeps (ASEQ (bs, first, second)) = bs @ bmkeps first @ bmkeps second
  | bmkeps (ASTAR (bs, _)) = bs @ [S]

(* bder c r: derivative of the annotated regular expression r with respect
   to the character c, carrying the bitcodes along. *)
fun bder _ AZERO = AZERO
  | bder _ (AONE _) = AZERO
  | bder c (ACHAR (bs, d)) = if c = d then AONE bs else AZERO
  | bder c (AALTS (bs, branches)) = AALTS (bs, map (bder c) branches)
  | bder c (ASEQ (bs, first, second)) =
      let
        val d_first = bder c first
      in
        if bnullable first
        then
          (* either c is consumed by first, or first matches the empty
             string (recorded via bmkeps) and c is consumed by second *)
          AALTS (bs, [ASEQ ([], d_first, second),
                      fuse (bmkeps first) (bder c second)])
        else ASEQ (bs, d_first, second)
      end
  | bder c (ASTAR (bs, body)) =
      (* Z records that one more iteration of the star has started *)
      ASEQ (bs, fuse [Z] (bder c body), ASTAR ([], body))

(* bders s r: derivative of r with respect to the character list s, taken
   one character at a time from left to right, without simplification. *)
fun bders [] r = r
  | bders (c :: rest) r = bders rest (bder c r)


(* Raised when the input has been consumed but the remaining regular
   expression is not nullable, i.e. the input does not match. *)
exception LexError

(* blex r s: bitcode of the match of the character list s against r,
   computed by repeated derivatives without simplification. *)
fun blex r [] = if bnullable r then bmkeps r else raise LexError
  | blex r (c :: rest) = blex (bder c r) rest

(* blexing r s: lexes the string s with r (no simplification) and decodes
   the resulting bitcode into a value. *)
fun blexing r s =
  let
    val code = blex (internalise r) (explode s)
  in
    decode r code
  end

(* Simplification *)

(* distinctBy xs f acc: keeps, in order, the elements of xs whose key (f x)
   has not been seen before.  acc is the list of keys already seen; pass []
   to start fresh. *)
fun distinctBy [] _ _ = []
  | distinctBy (x :: rest) f acc =
      let
        val key = f x
        fun is_key k = (k = key)
      in
        if List.exists is_key acc
        then distinctBy rest f acc
        else x :: distinctBy rest f (key :: acc)
      end

(* Flattens a list of alternatives by one level: AZERO entries are dropped
   and each AALTS entry is replaced by its branches, with the bitcode of
   that AALTS fused onto every one of them. *)
fun flats [] = []
  | flats (AZERO :: rest) = flats rest
  | flats (AALTS (bs, inner) :: rest) = map (fuse bs) inner @ flats rest
  | flats (other :: rest) = other :: flats rest


(* stack r1 r2: r1 followed by r2.  When r1 is AONE its bitcode is fused
   onto r2 instead of building a sequence node. *)
fun stack (AONE bs) r2 = fuse bs r2
  | stack r1 r2 = ASEQ ([], r1, r2)


(* Simplifies an annotated regular expression while keeping its bitcodes.

   Sequences: both parts are simplified first; the result is AZERO if either
   part is AZERO, a leading AONE is fused into the second part, and a leading
   alternative is distributed over the second part (via stack).
   Alternatives: branches are simplified, flattened, and de-duplicated by
   their erased form; no branch gives AZERO, one branch gives that branch
   with the bitcode fused on.
   Every other expression (including ASTAR) is returned unchanged. *)
fun bsimp (ASEQ (bs1, r1, r2)) =
      (case (bsimp r1, bsimp r2) of
           (AZERO, _) => AZERO
         | (_, AZERO) => AZERO
         | (AONE bs2, tail) => fuse (bs1 @ bs2) tail
         | (AALTS (bs2, branches), tail) =>
             AALTS (bs1 @ bs2, map (fn b => stack b tail) branches)
         | (head, tail) => ASEQ (bs1, head, tail))
  | bsimp (AALTS (bs1, rs)) =
      let
        val simplified = map bsimp rs
        val unique = distinctBy (flats simplified) erase []
      in
        case unique of
            [] => AZERO
          | [only] => fuse bs1 only
          | several => AALTS (bs1, several)
      end
  | bsimp r = r

(* bders_simp r s: derivative of r with respect to the character list s,
   simplifying with bsimp after every character. *)
fun bders_simp r [] = r
  | bders_simp r (c :: rest) = bders_simp (bsimp (bder c r)) rest

(* blex_simp r s: like blex, but simplifies with bsimp after every
   derivative step.  Raises LexError if s is not matched by r. *)
fun blex_simp r [] = if bnullable r then bmkeps r else raise LexError
  | blex_simp r (c :: rest) = blex_simp (bsimp (bder c r)) rest

(* blexing_simp r s: lexes the string s with r using the simplifying lexer
   and decodes the resulting bitcode into a value. *)
fun blexing_simp r s =
  let
    val code = blex_simp (internalise r) (explode s)
  in
    decode r code
  end


(* Lexing rules for a small WHILE language *)
local
  (* alternative over every character of the given string *)
  fun any_char_of chars = alts (List.map chr (explode chars))
  (* alternative over the given literal strings *)
  fun any_word_of words = alts (List.map str words)
in
  (* single lower-case letter / single decimal digit *)
  val sym = any_char_of "abcdefghijklmnopqrstuvwxyz"
  val digit = any_char_of "0123456789"
  (* identifiers: a letter followed by letters and digits *)
  val idents = sym -- STAR (sym ++ digit)
  (* numbers: one or more digits *)
  val nums = plus digit
  val keywords =
    any_word_of ["skip", "while", "do", "if", "then", "else", "read", "write", "true", "false"]
  val semicolon = str ";"
  val ops =
    any_word_of [":=", "==", "-", "+", "*", "!=", "<", ">", "<=", ">=", "%", "/"]
  (* one or more blanks, newlines or tabs *)
  val whitespace = plus (str " " ++ str "\n" ++ str "\t")
  val rparen = str ")"
  val lparen = str "("
  val begin_paren = str "{"
  val end_paren = str "}"
end


(* The complete lexer for the WHILE language: any number of tokens, each
   tagged with a one-letter label naming its token class. *)
val while_regs =
  let
    val keyword_tok = "k" $ keywords
    val ident_tok = "i" $ idents
    val op_tok = "o" $ ops
    val num_tok = "n" $ nums
    val semi_tok = "s" $ semicolon
    val paren_tok = "p" $ (lparen ++ rparen)
    val brace_tok = "b" $ (begin_paren ++ end_paren)
    val space_tok = "w" $ whitespace
  in
    STAR (keyword_tok ++ ident_tok ++ op_tok ++ num_tok ++
          semi_tok ++ paren_tok ++ brace_tok ++ space_tok)
  end



(* Some Tests
  ============ *)

(* time f x: evaluates f x five times, prints the average user CPU time per
   evaluation (in seconds) followed by a newline, and returns the result of
   the last evaluation. *)
fun time f x =
  let
    val timer = Timer.startCPUTimer ()
    val result = (f x; f x; f x; f x; f x)
    val {usr, ...} = Timer.checkCPUTimer timer
    val average = Time.toReal usr / 5.0
  in
    print (Real.toString average ^ "\n");
    result
  end


(* Sample WHILE program used as lexer input: its source lines joined with
   newline characters (no trailing newline). *)
val prog2 =
  let
    val source_lines =
      [ "i := 2;"
      , "max := 100;"
      , "while i < max do {"
      , "  isprime := 1;"
      , "  j := 2;"
      , "  while (j * j) <= i + 1  do {"
      , "    if i % j == 0 then isprime := 0  else skip;"
      , "    j := j + 1"
      , "  };"
      , " if isprime == 1 then write i else skip;"
      , " i := i + 1"
      , "}"
      ]
  in
    String.concatWith "\n" source_lines
  end;


(* loops in ML *)
(* An inclusive integer range, written lo to up. *)
datatype for = to of int * int
infix to

(* for (lo to up) f: applies f to lo, lo + 1, ..., up in ascending order and
   discards the results.  Does nothing when lo > up. *)
fun for (lo to up) f =
  let
    fun loop i =
      if i <= up
      then (f i; loop (i + 1))
      else ()
  in
    loop lo
  end

(* forby n (lo to up) f: applies f to lo, lo + n, lo + 2n, ... for as long
   as the value does not exceed up, discarding the results.
   NOTE: the loop only stops once the counter exceeds up, so with n <= 0 and
   lo <= up it does not terminate by itself. *)
fun forby n (lo to up) f =
  let
    fun loop i =
      if i <= up
      then (f i; loop (i + n))
      else ()
  in
    loop lo
  end


(* step_simp i: prints "i: ", then times the simplifying lexer on i
   concatenated copies of prog2 (time prints the averaged CPU time) and
   returns the lexed value. *)
fun step_simp i =
  let
    (* print the label before the input is built, as the timing follows it *)
    val () = print (Int.toString i ^ ": ")
    val lexer = blexing_simp while_regs
    val input = string_repeat prog2 i
  in
    time lexer input
  end;

(*
val main1 = forby 1000 (1000 to 5000) step_simp;
print "\n";
val main2 = forby 1000 (1000 to 5000) step_simp2;
print "\n";
val main3 = forby 1000 (1000 to 5000) step_acc;
print "\n";
val main4 = forby 1000 (1000 to 5000) step_acc2; 
*)

print "\n";
(* time the simplifying lexer on 10, 20, ..., 50 copies of prog2 *)
val main5 = forby 10 (10 to 50) step_simp;

(* report the size of the simplified derivative of while_regs after
   consuming 50 and then 100 copies of prog2 *)
List.app
  (fn (label, copies) =>
     let
       val start = internalise while_regs
       val input = explode (string_repeat prog2 copies)
       val residual = bders_simp start input
     in
       print (label ^ PolyML.makestring (size (erase residual)) ^ "\n")
     end)
  [("Size after 50: ", 50), ("Size after 100: ", 100)];

