(*  Title:      Nominal2_Eqvt
    Author:     Brian Huffman
    Author:     Christian Urban

    Equivariance, Supp and Fresh Lemmas for Operators. 
    (Contains many, but not all such lemmas.)
*)
theory Nominal2_Eqvt
imports Nominal2_Base Nominal2_Atoms
uses ("nominal_thmdecls.ML")
     ("nominal_permeq.ML")
begin


section {* Logical Operators *}

(* Equivariance lemmas for the HOL connectives, the quantifiers and THE.
   Each lemma states how a permutation p distributes over the operator.
   The proofs unfold permute_bool_def and, where a predicate argument
   occurs, permute_fun_def. *)

lemma eq_eqvt:
  shows "p \<bullet> (x = y) \<longleftrightarrow> (p \<bullet> x) = (p \<bullet> y)"
  unfolding permute_eq_iff permute_bool_def ..

lemma if_eqvt:
  shows "p \<bullet> (if b then x else y) = (if p \<bullet> b then p \<bullet> x else p \<bullet> y)"
  by (simp add: permute_fun_def permute_bool_def)

(* The boolean constants are fixed by every permutation. *)
lemma True_eqvt:
  shows "p \<bullet> True = True"
  unfolding permute_bool_def ..

lemma False_eqvt:
  shows "p \<bullet> False = False"
  unfolding permute_bool_def ..

lemma imp_eqvt:
  shows "p \<bullet> (A \<longrightarrow> B) = ((p \<bullet> A) \<longrightarrow> (p \<bullet> B))"
  by (simp add: permute_bool_def)

lemma conj_eqvt:
  shows "p \<bullet> (A \<and> B) = ((p \<bullet> A) \<and> (p \<bullet> B))"
  by (simp add: permute_bool_def)

lemma disj_eqvt:
  shows "p \<bullet> (A \<or> B) = ((p \<bullet> A) \<or> (p \<bullet> B))"
  by (simp add: permute_bool_def)

lemma Not_eqvt:
  shows "p \<bullet> (\<not> A) = (\<not> (p \<bullet> A))"
  by (simp add: permute_bool_def)

(* The quantifier lemmas come in two forms.  The plain form permutes the
   predicate as a function, (p \<bullet> P) x.  The "2" form pushes p under the
   binder instead, applying the inverse permutation - p to the bound
   variable. *)

lemma all_eqvt:
  shows "p \<bullet> (\<forall>x. P x) = (\<forall>x. (p \<bullet> P) x)"
  unfolding permute_fun_def permute_bool_def
  by (auto, drule_tac x="p \<bullet> x" in spec, simp)

lemma all_eqvt2:
  shows "p \<bullet> (\<forall>x. P x) = (\<forall>x. p \<bullet> P (- p \<bullet> x))"
  unfolding permute_fun_def permute_bool_def
  by (auto, drule_tac x="p \<bullet> x" in spec, simp)

lemma ex_eqvt:
  shows "p \<bullet> (\<exists>x. P x) = (\<exists>x. (p \<bullet> P) x)"
  unfolding permute_fun_def permute_bool_def
  by (auto, rule_tac x="p \<bullet> x" in exI, simp)

lemma ex_eqvt2:
  shows "p \<bullet> (\<exists>x. P x) = (\<exists>x. p \<bullet> P (- p \<bullet> x))"
  unfolding permute_fun_def permute_bool_def
  by (auto, rule_tac x="p \<bullet> x" in exI, simp)

(* Unique existence: reduced via Ex1_def to the lemmas above. *)
lemma ex1_eqvt:
  shows "p \<bullet> (\<exists>!x. P x) = (\<exists>!x. (p \<bullet> P) x)"
  unfolding Ex1_def 
  by (simp add: ex_eqvt permute_fun_def conj_eqvt all_eqvt imp_eqvt eq_eqvt)

lemma ex1_eqvt2:
  shows "p \<bullet> (\<exists>!x. P x) = (\<exists>!x. p \<bullet> P (- p \<bullet> x))"
  unfolding Ex1_def ex_eqvt2 conj_eqvt all_eqvt2 imp_eqvt eq_eqvt
  by simp

(* THE commutes with a permutation only under the uniqueness assumption.
   The result is stated in the "2" form (inverse permutation on the bound
   variable).  The proof uses the1_equality, with ex1_eqvt2 transferring
   uniqueness to the permuted predicate. *)
