// package zre 
//Zre5: eliminated mems table



import scala.collection.mutable.{Map => MMap}
import scala.collection.mutable.{ArrayBuffer => MList}
//import pprint._

import scala.util.Try



/** Parse values (parse trees) produced by the zipper parser.
  *
  * Note: `Left` and `Right` below shadow the standard `scala.Left` / `scala.Right`
  * everywhere in this file.
  */
sealed abstract class Val
case object Empty extends Val                        // value of ONE
final case class Chr(c: Char) extends Val            // value of CHAR(c)
final case class Sequ(v1: Val, v2: Val) extends Val  // value of SEQ(r1, r2)
final case class Left(v: Val) extends Val            // value of the left branch of an ALT
final case class Right(v: Val) extends Val           // value of the right branch of an ALT
final case class Stars(vs: List[Val]) extends Val    // value of STAR: iterations are prepended, so the most recent comes first
case object DummyFilling extends Val                 // placeholder returned by mkeps when no empty-match value exists


/** Regular expressions.
  *
  * Case-class (structural) equality matters here: `mems` is keyed by
  * (position, Rexp), so structurally equal sub-expressions entered at the same
  * position share one memo entry. An earlier experiment overrode `equals` to
  * use reference equality (`this.eq(other)`) instead.
  */
sealed abstract class Rexp
case object ZERO extends Rexp                          // matches nothing
case object ONE extends Rexp                           // matches the empty string
final case class CHAR(c: Char) extends Rexp            // matches the character c
final case class ALT(r1: Rexp, r2: Rexp) extends Rexp  // alternative
final case class AL1(r1: Rexp) extends Rexp            // not handled by original_down
final case class SEQ(r1: Rexp, r2: Rexp) extends Rexp  // sequence
final case class STAR(r: Rexp) extends Rexp            // zero or more repetitions
case object RTOP extends Rexp                          // not handled by original_down


//Seq a b --> Seq Seqa Seqb
//Seq a b --> Sequ chra chrb
//ALT r1 r2 --> mALT
//         AltC L   AltC R
// Not referenced anywhere else in this file.
var cyclicPreventionList: Set[Int] = Set.empty[Int]
/** Parent link of a memo cell: the construct that entered it and is waiting
  * for its results. `mForMyself` is the memo cell that receives this
  * construct's own value once it is complete.
  */
sealed abstract class Ctx
case object RootC extends Ctx
// Inside a SEQ: `processedSibling` holds the already-parsed left value,
// `unpSibling` the regex(es) still to parse on the right.
final case class SeqC(mForMyself: Mem, processedSibling: List[Val], unpSibling: List[Rexp]) extends Ctx
// Inside one branch of an ALT (or a pass-through for SEQ): `wrapper` tags the branch's value.
final case class AltC(mForMyself: Mem, wrapper: Val => Val) extends Ctx
// Inside a STAR: `vs` are the iterations parsed so far (most recent first), `inside` the repeated regex.
final case class StarC(mForMyself: Mem, vs: List[Val], inside: Rexp) extends Ctx

/** Memo cell for a regex entered at some start position (the ones created by
  * check_before_down are stored in `mems`). `parents` are the contexts waiting
  * for its results; `result` collects (end position, value) pairs. Both fields
  * are mutated in place by the parser.
  */
case class Mem(var parents: List[Ctx], var result : MList[(Int, Val)])

//AltC(Mem(RootC::Nil, Map()))



/** A parse in progress: the value built so far paired with the memo cell it belongs to. */
type Zipper = Tuple2[Val, Mem]

/** Memo table: (start position, regex) --> memo cell holding its parents and results. */
var mems : MMap[(Int, Rexp), Mem] = MMap.empty[(Int, Rexp), Mem]

/** Current input position; set by `derive`, read by the up/down functions. */
var pos : Int = 0



//input ..................
//          ^       ^
//          p       q
//          r

//parse r[p...q] --> v

//(a+aa)*
//aaa
//[R(Sequ(a, a)), vs]
//[L(a), L(a), vs]
/** Enter regex `r` at the current input position `pos`, coming from context `c`.
  *
  * Entries are memoised per (pos, r) in `mems`. The first visit creates a fresh
  * [[Mem]] whose only parent is `c`, then descends into `r` via original_down.
  * A repeat visit for the same (pos, r) yields no zippers.
  *
  * NOTE(review): because `mems` uses structural Rexp equality, identical
  * sub-expressions reached through different contexts at the same position
  * share one entry, and the later contexts are dropped. An earlier version
  * appended `c` to `m.parents` and replayed stored results through
  * original_up; both are currently disabled.
  *
  * @param c context the caller is descending from
  * @param r regex to enter
  * @param d recursion depth, passed on to original_down unchanged
  * @return zippers produced by descending into `r` (empty on a repeat visit)
  */
