(*  Nominal Mutual Functions
    Author:  Christian Urban

    heavily based on the code of Alexander Krauss
    (code forked on 14 January 2011)

    Joachim Breitner helped with the auxiliary graph
    definitions (7 August 2012)

Mutually recursive nominal function definitions.
*)


(* Interface of the front end for mutually recursive nominal function
   definitions. *)
signature NOMINAL_FUNCTION_MUTUAL =
sig

  (* prepare_nominal_function_mutual config defname fixes eqs lthy
     config  -- nominal function configuration, passed on to the core
     defname -- base name; e.g. defname ^ "_sum" names the combined sum function
     fixes   -- the function variables being defined, with their mixfix syntax
     eqs     -- the (mutually recursive) defining equations
     Returns the goal state of the proof obligations, a continuation that turns
     the finished proof into the nominal_function_result, and the extended
     local theory. *)
  val prepare_nominal_function_mutual : Nominal_Function_Common.nominal_function_config
    -> string (* defname *)
    -> ((string * typ) * mixfix) list
    -> term list
    -> local_theory
    -> ((thm (* goalstate *)
        * (thm -> Nominal_Function_Common.nominal_function_result) (* proof continuation *)
       ) * local_theory)

end


structure Nominal_Function_Mutual: NOMINAL_FUNCTION_MUTUAL =
struct

open Function_Lib
open Function_Common
open Nominal_Function_Common

(* A defining equation split by split_def (see analyze_eqs):
   (function name, quantified variables, guards, argument patterns,
   right-hand side).  Arguments, guards and rhs may contain loose bound
   variables referring to the quantified variables (see in_context). *)
type qgar = string * (string * typ) list * term list * term list * term

(* Per-function data of a mutual definition.
   i        -- position (1-based) of the function's summand in the argument sum type ST
   i'       -- position (1-based) of its result type among the distinct result types (RST)
   fvar     -- the fixed function variable (name, type)
   cargTs   -- types of the curried arguments consumed by the equations
   f_def    -- the function as a projection of the sum function, with the sum
               function abstracted into a loose bound variable
   f, f_defthm -- defined constant and its definition theorem; NONE until
               define_projections has run *)
datatype mutual_part = MutualPart of
 {i : int,
  i' : int,
  fvar : string * typ,
  cargTs: typ list,
  f_def: term,
  f: term option,
  f_defthm : thm option}

(* Data describing the whole mutual definition.
   n        -- number of mutually defined functions (summands of ST)
   n'       -- number of distinct result types (summands of RST)
   fsum_var -- variable for the combined sum function, of type ST --> RST
   ST, RST  -- balanced sum types of the tupled argument types and of the
               distinct result types
   parts    -- one mutual_part per function, in the order of the fixes
   fqgars   -- the split original equations
   qglrs    -- the equations of the sum function: (quantified vars, guards, lhs, rhs)
   fsum     -- the defined sum function; NONE until define_projections has run *)
datatype mutual_info = Mutual of
 {n : int,
  n' : int,
  fsum_var : string * typ,

  ST: typ,
  RST: typ,

  parts: mutual_part list,
  fqgars: qgar list,
  qglrs: ((string * typ) list * term list * term * term) list,

  fsum : term option}

(* Names of the predicate variables in the mutual induction rules:
   "P", "Q", "R", "S" (first n of them) for at most four functions,
   otherwise "P1", ..., "Pn". *)
fun mutual_induct_Pnames n =
  if n >= 5 then map (fn k => "P" ^ string_of_int k) (1 upto n)
  else #1 (chop n ["P", "Q", "R", "S"])

(* The mutual_part whose function variable is named fname; fails (via the)
   if there is no such part. *)
fun get_part fname parts =
  the (find_first (fn (MutualPart {fvar = (name, _), ...}) => name = fname) parts)

(* Builds the HOL pair (t1, t2) whose components may contain loose bound
   variables standing for the variables listed in e.  The last element of e
   corresponds to Bound 0 (hence the rev), matching subst_bounds (rev qs, _)
   in in_context.  FIXME (inherited from the original code). *)
fun mk_prod_abs e (t1, t2) =
  let
    val boundTs = rev (map snd e)
    fun typ_of t = fastype_of1 (boundTs, t)
  in
    HOLogic.pair_const (typ_of t1) (typ_of t2) $ t1 $ t2
  end