lemma the_eqvt:
  assumes unique: "\<exists>!x. P x"
  shows "(p \<bullet> (THE x. P x)) = (THE x. p \<bullet> P (- p \<bullet> x))"
  apply(rule the1_equality [symmetric])
  apply(simp add: ex1_eqvt2[symmetric])
  apply(simp add: permute_bool_def unique)
  apply(simp add: permute_bool_def)
  apply(rule theI'[OF unique])
  done

section {* Set Operations *}

(* Equivariance, support and freshness lemmas for set operations.  The
   proofs unfold set membership and comprehension (mem_def, Collect_def)
   to function application, so they reduce to the boolean and function
   lemmas of the previous section. *)

(* Membership is invariant when both element and set are permuted. *)
lemma mem_permute_iff:
  shows "(p \<bullet> x) \<in> (p \<bullet> X) \<longleftrightarrow> x \<in> X"
unfolding mem_def permute_fun_def permute_bool_def
by simp

lemma mem_eqvt:
  shows "p \<bullet> (x \<in> A) \<longleftrightarrow> (p \<bullet> x) \<in> (p \<bullet> A)"
  unfolding mem_permute_iff permute_bool_def by simp

lemma not_mem_eqvt:
  shows "p \<bullet> (x \<notin> A) \<longleftrightarrow> (p \<bullet> x) \<notin> (p \<bullet> A)"
  unfolding mem_def permute_fun_def by (simp add: Not_eqvt)

(* Set comprehension, in the two forms used for the quantifiers. *)
lemma Collect_eqvt:
  shows "p \<bullet> {x. P x} = {x. (p \<bullet> P) x}"
  unfolding Collect_def permute_fun_def ..

lemma Collect_eqvt2:
  shows "p \<bullet> {x. P x} = {x. p \<bullet> (P (-p \<bullet> x))}"
  unfolding Collect_def permute_fun_def ..

lemma empty_eqvt:
  shows "p \<bullet> {} = {}"
  unfolding empty_def Collect_eqvt2 False_eqvt ..

(* The empty set has empty support, so every atom is fresh for it. *)
lemma supp_set_empty:
  shows "supp {} = {}"
  by (simp add: supp_def empty_eqvt)

lemma fresh_set_empty:
  shows "a \<sharp> {}"
  by (simp add: fresh_def supp_set_empty)

lemma UNIV_eqvt:
  shows "p \<bullet> UNIV = UNIV"
  unfolding UNIV_def Collect_eqvt2 True_eqvt ..

lemma union_eqvt:
  shows "p \<bullet> (A \<union> B) = (p \<bullet> A) \<union> (p \<bullet> B)"
  unfolding Un_def Collect_eqvt2 disj_eqvt mem_eqvt by simp

lemma inter_eqvt:
  shows "p \<bullet> (A \<inter> B) = (p \<bullet> A) \<inter> (p \<bullet> B)"
  unfolding Int_def Collect_eqvt2 conj_eqvt mem_eqvt by simp

lemma Diff_eqvt:
  fixes A B :: "'a::pt set"
  shows "p \<bullet> (A - B) = p \<bullet> A - p \<bullet> B"
  unfolding set_diff_eq Collect_eqvt2 conj_eqvt Not_eqvt mem_eqvt by simp

(* Complement, via - A = UNIV - A. *)
lemma Compl_eqvt:
  fixes A :: "'a::pt set"
  shows "p \<bullet> (- A) = - (p \<bullet> A)"
  unfolding Compl_eq_Diff_UNIV Diff_eqvt UNIV_eqvt ..

lemma insert_eqvt:
  shows "p \<bullet> (insert x A) = insert (p \<bullet> x) (p \<bullet> A)"
  unfolding permute_set_eq_image image_insert ..

(* Preimage and image: the function f is permuted as well. *)
lemma vimage_eqvt:
  shows "p \<bullet> (f -` A) = (p \<bullet> f) -` (p \<bullet> A)"
  unfolding vimage_def permute_fun_def [where f=f]
  unfolding Collect_eqvt2 mem_eqvt ..

lemma image_eqvt:
  shows "p \<bullet> (f ` A) = (p \<bullet> f) ` (p \<bullet> A)"
  unfolding permute_set_eq_image
  unfolding permute_fun_def [where f=f]
  by (simp add: image_image)

(* Finiteness is preserved because permute p is a bijection (bij_permute). *)
lemma finite_permute_iff:
  shows "finite (p \<bullet> A) \<longleftrightarrow> finite A"
  unfolding permute_set_eq_vimage
  using bij_permute by (rule finite_vimage_iff)

lemma finite_eqvt:
  shows "p \<bullet> finite A = finite (p \<bullet> A)"
  unfolding finite_permute_iff permute_bool_def ..


section {* List Operations *}

(* Equivariance, support and freshness lemmas for list operations.  All
   are proved by structural induction on xs. *)

lemma append_eqvt:
  shows "p \<bullet> (xs @ ys) = (p \<bullet> xs) @ (p \<bullet> ys)"
  by (induct xs) auto

(* The support of an appended list is the union of the supports. *)
lemma supp_append:
  shows "supp (xs @ ys) = supp xs \<union> supp ys"
  by (induct xs) (auto simp add: supp_Nil supp_Cons)

lemma fresh_append:
  shows "a \<sharp> (xs @ ys) \<longleftrightarrow> a \<sharp> xs \<and> a \<sharp> ys"
  by (induct xs) (simp_all add: fresh_Nil fresh_Cons)

lemma rev_eqvt:
  shows "p \<bullet> (rev xs) = rev (p \<bullet> xs)"
  by (induct xs) (simp_all add: append_eqvt)

(* Reversal does not change support or freshness. *)
lemma supp_rev:
  shows "supp (rev xs) = supp xs"
  by (induct xs) (auto simp add: supp_append supp_Cons supp_Nil)

lemma fresh_rev:
  shows "a \<sharp> rev xs \<longleftrightarrow> a \<sharp> xs"
  by (induct xs) (auto simp add: fresh_append fresh_Cons fresh_Nil)

lemma set_eqvt:
  shows "p \<bullet> (set xs) = set (p \<bullet> xs)"
  by (induct xs) (simp_all add: empty_eqvt insert_eqvt)

(* TODO: supp_set is not proved yet.  It needs a finite-support premise
   (presumably on the element type).  Draft statement below; note that
   its fixes clause declares x, which the statement does not mention.

lemma supp_set:
  fixes x :: "'a::pt"
  shows "supp (set xs) = supp xs"
*)

(* In the Cons case, after simp_all, the remaining permuted function
   application is rewritten with permute_fun_app_eq. *)
lemma map_eqvt: 
  shows "p \<bullet> (map f xs) = map (p \<bullet> f) (p \<bullet> xs)"
  by (induct xs) (simp_all, simp only: permute_fun_app_eq)

section {* Product Operations *}

(* The projections commute with permutation; proved by case analysis on
   the pair x. *)
lemma fst_eqvt:
  "p \<bullet> (fst x) = fst (p \<bullet> x)"
 by (cases x) simp

lemma snd_eqvt:
  "p \<bullet> (snd x) = snd (p \<bullet> x)"
 by (cases x) simp

section {* Units *}

(* The unit value has empty support, so every atom is fresh for it. *)
lemma supp_unit:
  shows "supp () = {}"
  by (simp add: supp_def)

lemma fresh_unit:
  shows "a \<sharp> ()"
  by (simp add: fresh_def supp_unit)

(* The permutation operation is itself equivariant: permuting the function
   permute gives back permute.  Proved pointwise using expand_fun_eq and
   permute_eqvt. *)
lemma permute_eqvt_raw:
  shows "p \<bullet> permute = permute"
apply(simp add: expand_fun_eq permute_fun_def)
apply(subst permute_eqvt)
apply(simp)
done

section {* Equivariance automation *}

text {* Setup of the theorem attributes @{text eqvt} and @{text eqvt_force} *}

(* Load and install Nominal_ThmDecls, which provides the eqvt and eqvt_raw
   attributes used in the declarations below. *)
use "nominal_thmdecls.ML"
setup "Nominal_ThmDecls.setup"

(* Register the equivariance lemmas as eqvt rules.  Names not proved in
   this theory (e.g. supp_eqvt, Pair_eqvt, atom_eqvt, add_perm_eqvt) are
   presumably provided by Nominal2_Base / Nominal2_Atoms.  The folded copy
   of imp_eqvt states the same fact in terms of induct_implies, presumably
   so that it also applies to goals produced by the induct method. *)
lemmas [eqvt] = 
  (* connectives *)
  eq_eqvt if_eqvt imp_eqvt disj_eqvt conj_eqvt Not_eqvt 
  True_eqvt False_eqvt ex_eqvt all_eqvt ex1_eqvt
  imp_eqvt [folded induct_implies_def]

  (* nominal *)
  supp_eqvt fresh_eqvt

  (* datatypes *)
  permute_prod.simps append_eqvt rev_eqvt set_eqvt
  map_eqvt fst_eqvt snd_eqvt Pair_eqvt permute_list.simps

  (* sets *)
  empty_eqvt UNIV_eqvt union_eqvt inter_eqvt mem_eqvt
  Diff_eqvt Compl_eqvt insert_eqvt Collect_eqvt image_eqvt

  atom_eqvt add_perm_eqvt

(* eqvt_raw rules are given as meta-equations, hence eq_reflection. *)
lemmas [eqvt_raw] =
  permute_eqvt_raw[THEN eq_reflection] (* the normal version of this lemma loops *)

text {* helper lemmas for the eqvt_tac *}

(* unpermute p is the inverse permutation, permute (- p).  It is kept as a
   separate constant, presumably so that the tactic can recognise the
   inverse it introduces under a lambda (eqvt_lambda) and cancel it again
   (eqvt_bound). *)
definition
  "unpermute p = permute (- p)"

(* The three rewrite rules are stated as meta-equations (\<equiv>). *)

(* Push a permutation into a function application. *)
lemma eqvt_apply:
  fixes f :: "'a::pt \<Rightarrow> 'b::pt" 
  and x :: "'a::pt"
  shows "p \<bullet> (f x) \<equiv> (p \<bullet> f) (p \<bullet> x)"
  unfolding permute_fun_def by simp

(* Push a permutation under a lambda; the bound variable is wrapped in
   unpermute p. *)
lemma eqvt_lambda:
  fixes f :: "'a::pt \<Rightarrow> 'b::pt"
  shows "p \<bullet> (\<lambda>x. f x) \<equiv> (\<lambda>x. p \<bullet> (f (unpermute p x)))"
  unfolding permute_fun_def unpermute_def by simp

(* A permutation cancels against its own unpermute. *)
lemma eqvt_bound:
  shows "p \<bullet> unpermute p x \<equiv> x"
  unfolding unpermute_def by simp

(* Load and install the permutation-pushing tactic (Nominal_Permeq). *)
use "nominal_permeq.ML"
setup Nominal_Permeq.setup

(* Proof method perm_simp: runs Nominal_Permeq.eqvt_tac on the first goal,
   with any theorems given as method arguments.  The list ["The"] is
   passed to the tactic, presumably naming constants it must not push
   permutations into (there is no raw eqvt rule for THE; see the example
   further below).  TODO confirm against nominal_permeq.ML. *)
method_setup perm_simp =
 {* Attrib.thms >> 
    (fn thms => fn ctxt => SIMPLE_METHOD (HEADGOAL (Nominal_Permeq.eqvt_tac ctxt thms ["The"]))) *}
 {* pushes permutations inside *}

(* Strict variant: same arguments, but uses eqvt_strict_tac. *)
method_setup perm_strict_simp =
 {* Attrib.thms >> 
    (fn thms => fn ctxt => SIMPLE_METHOD (HEADGOAL (Nominal_Permeq.eqvt_strict_tac ctxt thms ["The"]))) *}
 {* pushes permutations inside, raises an error if it cannot solve all permutations *}

(* Interactive examples of perm_simp, with eqvt tracing switched on.
   Every example ends in oops, so none of them adds a theorem.  Goals of
   the form "... = foo" apparently serve only to show the result of
   pushing permutations into the left-hand side. *)
declare [[trace_eqvt = true]]

(* equality at a pt type *)
lemma 
  fixes B::"'a::pt"
  shows "p \<bullet> (B = C)"
apply(perm_simp)
oops

(* equality at type bool *)
lemma 
  fixes B::"bool"
  shows "p \<bullet> (B = C)"
apply(perm_simp)
oops

lemma 
  fixes B::"bool"
  shows "p \<bullet> (A \<longrightarrow> B = C)"
apply (perm_simp) 
oops

(* permutations pushed under lambdas and binders *)
lemma 
  shows "p \<bullet> (\<lambda>(x::'a::pt). A \<longrightarrow> (B::'a \<Rightarrow> bool) x = C) = foo"
apply(perm_simp)
oops

lemma 
  shows "p \<bullet> (\<lambda>B::bool. A \<longrightarrow> (B = C)) = foo"
apply (perm_simp)
oops

lemma 
  shows "p \<bullet> (\<lambda>x y. \<exists>z. x = z \<and> x = y \<longrightarrow> z \<noteq> x) = foo"
apply (perm_simp)
oops

lemma 
  shows "p \<bullet> (\<lambda>f x. f (g (f x))) = foo"
apply (perm_simp)
oops

(* nested permutations *)
lemma 
  fixes p q::"perm"
  and   x::"'a::pt"
  shows "p \<bullet> (q \<bullet> x) = foo"
apply(perm_simp)
oops

lemma 
  fixes p q r::"perm"
  and   x::"'a::pt"
  shows "p \<bullet> (q \<bullet> r \<bullet> x) = foo"
apply(perm_simp)
oops

lemma 
  fixes p r::"perm"
  shows "p \<bullet> (\<lambda>q::perm. q \<bullet> (r \<bullet> x)) = foo"
apply (perm_simp)
oops

(* permutation occurring below an unknown function B *)
lemma 
  fixes C D::"bool"
  shows "B (p \<bullet> (C = D))"
apply(perm_simp)
oops

declare [[trace_eqvt = false]]

(* perm_strict_simp raises an error if it cannot push all permutations
   inside, which is presumably why it is commented out in this example. *)
text {* Problem: there is no raw eqvt-rule for The *}
lemma "p \<bullet> (THE x. P x) = foo"
apply(perm_simp)
(* apply(perm_strict_simp) *)
oops

(* Atom type whose name "var" is hard-coded in prove_eqvt further below. *)
atom_decl var

ML {*
(* Helpers used by prove_eqvt (later in this theory). *)

val inductive_atomize = @{thms induct_atomize};

(* Conversion that rewrites a term with the induct_atomize rules only,
   turning meta-level connectives into their HOL counterparts. *)
val atomize_conv =
  MetaSimplifier.rewrite_cterm (true, false, false) (K (K NONE))
    (HOL_basic_ss addsimps inductive_atomize);
(* Apply atomize_conv to every premise of a theorem. *)
val atomize_intr = Conv.fconv_rule (Conv.prems_conv ~1 atomize_conv);
(* Apply atomize_conv to the premises nested inside each premise of a
   theorem, underneath that premise's parameters. *)
fun atomize_induct ctxt = Conv.fconv_rule (Conv.prems_conv ~1
  (Conv.params_conv ~1 (K (Conv.prems_conv ~1 atomize_conv)) ctxt));

(* map_term f t u traverses t and u in parallel.  At each pair of
   corresponding subterms it first tries f; if f returns NONE it descends
   into matching applications and abstractions.  The result is SOME of t
   with every successful replacement applied (unchanged parts, binder
   names and types are taken from t), or NONE if nothing was replaced or
   the two terms' shapes differ. *)
fun map_term f t u = (case f t u of
      NONE => map_term' f t u | x => x)
and map_term' f (t $ u) (t' $ u') = (case (map_term f t t', map_term f u u') of
      (NONE, NONE) => NONE
    | (SOME t'', NONE) => SOME (t'' $ u)
    | (NONE, SOME u'') => SOME (t $ u'')
    | (SOME t'', SOME u'') => SOME (t'' $ u''))
  | map_term' f (Abs (s, T, t)) (Abs (s', T', t')) = (case map_term f t t' of
      NONE => NONE
    | SOME t'' => SOME (Abs (s, T, t'')))
  | map_term' _ _ _ = NONE;


(* map_thm ctxt f tac monos opt th: rewrite the proposition of th with
   map_term f (paired with opt, defaulting to the proposition itself) and,
   if anything changed, prove the new proposition from th.  The proof cuts
   th in, applies rev_mp, resolves repeatedly with the monotonicity rules
   monos, then introduces implications and closes each by assumption or
   with tac.  Returns NONE if map_term made no change. *)
fun map_thm ctxt f tac monos opt th =
  let
    val prop = prop_of th;
    fun prove t =
      Goal.prove ctxt [] [] t (fn _ =>
        EVERY [cut_facts_tac [th] 1, etac rev_mp 1,
          REPEAT_DETERM (FIRSTGOAL (resolve_tac monos)),
          REPEAT_DETERM (rtac impI 1 THEN (atac 1 ORELSE tac))])
  in Option.map prove (map_term f prop (the_default prop opt)) end;

(* split_conj f names t _: if t is a HOL conjunction p & q whose left
   conjunct p is headed by a constant listed in names, return SOME (f p q);
   otherwise NONE.  The last argument is ignored; it exists so that the
   function fits the f t u shape expected by map_term. *)
fun split_conj f names (Const ("op &", _) $ p $ q) _ = (case head_of p of
      Const (name, _) =>
        if name mem names then SOME (f p q) else NONE
    | _ => NONE)
  | split_conj _ _ _ _ = NONE;
*}

ML {*
(* NOTE(review): perm_bool is bound here but not used in the visible code. *)
val perm_bool = @{thm "permute_bool_def"};
val perm_boolI = @{thm "permute_boolI"};
(* perm_boolI_pi: the cterm occupying the first argument position of the
   head of perm_boolI's conclusion, presumably its permutation variable;
   it is the variable instantiated by mk_perm_bool. *)
val (_, [perm_boolI_pi, _]) = Drule.strip_comb (snd (Thm.dest_comb
  (Drule.strip_imp_concl (cprop_of perm_boolI))));

(* mk_perm_bool pi th: instantiate the permutation of perm_boolI with the
   cterm pi and resolve th with the result (th RS ...), so the conclusion
   is presumably th's proposition permuted by pi. *)
fun mk_perm_bool pi th = th RS Drule.cterm_instantiate
  [(perm_boolI_pi, pi)] perm_boolI;

*}

ML {*
(* Transpose a list of equal-length lists: the i-th output list collects
   the i-th element of every input list.  The [] clause is required:
   without it, transp [] falls through to the last clause and recurses
   forever, because map hd [] and map tl [] are both []. *)
fun transp [] = []
  | transp ([] :: _) = []
  | transp xs = map hd xs :: transp (map tl xs);

(* prove_eqvt s xatoms ctxt
   Proves equivariance of the inductive predicate(s) defined together with
   the constant s.  For a fresh permutation pi, each predicate P gets the
   theorem
       P ts1 ts2 ==> P ts1 (permute pi applied to each of ts2)
   where ts1 are the leading arguments that are parameters of the
   inductive definition (left unpermuted) and ts2 are the remaining ones.
   The theorem for predicate name is stored as name.eqvt with the eqvt
   attribute.
   xatoms: names of atom types to prove it for; if empty, the hard-coded
   list atoms' is used.  Raises an error for duplicate or unknown atom
   names, or if an introduction rule cannot be handled.
   Returns ctxt extended with the new theorems. *)
fun prove_eqvt s xatoms ctxt =
  let
    val thy = ProofContext.theory_of ctxt;
    val ({names, ...}, {raw_induct, intrs, elims, ...}) =
      Inductive.the_inductive ctxt (Sign.intern_const thy s);
    (* turn meta-level connectives in the rules into HOL ones, so that they
       combine with mp, conjunct2, etc. below *)
    val raw_induct = atomize_induct ctxt raw_induct;
    val elims = map (atomize_induct ctxt) elims;
    val intrs = map atomize_intr intrs;
    val monos = Inductive.get_monos ctxt;
    (* pair each introduction rule with the variables inferred for it,
       grouped per predicate and flattened back into the order of intrs *)
    val intrs' = Inductive.unpartition_rules intrs
      (map (fn (((s, ths), (_, k)), th) =>
           (s, ths ~~ Inductive.infer_intro_vars th k ths))
         (Inductive.partition_rules raw_induct intrs ~~
          Inductive.arities_of raw_induct ~~ elims));
    (* number of leading arguments that are parameters of the definition;
       these are not permuted *)
    val k = length (Inductive.params_of raw_induct);
    (* NOTE(review): the known atom types are hard-coded to var.  The name
       is interned so that it is comparable with the interned xatoms below;
       comparing against the bare string "var" rejected every explicit
       atom list. *)
    val atoms' = [Sign.intern_type thy "var"];
    val atoms =
      if null xatoms then atoms' else
      let val atoms = map (Sign.intern_type thy) xatoms
      in
        (case duplicates op = atoms of
             [] => ()
           | xs => error ("Duplicate atoms: " ^ commas xs);
         case subtract (op =) atoms' atoms of
             [] => ()
           | xs => error ("No such atoms: " ^ commas xs);
         atoms)
      end;
    (* NOTE(review): perm_pi_simp is looked up by name and is not defined
       in this theory; confirm an ancestor provides it, otherwise this
       lookup fails at run time *)
    val perm_pi_simp = PureThy.get_thms thy "perm_pi_simp";
    val eqvt_ss = Simplifier.global_context thy HOL_basic_ss addsimps
      (Nominal_ThmDecls.get_eqvts_thms ctxt @ perm_pi_simp);
    (* fix the variables of the induction conclusion, plus a fresh name
       for the permutation *)
    val (([t], [pi]), ctxt') = ctxt |>
      Variable.import_terms false [concl_of raw_induct] ||>>
      Variable.variant_fixes ["pi"];
    (* the predicate applications P ts, taken from the conjunction of
       implications P ts --> ... in the induction conclusion *)
    val ps = map (fst o HOLogic.dest_imp)
      (HOLogic.dest_conj (HOLogic.dest_Trueprop t));
    (* the term  permute pi_t t,  built at the type of t *)
    fun mk_perm pi_t t =
      let val ty = fastype_of t
      in Const (@{const_name permute}, @{typ perm} --> ty --> ty) $ pi_t $ t end;
    (* Discharge one induction case: apply the introduction rule intr and
       close its premises with the hypotheses, which are permuted by pi and
       simplified with eqvt_ss.  On failure, report an error naming the
       rule. *)
    fun eqvt_tac ctxt'' pi (intr, vs) st =
      let
        fun eqvt_err s =
          let val ([t], ctxt''') = Variable.import_terms true [prop_of intr] ctxt
          in error ("Could not prove equivariance for introduction rule\n" ^
            Syntax.string_of_term ctxt''' t ^ "\n" ^ s)
          end;
        val res = SUBPROOF (fn {prems, params, ...} =>
          let
            (* of a hypothesis Q & R whose Q is headed by one of the
               predicates, keep only R *)
            val prems' = map (fn th => the_default th (map_thm ctxt'
              (split_conj (K I) names) (etac conjunct2 1) monos NONE th)) prems;
            val prems'' = map (fn th => Simplifier.simplify eqvt_ss
              (mk_perm_bool (cterm_of thy pi) th)) prems';
            (* NOTE(review): the rule is used uninstantiated (params and vs
               are unused); its variables are found by unification with the
               permuted goal *)
            val intr' = intr
          in (rtac intr' THEN_ALL_NEW (TRY o resolve_tac prems'')) 1
          end) ctxt' 1 st
      in
        case (Seq.pull res handle THM (s, _, _) => eqvt_err s) of
          NONE => eqvt_err ("Rule does not match goal\n" ^
            Syntax.string_of_term ctxt'' (hd (prems_of st)))
        | SOME (th, _) => Seq.single th
      end;
    (* one list of theorems (one per predicate) for each atom type; the body
       does not mention atom, since pi' always has type perm *)
    val thss = map (fn atom =>
      let val pi' = Free (pi, @{typ perm})
      in map (fn th => zero_var_indexes (th RS mp))
        (Datatype_Aux.split_conj_thm (Goal.prove ctxt' [] []
          (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj (map (fn p =>
            let
              val (h, ts) = strip_comb p;
              val (ts1, ts2) = chop k ts
            in
              (* P ts1 ts2 --> P ts1 (pi' applied to each of ts2): the
                 parameters ts1 stay fixed, only ts2 is permuted.  Leaving
                 ts2 out entirely would give an ill-typed implication
                 whenever ts2 is non-empty. *)
              HOLogic.mk_imp (p, list_comb (h, ts1 @ map (mk_perm pi') ts2))
            end) ps)))
          (fn {context, ...} => EVERY (rtac raw_induct 1 :: map (fn intr_vs =>
              full_simp_tac eqvt_ss 1 THEN
              eqvt_tac context pi' intr_vs) intrs')) |>
          singleton (ProofContext.export ctxt' ctxt)))
      end) atoms
  in
    ctxt |>
    Local_Theory.notes (map (fn (name, ths) =>
        ((Binding.qualified_name (Long_Name.qualify (Long_Name.base_name name) "eqvt"),
          [Attrib.internal (K Nominal_ThmDecls.eqvt_add)]), [(ths, [])]))
      (names ~~ transp thss)) |> snd
  end;
*}

end