def check_before_down(c: Ctx, r: Rexp, d: Int = 0) : List[Zipper] = {
    mems.get((pos, r)) match {
        case Some(_) =>
            // Already entered at this position. The previous code only handled
            // "no result yet" and threw a MatchError when a result already
            // existed at `pos` (e.g. STAR(ONE)); both mean nothing new to explore.
            Nil
        case None =>
            val m = Mem(c :: Nil, MList.empty[(Int, Val)])
            // Update in place: `mems + (...)` on a mutable map copies the whole map.
            mems((pos, r)) = m
            original_down(r, m, d)
    }
}

// Layout of mems: (start pos, r) --> Mem(parents, [(end pos, value), ...])
// Example: input "aaa" (positions 0, 1, 2) with r = SEQ(a, a) entered at 0:
//   (0, a~a) --> Mem(..., [(2, Sequ(a, a))])


/** Record `vres` as a completed parse of `m`'s regex ending at the current
  * position `pos`, then deliver it to every parent context of `m`.
  *
  * @return all zippers produced by continuing the parse from each parent
  */
def mem_up(vres: Val, m: Mem, rec_depth : Int = 0) : List[Zipper] = {
    m.result.append((pos, vres))
    val continuations = for (parent <- m.parents) yield original_up(vres, parent, rec_depth)
    continuations.flatten
}

/** Descend into regex `r`, whose memo cell is `m`, at input position `pos`.
  *
  *  - CHAR(b): yields the zipper `(Chr(b), m)` if the input has `b` at `pos`,
  *    else nothing. The next `derive` call moves that value up.
  *  - ZERO: matches nothing.
  *  - ONE: the empty match is reported upwards immediately.
  *  - SEQ(r1, r2): descends into r1 under a SeqC that remembers r2; the SeqC
  *    reports to a pass-through AltC (identity wrapper) on `m`.
  *  - ALT(r1, r2): descends into both branches, left first.
  *  - STAR(r0): reports the empty match (zero iterations), then descends into
  *    a first iteration of r0.
  *
  * @throws Exception for regex forms not handled here (AL1, RTOP)
  */
def original_down(r: Rexp, m: Mem, d: Int = 0) : List[Zipper] = (r, m) match {
    case (CHAR(b), m) =>
        // `lift` returns None past the end instead of throwing.
        if (input.lift(pos).contains(b)) List((Chr(b), m))
        else Nil
    case (ZERO, _) => Nil
    case (ONE, m) => mem_up(Empty, m, d + 1)
    case (SEQ(r1, r2), m) =>
        val mprime = Mem(AltC(m, x => x) :: Nil, MList.empty[(Int, Val)])
        check_before_down(SeqC(mprime, Nil, List(r2)), r1, d)
    case (ALT(r1, r2), m) =>
        check_before_down(AltC(m, Left(_)), r1, d) :::
        check_before_down(AltC(m, Right(_)), r2, d)
    case (STAR(r0), m) =>
        // STAR is nullable: without this, e.g. SEQ(STAR(a), b) could not parse "b".
        // The order (report first, then iterate) mirrors the StarC case of original_up.
        mem_up(Stars(Nil), m, d + 1) :::
        check_before_down(StarC(m, Nil, r0), r0, d)
    case (_, _) => throw new Exception("original down unexpected r or m")
}

/** Deliver value `v` (a completed parse of some child) to its parent context
  * `c` at the current position `pos`, and continue the parse from there.
  *
  *  - SeqC with the left value already parsed: combine into `Sequ(v1, v)` and
  *    report it to the SEQ's memo cell.
  *  - SeqC with the right regex still to parse: keep `v` as the left value and
  *    descend into the right regex.
  *  - AltC: if the memo cell already holds a result at `pos`, record `wrap(v)`
  *    as an additional result there and stop; otherwise report `wrap(v)` upwards.
  *  - RootC: the parse reached the top; no further zippers.
  *  - StarC: report the iterations so far as `Stars(v :: vs)` (most recent
  *    first) AND try one more iteration of the starred regex.
  *
  * @param v value produced by the child
  * @param c parent context to deliver it to
  * @param d recursion depth; incremented when calling mem_up
  * @throws Exception for SeqC shapes other than the two handled above
  */
