theory Lambda
imports "../Parser"
begin

(* declares the atom sort "name" that is used for term variables *)
atom_decl name

(*
  Lambda-terms: variables, applications and abstractions.
  In "Lam x l" the name x is declared to be bound in the body l.
*)
nominal_datatype lam =
  Var "name"
| App "lam" "lam"
| Lam x::"name" l::"lam"  bind x in l

(* lam.fv simplified with lam.supp; used in the proofs below when
   freshness conditions for Lam-terms are unfolded *)
lemmas supp_fn' = lam.fv[simplified lam.supp]

(* registers the permutation equations of lam with the eqvt attribute *)
declare lam.perm[eqvt]

section {* Strong Induction Principles*}

(* 
  Old way of establishing strong induction
  principles by choosing a fresh name.

  The statement is first generalised to "P c (p \<bullet> lam)" for every
  permutation p, shown with the induction rule lam.induct, and
  afterwards specialised to the permutation 0.
*)
lemma
  fixes c::"'a::fs"
  assumes a1: "\<And>name c. P c (Var name)"
  and     a2: "\<And>lam1 lam2 c. \<lbrakk>\<And>d. P d lam1; \<And>d. P d lam2\<rbrakk> \<Longrightarrow> P c (App lam1 lam2)"
  and     a3: "\<And>name lam c. \<lbrakk>atom name \<sharp> c; \<And>d. P d lam\<rbrakk> \<Longrightarrow> P c (Lam name lam)"
  shows "P c lam"
