theory Fv
imports "Nominal2_Atoms" "Abs" "Perm" "Rsp"
begin

(* The bindings data structure:

  Bindings are a list of lists of lists of triples.

   The first list represents the datatypes defined.
   The second list represents the constructors.
   The innermost list is a list of all the bindings that
   concern the constructor.

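   Every triple (b, i, j) consists of the binding function b, the
   index i of the bound argument and the index j of the body. The
   binding function is either NONE (a shallow binding) or
   SOME (f, is_rec), where the flag is_rec marks a recursive binding.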

  Eg:
nominal_datatype

   C1
 | C2 x y z bind x in z
 | C3 x y z bind f x in z bind g y in z

yields (the constructor-level list for this one datatype):
[
 [],
 [(NONE, 0, 2)],
 [(SOME (Const f, is_rec_f), 0, 2), (SOME (Const g, is_rec_g), 1, 2)]]

A SOME binding has to have a function which takes an appropriate
argument and returns an atom set. A NONE binding has to be on an
argument that is an atom or an atom set.
*)

(*
An overview of the generation of free variables:

1) fv_bn functions are generated only for the non-recursive binds.

   An fv_bn for a constructor is a union of values for the arguments:

   For an argument x that is in the bn function
   - if it is a recursive argument bn' we return: fv_bn' x
   - otherwise empty

   For an argument x that is not in the bn function
   - for atom we return: {atom x}
   - for atom set we return: atom ` x
   - for a recursive call to type ty' we return: fv_ty' x
     with fv of the appropriate type
   - otherwise empty

2) fv_ty functions generated for all types being defined:

   fv_ty for a constructor is a union of values for the arguments.

   For an argument that is bound in a shallow binding we return empty.

   For an argument x that is bound in a non-recursive deep binding
   we return: fv_bn x.

   Otherwise we return the free variables of the argument minus the
   bound variables of the argument.

   The free variables for an argument x are:
   - for an atom: {atom x}
   - for atom set: atom ` x
   - for recursive call to type ty' return: fv_ty' x
   - for nominal datatype ty' return: fv_ty' x

   The bound variables are a union of results of all bindings that
   involve the given argument. For a particular binding:

   - for a binding function bn: bn x
   - for a recursive argument of type ty': fv_ty' x
   - for nominal datatype ty' return: fv_ty' x
*)

(*
An overview of the generation of alpha-equivalence:

1) alpha_bn relations are generated for binding functions.

   An alpha_bn for a constructor is true if a conjunction of
   propositions for each argument holds.

   For each argument, a proposition is built as follows:

   - for a recursive argument in the bn function, we return: alpha_bn argl argr
   - for a recursive argument for type ty not in bn, we return: alpha_ty argl argr
   - for other arguments in the bn function we return: True
   - for other arguments not in the bn function we return: argl = argr

2) alpha_ty relations are generated for all the types being defined:

   !!! permutations !!!

   An alpha_ty for a constructor is true if a conjunction of
   propositions for each argument holds.

   For an argument we allow bindings where only one of the following
   holds:

   - Argument is bound in some shallow bindings: We return true
   - Argument is bound in some deep recursive bindings
     !!! still to describe !!!
   - Argument is bound in some deep non-recursive bindings.
     We return: alpha_bn argl argr
   - Argument has some shallow and/or non-recursive bindings.
     !!! still to describe !!!
   - Argument has some recursive bindings. The bindings were
     already treated in 2nd case so we return: True
   - Argument has no bindings and is not bound.
     If it is recursive for type ty, we return: alpha_ty argl argr
     Otherwise we return: argl = argr

*)

ML {*
(* Does the type typ have sort "at"? *)
fun is_atom thy typ = Sign.of_sort thy (typ, @{sort at});

(* Is typ a set of atoms?  Sets are recognised here by their
   characteristic-function type "elem => bool". *)
fun is_atom_set thy typ =
  (case typ of
    Type ("fun", [elem_ty, @{typ bool}]) => is_atom thy elem_ty
  | _ => false);
*}



(* Like map2, except that when the second list runs out first, f is
   applied to each remaining element of the first list together with
   the empty list instead of raising an error.  UnequalLengths is still
   raised if the second list is the longer one. *)
ML {*
fun map2i f xs ys =
  (case (xs, ys) of
    ([], []) => []
  | (x :: xs', y :: ys') => f x y :: map2i f xs' ys'
  | (x :: xs', []) => f x [] :: map2i f xs' []
  | _ => raise UnequalLengths);
*}

(* For every constructor of every datatype, groups the binding triples
   (f, bi, bo) by their (function, bound argument) pair and collects all
   bodies bo of each group, giving one ((f, bi), bodies) entry per
   distinct pair.  Bodies keep the order in which they occur.
 *)
ML {*
fun gather_binds binds =
let
  fun key_of (f, bi, _) = (f, bi)
  fun body_of (_, _, bo) = bo
  fun group_constr cbinds =
    let
      fun bodies_for key = map body_of (filter (fn b => key_of b = key) cbinds)
      val keys = distinct (op =) (map key_of cbinds)
    in
      map (fn key => (key, bodies_for key)) keys
    end
in
  map (map group_constr) binds
end
*}

ML {*
(* Undoes the per-constructor grouping of gather_binds: every group
   ((f, bi), bodies) carrying an extra tag pi (in define_fv_alpha a
   permutation variable) is expanded into one ((f, bi, body), pi) entry
   per body, keeping the original order. *)
fun un_gather_binds_cons binds =
  let
    fun expand (((f, bi), bos), pi) = map (fn bo => ((f, bi, bo), pi)) bos
  in
    maps expand binds
  end
*}

ML {*
  open Datatype_Aux; (* typ_of_dtyp, DtRec, ... *);
  (* TODO: It is the same as one in 'nominal_atoms' *)
  (* The constant atom :: ty => atom *)
  fun mk_atom ty = Const (@{const_name atom}, ty --> @{typ atom});
  (* The empty atom set; also the unit of mk_union *)
  val noatoms = @{term "{} :: atom set"};
  (* The singleton {atom x} *)
  fun mk_single_atom x = HOLogic.mk_set @{typ atom} [mk_atom (type_of x) $ x];
  (* Right-nested union of all sets.  Empty sets are skipped, and a set
     that is syntactically equal to the union of the remaining ones is
     not added again.  Returns noatoms for the empty list. *)
  fun mk_union sets =
    fold (fn a => fn b =>
      if a = noatoms then b else
      if b = noatoms then a else
      if a = b then a else
      HOLogic.mk_binop @{const_name sup} (a, b)) (rev sets) noatoms;
  (* Right-nested intersection; foldr1 is partial, so not usable on [] *)
  val mk_inter = foldr1 (HOLogic.mk_binop @{const_name inf})
  (* Right-nested conjunction; True conjuncts are dropped and the empty
     list yields True *)
  fun mk_conjl props =
    fold (fn a => fn b =>
      if a = @{term True} then b else
      if b = @{term True} then a else
      HOLogic.mk_conj (a, b)) (rev props) @{term True};
  (* Set difference a - b, simplified when b is empty or equal to a *)
  fun mk_diff a b =
    if b = noatoms then a else
    if b = a then noatoms else
    HOLogic.mk_binop @{const_name minus} (a, b);
  (* The image  atom ` t  for a term t that is a set of atoms of some
     concrete atom type *)
  fun mk_atoms t =
    let
      val ty = fastype_of t;
      val atom_ty = HOLogic.dest_setT ty --> @{typ atom};
      val img_ty = atom_ty --> ty --> @{typ "atom set"};
    in
      (Const (@{const_name image}, img_ty) $ Const (@{const_name atom}, atom_ty) $ t)
    end;
  (* Similar to one in USyntax *)
  (* The pair (fst, snd), with the Pair constant typed from both components *)
  fun mk_pair (fst, snd) =
    let val ty1 = fastype_of fst
      val ty2 = fastype_of snd
      val c = HOLogic.pair_const ty1 ty2
    in c $ fst $ snd
    end;
*}

(* Given [fv1, fv2, fv3] creates %(x, y, z). fv1 x u fv2 y u fv3 z,
   i.e. one free-variable function on right-nested tuples, where the
   i-th component is passed to the i-th function.  An empty list is
   rejected with an explicit error (it used to fail with an ML Bind
   exception). *)
ML {*
fun mk_compound_fv fvs =
let
  (* de Bruijn indices: the last component is Bound 0 *)
  val nos = (length fvs - 1) downto 0;
  val fvs_applied = map (fn (fv, no) => fv $ Bound no) (fvs ~~ nos);
  val fvs_union = mk_union fvs_applied;
  (* Abstractions are added from the last component outwards; all but the
     innermost one are wrapped in split so that a single tuple is taken. *)
  fun fold_fun ty t = HOLogic.mk_split (Abs ("", ty, t))
in
  case rev (map (domain_type o fastype_of) fvs) of
    [] => error "mk_compound_fv: empty list of fv functions"
  | tyh :: tys => fold fold_fun tys (Abs ("", tyh, fvs_union))
end;
*}

(* Sanity check: the shape of term that mk_compound_alpha is meant to
   build, obtained here directly from the parser for comparison. *)
ML {* @{term "\<lambda>(x, y, z). \<lambda>(x', y', z'). R x x' \<and> R2 y y' \<and> R3 z z'"} *}

(* Given [R1, R2, R3] creates
     %(x, y, z). %(x', y', z'). R1 x x' \<and> R2 y y' \<and> R3 z z'
   i.e. one relation on right-nested tuples that relates the tuples
   componentwise.  An empty list is rejected with an explicit error (it
   used to fail with an ML Bind exception). *)
ML {*
fun mk_compound_alpha Rs =
let
  (* Inside both tuple abstractions, with n = length Rs, the k-th relation
     is applied to the k-th outer component (index 2n-k) and the k-th inner
     component (index n-k). *)
  val nos = (length Rs - 1) downto 0;
  val nos2 = (2 * length Rs - 1) downto length Rs;
  val Rs_applied = map (fn (R, (no2, no)) => R $ Bound no2 $ Bound no) (Rs ~~ (nos2 ~~ nos));
  val Rs_conj = mk_conjl Rs_applied;
  fun fold_fun ty t = HOLogic.mk_split (Abs ("", ty, t))
in
  case rev (map (domain_type o fastype_of) Rs) of
    [] => error "mk_compound_alpha: empty list of relations"
  | tyh :: tys =>
      let
        (* inner (primed) tuple abstraction *)
        val abs_rhs = fold fold_fun tys (Abs ("", tyh, Rs_conj))
      in
        (* outer (unprimed) tuple abstraction *)
        fold fold_fun tys (Abs ("", tyh, abs_rhs))
      end
end;
*}

(* Sanity check: build a compound relation from three sample relations
   and certify the result in the current theory. *)
ML {* cterm_of @{theory} (mk_compound_alpha [@{term "R :: 'a \<Rightarrow> 'a \<Rightarrow> bool"}, @{term "R2 :: 'b \<Rightarrow> 'b \<Rightarrow> bool"}, @{term "R3 :: 'b \<Rightarrow> 'b \<Rightarrow> bool"}]) *}

ML {*
(* The term p1 + p2 for two permutation terms p1 and p2. *)
fun add_perm (p1, p2) =
  let
    val plus_perm = Const (@{const_name plus}, @{typ "perm \<Rightarrow> perm \<Rightarrow> perm"})
  in
    plus_perm $ p1 $ p2
  end
*}

ML {*
(* All binding functions f that occur in some non-recursive binding
   (SOME (f, false), _, _) anywhere in the bindings structure, without
   duplicates. *)
fun non_rec_binds l =
let
  fun non_rec_fun (SOME (f, false), _, _) = SOME f
    | non_rec_fun _ = NONE
in
  flat (flat l)
  |> map_filter non_rec_fun
  |> distinct (op =)
end
*}

(* We assume no bindings in the type on which bn is defined *)
(* TODO: currently works only with current fv_bn function *)
(* Builds the defining equations of fv_bn for the binding function bn,
   which is defined on datatype number ith_dtyp.  args_in_bns gives, per
   constructor, the indices of the arguments bn looks at (missing entries
   are treated as [] by map2i).  fv_frees are the fv functions of all
   datatypes.  Returns ((bn, fv_bn as a Free), (name of fv_bn, equations)). *)
ML {*
fun fv_bn thy (dt_info : Datatype_Aux.info) fv_frees (bn, ith_dtyp, args_in_bns) =
let
  val {descr, sorts, ...} = dt_info;
  fun nth_dtyp i = typ_of_dtyp descr sorts (DtRec i);
  val fvbn_name = "fv_" ^ (Long_Name.base_name (fst (dest_Const bn)));
  (* fv_bn has the same type as the fv function of datatype ith_dtyp *)
  val fvbn = Free (fvbn_name, fastype_of (nth fv_frees ith_dtyp));
  (* One equation  fv_bn (C x1 .. xn) = <union over the arguments>  *)
  fun fv_bn_constr (cname, dts) args_in_bn =
  let
    val Ts = map (typ_of_dtyp descr sorts) dts;
    val names = Datatype_Prop.make_tnames Ts;
    val args = map Free (names ~~ Ts);
    val c = Const (cname, Ts ---> (nth_dtyp ith_dtyp));
    fun fv_arg ((dt, x), arg_no) =
      let
        val ty = fastype_of x
      in
        (* Arguments bn looks at: recurse with fv_bn on recursive ones
           (only allowed on the same datatype), nothing otherwise *)
        if arg_no mem args_in_bn then 
          (if is_rec_type dt then
            (if body_index dt = ith_dtyp then fvbn $ x else error "fv_bn: recursive argument, but wrong datatype.")
          else @{term "{} :: atom set"}) else
        (* Other arguments: their ordinary free variables *)
        if is_atom thy ty then mk_single_atom x else
        if is_atom_set thy ty then mk_atoms x else
        if is_rec_type dt then nth fv_frees (body_index dt) $ x else
        @{term "{} :: atom set"}
      end;
    val arg_nos = 0 upto (length dts - 1)
  in
    HOLogic.mk_Trueprop (HOLogic.mk_eq
      (fvbn $ list_comb (c, args), mk_union (map fv_arg (dts ~~ args ~~ arg_nos))))
  end;
  val (_, (_, _, constrs)) = nth descr ith_dtyp;
  val eqs = map2i fv_bn_constr constrs args_in_bns
in
  ((bn, fvbn), (fvbn_name, eqs))
end
*}

ML {*
(* Builds the introduction rules of alpha_bn for the binding function bn,
   which is defined on datatype number ith_dtyp.  args_in_bns lists, per
   constructor, the indices of the arguments bn looks at (missing entries
   are treated as [] by map2i).  For a constructor C the rule is
     prem_1 \<and> ... \<and> prem_n \<Longrightarrow> alpha_bn (C x1 .. xn) (C y1 .. yn)
   (True premises dropped by mk_conjl), where prem_k is
     - alpha_bn xk yk   for a recursive argument bn looks at,
     - alpha_ty xk yk   for a recursive argument bn does not look at,
     - True             for a non-recursive argument bn looks at,
     - xk = yk          otherwise.
   Returns ((bn, alpha_bn as a Free), (name of alpha_bn, rules)).
   The is_rec flag is currently unused (see the commented-out type). *)
fun alpha_bn thy (dt_info : Datatype_Aux.info) alpha_frees ((bn, ith_dtyp, args_in_bns), is_rec) =
let
  val {descr, sorts, ...} = dt_info;
  fun nth_dtyp i = typ_of_dtyp descr sorts (DtRec i);
  val alpha_bn_name = "alpha_" ^ (Long_Name.base_name (fst (dest_Const bn)));
  val alpha_bn_type =
    (*if is_rec then @{typ perm} --> nth_dtyp ith_dtyp --> nth_dtyp ith_dtyp --> @{typ bool} else*)
    nth_dtyp ith_dtyp --> nth_dtyp ith_dtyp --> @{typ bool};
  val alpha_bn_free = Free(alpha_bn_name, alpha_bn_type);
  fun alpha_bn_constr (cname, dts) args_in_bn =
  let
    val Ts = map (typ_of_dtyp descr sorts) dts;
    (* NOTE(review): "pi" is still avoided although no permutation
       variable occurs in these rules; kept so the generated variable
       names do not change. *)
    val names = Name.variant_list ["pi"] (Datatype_Prop.make_tnames Ts);
    val names2 = Name.variant_list ("pi" :: names) (Datatype_Prop.make_tnames Ts);
    val args = map Free (names ~~ Ts);
    val args2 = map Free (names2 ~~ Ts);
    val c = Const (cname, Ts ---> (nth_dtyp ith_dtyp));
    val rhs = HOLogic.mk_Trueprop
      (alpha_bn_free $ (list_comb (c, args)) $ (list_comb (c, args2)));
    fun lhs_arg ((dt, arg_no), (arg, arg2)) =
      if is_rec_type dt then
        if arg_no mem args_in_bn then
          (* alpha_bn only relates values of datatype ith_dtyp: fail early
             with a clear message, as fv_bn does, rather than producing an
             ill-typed premise *)
          (if body_index dt = ith_dtyp then alpha_bn_free $ arg $ arg2
           else error "alpha_bn: recursive argument, but wrong datatype.")
        else (nth alpha_frees (body_index dt)) $ arg $ arg2
      else
        if arg_no mem args_in_bn then @{term True}
        else HOLogic.mk_eq (arg, arg2)
    val arg_nos = 0 upto (length dts - 1)
    val lhss = mk_conjl (map lhs_arg (dts ~~ arg_nos ~~ (args ~~ args2)))
  in
    Logic.mk_implies (HOLogic.mk_Trueprop lhss, rhs)
  end
  val (_, (_, _, constrs)) = nth descr ith_dtyp;
  val eqs = map2i alpha_bn_constr constrs args_in_bns
in
  ((bn, alpha_bn_free), (alpha_bn_name, eqs))
end
*}

(* Checks that a list of bindings contains only compatible ones, i.e. that
   all of them use one and the same binding function.  The empty list is
   not considered compatible. *)
ML {*
fun bns_same l =
  (case distinct (op =) (map (fn ((b, _, _), _) => b) l) of
    [_] => true
  | _ => false)
*}

(* TODO: Notice datatypes without bindings and replace alpha with equality *)
(* Defines the free-variable functions and the alpha-equivalence relations
   of the raw datatypes described by dt_info, following the overviews at
   the top of this file.

   bindsall : bindings per datatype and per constructor, as triples
              (NONE | SOME (bn, is_rec), bound argument, body argument)
   bns      : binding functions as (bn, index of the datatype bn is
              defined on, per-constructor indices of the arguments bn
              looks at)

   The fv equations (together with the fv_bn equations of the binding
   functions used in non-recursive bindings) are defined with
   Primrec.add_primrec; the alpha relations together with all alpha_bn
   relations are defined by one Inductive.add_inductive_i.  Returns
   ((all_fvs, alphas), lthy), where all_fvs concatenates the components
   of the primrec results and alphas is the inductive result. *)
ML {*
fun define_fv_alpha (dt_info : Datatype_Aux.info) bindsall bns lthy =
let
  val thy = ProofContext.theory_of lthy;
  val {descr, sorts, ...} = dt_info;
  fun nth_dtyp i = typ_of_dtyp descr sorts (DtRec i);
  (* One fv function per datatype, as a Free *)
  val fv_names = Datatype_Prop.indexify_names (map (fn (i, _) =>
    "fv_" ^ name_of_typ (nth_dtyp i)) descr);
  val fv_types = map (fn (i, _) => nth_dtyp i --> @{typ "atom set"}) descr;
  val fv_frees = map Free (fv_names ~~ fv_types);
  (* fv_bn is generated only for binding functions that occur in some
     non-recursive binding *)
  val nr_bns = non_rec_binds bindsall;
  val rel_bns = filter (fn (bn, _, _) => bn mem nr_bns) bns;
  val (bn_fv_bns, fv_bn_names_eqs) = split_list (map (fv_bn thy dt_info fv_frees) rel_bns);
  val (fv_bn_names, fv_bn_eqs) = split_list fv_bn_names_eqs;
  (* One alpha relation per datatype, as a Free *)
  val alpha_names = Datatype_Prop.indexify_names (map (fn (i, _) =>
    "alpha_" ^ name_of_typ (nth_dtyp i)) descr);
  val alpha_types = map (fn (i, _) => nth_dtyp i --> nth_dtyp i --> @{typ bool}) descr;
  val alpha_frees = map Free (alpha_names ~~ alpha_types);
  (* We assume that a bn is either recursive or not *)
  val bns_rec = map (fn (bn, _, _) => not (bn mem nr_bns)) bns;
  (* alpha_bn is generated for every binding function *)
  val (bn_alpha_bns, alpha_bn_names_eqs) = split_list (map (alpha_bn thy dt_info alpha_frees) (bns ~~ bns_rec))
  val (alpha_bn_names, alpha_bn_eqs) = split_list alpha_bn_names_eqs;
  val alpha_bn_frees = map snd bn_alpha_bns;
  val alpha_bn_types = map fastype_of alpha_bn_frees;
  (* For constructor cname of datatype number ith_dtyp, with its bindings
     bindcs already grouped by gather_binds, returns the fv equation and
     the alpha introduction rule.  One permutation variable (distinct
     names derived from "pi") is introduced per group; the premise of the
     alpha rule quantifies all of them existentially. *)
  fun fv_alpha_constr ith_dtyp (cname, dts) bindcs =
    let
      val Ts = map (typ_of_dtyp descr sorts) dts;
      val bindslen = length bindcs
      val pi_strs_same = replicate bindslen "pi"
      val pi_strs = Name.variant_list [] pi_strs_same;
      val pis = map (fn ps => Free (ps, @{typ perm})) pi_strs;
      val bind_pis_gath = bindcs ~~ pis;
      (* ((binder, bound arg, body), pi) entries, one per single binding *)
      val bind_pis = un_gather_binds_cons bind_pis_gath;
      (* From here on bindcs is the ungrouped list of binding triples *)
      val bindcs = map fst bind_pis;
      (* Argument names of both sides, avoiding the permutation names *)
      val names = Name.variant_list pi_strs (Datatype_Prop.make_tnames Ts);
      val args = map Free (names ~~ Ts);
      val names2 = Name.variant_list (pi_strs @ names) (Datatype_Prop.make_tnames Ts);
      val args2 = map Free (names2 ~~ Ts);
      val c = Const (cname, Ts ---> (nth_dtyp ith_dtyp));
      val fv_c = nth fv_frees ith_dtyp;
      val alpha = nth alpha_frees ith_dtyp;
      val arg_nos = 0 upto (length dts - 1)
      (* Atoms bound by one binding: for a shallow binding (NONE) the free
         variables of the bound argument itself, for a deep binding the
         binding function applied to the bound argument *)
      fun fv_bind args (NONE, i, _) =
            if is_rec_type (nth dts i) then (nth fv_frees (body_index (nth dts i))) $ (nth args i) else
            if ((is_atom thy) o fastype_of) (nth args i) then mk_single_atom (nth args i) else
            if ((is_atom_set thy) o fastype_of) (nth args i) then mk_atoms (nth args i) else
            (* TODO we do not know what to do with non-atomizable things *)
            @{term "{} :: atom set"}
        | fv_bind args (SOME (f, _), i, _) = f $ (nth args i);
      fun fv_binds args relevant = mk_union (map (fv_bind args) relevant)
      (* SOME f if argument j is the bound argument of a non-recursive
         deep binding with binding function f *)
      fun find_nonrec_binder j (SOME (f, false), i, _) = if i = j then SOME f else NONE
        | find_nonrec_binder _ _ = NONE
      (* An argument bound in a non-recursive deep binding contributes
         fv_bn of itself; any other argument contributes its free variables
         minus the atoms of all bindings it is the bound argument or the
         body of *)
      fun fv_arg ((dt, x), arg_no) =
        case get_first (find_nonrec_binder arg_no) bindcs of
          SOME f =>
            (case get_first (fn (x, y) => if x = f then SOME y else NONE) bn_fv_bns of
                SOME fv_bn => fv_bn $ x
              | NONE => error "bn specified in a non-rec binding but not in bn list")
        | NONE =>
            let
              val arg =
                if is_rec_type dt then nth fv_frees (body_index dt) $ x else
                if ((is_atom thy) o fastype_of) x then mk_single_atom x else
                if ((is_atom_set thy) o fastype_of) x then mk_atoms x else
                (* TODO we do not know what to do with non-atomizable things *)
                @{term "{} :: atom set"};
              (* If i = j then we generate it only once *)
              val relevant = filter (fn (_, i, j) => ((i = arg_no) orelse (j = arg_no))) bindcs;
              val sub = fv_binds args relevant
            in
              mk_diff arg sub
            end;
      val fv_eq = HOLogic.mk_Trueprop (HOLogic.mk_eq
        (fv_c $ list_comb (c, args), mk_union (map fv_arg  (dts ~~ args ~~ arg_nos))))
      val alpha_rhs =
        HOLogic.mk_Trueprop (alpha $ (list_comb (c, args)) $ (list_comb (c, args2)));
      (* The premise contributed by one argument pair, depending on how the
         argument takes part in the bindings:
           rel_in_simp_binds : shallow bindings where it is the bound arg
           rel_in_comp_binds : deep bindings where it is the bound arg
           rel_has_binds     : shallow or non-recursive deep bindings where
                               it is the body
           rel_has_rec_binds : recursive deep bindings where it is the body *)
      fun alpha_arg ((dt, arg_no), (arg, arg2)) =
        let
          val rel_in_simp_binds = filter (fn ((NONE, i, _), _) => i = arg_no | _ => false) bind_pis;
          val rel_in_comp_binds = filter (fn ((SOME _, i, _), _) => i = arg_no | _ => false) bind_pis;
          val rel_has_binds = filter (fn ((NONE, _, j), _) => j = arg_no
                                       | ((SOME (_, false), _, j), _) => j = arg_no
                                       | _ => false) bind_pis;
          val rel_has_rec_binds = filter 
            (fn ((SOME (_, true), _, j), _) => j = arg_no | _ => false) bind_pis;
        in
          case (rel_in_simp_binds, rel_in_comp_binds, rel_has_binds, rel_has_rec_binds) of
            (* no bindings at all: alpha of its type, or plain equality *)
            ([], [], [], []) =>
              if is_rec_type dt then (nth alpha_frees (body_index dt) $ arg $ arg2)
              else (HOLogic.mk_eq (arg, arg2))
            (* only bound in shallow bindings: nothing to check here *)
          | (_, [], [], []) => @{term True}
            (* only body of recursive bindings: handled at the bound arg *)
          | ([], [], [], _) => @{term True}
            (* bound in deep bindings, all with the same binding function *)
          | ([], ((((SOME (bn, is_rec)), _, _), pi) :: _), [], []) =>
            if not (bns_same rel_in_comp_binds) then error "incompatible bindings for an argument" else
            if is_rec then
              (* recursive: one alpha_gen over the tuple of this argument
                 and all bodies, using compound fv and alpha functions *)
              let
                val (rbinds, rpis) = split_list rel_in_comp_binds
                val bound_in_nos = map (fn (_, _, i) => i) rbinds
                val bound_in_ty_nos = map (fn i => body_index (nth dts i)) bound_in_nos;
                val bound_args = arg :: map (nth args) bound_in_nos;
                val bound_args2 = arg2 :: map (nth args2) bound_in_nos;
                (* NOTE(review): bound_in is unused *)
                fun bound_in args (_, _, i) = nth args i;
                val lhs_binds = fv_binds args rbinds
                val lhs_arg = foldr1 HOLogic.mk_prod bound_args
                val lhs = mk_pair (lhs_binds, lhs_arg);
                val rhs_binds = fv_binds args2 rbinds;
                val rhs_arg = foldr1 HOLogic.mk_prod bound_args2;
                val rhs = mk_pair (rhs_binds, rhs_arg);
                val fvs = map (nth fv_frees) ((body_index dt) :: bound_in_ty_nos);
                val fv = mk_compound_fv fvs;
                val alphas = map (nth alpha_frees) ((body_index dt) :: bound_in_ty_nos);
                val alpha = mk_compound_alpha alphas;
                val pi = foldr1 add_perm (distinct (op =) rpis);
                val alpha_gen_pre = Const (@{const_name alpha_gen}, dummyT) $ lhs $ alpha $ fv $ pi $ rhs;
                val alpha_gen = Syntax.check_term lthy alpha_gen_pre
              in
                alpha_gen
              end
            else
              (* non-recursive: the alpha_bn relation of the binding
                 function.  NOTE(review): if bn is not in bns, find_index
                 presumably yields ~1 and nth fails with an unhelpful
                 exception; ty and permute are unused. *)
              let
                val alpha_bn_const =
                  nth alpha_bn_frees (find_index (fn (b, _, _) => b = bn) bns)
                val ty = fastype_of (bn $ arg)
                val permute = Const(@{const_name permute}, @{typ perm} --> ty --> ty)
              in
                alpha_bn_const $ arg $ arg2
              end
            (* body of shallow / non-recursive deep bindings: alpha_gen on
               the pair (bound atoms, argument) *)
          | ([], [], relevant, []) =>
            let
              val (rbinds, rpis) = split_list relevant
              val lhs_binds = fv_binds args rbinds
              val lhs = mk_pair (lhs_binds, arg);
              val rhs_binds = fv_binds args2 rbinds;
              val rhs = mk_pair (rhs_binds, arg2);
              val alpha = nth alpha_frees (body_index dt);
              val fv = nth fv_frees (body_index dt);
              val pi = foldr1 add_perm (distinct (op =) rpis);
              val alpha_gen_pre = Const (@{const_name alpha_gen}, dummyT) $ lhs $ alpha $ fv $ pi $ rhs;
              val alpha_gen = Syntax.check_term lthy alpha_gen_pre
            in
              alpha_gen
            end
          | _ => error "Fv.alpha: not supported binding structure"
        end
      val alphas = map alpha_arg (dts ~~ arg_nos ~~ (args ~~ args2))
      val alpha_lhss = mk_conjl alphas
      (* existentially quantify every permutation variable *)
      val alpha_lhss_ex =
        fold (fn pi_str => fn t => HOLogic.mk_exists (pi_str, @{typ perm}, t)) pi_strs alpha_lhss
      val alpha_eq = Logic.mk_implies (HOLogic.mk_Trueprop alpha_lhss_ex, alpha_rhs)
    in
      (fv_eq, alpha_eq)
    end;
  fun fv_alpha_eq (i, (_, _, constrs)) binds = map2i (fv_alpha_constr i) constrs binds;
  val fveqs_alphaeqs = map2i fv_alpha_eq descr (gather_binds bindsall)
  (* fv equations stay grouped per datatype; alpha rules are flattened *)
  val (fv_eqs_perfv, alpha_eqs) = apsnd flat (split_list (map split_list fveqs_alphaeqs))
  (* indices of the datatypes on which some fv_bn is defined *)
  val rel_bns_nos = map (fn (_, i, _) => i) rel_bns;
  fun filter_fun (_, b) = b mem rel_bns_nos;
  val all_fvs = (fv_names ~~ fv_eqs_perfv) ~~ (0 upto (length fv_names - 1))
  val (fv_names_fst, fv_eqs_fst) = apsnd flat (split_list (map fst (filter_out filter_fun all_fvs)))
  val (fv_names_snd, fv_eqs_snd) = apsnd flat (split_list (map fst (filter filter_fun all_fvs)))
  val fv_eqs_all = fv_eqs_fst @ (flat fv_bn_eqs);
  val fv_names_all = fv_names_fst @ fv_bn_names;
  val add_binds = map (fn x => (Attrib.empty_binding, x))
(* Function_Fun.add_fun Function_Common.default_config ... true *)
  (* First primrec: fv of the datatypes on which no fv_bn is defined,
     together with all fv_bn functions *)
  val (fvs, lthy') = (Primrec.add_primrec
    (map (fn s => (Binding.name s, NONE, NoSyn)) fv_names_all) (add_binds fv_eqs_all) lthy)
  (* Second primrec: fv of the datatypes on which some fv_bn is defined;
     skipped when there are none *)
  val (fvs2, lthy'') =
    if fv_eqs_snd = [] then (([], []), lthy') else
   (Primrec.add_primrec
    (map (fn s => (Binding.name s, NONE, NoSyn)) fv_names_snd) (add_binds fv_eqs_snd) lthy')
  (* All alpha_ty and alpha_bn relations as one inductive definition *)
  val (alphas, lthy''') = (Inductive.add_inductive_i
     {quiet_mode = true, verbose = false, alt_name = Binding.empty,
      coind = false, no_elim = false, no_ind = false, skip_mono = true, fork_mono = false}
     (map2 (fn x => fn y => ((Binding.name x, y), NoSyn)) (alpha_names @ alpha_bn_names)
     (alpha_types @ alpha_bn_types)) []
     (add_binds (alpha_eqs @ flat alpha_bn_eqs)) [] lthy'')
  val all_fvs = (fst fvs @ fst fvs2, snd fvs @ snd fvs2)
in
  ((all_fvs, alphas), lthy''')
end
*}

(*
atom_decl name
datatype lam =
  VAR "name"
| APP "lam" "lam"
| LET "bp" "lam"
and bp =
  BP "name" "lam"
primrec
  bi::"bp \<Rightarrow> atom set"
where
  "bi (BP x t) = {atom x}"
setup {* snd o define_raw_perms (Datatype.the_info @{theory} "Fv.lam") 2 *}
local_setup {*
  snd o define_fv_alpha (Datatype.the_info @{theory} "Fv.lam")
  [[[], [], [(SOME (@{term bi}, true), 0, 1)]], [[]]] [(@{term bi}, 1, [[0]])] *}
print_theorems
*)

(*atom_decl name
datatype rtrm1 =
  rVr1 "name"
| rAp1 "rtrm1" "rtrm1"
| rLm1 "name" "rtrm1"        --"name is bound in trm1"
| rLt1 "bp" "rtrm1" "rtrm1"   --"all variables in bp are bound in the 2nd trm1"
and bp =
  BUnit
| BVr "name"
| BPr "bp" "bp"
primrec
  bv1
where
  "bv1 (BUnit) = {}"
| "bv1 (BVr x) = {atom x}"
| "bv1 (BPr bp1 bp2) = (bv1 bp1) \<union> (bv1 bp2)"
setup {* snd o define_raw_perms (Datatype.the_info @{theory} "Fv.rtrm1") 2 *}
local_setup {*
  snd o define_fv_alpha (Datatype.the_info @{theory} "Fv.rtrm1")
  [[[], [], [(NONE, 0, 1)], [(SOME (@{term bv1}, false), 0, 2)]],
  [[], [], []]] [(@{term bv1}, 1, [[], [0], [0, 1]])] *}
print_theorems
*)

(*
atom_decl name
datatype rtrm5 =
  rVr5 "name"
| rLt5 "rlts" "rtrm5" --"bind (bv5 lts) in (rtrm5)"
and rlts =
  rLnil
| rLcons "name" "rtrm5" "rlts"
primrec
  rbv5
where
  "rbv5 rLnil = {}"
| "rbv5 (rLcons n t ltl) = {atom n} \<union> (rbv5 ltl)"
setup {* snd o define_raw_perms (Datatype.the_info @{theory} "Fv.rtrm5") 2 *}
local_setup {* snd o define_fv_alpha (Datatype.the_info @{theory} "Fv.rtrm5")
  [[[], [(SOME (@{term rbv5}, false), 0, 1)]], [[], []]] [(@{term rbv5}, 1, [[], [0, 2]])] *}
print_theorems
*)

ML {*
(* Proves one injectivity goal of the alpha relations, as built by
   build_alpha_inj_gl.  It first tries to close the goal outright by
   simplification with the introduction rules intrs.  Otherwise it splits
   the equivalence with iffI:
   - the first direction by elimination with elims followed by
     simplification with dist_inj,
   - the second direction by simplification with intrs. *)
fun alpha_inj_tac dist_inj intrs elims =
  SOLVED' (asm_full_simp_tac (HOL_ss addsimps intrs)) ORELSE'
  (rtac @{thm iffI} THEN' RANGE [
     (eresolve_tac elims THEN_ALL_NEW
       asm_full_simp_tac (HOL_ss addsimps dist_inj)
     ),
     asm_full_simp_tac (HOL_ss addsimps intrs)])
*}

ML {*
(* Turns an introduction rule  H1 \<Longrightarrow> ... \<Longrightarrow> Hn \<Longrightarrow> C  into the object-level
   goal  C = (H1 \<and> ... \<and> Hn); a rule without premises yields just C. *)
fun build_alpha_inj_gl thm =
  let
    val prop = prop_of thm;
    val concl = HOLogic.dest_Trueprop (Logic.strip_imp_concl prop);
  in
    case map HOLogic.dest_Trueprop (Logic.strip_imp_prems prop) of
      [] => concl
    | hyps => HOLogic.mk_eq (concl, foldr1 HOLogic.mk_conj hyps)
  end;
*}

ML {*
(* Derives the injectivity theorems of the alpha relations from their
   introduction rules intrs.  There is one goal per rule (see
   build_alpha_inj_gl); each is proved with alpha_inj_tac in a context
   where the rules' variables are fixed, and the results are exported back
   to ctxt. *)
fun build_alpha_inj intrs dist_inj elims ctxt =
let
  val ((_, imported_intrs), ctxt_fixed) = Variable.import false intrs ctxt;
  val goals = map (HOLogic.mk_Trueprop o build_alpha_inj_gl) imported_intrs;
  fun prove goal =
    Goal.prove ctxt_fixed [] [] goal (fn _ => alpha_inj_tac dist_inj intrs elims 1);
in
  Variable.export ctxt_fixed ctxt (map prove goals)
end
*}

ML {*
(* For alphas = [alpha_1, ..., alpha_n] and variable names (x, y, z) builds
   three goals, each a conjunction over all alpha_i (x, y, z are Frees of
   the domain type of alpha_i):
     refl:  alpha_i x x
     sym:   alpha_i x y \<longrightarrow> alpha_i y x
     trans: alpha_i x y \<longrightarrow> (\<forall>z. alpha_i y z \<longrightarrow> alpha_i x z)
   The result is (refl, (sym, trans)). *)
fun build_alpha_refl_gl alphas (x, y, z) =
let
  fun goals_of alpha =
    let
      val T = domain_type (fastype_of alpha);
      val (a, b, c) = (Free (x, T), Free (y, T), Free (z, T));
      val refl_t = alpha $ a $ a;
      val sym_t = HOLogic.mk_imp (alpha $ a $ b, alpha $ b $ a);
      val trans_t =
        HOLogic.mk_imp (alpha $ a $ b,
          HOLogic.mk_all (z, T, HOLogic.mk_imp (alpha $ b $ c, alpha $ a $ c)));
    in
      (refl_t, sym_t, trans_t)
    end;
  val goals = map goals_of alphas;
  fun conj ts = @{term Trueprop} $ foldr1 HOLogic.mk_conj ts;
in
  (conj (map (fn (r, _, _) => r) goals),
   (conj (map (fn (_, s, _) => s) goals), conj (map (fn (_, _, t) => t) goals)))
end
*}

ML {*
(* Proves the reflexivity goal of build_alpha_refl_gl.  The steps are:
   - induction on the raw datatypes (rule induct);
   - simplification with the alpha injectivity rules inj;
   - instantiating every existential permutation with 0;
   - discharging the remaining alpha_gen / freshness conditions by
     simplification. *)
fun reflp_tac induct inj ctxt =
  rtac induct THEN_ALL_NEW
  simp_tac ((mk_minimal_ss ctxt) addsimps inj) THEN_ALL_NEW
  split_conjs THEN_ALL_NEW REPEAT o rtac @{thm exI[of _ "0 :: perm"]}
  THEN_ALL_NEW split_conjs THEN_ALL_NEW asm_full_simp_tac (HOL_ss addsimps
     @{thms alpha_gen fresh_star_def fresh_zero_perm permute_zero ball_triv
       add_0_left supp_zero_perm Int_empty_left split_conv})
*}


(* From an existential permutation pi with P pi, obtain the witness -pi for
   Q, given that P p implies Q (- p).  Used by symp_tac. *)
lemma exi_neg: "\<exists>(pi :: perm). P pi \<Longrightarrow> (\<And>(p :: perm). P p \<Longrightarrow> Q (- p)) \<Longrightarrow> \<exists>pi. Q pi"
apply (erule exE)
apply (rule_tac x="-pi" in exI)
by auto

ML {*
(* Proves the symmetry goal of build_alpha_refl_gl.  The steps are:
   - induction on the alpha relations (rule induct);
   - simplification with the injectivity rules inj;
   - replacing every existential permutation by its inverse (exi_neg);
   - closing the alpha_gen conditions with alpha_gen_compose_sym(2);
   - final simplification with the equivariance rules. *)
fun symp_tac induct inj eqvt ctxt =
  ind_tac induct THEN_ALL_NEW
  simp_tac ((mk_minimal_ss ctxt) addsimps inj) THEN_ALL_NEW split_conjs
  THEN_ALL_NEW
  REPEAT o etac @{thm exi_neg}
  THEN_ALL_NEW
  split_conjs THEN_ALL_NEW
  asm_full_simp_tac (HOL_ss addsimps @{thms supp_minus_perm minus_add[symmetric]}) THEN_ALL_NEW
  TRY o (rtac @{thm alpha_gen_compose_sym2} ORELSE' rtac @{thm alpha_gen_compose_sym}) THEN_ALL_NEW
  (asm_full_simp_tac (HOL_ss addsimps (eqvt @ all_eqvts ctxt)))
*}

ML {*
(* For a focused goal whose conclusion has the shape  _ $ (_ $ asm $ _)
   (in practice an implication  asm \<longrightarrow> B), introduces the implication and
   moves the new assumption to the front.  It then eliminates that
   assumption with those case_rules whose first assumption matches asm
   (Pattern.matches).  Goals of any other shape fail with no_tac. *)
fun imp_elim_tac case_rules =
  Subgoal.FOCUS (fn {concl, context, ...} =>
    case term_of concl of
      _ $ (_ $ asm $ _) =>
        let
          fun filter_fn case_rule = (
            case Logic.strip_assums_hyp (prop_of case_rule) of
              ((_ $ asmc) :: _) =>
                let
                  val thy = ProofContext.theory_of context
                in
                  Pattern.matches thy (asmc, asm)
                end
            | _ => false)
          val matching_rules = filter filter_fn case_rules
        in
         (rtac impI THEN' rotate_tac (~1) THEN' eresolve_tac matching_rules) 1
        end
    | _ => no_tac
  )
*}


(* Combines two existential permutations p and pi into the single witness
   pi + p.  Used by transp_tac via eetac. *)
lemma exi_sum: "\<exists>(pi :: perm). P pi \<Longrightarrow> \<exists>(pi :: perm). Q pi \<Longrightarrow> (\<And>(p :: perm) (pi :: perm). P p \<Longrightarrow> Q pi \<Longrightarrow> R (pi + p)) \<Longrightarrow> \<exists>pi. R pi"
apply (erule exE)+
apply (rule_tac x="pia + pi" in exI)
by auto

ML {*
(* Is the term an existential quantification  Ex (%x. P x)?  The constant
   name is taken from the antiquotation rather than hard-coded, so the
   test follows the actual internal name of HOL's Ex, which is what
   HOLogic.mk_exists produces. *)
fun is_ex (Const (@{const_name Ex}, _) $ Abs _) = true
  | is_ex _ = false;
*}

ML {*
(* In a focused subgoal, eliminates with rule (e.g. exi_sum).  The first
   resulting subgoal is then solved by assumption.  To the second it
   applies one of the thinning rules obtained by instantiating thin_rl
   with each existential premise of the focused goal. *)
fun eetac rule = Subgoal.FOCUS_PARAMS 
  (fn (focus) =>
     let
       val concl = #concl focus
       val prems = Logic.strip_imp_prems (term_of concl)
       val exs = filter (fn x => is_ex (HOLogic.dest_Trueprop x)) prems
       val cexs = map (SOME o (cterm_of (ProofContext.theory_of (#context focus)))) exs
       val thins = map (fn cex => Drule.instantiate' [] [cex] Drule.thin_rl) cexs
     in
     (etac rule THEN' RANGE[
        atac,
        eresolve_tac thins
     ]) 1
     end
  )
*}

ML {*
(* Proves the transitivity goal of build_alpha_refl_gl.  The steps are:
   - induction on the alpha relations;
   - introducing the \<forall>z and the implication, then eliminating the second
     alpha premise with the matching case rules (imp_elim_tac);
   - simplification with alpha_inj;
   - merging pairs of existential permutations with exi_sum (eetac);
   - simplification with term_inj / distinct;
   - alpha_gen_compose_trans for the alpha_gen conditions;
   - final simplification with the equivariance rules. *)
fun transp_tac ctxt induct alpha_inj term_inj distinct cases eqvt =
  ind_tac induct THEN_ALL_NEW
  (TRY o rtac allI THEN' imp_elim_tac cases ctxt) THEN_ALL_NEW
  asm_full_simp_tac ((mk_minimal_ss ctxt) addsimps alpha_inj) THEN_ALL_NEW
  split_conjs THEN_ALL_NEW REPEAT o (eetac @{thm exi_sum} ctxt) THEN_ALL_NEW split_conjs
  THEN_ALL_NEW (asm_full_simp_tac (HOL_ss addsimps (term_inj @ distinct)))
  THEN_ALL_NEW split_conjs THEN_ALL_NEW
  TRY o (etac @{thm alpha_gen_compose_trans} THEN' RANGE[atac]) THEN_ALL_NEW
  (asm_full_simp_tac (HOL_ss addsimps (all_eqvts ctxt @ eqvt @ term_inj @ distinct)))
*}

(* transp R from the curried/object-level form of transitivity produced by
   build_alpha_refl_gl; used by equivp_tac. *)
lemma transp_aux:
  "(\<And>xa ya. R xa ya \<longrightarrow> (\<forall>z. R ya z \<longrightarrow> R xa z)) \<Longrightarrow> transp R"
  unfolding transp_def
  by blast

ML {*
(* Proves  equivp R  by unfolding it into reflexivity, symmetry and
   transitivity.  Each part is then closed with one of the given facts
   reflps / symps / transps (transitivity via transp_aux). *)
fun equivp_tac reflps symps transps =
  simp_tac (HOL_ss addsimps @{thms equivp_reflp_symp_transp reflp_def symp_def})
  THEN' rtac conjI THEN' rtac allI THEN'
  resolve_tac reflps THEN'
  rtac conjI THEN' rtac allI THEN' rtac allI THEN'
  resolve_tac symps THEN'
  rtac @{thm transp_aux} THEN' resolve_tac transps
*}

ML {*
(* Proves  equivp alpha  for every alpha in alphas and returns the
   theorems in the order of alphas.  The steps are:
   - prove the conjoined refl / sym / trans goals of build_alpha_refl_gl
     with reflp_tac, symp_tac and transp_tac, over fresh names for x, y, z;
   - export these goals and split them into their conjuncts;
   - prove each equivp theorem from the conjuncts with equivp_tac. *)
fun build_equivps alphas term_induct alpha_induct term_inj alpha_inj distinct cases eqvt ctxt =
let
  val ([x, y, z], ctxt') = Variable.variant_fixes ["x","y","z"] ctxt;
  val (reflg, (symg, transg)) = build_alpha_refl_gl alphas (x, y, z)
  fun reflp_tac' _ = reflp_tac term_induct alpha_inj ctxt 1;
  fun symp_tac' _ = symp_tac alpha_induct alpha_inj eqvt ctxt 1;
  fun transp_tac' _ = transp_tac ctxt alpha_induct alpha_inj term_inj distinct cases eqvt 1;
  val reflt = Goal.prove ctxt' [] [] reflg reflp_tac';
  val symt = Goal.prove ctxt' [] [] symg symp_tac';
  val transt = Goal.prove ctxt' [] [] transg transp_tac';
  (* generalise x, y, z again *)
  val [refltg, symtg, transtg] = Variable.export ctxt' ctxt [reflt, symt, transt]
  val reflts = HOLogic.conj_elims refltg
  val symts = HOLogic.conj_elims symtg
  val transts = HOLogic.conj_elims transtg
  fun equivp alpha =
    let
      val equivp = Const (@{const_name equivp}, fastype_of alpha --> @{typ bool})
      val goal = @{term Trueprop} $ (equivp $ alpha)
      fun tac _ = equivp_tac reflts symts transts 1
    in
      Goal.prove ctxt [] [] goal tac
    end
in
  map equivp alphas
end
*}

(*
Tests:
prove alpha1_reflp_aux: {* fst (build_alpha_refl_gl [@{term alpha_rtrm1}, @{term alpha_bp}] ("x","y","z")) *}
by (tactic {* reflp_tac @{thm rtrm1_bp.induct} @{thms alpha1_inj} 1 *})

prove alpha1_symp_aux: {* (fst o snd) (build_alpha_refl_gl [@{term alpha_rtrm1}, @{term alpha_bp}] ("x","y","z")) *}
by (tactic {* symp_tac @{thm alpha_rtrm1_alpha_bp.induct} @{thms alpha1_inj} @{thms alpha1_eqvt} 1 *})

prove alpha1_transp_aux: {* (snd o snd) (build_alpha_refl_gl [@{term alpha_rtrm1}, @{term alpha_bp}] ("x","y","z")) *}
by (tactic {* transp_tac @{context} @{thm alpha_rtrm1_alpha_bp.induct} @{thms alpha1_inj} @{thms rtrm1.inject bp.inject} @{thms rtrm1.distinct bp.distinct} @{thms alpha_rtrm1.cases alpha_bp.cases} @{thms alpha1_eqvt} 1 *})

lemma alpha1_equivp:
  "equivp alpha_rtrm1"
  "equivp alpha_bp"
apply (tactic {*
  (simp_tac (HOL_ss addsimps @{thms equivp_reflp_symp_transp reflp_def symp_def})
  THEN' rtac @{thm conjI} THEN' rtac @{thm allI} THEN'
  resolve_tac (HOLogic.conj_elims @{thm alpha1_reflp_aux})
  THEN' rtac @{thm conjI} THEN' rtac @{thm allI} THEN' rtac @{thm allI} THEN'
  resolve_tac (HOLogic.conj_elims @{thm alpha1_symp_aux}) THEN' rtac @{thm transp_aux}
  THEN' resolve_tac (HOLogic.conj_elims @{thm alpha1_transp_aux})
)
1 *})
done*)

ML {*
(* Index in dts (a list of (type name, _) pairs) of the datatype named by
   the given type constructor.  Type variables are rejected, and so is a
   type constructor that does not occur in dts. *)
fun dtyp_no_of_typ _ (TFree _) = error "dtyp_no_of_typ: Illegal free"
  | dtyp_no_of_typ _ (TVar _) = error "dtyp_no_of_typ: Illegal schematic"
  | dtyp_no_of_typ dts (Type (tname, _)) =
      (* find_index signals "not found" by returning ~1, not by raising an
         exception, so the result has to be checked explicitly (wrapping
         it in try could never yield NONE). *)
      (case find_index (fn (name, _) => name = tname) dts of
        ~1 => error "dtyp_no_of_typ: Illegal recursion"
      | i => i)
*}

(* Splits non-membership in a union; used as a rewrite rule by supports_tac. *)
lemma not_in_union: "c \<notin> a \<union> b \<equiv> (c \<notin> a \<and> c \<notin> b)"
by auto

ML {*
(* Proves a goal built by mk_supports_eq.  It unfolds supports_def and
   splits unions (not_in_union), adding the permutation rules perm.  Then
   it introduces the quantified atoms and implications and closes each
   subgoal by simplification with swapping / freshness lemmas. *)
fun supports_tac perm =
  simp_tac (HOL_ss addsimps @{thms supports_def not_in_union} @ perm) THEN_ALL_NEW (
    REPEAT o rtac allI THEN' REPEAT o rtac impI THEN' split_conjs THEN'
    asm_full_simp_tac (HOL_ss addsimps @{thms fresh_def[symmetric]
      swap_fresh_fresh fresh_atom swap_at_base_simps(3) swap_atom_image_fresh}))
*}

ML {*
(* The term  supp x  for a term x of type ty. *)
fun mk_supp ty x =
  let
    val supp_const = Const (@{const_name supp}, ty --> @{typ "atom set"})
  in
    supp_const $ x
  end
*}

ML {*
(* For a constructor cnstr : T1 => ... => Tn => T builds the goal
     supports (S1 \<union> ... \<union> Sn) (cnstr x1 ... xn)
   where Si is supp {atom xi} for atoms, supp (atom ` xi) for atom sets
   and supp xi otherwise.  Returns the variable names and the goal. *)
fun mk_supports_eq thy cnstr =
let
  val (arg_tys, res_ty) = strip_type (fastype_of cnstr)
  val arg_names = Datatype_Prop.make_tnames arg_tys
  val arg_frees = map Free (arg_names ~~ arg_tys)
  val applied = list_comb (cnstr, arg_frees)

  fun supp_of (x, T) =
    if is_atom thy T then mk_supp @{typ atom} (mk_atom T $ x)
    else if is_atom_set thy T then mk_supp @{typ "atom set"} (mk_atoms x)
    else mk_supp T x
  val supp_set = mk_union (map supp_of (arg_frees ~~ arg_tys))
  val supports_const =
    Const(@{const_name "supports"}, @{typ "atom set"} --> res_ty --> @{typ bool})
in
  (arg_names, HOLogic.mk_Trueprop (supports_const $ supp_set $ applied))
end
*}

ML {*
(* Proves the supports goal of mk_supports_eq for the constructor cnst,
   using supports_tac with the permutation rules perms. *)
fun prove_supports ctxt perms cnst =
  let
    val thy = ProofContext.theory_of ctxt
    val (fixes, goal) = mk_supports_eq thy cnst
    fun tac _ = supports_tac perms 1
  in
    Goal.prove ctxt fixes [] goal tac
  end
*}

ML {*
(* For types [T1, ..., Tn] builds the goal
     finite (supp x1) \<and> ... \<and> finite (supp xn)
   with fresh variables xi :: Ti.  Returns the variable names and the goal. *)
fun mk_fs tys =
let
  val names = Datatype_Prop.make_tnames tys
  val finite_const = @{term "finite :: atom set \<Rightarrow> bool"}
  fun finite_supp (T, name) = finite_const $ mk_supp T (Free (name, T))
in
  (names, HOLogic.mk_Trueprop (mk_conjl (map finite_supp (tys ~~ names))))
end
*}

ML {*
(* Proves the goal of mk_fs.  It uses induction (rule induct); each case
   is reduced via supports_finite to one of the supports theorems, and
   the remaining finiteness conditions are closed by simplification. *)
fun fs_tac induct supports = ind_tac induct THEN_ALL_NEW (
  rtac @{thm supports_finite} THEN' resolve_tac supports) THEN_ALL_NEW
  asm_full_simp_tac (HOL_ss addsimps @{thms supp_atom supp_atom_image finite_insert finite.emptyI finite_Un})
*}

ML {*
(* Proves the finite-support goal of mk_fs for the types tys, using
   fs_tac with the induction rule induct and the supports theorems. *)
fun prove_fs ctxt induct supports tys =
  let
    val (fixes, goal) = mk_fs tys
  in
    Goal.prove ctxt fixes [] goal (K (fs_tac induct supports 1))
  end
*}

end