def original_up(v: Val, c: Ctx, d: Int = 0) : List[Zipper] = 
{

(v, c) match {
    case (v, SeqC(m, v1::Nil, Nil)) => 
        mem_up(Sequ(v1, v), m, d + 1)
    case (v, SeqC(m, Nil, u1::Nil)) => 
        check_before_down(SeqC(m, v::Nil, Nil), u1, d)
    // Ambiguity point: a second value for the same span is stored but not propagated again.
    case (v, AltC(m, wrap)) => m.result.find(tup2 => tup2._1 == pos) match {
        case Some( (i, vPrime)  ) => 
            m.result += (i -> wrap(v))
            Nil
        case None => 
            mem_up(wrap(v), m, d + 1)
    } //mem_up(AL1(r), par)
    //case (v, StarC(m, vs, r0)) => throw new Exception("why not hit starC")

    case (v, RootC) => 
        Nil
    // mem_up (mutates m.result) runs before check_before_down (mutates mems);
    // both touch shared memo state, so treat this order as significant.
    case (v, StarC(m, vs, r0) ) => mem_up(Stars(v::vs), m, d + 1) ::: 
        check_before_down(StarC(m, v::vs, r0), r0, d)
    case (_, _) => throw new Exception("hit unexpected context")
}

}


/** Advance zipper `z` at input position `p`: set the global `pos` to `p`, then
  * report `z`'s value upwards from its memo cell.
  *
  * @return the zippers waiting at the next character
  */
def derive(p: Int, z: Zipper) : List[Zipper] = {
    pos = p
    // A Zipper is always a pair, so the former `case _ => throw` was unreachable.
    mem_up(z._1, z._2)
}
//let e' = Seq([]) in 
//
/** Build the starting zipper for parsing `r`.
  *
  * Shape: RootC <- m_top <- SeqC(m_top, Nil, r :: Nil) <- m_r. The first
  * `derive` call reports the returned value up through that SeqC, which keeps
  * it as the left value and descends into `r`. The value therefore becomes the
  * left component of the final root `Sequ`.
  * TODO: confirm whether any starting value other than Empty is wanted.
  */
def init_zipper(r: Rexp) : Zipper = {
    val m_top = Mem(RootC :: Nil, MList.empty[(Int, Val)])
    val c_top = SeqC(m_top, Nil, r :: Nil)
    val m_r = Mem(c_top :: Nil, MList.empty[(Int, Val)])
    val start: Zipper = (Empty, m_r)
    // Log the value actually returned (the message previously claimed ZERO).
    println(s"initial zipper is (Empty, $m_r)")
    start
}


/** Plug value `v` into context `c` and keep plugging up to the root, turning
  * a partial parse into complete values.
  *
  *  - RootC: `v` is already a complete value.
  *  - SeqC whose left value is parsed: pair that value with `v`.
  *  - SeqC whose right regex is unparsed: complete it with mkeps if it is
  *    nullable; otherwise this path yields nothing.
  *  - AltC / StarC: wrap `v` as the context dictates.
  * Other SeqC shapes are not matched.
  */
def plug_convert(v: Val, c: Ctx) : List[Val] = c match {
    case RootC => v :: Nil
    // TODO: non-binary Seq requires ps.rev
    case SeqC(m, left :: Nil, Nil) => plug_mem(Sequ(left, v), m)
    // TODO: un not nullable--partial values?
    case SeqC(m, Nil, rest :: Nil) if nullable(rest) => plug_mem(Sequ(v, mkeps(rest)), m)
    case SeqC(_, Nil, _ :: Nil) => Nil
    // TODO: when multiple results stored in m, which one to choose?
    case AltC(m, wrap) => plug_mem(wrap(v), m)
    case StarC(m, vs, _) => plug_mem(Stars(v :: vs), m)
}


var cnt = 0 // not referenced anywhere else in this file

/** Store `v` in `m.result` at the current position `pos`, then plug it into
  * every parent context of `m`.
  *
  * Note: this mutates `m.result`, so plugging leaves extra entries behind in
  * the memo cells.
  */
def plug_mem(v: Val, m: Mem) : List[Val] = {
    m.result.append((pos, v))
    for {
        parent  <- m.parents
        plugged <- plug_convert(v, parent)
    } yield plugged
}

/** Turn every zipper into complete values by plugging its value into its memo cell. */
def plug_all(zs: List[Zipper]) : List[Val] =
    for {
        z     <- zs
        value <- plug_mem(z._1, z._2)
    } yield value


/** Build the value describing how a nullable regex `r` matches the empty string.
  *
  * For an ALT, the left branch is preferred when it is nullable. A STAR matches
  * the empty string with zero iterations (previously it fell through to
  * DummyFilling). Forms with no empty match (ZERO, CHAR, ...) yield the
  * placeholder DummyFilling.
  */