(* Encodes the equations eqs of the mutually recursive functions fs (name/type
   pairs, in the order of the fixes) as equations of a single sum function:
   - the curried arguments taken by each function's equations are tupled and
     the tuple types are combined into the balanced sum type ST;
   - the distinct result types are combined into the balanced sum type RST;
   - a variable based on defname ^ "_sum" of type ST --> RST stands for the
     sum function;
   - function number i is projection i' (of RST) of the sum function applied
     to injection i (into ST) of its tupled arguments;
   - every equation is rewritten in terms of the sum function (qglrs).
   The result has fsum = NONE and parts with f = NONE and f_defthm = NONE;
   define_projections fills these in. *)
fun analyze_eqs ctxt defname fs eqs =
  let
    val num = length fs
    val fqgars = map (split_def ctxt (K true)) eqs
    (* number of curried arguments each function takes in its equations;
       AList.lookup uses the first equation found for a name *)
    val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
      |> AList.lookup (op =) #> the

    (* splits a function type into the argument types consumed by the
       equations and the type of the remaining result *)
    fun curried_types (fname, fT) =
      let
        val (caTs, uaTs) = chop (arity_of fname) (binder_types fT)
      in
        (caTs, uaTs ---> body_type fT)
      end

    val (caTss, resultTs) = split_list (map curried_types fs)
    (* each function's arguments as one tuple type *)
    val argTs = map (foldr1 HOLogic.mk_prodT) caTss

    (* functions with equal result types share one summand of RST *)
    val dresultTs = distinct (op =) resultTs
    val n' = length dresultTs

    val RST = Balanced_Tree.make (uncurry SumTree.mk_sumT) dresultTs
    val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) argTs

    val fsum_type = ST --> RST

    (* only the name returned by add_fixes is kept; the extended context is discarded *)
    val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt
    val fsum_var = (fsum_var_name, fsum_type)

    (* Data for function number i: its definition as projection i' of the sum
       function applied to injection i of the tupled arguments (with the sum
       function variable abstracted into a loose bound variable), and a pair
       (name, term) used to replace the function in right-hand sides. *)
    fun define (fvar as (n, _)) caTs resultT i =
      let
        val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *)
        (* 1-based position of resultT among the distinct result types *)
        val i' = find_index (fn Ta => Ta = resultT) dresultTs + 1

        val f_exp = SumTree.mk_proj RST n' i' (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
        val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp)

        val rew = (n, fold_rev lambda vars f_exp)
      in
        (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew)
      end

    val (parts, rews) = split_list (map4 define fs caTss resultTs (1 upto num))

    (* Translates one split equation into an equation of the sum function:
       lhs = injection i of the tupled arguments (which may contain loose bound
       variables for qs, hence mk_prod_abs); rhs = injection i' of the
       right-hand side with the function variables replaced via rews,
       beta-normalised. *)
    fun convert_eqs (f, qs, gs, args, rhs) =
      let
        val MutualPart {i, i', ...} = get_part f parts
        val rhs' = rhs
             |> map_aterms (fn t as Free (n, _) => the_default t (AList.lookup (op =) rews n) | t => t)
      in
        (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
         Envir.beta_norm (SumTree.mk_inj RST n' i' rhs'))
      end

    val qglrs = map convert_eqs fqgars
  in
    Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST,
      parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE}
  end

(* Defines each function as a constant with the mixfix syntax from fixes.
   The definition is the part's f_def with the sum function fsum substituted
   for its loose bound variable.  Each definition theorem is named
   fname ^ "_def" and concealed.  Returns the mutual_info with f, f_defthm and
   fsum filled in, together with the extended local theory. *)
fun define_projections fixes mutual fsum lthy =
  let
    fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy =
      let
        val ((f, (_, f_defthm)), lthy') =
          Local_Theory.define
            ((Binding.name fname, mixfix),
              ((Binding.conceal (Binding.name (fname ^ "_def")), []),
              Term.subst_bound (fsum, f_def))) lthy
      in
        (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def,
           f=SOME f, f_defthm=SOME f_defthm },
         lthy')
      end

    val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual
    (* parts and fixes are paired positionally (same order as in analyze_eqs) *)
    val (parts', lthy') = fold_map def (parts ~~ fixes) lthy
  in
    (Mutual { n=n, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, parts=parts',
       fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum },
     lthy')
  end

(* Opens the split equation (f, pre_qs, pre_gs, pre_args, pre_rhs) for F.
   The loose bound variables for the quantified variables pre_qs are replaced
   by free variables with variant names.  The context extended by
   Variable.variant_fixes is not passed on.  F is called with ctxt, the
   instantiated equation (f, qs, gs, args, rhs) and two theorem transformers:
   - import: instantiates the outermost meta-quantifiers with qs and
     discharges premises with the assumed guards gs;
   - export: turns the assumed guards back into premises and re-quantifies
     over qs, using the original names oqnames. *)
fun in_context ctxt (f, pre_qs, pre_gs, pre_args, pre_rhs) F =
  let
    val thy = Proof_Context.theory_of ctxt

    val oqnames = map fst pre_qs
    val (qs, _) = Variable.variant_fixes oqnames ctxt
      |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs

    (* the last element of qs replaces Bound 0 *)
    fun inst t = subst_bounds (rev qs, t)
    val gs = map inst pre_gs
    val args = map inst pre_args
    val rhs = inst pre_rhs

    val cqs = map (cterm_of thy) qs
    val ags = map (Thm.assume o cterm_of thy) gs

    val import = fold Thm.forall_elim cqs
      #> fold Thm.elim_implies ags

    val export = fold_rev (Thm.implies_intr o cprop_of) ags
      #> fold_rev forall_intr_rename (oqnames ~~ cqs)
  in
    F ctxt (f, qs, gs, args, rhs) import export
  end

(* Recovers the partial simplification rule of one function f from the
   corresponding rule sum_psimp_eq of the sum function.  It is used as the F
   argument of in_context.
   - The imported rule may have at most one premise.  That premise is assumed
     during the proof and re-introduced afterwards.  More premises raise
     Fail "Too many conditions".
   - The goal f args = rhs is proved by unfolding the definitions
     all_orig_fdefs, then substituting with the sum-function rule
     (EqSubst.eqsubst_tac, occurrence 0), then simplifying.
   - The result is passed through export. *)
fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs)
  import (export : thm -> thm) sum_psimp_eq =
  let
    val (MutualPart {f=SOME f, ...}) = get_part fname parts

    val psimp = import sum_psimp_eq
    val (simp, restore_cond) =
      case cprems_of psimp of
        [] => (psimp, I)
      | [cond] => (Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond)
      | _ => raise General.Fail "Too many conditions"
  in
    Goal.prove ctxt [] []
      (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs))
      (fn _ => (Local_Defs.unfold_tac ctxt all_orig_fdefs)
         THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
         THEN (simp_tac (simpset_of ctxt)) 1) (* FIXME: global simpset?!! *)
    |> restore_cond
    |> export
  end

(* Auxiliary rules for recover_mutual_eqvt: for a sum value known to be an
   Inl (resp. Inr), the projection Projl (resp. Projr) commutes with permute. *)
val inl_perm = @{lemma "x = Inl y ==> Sum_Type.Projl (permute p x) = permute p (Sum_Type.Projl x)" by simp}
val inr_perm = @{lemma "x = Inr y ==> Sum_Type.Projr (permute p x) = permute p (Sum_Type.Projr x)" by simp}

(* Recovers the equivariance statement
     mk_perm p (f args) = f (map (mk_perm p) args)
   of one function f.  It uses the corresponding sum-function rule
   sum_psimp_eq and the equivariance theorem eqvt_thm of the sum function,
   and is used as the F argument of in_context.
   - As in recover_mutual_psimp, at most one premise is allowed; otherwise
     Fail "Too many conditions" is raised.  The premise, if any, is assumed,
     combined with eqvt_thm via MRS, and re-introduced at the end.
   - p is a fresh fixed variable of type perm in ctxt'.  It is generalised
     when exporting from ctxt' back to ctxt. *)
fun recover_mutual_eqvt eqvt_thm all_orig_fdefs parts ctxt (fname, _, _, args, _)
  import (export : thm -> thm) sum_psimp_eq =
  let
    val (MutualPart {f=SOME f, ...}) = get_part fname parts

    val psimp = import sum_psimp_eq
    val (cond, simp, restore_cond) =
      case cprems_of psimp of
        [] => ([], psimp, I)
      | [cond] => ([Thm.assume cond], Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond)
      | _ => raise General.Fail "Too many conditions"

    val ([p], ctxt') = Variable.variant_fixes ["p"] ctxt
    val p = Free (p, @{typ perm})
    (* pushes permutations inside sums/pairs, then uses the (reversed)
       sum-function equivariance and the equation of the sum function *)
    val ss = HOL_basic_ss addsimps 
      @{thms permute_sum.simps[symmetric] Pair_eqvt[symmetric]} @
      @{thms Projr.simps Projl.simps} @
      [(cond MRS eqvt_thm) RS @{thm sym}] @ 
      [inl_perm, inr_perm, simp] 
    val goal_lhs = mk_perm p (list_comb (f, args))
    val goal_rhs = list_comb (f, map (mk_perm p) args)
  in
    Goal.prove ctxt' [] [] (HOLogic.Trueprop $ HOLogic.mk_eq (goal_lhs, goal_rhs))
      (fn _ => (Local_Defs.unfold_tac ctxt all_orig_fdefs)
         THEN (asm_full_simp_tac ss 1))
    |> singleton (Proof_Context.export ctxt' ctxt)
    |> restore_cond
    |> export
  end

(* Turns an equation thm between functions into an equation between their
   applications to free variables x0, x1, ... of the types caTs.  The result
   is beta-reduced and the variables are finally turned into schematic ones.
   FIXME (inherited: "Bind xs properly"): the names x0, x1, ... are fixed
   strings, not chosen fresh w.r.t. ctxt. *)
fun mk_applied_form ctxt caTs thm =
  let
    val cert = cterm_of (Proof_Context.theory_of ctxt)
    val xs =
      caTs |> map_index (fn (j, T) => cert (Free ("x" ^ string_of_int j, T)))
    fun apply_to x th = Thm.combination th (Thm.reflexive x)
  in
    thm
    |> fold apply_to xs
    |> Conv.fconv_rule (Thm.beta_conversion true)
    |> fold_rev Thm.forall_intr xs
    |> Thm.forall_elim_vars 0
  end

(* Derives one induction rule per function from the induction rule induct of
   the sum function.
   1. The outermost quantified variable of induct (presumably the induction
      predicate — TODO confirm) is instantiated with a case distinction over
      predicates named by mutual_induct_Pnames, applied to the tupled arguments.
   2. The result is simplified with SumTree.sumcase_split_ss and all_f_defs.
   3. It is then specialised to each function's injection into ST and
      generalised over the argument variables and the predicates. *)
fun mutual_induct_rules lthy induct all_f_defs (Mutual {n, ST, parts, ...}) =
  let
    val cert = cterm_of (Proof_Context.theory_of lthy)
    (* one predicate variable per function, over its curried argument types *)
    val newPs =
      map2 (fn Pname => fn MutualPart {cargTs, ...} =>
          Free (Pname, cargTs ---> HOLogic.boolT))
        (mutual_induct_Pnames (length parts)) parts

    (* the tupled version of predicate P *)
    fun mk_P (MutualPart {cargTs, ...}) P =
      let
        val avars = map_index (fn (i,T) => Var (("a", i), T)) cargTs
        val atup = foldr1 HOLogic.mk_prod avars
      in
        HOLogic.tupled_lambda atup (list_comb (P, avars))
      end

    val Ps = map2 mk_P parts newPs
    (* predicate on ST: case distinction over the summands *)
    val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps

    val induct_inst =
      Thm.forall_elim (cert case_exp) induct
      |> full_simplify SumTree.sumcase_split_ss
      |> full_simplify (HOL_basic_ss addsimps all_f_defs)

    (* specialises rule to summand i; k counts the argument variables used
       so far, so that names differ between functions *)
    fun project rule (MutualPart {cargTs, i, ...}) k =
      let
        val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *)
        val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
      in
        (rule
         |> Thm.forall_elim (cert inj)
         |> full_simplify SumTree.sumcase_split_ss
         |> fold_rev (Thm.forall_intr o cert) (afs @ newPs),
         k + length cargTs)
      end
  in
    fst (fold_map (project induct_inst) parts 0)
  end


(* Term-level (not theorem-level) removal of an outermost meta-quantifier
   "all": the bound variable is replaced by s.  Any other term is returned
   unchanged.  forall_elim_list does this successively for a list of terms,
   outermost quantifier first. *)
fun forall_elim s t =
  (case t of
    Const ("all", _) $ Abs (_, _, body) => subst_bound (s, body)
  | _ => t)

val forall_elim_list = fold forall_elim

(* Splits a theorem whose conclusion is a (nested) HOL conjunction into its
   conjuncts, left to right.  A theorem that cannot be resolved with
   conjunct1/conjunct2 is returned as a singleton list. *)
fun split_conj_thm th =
  let
    val left = th RS conjunct1
    val right = th RS conjunct2
  in
    split_conj_thm left @ split_conj_thm right
  end
  handle THM _ => [th];

(* Proves equivariance of the mutually defined functions fs, whose argument
   types are given by argTss.  For each function f_i the statement is
     acc_i --> mk_perm p (f_i args_i) = f_i (map (mk_perm p) args_i)
   where acc_i is the first premise of the corresponding rule in induct_thms.
   - The statements are combined into one balanced conjunction.
   - The conjunction is proved with the (combined) induction rule; the cases
     are solved with the theorems eqvts_thms.
   - The result is split into one theorem per function, with the implication
     turned into a premise (mp).
   NOTE(review): argument names come from Name.invent on the names of ctxt,
   while p is variant-fixed separately in ctxt.  Confirm the two can never
   coincide. *)
fun prove_eqvt ctxt fs argTss eqvts_thms induct_thms =
  let
    (* free variables for the arguments of one function, all based on the name
       s and made distinct from the names of ctxt and fs *)
    fun aux argTs s = argTs
      |> map (pair s)
      |> Variable.variant_frees ctxt fs
    val argss' = map2 aux argTss (Name.invent (Variable.names_of ctxt) "" (length fs))
    val argss = (map o map) Free argss'
    val arg_namess = (map o map) fst argss'

    val ([p_name], ctxt') = Variable.variant_fixes ["p"] ctxt
    val p = Free (p_name, @{typ perm})

    (* extracting the acc-premises from the induction theorems *)
    val acc_prems =
     map prop_of induct_thms
     |> map2 forall_elim_list argss
     |> map (strip_qnt_body "all")
     |> map (curry Logic.nth_prem 1)
     |> map HOLogic.dest_Trueprop

    fun mk_goal acc_prem (f, args) =
      let
        val goal_lhs = mk_perm p (list_comb (f, args))
        val goal_rhs = list_comb (f, map (mk_perm p) args)
      in
        HOLogic.mk_imp (acc_prem, HOLogic.mk_eq (goal_lhs, goal_rhs))
      end

    val goal = fold_conj_balanced (map2 mk_goal acc_prems (fs ~~ argss))
      |> HOLogic.mk_Trueprop

    (* A single rule has its premises rotated by one and is atomized.
       Several rules are combined into one mutual rule first. *)
    val induct_thm = case induct_thms of
        [thm] => thm
          |> Drule.gen_all
          |> Thm.permute_prems 0 1
          |> (fn thm => atomize_rule (length (prems_of thm) - 1) thm)
      | thms => thms
          |> map Drule.gen_all
          |> map (Rule_Cases.add_consumes 1)
          |> snd o Rule_Cases.strict_mutual_rule ctxt'
          |> atomize_concl

    fun tac thm = rtac (Drule.gen_all thm) THEN_ALL_NEW atac
  in
    Goal.prove ctxt' (flat arg_namess) [] goal
      (K (HEADGOAL (DETERM o (rtac induct_thm) THEN' RANGE (map tac eqvts_thms))))
    |> singleton (Proof_Context.export ctxt' ctxt)
    |> split_conj_thm
    |> map (fn th => th RS mp)
  end

(* Continuation run once the proof obligations of the sum function are
   proved.  It obtains the sum-function results from inner_cont and
   translates them to the individual functions: partial simp rules,
   induction rules, termination rule, domain introduction rules and
   equivariance theorems.  G, R and cases are passed through unchanged.
   The pattern match expects exactly one simple_pinduct and one eqvt theorem
   for the sum function. *)
fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof =
  let
    val result = inner_cont proof
    val NominalFunctionResult {G, R, cases, psimps, simple_pinducts=[simple_pinduct],
      termination, domintros, eqvts=[eqvt],...} = result

    (* per function: its reversed definition in applied form, and the constant *)
    val (all_f_defs, fs) =
      map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
          (mk_applied_form lthy cargTs (Thm.symmetric f_def), f))
      parts
      |> split_list

    val all_orig_fdefs =
      map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts

    val cargTss =
      map (fn MutualPart {f = SOME f, cargTs, ...} => cargTs) parts

    fun mk_mpsimp fqgar sum_psimp =
      in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp

    fun mk_meqvts fqgar sum_psimp =
      in_context lthy fqgar (recover_mutual_eqvt eqvt all_orig_fdefs parts) sum_psimp

    (* simplifies with the reversed applied definitions all_f_defs *)
    val rew_ss = HOL_basic_ss addsimps all_f_defs
    val mpsimps = map2 mk_mpsimp fqgars psimps
    val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
    val mtermination = full_simplify rew_ss termination
    val mdomintros = Option.map (map (full_simplify rew_ss)) domintros
    val meqvts = map2 mk_meqvts fqgars psimps
    val meqvt_funs = prove_eqvt lthy fs cargTss meqvts minducts
 in
    NominalFunctionResult { fs=fs, G=G, R=R,
      psimps=mpsimps, simple_pinducts=minducts,
      cases=cases, termination=mtermination,
      domintros=mdomintros, eqvts=meqvt_funs }
  end

(* nominal *)
(* Instantiates the outermost abstraction of a term of the form _ $ Abs _
   (e.g. a meta-quantified proposition) with s.  The result is then
   universally quantified (Logic.all) over all free variables occurring in s.
   Any other term raises TERM; the previous non-exhaustive match raised an
   uninformative Match instead. *)
fun subst_all s (_ $ Abs (_, _, t)) =
      let
        val vs = map Free (Term.add_frees s [])
      in
        fold Logic.all vs (subst_bound (s, t))
      end
  | subst_all _ t = raise TERM ("subst_all: expected an application to an abstraction", [t])

(* The composition t o s with a dummy type, to be fixed by type checking. *)
fun mk_comp_dummy t s = list_comb (Const (@{const_name comp}, dummyT), [t, s])

(* Meta-level universal quantification over the variable term v: abstracts
   v in t and applies the meta "all" constant at the type of v. *)
fun all v t =
  let
    val vT = Term.fastype_of v
    val body = abstract_over (v, t)
  in
    Logic.all_const vT $ absdummy vT body
  end

(* nominal *)
(* Entry point: prepares a mutually recursive nominal function definition.
   1. The equations are encoded as equations of one sum function
      (analyze_eqs).
   2. Its proof obligations are set up by
      Nominal_Function_Core.prepare_nominal_function.
   3. The individual functions are defined as projections of it
      (define_projections).
   4. An auxiliary graph G_aux is defined inductively.  Its rules are the
      introduction rules of the graph G, with G renamed to G_aux and the
      outermost bound variable instantiated (subst_all) with a case
      distinction over per-function variables named f ^ "_aux".
   5. G = G_aux is proved, and the goal state is simplified with this
      equation.
   Returns the simplified goal state, the continuation producing the
   per-function results, and the final local theory.  Fails with
   "auxiliary equivalence proof failed" if the simplification fails. *)
fun prepare_nominal_function_mutual config defname fixes eqss lthy =
  let
    val mutual as Mutual {fsum_var=(n, T), qglrs, ...} =
      analyze_eqs lthy defname (map fst fixes) (map Envir.beta_eta_contract eqss)

    val ((fsum, G, GIntro_thms, G_induct, goalstate, cont), lthy') =
      Nominal_Function_Core.prepare_nominal_function config defname [((n, T), NoSyn)] qglrs lthy

    (* num_funs: number of functions (summands of ST);
       n': number of distinct result types (summands of RST) *)
    val (mutual' as Mutual {n = num_funs, n', parts, ST, RST, ...}, lthy'') =
      define_projections fixes mutual fsum lthy'

    val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual'

    (* defining the auxiliary graph *)
    fun mk_cases (MutualPart {i', fvar as (n, T), ...}) =
      let
        val (tys, ty) = strip_type T
        val fun_var = Free (n ^ "_aux", HOLogic.mk_tupleT tys --> ty)
        val inj_fun = absdummy dummyT (SumTree.mk_inj RST n' i' (Bound 0))
      in
        Syntax.check_term lthy'' (mk_comp_dummy inj_fun fun_var)
      end

    val sum_case_exp = map mk_cases parts
      |> SumTree.mk_sumcases RST

    val (G_name, G_type) = dest_Free G
    val G_name_aux = G_name ^ "_aux"
    val subst = [(G, Free (G_name_aux, G_type))]
    val GIntros_aux = GIntro_thms
      |> map prop_of
      |> map (Term.subst_free subst)
      |> map (subst_all sum_case_exp)

    val ((G_aux, GIntro_aux_thms, _, G_aux_induct), lthy''') =
      Nominal_Function_Core.inductive_def ((Binding.name G_name_aux, G_type), NoSyn) GIntros_aux lthy''

    (* proof of equivalence between graph and auxiliary graph *)
    val x = Var(("x", 0), ST)
    val y = Var(("y", 1), RST)
    val G_aux_prem = HOLogic.mk_Trueprop (G_aux $ x $ y)
    val G_prem = HOLogic.mk_Trueprop (G $ x $ y)

    (* For function number i with result summand i':
         ALL z. x = Inj_i z --> Inj_i' (Proj_i' y) = y
       i.e. G_aux relates arguments in summand i of ST only to results in
       summand i' of RST.
       - ST has one summand per function, so the argument side must use
         (num_funs, i).
       - The previous (n', i') only agreed with this when all result types
         were distinct (then n' = num_funs and i' = i).
       - With 1 < n' < num_funs the previous form stated a false lemma. *)
    fun mk_inj_goal (MutualPart {i, i', ...}) =
      let
        val injs = SumTree.mk_inj ST num_funs i (Bound 0)
        val projs = y
          |> SumTree.mk_proj RST n' i'
          |> SumTree.mk_inj RST n' i'
      in
        Const (@{const_name "All"}, dummyT) $ absdummy dummyT
          (HOLogic.mk_imp (HOLogic.mk_eq(x, injs), HOLogic.mk_eq(projs, y)))
      end

    val goal_inj = Logic.mk_implies (G_aux_prem,
      HOLogic.mk_Trueprop (fold_conj (map mk_inj_goal parts)))
      |> all x |> all y
      |> Syntax.check_term lthy'''
    val goal_iff1 = Logic.mk_implies (G_aux_prem, G_prem)
      |> all x |> all y
    val goal_iff2 = Logic.mk_implies (G_prem, G_aux_prem)
      |> all x |> all y

    val simp_thms = @{thms Projl.simps Projr.simps sum.inject sum.cases sum.distinct o_apply}
    val ss0 = HOL_basic_ss addsimps simp_thms
    val ss1 = HOL_ss addsimps simp_thms

    val inj_thm = Goal.prove lthy''' [] [] goal_inj
      (K (HEADGOAL (DETERM o etac G_aux_induct THEN_ALL_NEW asm_simp_tac ss1)))

    fun aux_tac thm =
      rtac (Drule.gen_all thm) THEN_ALL_NEW (asm_full_simp_tac (ss1 addsimps [inj_thm]))

    val iff1_thm = Goal.prove lthy''' [] [] goal_iff1
      (K (HEADGOAL (DETERM o etac G_aux_induct THEN' RANGE (map aux_tac GIntro_thms))))
      |> Drule.gen_all
    val iff2_thm = Goal.prove lthy''' [] [] goal_iff2
      (K (HEADGOAL (DETERM o etac G_induct THEN' RANGE (map (aux_tac o simplify ss0) GIntro_aux_thms))))
      |> Drule.gen_all

    val iff_thm = Goal.prove lthy''' [] [] (HOLogic.mk_Trueprop (HOLogic.mk_eq (G, G_aux)))
      (K (HEADGOAL (EVERY' ((map rtac @{thms ext ext iffI}) @ [etac iff2_thm, etac iff1_thm]))))

    val tac = HEADGOAL (simp_tac (HOL_basic_ss addsimps [iff_thm]))
    val goalstate' =
      case (SINGLE tac) goalstate of
        NONE => error "auxiliary equivalence proof failed"
      | SOME st => st
  in
    ((goalstate', mutual_cont), lthy''')
  end

end