proof -
  have "\<And>p. P c (p \<bullet> lam)"
    apply(induct lam arbitrary: c rule: lam.induct)
    (* Var case: push the permutation inside, then use a1 *)
    apply(perm_simp)
    apply(rule a1)
    (* App case: push the permutation inside, then use a2 with both hypotheses *)
    apply(perm_simp)
    apply(rule a2)
    apply(assumption)
    apply(assumption)
    (* Lam case: claim a name "new" that is fresh for the context c 
       and for the permuted abstraction; defer postpones the main goal *)
    apply(subgoal_tac "\<exists>new::name. (atom new) \<sharp> (c, Lam (p \<bullet> name) (p \<bullet> lam))")
    defer
    (* the claim: such a name is obtained with obtain_at_base from the support of the pair *)
    apply(simp add: fresh_def)
    apply(rule_tac X="supp (c, Lam (p \<bullet> name) (p \<bullet> lam))" in obtain_at_base)
    apply(simp add: supp_Pair finite_supp)
    apply(blast)
    (* main goal: replace p \<bullet> Lam name lam by the same term with p \<bullet> name 
       additionally swapped with new; the equation is justified via flip_fresh_fresh *)
    apply(erule exE)
    apply(rule_tac t="p \<bullet> Lam name lam" and 
                   s="(((p \<bullet> name) \<leftrightarrow> new) + p) \<bullet> Lam name lam" in subst)
    apply(simp del: lam.perm)
    apply(subst lam.perm)
    apply(subst (2) lam.perm)
    apply(rule flip_fresh_fresh)
    apply(simp add: fresh_def)
    apply(simp only: supp_fn')
    apply(simp)
    apply(simp add: fresh_Pair)
    apply(simp)
    (* a3 is now applicable; the induction hypothesis is instantiated 
       with the extended permutation *)
    apply(rule a3)
    apply(simp add: fresh_Pair)
    apply(drule_tac x="((p \<bullet> name) \<leftrightarrow> new) + p" in meta_spec)
    apply(simp)
    done
  (* instantiate the generalised statement with the permutation 0 *)
  then have "P c (0 \<bullet> lam)" by blast
  then show "P c lam" by simp
qed

(* 
  New way of establishing strong induction
  principles by using an appropriate permutation.

  As in the proof above, the statement is generalised to
  "P c (p \<bullet> lam)" for every permutation p; in the Lam case a 
  permutation q is obtained from at_set_avoiding2 instead of 
  a single fresh name.
*)
lemma
  fixes c::"'a::fs"
  assumes a1: "\<And>name c. P c (Var name)"
  and     a2: "\<And>lam1 lam2 c. \<lbrakk>\<And>d. P d lam1; \<And>d. P d lam2\<rbrakk> \<Longrightarrow> P c (App lam1 lam2)"
  and     a3: "\<And>name lam c. \<lbrakk>atom name \<sharp> c; \<And>d. P d lam\<rbrakk> \<Longrightarrow> P c (Lam name lam)"
  shows "P c lam"
proof -
  have "\<And>p. P c (p \<bullet> lam)"
    apply(induct lam arbitrary: c rule: lam.induct)
    (* Var case *)
    apply(perm_simp)
    apply(rule a1)
    (* App case *)
    apply(perm_simp)
    apply(rule a2)
    apply(assumption)
    apply(assumption)
    (* Lam case: claim a permutation q such that q applied to the permuted 
       binder is fresh for c, while the support of the permuted abstraction 
       is fresh for q *)
    apply(subgoal_tac "\<exists>q. (q \<bullet> {p \<bullet> atom name}) \<sharp>* c \<and> supp (p \<bullet> Lam name lam) \<sharp>* q")
    apply(erule exE)
    (* rewrite the term to q \<bullet> p \<bullet> Lam name lam; defer postpones the
       equation introduced by subst, which is proved last with supp_perm_eq *)
    apply(rule_tac t="p \<bullet> Lam name lam" and 
                   s="q \<bullet> p \<bullet> Lam name lam" in subst)
    defer
    apply(simp)
    (* a3 is now applicable; the induction hypothesis is instantiated with q + p *)
    apply(rule a3)
    apply(simp add: eqvts fresh_star_def)
    apply(drule_tac x="q + p" in meta_spec)
    apply(simp)
    (* the claimed permutation comes from at_set_avoiding2; 
       its side conditions are discharged with finite_supp and supp_fn' *)
    apply(rule at_set_avoiding2)
    apply(simp add: finite_supp)
    apply(simp add: finite_supp)
    apply(simp add: finite_supp)
    apply(perm_simp)
    apply(simp add: fresh_star_def fresh_def supp_fn')
    (* the postponed equation *)
    apply(rule supp_perm_eq)
    apply(simp)
    done
  (* instantiate the generalised statement with the permutation 0 *)
  then have "P c (0 \<bullet> lam)" by blast
  then show "P c lam" by simp
qed

section {* Typing *}

(* simple types: type variables (named by strings) and function types *)
nominal_datatype ty =
  TVar string
| TFun ty ty

(* arrow notation for function types *)
notation
 TFun ("_ \<rightarrow> _") 

(* registers the permutation equations of ty with the eqvt attribute *)
declare ty.perm[eqvt]

(* 
  Validity of typing contexts (lists of name/type pairs):
  the empty context is valid, and a valid context may be extended 
  with a binding whose name is fresh for that context.
*)
inductive
  valid :: "(name \<times> ty) list \<Rightarrow> bool"
where
  "valid []"
| "\<lbrakk>atom x \<sharp> Gamma; valid Gamma\<rbrakk> \<Longrightarrow> valid ((x, T)#Gamma)"

(* 
  Typing judgment for lambda-terms with simple types.

  - t_Var: a variable has the type recorded for it in a valid context
  - t_App: an application has type T2 if the function has type
    T1 \<rightarrow> T2 AND the argument has type T1; both judgments are 
    separate premises of the rule (a disjunction of the two would 
    allow ill-typed applications to be derived)
  - t_Lam: an abstraction has type T1 \<rightarrow> T2 if its body has type T2 in 
    the context extended with the binder, which must be fresh for 
    the context
*)
inductive
  typing :: "(name\<times>ty) list \<Rightarrow> lam \<Rightarrow> ty \<Rightarrow> bool" ("_ \<turnstile> _ : _" [60,60,60] 60) 
where
    t_Var[intro]: "\<lbrakk>valid \<Gamma>; (x, T) \<in> set \<Gamma>\<rbrakk> \<Longrightarrow> \<Gamma> \<turnstile> Var x : T"
  | t_App[intro]: "\<lbrakk>\<Gamma> \<turnstile> t1 : T1 \<rightarrow> T2; \<Gamma> \<turnstile> t2 : T1\<rbrakk> \<Longrightarrow> \<Gamma> \<turnstile> App t1 t2 : T2"
  | t_Lam[intro]: "\<lbrakk>atom x \<sharp> \<Gamma>; (x, T1) # \<Gamma> \<turnstile> t : T2\<rbrakk> \<Longrightarrow> \<Gamma> \<turnstile> Lam x t : T1 \<rightarrow> T2"

ML {*
(* the theorems stored under the name induct_atomize 
   (presumably equations turning meta-level connectives into 
   object-level ones - confirm against the HOL sources) *)
val inductive_atomize = @{thms induct_atomize};

(* conversion that rewrites a cterm with HOL_basic_ss 
   extended by the induct_atomize theorems *)
val atomize_conv = 
  MetaSimplifier.rewrite_cterm (true, false, false) (K (K NONE))
    (HOL_basic_ss addsimps inductive_atomize);
(* applies atomize_conv to the premises of a theorem; used on the 
   introduction rules in eqvt_rel_tac *)
val atomize_intr = Conv.fconv_rule (Conv.prems_conv ~1 atomize_conv);
(* like atomize_intr, but atomize_conv is applied to the premises found 
   below the parameters of each premise; used on the raw induction rule *)
fun atomize_induct ctxt = Conv.fconv_rule (Conv.prems_conv ~1
  (Conv.params_conv ~1 (K (Conv.prems_conv ~1 atomize_conv)) ctxt));
*}

ML {*
(* 
  map_term f t tries f at the root of t first; only if f does not 
  fire there, the immediate subterms are traversed (map_term').
  
  - the result is NONE if f fires nowhere
  - otherwise SOME of the term in which the outermost positions
    where f fires are replaced by the results of f
*)
fun map_term f t =
  (case f t of
     SOME t' => SOME t'
   | NONE => map_term' f t)
and map_term' f (t $ u) =
    let
      val t_res = map_term f t
      val u_res = map_term f u
    in
      (case (t_res, u_res) of
         (NONE, NONE) => NONE
       | _ => SOME (the_default t t_res $ the_default u u_res))
    end
  | map_term' f (Abs (s, T, body)) =
      Option.map (fn body' => Abs (s, T, body')) (map_term f body)
  | map_term' _ _ = NONE;

(* 
  tactic used by map_thm: inserts thm into the first goal, turns it 
  into an implication (rev_mp), descends with the monotonicity rules 
  registered for inductive definitions, and closes the remaining 
  implications either by assumption or by the given tactic
*)
fun map_thm_tac ctxt tac thm =
let
  val mono_rules = Inductive.get_monos ctxt
  val descend_tac = REPEAT_DETERM (FIRSTGOAL (resolve_tac mono_rules))
  val close_tac = REPEAT_DETERM (rtac impI 1 THEN (atac 1 ORELSE tac))
in
  cut_facts_tac [thm] 1 
  THEN etac rev_mp 1
  THEN descend_tac
  THEN close_tac
end

(* 
  map_thm ctxt f tac thm derives F[f t] from the theorem thm = F[t]
  
  - F needs to be monotone
  - f returns SOME for the terms it fires on and NONE elsewhere 
  - if f fires nowhere in the proposition of thm, then thm is 
    returned unchanged
  - otherwise the transformed proposition is proved with map_thm_tac
*)
fun map_thm ctxt f tac thm =
  (case map_term f (prop_of thm) of
     SOME goal => Goal.prove ctxt [] [] goal (fn _ => map_thm_tac ctxt tac thm)
   | NONE => thm)

(* 
  replaces in thm every conjunction "p & q", where the head of p is 
  a constant whose name occurs in names, by its second conjunct q 
  (justified by conjunct2 via map_thm)
*)
fun transform_prem ctxt names thm =
let
  fun has_listed_head p =
    (case head_of p of
       Const (name, _) => name mem names
     | _ => false)

  fun drop_left_conjunct (Const ("op &", _) $ p $ q) =
        if has_listed_head p then SOME q else NONE
    | drop_left_conjunct _ = NONE
in
  map_thm ctxt drop_left_conjunct (etac conjunct2 1) thm
end
*}

ML {*
open Nominal_Permeq
open Nominal_ThmDecls
*}

ML {*
(* builds the term "permute p trm"; the permute constant is given the 
   type perm => ty => ty where ty is the type of trm *)
fun mk_perm p trm =
let
  val arg_ty = fastype_of trm
  val permute_const = 
    Const (@{const_name "permute"}, [@{typ "perm"}, arg_ty] ---> arg_ty)
in
  permute_const $ p $ trm
end

(* builds the term "uminus p" at type perm => perm *)
fun mk_minus p = 
let
  val uminus_const = 
    Const (@{const_name "uminus"}, @{typ "perm"} --> @{typ "perm"})
in
  uminus_const $ p
end
*}

ML {* 
(* 
  tactic for a single case of the equivariance proof; eqvt_rel_tac 
  maps it over the (atomized) introduction rules

  - pred_names: names of the inductive predicates, handed to 
    transform_prem
  - pi: the permutation term
  - intro: the introduction rule belonging to the case
*)
fun single_case_tac ctxt pred_names pi intro  = 
let
  val thy = ProofContext.theory_of ctxt
  (* permute_boolE with its first term variable instantiated by - pi *)
  val cpi = Thm.cterm_of thy (mk_minus pi)
  val rule = Drule.instantiate' [] [SOME cpi] @{thm permute_boolE}
in
  (* first push permutations inside the goal, then work on a focused subgoal *)
  eqvt_strict_tac ctxt [] [] THEN' 
  SUBPROOF (fn {prems, context as ctxt, ...} =>
    let
      (* premises with the conjuncts for the listed predicates dropped *)
      val prems' = map (transform_prem ctxt pred_names) prems
      (* for subgoals not directly matching a premise: apply the instantiated 
         permute_boolE, simplify with permute_minus_cancel(2) and then 
         resolve with a premise *)
      val side_cond_tac = EVERY' 
        [ rtac rule, 
          eqvt_strict_tac ctxt @{thms permute_minus_cancel(2)} [],
          resolve_tac prems' ]
    in
      (* apply the introduction rule; every new subgoal is solved either by 
         a premise directly or by side_cond_tac *)
      HEADGOAL (rtac intro THEN_ALL_NEW (resolve_tac prems' ORELSE' side_cond_tac)) 
    end) ctxt
end
*}

ML {*
(* 
  builds the implication "pred --> pred'", where pred' is pred with 
  the permutation pi applied to all arguments after the first 
  params_no ones
*)
fun prepare_pred params_no pi pred =
let
  val (head, args) = strip_comb pred
  val (params, rest) = chop params_no args
  val permuted_pred = list_comb (head, params @ map (mk_perm pi) rest)
in
  HOLogic.mk_imp (pred, permuted_pred)
end
*}

ML {*
(* 
  transposes a list of rows: the n-th list of the result collects 
  the n-th elements of all rows; the traversal stops as soon as 
  the first row is exhausted
  
  - the empty list of rows yields [] (without this clause the 
    function would recurse forever on [])
  - a row that is shorter than the first one still makes hd/tl 
    raise an exception
*)
fun transp [] = []
  | transp ([] :: _) = []
  | transp xs = map hd xs :: transp (map tl xs);
*}

ML {* 
  (* scratch expressions without any definition; presumably evaluated 
     only to inspect the ML types of the functions used in 
     note_named_thm and eqvt_rel_tac - confirm before removing *)
  Local_Theory.note;
  Local_Theory.notes;
  fold_map
*}

ML {*
(* 
  stores thm in the local theory under the name "<base>.eqvt", where 
  <base> is the base name of name, and attaches the eqvt_add attribute
*)
fun note_named_thm (name, thm) ctxt = 
let
  val base = Long_Name.base_name name
  val binding = Binding.qualified_name (Long_Name.qualify base "eqvt")
  val eqvt_attr = Attrib.internal (K eqvt_add)
in
  ctxt |> Local_Theory.note ((binding, [eqvt_attr]), [thm])
end
*}

ML {*
(* 
  proves and stores the equivariance lemmas for the inductive 
  predicate(s) introduced together with pred_name; this is the 
  function behind the "equivariance" command
  
  - looks up the raw induction rule and the introduction rules
  - states, for every predicate, the implication built by prepare_pred 
    with a fresh permutation variable pi
  - proves the conjunction of these implications by the raw induction 
    rule, with single_case_tac for every introduction rule
  - splits the result and notes each part via note_named_thm
*)
fun eqvt_rel_tac pred_name ctxt = 
let
  val thy = ProofContext.theory_of ctxt
  val ({names, ...}, {raw_induct, intrs, ...}) =
    Inductive.the_inductive ctxt (Sign.intern_const thy pred_name)
  (* atomized versions of the induction rule and the introduction rules *)
  val raw_induct = atomize_induct ctxt raw_induct;
  val intros = map atomize_intr intrs;
  val params_no = length (Inductive.params_of raw_induct)
  (* fix the variables of the conclusion and a fresh name for the permutation *)
  val (([raw_concl], [raw_pi]), ctxt') = 
    ctxt |> Variable.import_terms false [concl_of raw_induct] 
         ||>> Variable.variant_fixes ["pi"]
  val pi = Free (raw_pi, @{typ perm})
  (* the conclusion is a conjunction of implications; their left-hand 
     sides are the predicates applied to variables *)
  val preds = map (fst o HOLogic.dest_imp)
    (HOLogic.dest_conj (HOLogic.dest_Trueprop raw_concl));
  val goal = HOLogic.mk_Trueprop 
    (foldr1 HOLogic.mk_conj (map (prepare_pred params_no pi) preds))
  val thm = Goal.prove ctxt' [] [] goal (fn {context,...} => 
    HEADGOAL (EVERY' (rtac raw_induct :: map (single_case_tac context names pi) intros)))
    |> singleton (ProofContext.export ctxt' ctxt)
  (* one theorem per predicate, each implication turned into a rule with mp *)
  val thms = map (fn th => zero_var_indexes (th RS mp)) (Datatype_Aux.split_conj_thm thm)
in
   ctxt |> fold_map note_named_thm (names ~~ thms)
        |> snd
end
*}


ML {*
(* registers the local-theory command "equivariance": it parses the 
   name of an inductive predicate and runs eqvt_rel_tac on it *)
local structure P = OuterParse and K = OuterKeyword in

val _ =
  OuterSyntax.local_theory "equivariance"
    "prove equivariance for inductive predicate involving nominal datatypes" K.thy_decl
    (P.xname >> eqvt_rel_tac);

end;
*}

(* derive the equivariance lemmas for the two inductive predicates 
   defined above; they are stored as valid.eqvt and typing.eqvt *)
equivariance valid
equivariance typing

(* print the generated lemmas and the current collections of 
   equivariance theorems *)
thm valid.eqvt
thm typing.eqvt
thm eqvts
thm eqvts_raw


(* 
  Test predicate relating raw lambda-terms, with one clause for each 
  raw constructor.
  NOTE(review): the Lam_raw clause refers to alpha_lam_raw, not to 
  the predicate alpha_lam_raw' being defined - presumably on purpose, 
  confirm.
*)
inductive
 alpha_lam_raw'
where
  "name = namea \<Longrightarrow> alpha_lam_raw' (Var_raw name) (Var_raw namea)"
| "\<lbrakk>alpha_lam_raw' lam_raw1 lam_raw1a; alpha_lam_raw' lam_raw2 lam_raw2a\<rbrakk> \<Longrightarrow>
   alpha_lam_raw' (App_raw lam_raw1 lam_raw2) (App_raw lam_raw1a lam_raw2a)"
| "\<exists>pi. ({atom name}, lam_raw) \<approx>gen alpha_lam_raw fv_lam_raw pi ({atom namea}, lam_rawa) \<Longrightarrow>
   alpha_lam_raw' (Lam_raw name lam_raw) (Lam_raw namea lam_rawa)"

(* registers the permutation equations of the raw terms with the eqvt attribute *)
declare permute_lam_raw.simps[eqvt]
(* currently disabled: automatic equivariance proof for alpha_lam_raw' *)
(*declare alpha_gen_real_eqvt[eqvt]*)
(*equivariance alpha_lam_raw'*)

(* manual attempt at the equivariance of alpha_lam_raw'; 
   abandoned with oops after the induction step, so nothing is proved *)
lemma
  assumes a: "alpha_lam_raw' t1 t2"
  shows "alpha_lam_raw' (p \<bullet> t1) (p \<bullet> t2)"
using a
apply(induct)
oops


section {* size function *}

(* permuting a raw term does not change its size; 
   shown by induction over the raw terms *)
lemma size_eqvt_raw:
  fixes t::"lam_raw"
  shows "size (pi \<bullet> t)  = size t"
  apply (induct rule: lam_raw.inducts)
  apply simp_all
  done

(* 
  Makes lam an instance of the class size by lifting the size 
  function on raw terms to the quotient type.
*)
instantiation lam :: size 
begin

(* size on lam is defined as the lifted version of size on lam_raw *)
quotient_definition
  "size_lam :: lam \<Rightarrow> nat"
is
  "size :: lam_raw \<Rightarrow> nat"

(* terms related by alpha_lam_raw have the same size; 
   uses size_eqvt_raw after unfolding the alpha-relations *)
lemma size_rsp:
  "alpha_lam_raw x y \<Longrightarrow> size x = size y"
  apply (induct rule: alpha_lam_raw.inducts)
  apply (simp_all only: lam_raw.size)
  apply (simp_all only: alphas)
  apply clarify
  apply (simp_all only: size_eqvt_raw)
  done

(* respectfulness of size, registered for the quotient package *)
lemma [quot_respect]:
  "(alpha_lam_raw ===> op =) size size"
  by (simp_all add: size_rsp)

(* preservation lemma for size, registered for the quotient package *)
lemma [quot_preserve]:
  "(rep_lam ---> id) size = size"
  by (simp_all add: size_lam_def)

instance
  by default

end

(* the size equations lam_raw.size(4-6) lifted to lam and 
   declared as simplification rules *)
lemmas size_lam[simp] = 
  lam_raw.size(4)[quot_lifted]
  lam_raw.size(5)[quot_lifted]
  lam_raw.size(6)[quot_lifted]

(* registers size on lam as a measure function 
   (open question from the original author: is this needed?) *)
lemma [measure_function]: 
  "is_measure (size::lam\<Rightarrow>nat)" 
  by (rule is_measure_trivial)

section {* Matching *}

(* 
  MATCH M d x: if there is exactly one r such that M q = (True, x, r) 
  holds for some q, then this r is returned; otherwise the default d.
*)
definition
  MATCH :: "('c::pt \<Rightarrow> (bool * 'a::pt * 'b::pt)) \<Rightarrow> 'b \<Rightarrow> 'a \<Rightarrow> 'b"
where
  "MATCH M d x \<equiv> if (\<exists>!r. \<exists>q. M q = (True, x, r)) then (THE r. \<exists>q. M q = (True, x, r)) else d"

(*
lemma MATCH_eqvt:
  shows "p \<bullet> (MATCH M d x) = MATCH (p \<bullet> M) (p \<bullet> d) (p \<bullet> x)"
unfolding MATCH_def
apply(perm_simp the_eqvt)
apply (tactic {* Nominal_Permeq.eqvt_tac @{context} 1 *})
apply(simp)
thm eqvts_raw 
apply(subst if_eqvt)
apply(subst ex1_eqvt)
apply(subst permute_fun_def)
apply(subst ex_eqvt)
apply(subst permute_fun_def)
apply(subst eq_eqvt)
apply(subst permute_fun_app_eq[where f="M"])
apply(simp only: permute_minus_cancel)
apply(subst permute_prod.simps)
apply(subst permute_prod.simps)
apply(simp only: permute_minus_cancel)
apply(simp only: permute_bool_def)
apply(simp)
apply(subst ex1_eqvt)
apply(subst permute_fun_def)
apply(subst ex_eqvt)
apply(subst permute_fun_def)
apply(subst eq_eqvt)

apply(simp only: eqvts)
apply(simp)
apply(subgoal_tac "(p \<bullet> (\<exists>!r. \<exists>q. M q = (True, x, r))) = (\<exists>!r. \<exists>q. (p \<bullet> M) q = (True, p \<bullet> x, r))")
apply(drule sym)
apply(simp)
apply(rule impI)
apply(simp add: perm_bool)
apply(rule trans)
apply(rule pt_the_eqvt[OF pta at])
apply(assumption)
apply(simp add: pt_ex_eqvt[OF pt at])
apply(simp add: pt_eq_eqvt[OF ptb at])
apply(rule cheat)
apply(rule trans)
apply(rule pt_ex1_eqvt)
apply(rule pta)
apply(rule at)
apply(simp add: pt_ex_eqvt[OF pt at])
apply(simp add: pt_eq_eqvt[OF ptb at])
apply(subst pt_pi_rev[OF pta at])
apply(subst pt_fun_app_eq[OF pt at])
apply(subst pt_pi_rev[OF pt at])
apply(simp)
done

lemma MATCH_cng:
  assumes a: "M1 = M2" "d1 = d2"
  shows "MATCH M1 d1 x = MATCH M2 d2 x"
using a by simp

lemma MATCH_eq:
  assumes a: "t = l x" "G x" "\<And>x'. t = l x' \<Longrightarrow> G x' \<Longrightarrow> r x' = r x"
  shows "MATCH (\<lambda>x. (G x, l x, r x)) d t = r x"
using a
unfolding MATCH_def
apply(subst if_P)
apply(rule_tac a="r x" in ex1I)
apply(rule_tac x="x" in exI)
apply(blast)
apply(erule exE)
apply(drule_tac x="q" in meta_spec)
apply(auto)[1]
apply(rule the_equality)
apply(blast)
apply(erule exE)
apply(drule_tac x="q" in meta_spec)
apply(auto)[1]
done

lemma MATCH_eq2:
  assumes a: "t = l x1 x2" "G x1 x2" "\<And>x1' x2'. t = l x1' x2' \<Longrightarrow> G x1' x2' \<Longrightarrow> r x1' x2' = r x1 x2"
  shows "MATCH (\<lambda>(x1,x2). (G x1 x2, l x1 x2, r x1 x2)) d t = r x1 x2"
sorry

lemma MATCH_neq:
  assumes a: "\<And>x. t = l x \<Longrightarrow> G x \<Longrightarrow> False"
  shows "MATCH (\<lambda>x. (G x, l x, r x)) d t = d"
using a
unfolding MATCH_def
apply(subst if_not_P)
apply(blast)
apply(rule refl)
done

lemma MATCH_neq2:
  assumes a: "\<And>x1 x2. t = l x1 x2 \<Longrightarrow> G x1 x2 \<Longrightarrow> False"
  shows "MATCH (\<lambda>(x1,x2). (G x1 x2, l x1 x2, r x1 x2)) d t = d"
using a
unfolding MATCH_def
apply(subst if_not_P)
apply(auto)
done
*)

end