def mkeps(r: Rexp) : Val = r match {
  case ONE => Empty
  case ALT(r1, r2) =>
    if (nullable(r1)) Left(mkeps(r1)) else Right(mkeps(r2))
  case SEQ(r1, r2) => Sequ(mkeps(r1), mkeps(r2))
  case STAR(_) => Stars(Nil)
  case _ => DummyFilling
}

/** Whether regex `r` can match the empty string.
  *
  * STAR is always nullable (zero iterations); it previously fell through to
  * the catch-all and was reported as non-nullable.
  */
def nullable(r: Rexp) : Boolean = r match {
  case ZERO => false
  case ONE => true
  case CHAR(_) => false
  case ALT(r1, r2) => nullable(r1) || nullable(r2)
  case SEQ(r1, r2) => nullable(r1) && nullable(r2)
  case STAR(_) => true
  // NOTE(review): AL1 and RTOP fall through to false; confirm that is intended.
  case _ => false
}


// Default parser input; `lex` overwrites `input` on every call.
val tokList : List[Char] = List('a', 'a', 'b')
var input : List[Char] = tokList







/** Feed positions `index`, `index + 1`, ..., `input.length - 1` to the zippers
  * in turn, deriving every zipper at each position.
  *
  * @return the zippers left after the last position (`zs` itself if `index`
  *         is already at or past the end of the input)
  */
def lexRecurse(zs: List[Zipper], index: Int) : List[Zipper] =
    // `input.length` is O(n) on a List; evaluate it once rather than per step.
    (index until input.length).foldLeft(zs) { (current, i) =>
        current.flatMap(z => derive(i, z))
    }

/** Parse string `s` with regex `r`, returning the zippers remaining after the
  * whole input has been fed in (use plug_all to turn them into values).
  *
  * Side effects: sets the global `input` and clears the memo table `mems`.
  */
def lex(r: Rexp, s: String) : List[Zipper] = {
    input = s.toList
    // Memo entries are keyed by positions in the previous input. Stale entries
    // would make check_before_down skip paths never explored for this input.
    mems.clear()
    lexRecurse(init_zipper(r) :: Nil, 0)
}



/** Implicitly convert a character list to a regex. The empty list becomes ONE;
  * otherwise the result is a right-nested SEQ of CHARs, e.g. "abc" becomes
  * SEQ(CHAR('a'), SEQ(CHAR('b'), CHAR('c'))).
  */
implicit def charlist2rexp(s: List[Char]): Rexp =
    if (s.isEmpty) ONE
    else s.init.foldRight[Rexp](CHAR(s.last))((ch, rest) => SEQ(CHAR(ch), rest))

/** Implicitly convert a string to a regex via its character list. */
implicit def string2Rexp(s: String) : Rexp = {
    val chars: List[Char] = s.toList
    charlist2rexp(chars)
}

/** Regex combinators: `r | s` (ALT), `r ~ s` (SEQ), `r.%` (STAR).
  *
  * An implicit class replaces `implicit def ... = new { ... }`. That version
  * had a structural return type, so every operator call went through Java
  * reflection.
  */
implicit class RexpOps(r: Rexp) {
    def | (s: Rexp): ALT = ALT(r, s)
    def ~ (s: Rexp): SEQ = SEQ(r, s)
    def % : STAR = STAR(r)
}

/** Regex combinators usable directly on strings, e.g. `"a" | "aa"` or `"ab".%`.
  * Strings are turned into regexes by the string2Rexp conversion.
  *
  * An implicit class replaces the structural `new { ... }` result, which
  * required reflective calls.
  */
implicit class stringOps(s: String) {
    def | (r: Rexp): ALT = ALT(s, r)
    def | (r: String): ALT = ALT(s, r)
    def ~ (r: Rexp): SEQ = SEQ(s, r)
    def ~ (r: String): SEQ = SEQ(s, r)
    def % : STAR = STAR(s)
}

//derive(0, init_zipper(re0))

// println(re1s.length)
// mems.foreach(a => println(a))
// val re1sPlugged = plug_all(re1s)
// re1sPlugged.foreach(zipper => {
//                         println(zipper); 
//                         println("delimit") 
//                         })
                
// mems.clear()
// println(mems)
// println(re0)
// val re2s = lex(re0, "aab")
// val re2sPlugged = plug_all(re2s)
// re2sPlugged.foreach(v => {
//         val Sequ(Empty, vp) = v
//         println(vp)
//     }
// )
// val re0 = SEQ(ALT(CHAR('a'), SEQ(CHAR('a'),CHAR('a'))), 
// ALT(SEQ(CHAR('a'), CHAR('b')), SEQ(CHAR('b'), CHAR('c')) )
// )

// val (rgraph, re0root) = makeGraphFromObject(re0)
// val asciir = GraphLayout.renderGraph(rgraph)
// println("printing out re0")
// println(asciir)
// val re1s = lex(re0, "aabc")

/** Total number of parent paths to RootC over all zippers: for each zipper,
  * count the distinct chains of parent links from its memo cell to the root.
  */
def actualZipperSize(zs: List[Zipper]) : Int =
    // The previous non-tail recursion used one stack frame per zipper.
    zs.map(z => countParents(z._2)).sum

/** Number of distinct parent paths from memo cell `m` up to RootC. */
def countParents(m: Mem) : Int =
    m.parents.foldLeft(0)((total, ctx) => total + countGrandParents(ctx))

/** Number of distinct paths from context `c` up to RootC (1 for RootC itself). */
def countGrandParents(c: Ctx) : Int = c match {
    case RootC            => 1
    case SeqC(mem, _, _)  => countParents(mem)
    case AltC(mem, _)     => countParents(mem)
    case StarC(mem, _, _) => countParents(mem)
}

/** Drop zippers whose residual regexes (as computed by zipBackMem) equal those
  * of an earlier zipper, keeping the first occurrence and the original order.
  */
def zipperSimp(zs: List[Zipper]) : List[Zipper] = {
    def residue(z: Zipper): List[Rexp] = zipBackMem(z._2)
    zs.distinctBy(z => residue(z))
}

/** Rebuild the regexes still to be matched above context `c`, given that `r`
  * remains at the current level (ONE by default, i.e. nothing left).
  *
  *  - SeqC with nothing unparsed: `r` alone; with `unp` unparsed: SEQ(r, unp).
  *  - StarC: `r` followed by further iterations, SEQ(r, STAR(r0)).
  *  - AltC: `r` unchanged.
  * One regex is produced per parent path to RootC.
  */
def zipBackToRegex(c: Ctx, r: Rexp = ONE) : List[Rexp] = c match {
    case RootC                  => List(r)
    case SeqC(m, _, Nil)        => zipBackMem(m, r)
    case SeqC(m, _, unp :: Nil) => zipBackMem(m, SEQ(r, unp))
    case AltC(m, _)             => zipBackMem(m, r)
    case StarC(m, _, r0)        => zipBackMem(m, SEQ(r, STAR(r0)))
}

/** Residual regexes for memo cell `m`, one per parent path to RootC. */
def zipBackMem(m: Mem, r: Rexp = ONE) : List[Rexp] =
    for {
        parent   <- m.parents
        residual <- zipBackToRegex(parent, r)
    } yield residual

//def crystalizeZipper

// Demo: parse "aaaaa" with (a | aa)* and print how many parent paths the
// resulting zippers have, then deduplicate them by residual regex.
mems.clear()
val re1 = STAR("a" | "aa")
val re1ss = lex(re1, "aaaaa")

//drawZippers(re1ss)
println(actualZipperSize(re1ss))
//println(re1ss)
val re1S = zipperSimp(re1ss)
//println(actualZipperSize(re1S))


// val re2 = SEQ(ONE, "a")
// val re2res = lex(re2, "a")
// //lex(1~a, "a") --> lexRecurse((1v, m  (SeqC(m (RootC, Nil), Nil, [1~a] ) )))


// println(re2res)

// val re2resPlugged = plug_all(re2res)
// re2resPlugged.foreach(v => {
//         val Sequ(Empty, vp) = v
//         println(vp)
// }
// )

// println("remaining regex")
// println(re1ss.flatMap(z => zipBackMem(z._2)))


// val re1ssPlugged = plug_all(re1ss)
// println("each of the values")
// re1ssPlugged.foreach(v => {
//         //val Sequ(Empty, vp) = v
//         //println(vp)
//         println(v)
//     }
// )
// println(mems.size)
//println(mems)
//mems.map({case (ir, m) => if (ir._1 == 1 && ir._2 == CHAR('b')) println(printMem(m)) })
// println("Mkeps + inj:")
// println(
//     mems.get((0, re1)) match {
//         case Some(m) => printMem(m)
//         case None => ""
//     }
//     )

// println(re1sPlugged)
//drawZippers(re1s, plugOrNot = false)
// re1s.foreach{
//   re1 => 
//   {

//     drawZippers(derive(1, re1), plugOrNot = true)

//   }
// }


