(*  Title:      nominal_library.ML
    Author:     Christian Urban

  Basic functions for nominal.
*)

(* Interface of the nominal library: small list utilities, constructors and
   destructors for the terms and types used by the nominal package (permute,
   atom, supp, fresh_star, ...), helpers for datatype descriptions, tactics
   for the function package and transformations of induction premises.
   Throughout, a "*_ty" variant takes the relevant type explicitly, while the
   variant without the suffix computes it from its term argument. *)
signature NOMINAL_LIBRARY =
sig
  (* last two elements of a list; raises Empty on lists shorter than two *)
  val last2: 'a list -> 'a * 'a
  (* values of an association list, listed in the order of the given keys *)
  val order: ('a * 'a -> bool) -> 'a list -> ('a * 'b) list -> 'b list
  (* removes duplicates, keeping the last occurrence of each element *)
  val remove_dups: ('a * 'a -> bool) -> 'a list -> 'a list
  (* chops a list into consecutive chunks of the given lengths *)
  val partitions: 'a list -> int list -> 'a list list
  (* (elements satisfying the predicate, elements not satisfying it) *)
  val split_filter: ('a -> bool) -> 'a list -> 'a list * 'a list

  (* tests for the proposition Trueprop True *)
  val is_true: term -> bool

  (* element types of list and fset types; raise TYPE on other types *)
  val dest_listT: typ -> typ
  val dest_fsetT: typ -> typ

  val mk_id: term -> term
  val mk_all: (string * typ) -> term -> term

  (* the sum_case combinator and its application to two branches *)
  val sum_case_const: typ -> typ -> typ -> term
  val mk_sum_case: term -> term -> term

  (* negation and addition of permutations *)
  val mk_minus: term -> term
  val mk_plus: term -> term -> term

  (* the permute operation *)
  val perm_ty: typ -> typ
  val mk_perm_ty: typ -> term -> term -> term
  val mk_perm: term -> term -> term
  val dest_perm: term -> term * term

  (* sort_of and the atom function *)
  val mk_sort_of: term -> term
  val atom_ty: typ -> typ
  val atom_const: typ -> term
  val mk_atom_ty: typ -> term -> term
  val mk_atom: term -> term

  (* coerce sets / fsets / lists of concrete atoms into ones of general atoms *)
  val mk_atom_set_ty: typ -> term -> term
  val mk_atom_set: term -> term
  val mk_atom_fset_ty: typ -> term -> term
  val mk_atom_fset: term -> term
  val mk_atom_list_ty: typ -> term -> term
  val mk_atom_list: term -> term

  (* tests for concrete atom types (sort at_base) and containers of them *)
  val is_atom: Proof.context -> typ -> bool
  val is_atom_set: Proof.context -> typ -> bool
  val is_atom_fset: Proof.context -> typ -> bool
  val is_atom_list: Proof.context -> typ -> bool

  (* coerces atom lists / atom fsets into atom sets; other terms unchanged *)
  val to_set_ty: typ -> term -> term
  val to_set: term -> term

  (* coercions into general atoms, atom sets and atom lists; raise TERM
     when the argument type is not supported *)
  val atomify_ty: Proof.context -> typ -> term -> term
  val atomify: Proof.context -> term -> term
  val setify_ty: Proof.context -> typ -> term -> term
  val setify: Proof.context -> term -> term
  val listify_ty: Proof.context -> typ -> term -> term
  val listify: Proof.context -> term -> term

  (* fresh_star, supp, supp_rel, supports and finite *)
  val fresh_star_ty: typ -> typ
  val fresh_star_const: typ -> term
  val mk_fresh_star_ty: typ -> term -> term -> term
  val mk_fresh_star: term -> term -> term

  val supp_ty: typ -> typ
  val supp_const: typ -> term
  val mk_supp_ty: typ -> term -> term
  val mk_supp: term -> term

  val supp_rel_ty: typ -> typ
  val supp_rel_const: typ -> term
  val mk_supp_rel_ty: typ -> term -> term -> term
  val mk_supp_rel: term -> term -> term

  val supports_const: typ -> term
  val mk_supports_ty: typ -> term -> term -> term
  val mk_supports: term -> term -> term

  val finite_const: typ -> term
  val mk_finite_ty: typ -> term -> term
  val mk_finite: term -> term

  (* resolution with eq_reflection; the safe variant returns the theorem
     unchanged when resolution fails *)
  val mk_equiv: thm -> thm
  val safe_mk_equiv: thm -> thm

  (* set / list operations that avoid producing empty atom sets or lists *)
  val mk_diff: term * term -> term
  val mk_append: term * term -> term
  val mk_union: term * term -> term
  val fold_union: term list -> term
  val fold_append: term list -> term
  val mk_conj: term * term -> term
  val fold_conj: term list -> term

  (* fresh arguments for a term *)
  val fresh_args: Proof.context -> term -> term list

  (* datatype operations *)
  type cns_info = (term * typ * typ list * bool list) list

  val all_dtyps: Datatype_Aux.descr -> (string * sort) list -> typ list
  val nth_dtyp: Datatype_Aux.descr -> (string * sort) list -> int -> typ
  val all_dtyp_constrs_types: Datatype_Aux.descr -> (string * sort) list -> cns_info list
  val nth_dtyp_constrs_types: Datatype_Aux.descr -> (string * sort) list -> int -> cns_info
  val prefix_dt_names: Datatype_Aux.descr -> (string * sort) list -> string -> string list

  (* tactics for function package *)
  val pat_completeness_simp: thm list -> Proof.context -> tactic
  val prove_termination: thm list -> Proof.context -> Function.info * local_theory

  (* transformations of premises in inductions *)
  val transform_prem1: Proof.context -> string list -> thm -> thm
  val transform_prem2: Proof.context -> string list -> thm -> thm

  (* transformation into the object logic *)
  val atomize: thm -> thm

end


structure Nominal_Library: NOMINAL_LIBRARY =
struct

(* orders an AList according to keys: returns, for each key in turn, the
   value bound to it in list; an unbound key makes `the` fail *)
fun order eq keys list =
  map (fn key => the (AList.lookup eq list key)) keys

(* remove duplicates: of several elements equal under eq only the last
   occurrence is kept; kept elements stay in their original order *)
fun remove_dups _ [] = []
  | remove_dups eq (x :: xs) =
      let
        val seen_later = member eq xs x
        val rest = remove_dups eq xs
      in
        if seen_later then rest else x :: rest
      end

(* returns the last two elements of a list, in list order;
   raises Empty when the list has fewer than two elements *)
fun last2 xs =
  (case rev xs of
     y :: x :: _ => (x, y)
   | _ => raise Empty)

(* partitions a list into consecutive chunks whose lengths are given by the
   int list, e.g. partitions [a, b, c] [1, 2] = [[a], [b, c]];
   raises ListPair.UnequalLengths if elements are left over once all
   lengths have been used *)
fun partitions [] [] = []
  | partitions xs (i :: js) =
      let
        val (head, tail) = chop i xs
      in
        head :: partitions tail js
      end
  | partitions (_ :: _) [] = raise ListPair.UnequalLengths

(* splits a list into (elements satisfying f, elements not satisfying f),
   preserving the order within each part; f is applied from the last
   element towards the first *)
fun split_filter f xs =
  List.foldr
    (fn (x, (yes, no)) => if f x then (x :: yes, no) else (yes, x :: no))
    ([], []) xs


(* tests whether a proposition is literally Trueprop True *)
fun is_true t = (t = @{term "Trueprop True"})

(* element type of a list type; raises TYPE for any other type *)
fun dest_listT T =
  (case T of
     Type (@{type_name list}, [elemT]) => elemT
   | _ => raise TYPE ("dest_listT: list type expected", [T], []))

(* element type of an fset type; raises TYPE for any other type *)
fun dest_fsetT T =
  (case T of
     Type (@{type_name fset}, [elemT]) => elemT
   | _ => raise TYPE ("dest_fsetT: fset type expected", [T], []));



(* applies id, instantiated at the term's own type, to trm *)
fun mk_id trm =
  let
    val T = fastype_of trm
    val id_const = Const (@{const_name id}, T --> T)
  in
    id_const $ trm
  end

(* builds Term.all T applied to the abstraction of body over name :: T *)
fun mk_all (name, T) body = Term.all T $ Abs (name, T, body)

(* the constant sum_case at type
   (ty1 => ty3) => (ty2 => ty3) => (ty1, ty2) sum => ty3 *)
fun sum_case_const ty1 ty2 ty3 =
  let
    val sumT = Type (@{type_name sum}, [ty1, ty2])
  in
    Const (@{const_name sum_case}, (ty1 --> ty3) --> (ty2 --> ty3) --> sumT --> ty3)
  end

(* builds sum_case trm1 trm2, where trm1 :: ty1 => ty3 and trm2 :: ty2 => ty3.
   Only the outermost arrow of trm1's type is taken apart, so the result
   type ty3 may itself be a function type (strip_type would split it too
   and make the binding fail). *)
fun mk_sum_case trm1 trm2 =
  let
    val fun_ty1 = fastype_of trm1
    val ty1 = domain_type fun_ty1
    val ty3 = range_type fun_ty1
    val ty2 = domain_type (fastype_of trm2)
  in
    sum_case_const ty1 ty2 ty3 $ trm1 $ trm2
  end



(* negation and addition of permutations *)
fun mk_minus p = list_comb (@{term "uminus::perm => perm"}, [p])

fun mk_plus p q = list_comb (@{term "plus::perm => perm => perm"}, [p, q])

(* type of permute at ty, i.e. perm => ty => ty *)
fun perm_ty ty = [@{typ "perm"}, ty] ---> ty

(* permute p trm, with permute instantiated at ty *)
fun mk_perm_ty ty p trm =
  list_comb (Const (@{const_name "permute"}, perm_ty ty), [p, trm])

fun mk_perm p trm =
  let val ty = fastype_of trm in mk_perm_ty ty p trm end

(* splits permute p t into (p, t); raises TERM for any other term *)
fun dest_perm trm =
  (case trm of
     Const (@{const_name "permute"}, _) $ p $ t => (p, t)
   | _ => raise TERM ("dest_perm", [trm]));

(* sort_of applied to a term *)
fun mk_sort_of trm = list_comb (@{term "sort_of"}, [trm]);

(* the constant atom at type ty => atom, and its application *)
fun atom_ty ty = [ty] ---> @{typ "atom"};
fun atom_const ty = Const (@{const_name "atom"}, atom_ty ty)
fun mk_atom_ty ty trm = list_comb (atom_const ty, [trm]);
fun mk_atom trm =
  let val ty = fastype_of trm in mk_atom_ty ty trm end;

(* common shape of the three coercions below: apply the mapping constant
   map_name of a container type ty (element type elem_ty) to the atom
   function of elem_ty and to t, giving a container of type res_ty *)
fun mk_atom_map map_name elem_ty res_ty ty t =
  Const (map_name, (elem_ty --> @{typ atom}) --> ty --> res_ty) $ atom_const elem_ty $ t

(* image atom t, for t a set of concrete atoms of type ty *)
fun mk_atom_set_ty ty t =
  mk_atom_map @{const_name image} (HOLogic.dest_setT ty) @{typ "atom set"} ty t

(* map_fset atom t, for t an fset of concrete atoms of type ty *)
fun mk_atom_fset_ty ty t =
  mk_atom_map @{const_name map_fset} (dest_fsetT ty) @{typ "atom fset"} ty t

(* map atom t, for t a list of concrete atoms of type ty *)
fun mk_atom_list_ty ty t =
  mk_atom_map @{const_name map} (dest_listT ty) @{typ "atom list"} ty t

fun mk_atom_set trm = mk_atom_set_ty (fastype_of trm) trm
fun mk_atom_fset trm = mk_atom_fset_ty (fastype_of trm) trm
fun mk_atom_list trm = mk_atom_list_ty (fastype_of trm) trm

(* coerces an atom list (via set) or an atom fset (via fset) into an
   atom set; terms of any other type are returned unchanged *)
fun to_set_ty ty t =
  if ty = @{typ "atom list"} then @{term "set :: atom list => atom set"} $ t
  else if ty = @{typ "atom fset"} then @{term "fset :: atom fset => atom set"} $ t
  else t

fun to_set trm =
  let val ty = fastype_of trm in to_set_ty ty trm end


(* testing for concrete atom types *)

(* ty is a concrete atom type iff it has sort at_base *)
fun is_atom ctxt ty =
  Sign.of_sort (ProofContext.theory_of ctxt) (ty, @{sort at_base})

(* Set types are recognised with HOLogic.dest_setT, the same destructor
   used by mk_atom_set_ty / setify_ty. The test and the coercion therefore
   agree on the representation of set types, instead of this function
   hard-coding the encoding 'a => bool. *)
fun is_atom_set ctxt ty =
  (case try HOLogic.dest_setT ty of
     SOME elem_ty => is_atom ctxt elem_ty
   | NONE => false);

fun is_atom_fset ctxt (Type (@{type_name "fset"}, [ty])) = is_atom ctxt ty
  | is_atom_fset _ _ = false;

fun is_atom_list ctxt (Type (@{type_name "list"}, [ty])) = is_atom ctxt ty
  | is_atom_list _ _ = false


(* functions that coerce singletons, sets, fsets and lists of concrete
   atoms into general atoms sets / lists *)

(* the first matching test (in the order atom, set, fset, list) selects the
   coercion; an unsupported type raises TERM *)
fun atomify_ty ctxt ty t =
  let
    val coercions =
      [(is_atom, mk_atom_ty),
       (is_atom_set, mk_atom_set_ty),
       (is_atom_fset, mk_atom_fset_ty),
       (is_atom_list, mk_atom_list_ty)]
  in
    (case find_first (fn (applies, _) => applies ctxt ty) coercions of
       SOME (_, coerce) => coerce ty t
     | NONE => raise TERM ("atomify", [t]))
  end

(* coerces into an atom set: a single atom becomes a singleton set, atom
   fsets and atom lists are turned into sets with fset and set; an
   unsupported type raises TERM *)
fun setify_ty ctxt ty t =
  let
    val fset_to_set = @{term "fset :: atom fset => atom set"}
    val list_to_set = @{term "set :: atom list => atom set"}
  in
    if is_atom ctxt ty then HOLogic.mk_set @{typ atom} [mk_atom_ty ty t]
    else if is_atom_set ctxt ty then mk_atom_set_ty ty t
    else if is_atom_fset ctxt ty then fset_to_set $ mk_atom_fset_ty ty t
    else if is_atom_list ctxt ty then list_to_set $ mk_atom_list_ty ty t
    else raise TERM ("setify", [t])
  end

(* coerces into an atom list: a single atom becomes a one-element list,
   a list of concrete atoms is mapped with atom; anything else raises TERM *)
fun listify_ty ctxt ty t =
  let
    fun singleton () = HOLogic.mk_list @{typ atom} [mk_atom_ty ty t]
    fun from_list () = mk_atom_list_ty ty t
  in
    if is_atom ctxt ty then singleton ()
    else if is_atom_list ctxt ty then from_list ()
    else raise TERM ("listify", [t])
  end

(* variants of the above that take the type from the term itself *)
fun atomify ctxt trm = let val ty = fastype_of trm in atomify_ty ctxt ty trm end
fun setify ctxt trm = let val ty = fastype_of trm in setify_ty ctxt ty trm end
fun listify ctxt trm = let val ty = fastype_of trm in listify_ty ctxt ty trm end

(* fresh_star at type atom set => ty => bool *)
fun fresh_star_ty ty = @{typ "atom set"} --> ty --> @{typ bool}
fun fresh_star_const ty = Const (@{const_name fresh_star}, fresh_star_ty ty)
fun mk_fresh_star_ty ty t1 t2 = list_comb (fresh_star_const ty, [t1, t2])
fun mk_fresh_star t1 t2 = mk_fresh_star_ty (fastype_of t2) t1 t2

(* supp at type ty => atom set *)
fun supp_ty ty = [ty] ---> @{typ "atom set"};
fun supp_const ty = Const (@{const_name supp}, supp_ty ty)
fun mk_supp_ty ty t = list_comb (supp_const ty, [t])
fun mk_supp t = mk_supp_ty (fastype_of t) t

(* supp_rel at type (ty => ty => bool) => ty => atom set *)
fun supp_rel_ty ty = (ty --> ty --> @{typ bool}) --> ty --> @{typ "atom set"};
fun supp_rel_const ty = Const (@{const_name supp_rel}, supp_rel_ty ty)
fun mk_supp_rel_ty ty r t = list_comb (supp_rel_const ty, [r, t])
fun mk_supp_rel r t = mk_supp_rel_ty (fastype_of t) r t

(* supports at type atom set => ty => bool *)
fun supports_const ty =
  Const (@{const_name supports}, @{typ "atom set"} --> ty --> @{typ bool});
fun mk_supports_ty ty t1 t2 = list_comb (supports_const ty, [t1, t2]);
fun mk_supports t1 t2 = mk_supports_ty (fastype_of t2) t1 t2;

(* finite at type ty => bool, where ty is the type of the set argument *)
fun finite_const ty = Const (@{const_name finite}, [ty] ---> @{typ bool})
fun mk_finite_ty ty t = list_comb (finite_const ty, [t])
fun mk_finite t = mk_finite_ty (fastype_of t) t


(* resolves thm with eq_reflection (presumably turning an HOL equation into
   a meta-equation) *)
fun mk_equiv thm = thm RS @{thm eq_reflection};

(* as mk_equiv, but a failing resolution (Thm.THM) leaves thm unchanged *)
fun safe_mk_equiv thm =
  (mk_equiv thm) handle Thm.THM _ => thm;


(* functions that construct differences, appends and unions
   but avoid producing empty atom sets or empty atom lists *)

local
  val empty_set = @{term "{}::atom set"}
  val empty_list_set = @{term "set ([]::atom list)"}
  val empty_list = @{term "[]::atom list"}
in

(* t1 - t2; if either side is literally {} or set [], t1 is returned *)
fun mk_diff (t1, t2) =
  if t1 = empty_set orelse t2 = empty_set
     orelse t1 = empty_list_set orelse t2 = empty_list_set
  then t1
  else HOLogic.mk_binop @{const_name minus} (t1, t2)

(* t1 @ t2, dropping a literal [] on either side *)
fun mk_append (t1, t2) =
  if t2 = empty_list then t1
  else if t1 = empty_list then t2
  else HOLogic.mk_binop @{const_name "append"} (t1, t2)

(* sup t1 t2, dropping a literal {} or set [] on either side; the tests
   are made in this order: t2 = {}, t1 = {}, t2 = set [], t1 = set [] *)
fun mk_union (t1, t2) =
  if t2 = empty_set then t1
  else if t1 = empty_set then t2
  else if t2 = empty_list_set then t1
  else if t1 = empty_list_set then t2
  else HOLogic.mk_binop @{const_name "sup"} (t1, t2)

(* union / append of a list of terms; {} and [] for the empty list *)
fun fold_union trms = fold_rev (curry mk_union) trms empty_set
fun fold_append trms = fold_rev (curry mk_append) trms empty_list

end

(* conjunction that drops a literal True on either side *)
fun mk_conj (t1, t2) =
  if t2 = @{term "True"} then t1
  else if t1 = @{term "True"} then t2
  else HOLogic.mk_conj (t1, t2)

(* conjunction of a list of terms; True for the empty list *)
fun fold_conj trms = fold_rev (curry mk_conj) trms @{term "True"}


(* produces fresh arguments for a term: one free variable per argument
   (binder) type of f, named after "z" and made distinct by
   Variable.variant_frees with respect to ctxt and f *)
fun fresh_args ctxt f =
  let
    val named_tys = map (pair "z") (binder_types (fastype_of f))
  in
    map Free (Variable.variant_frees ctxt [f] named_tys)
  end



(** datatypes **)

(* constructor infos: a list with one entry per constructor of a datatype,
   each entry consisting of
    - term for constructor constant
    - the datatype the constructor belongs to (i.e. its result type),
      not the type of the constructor itself
    - types of the arguments
    - flags indicating whether the argument is recursive
*)
type cns_info = (term * typ * typ list * bool list) list

(* returns the type of the nth datatype in descr *)
fun nth_dtyp descr sorts n =
  Datatype_Aux.typ_of_dtyp descr sorts (Datatype_Aux.DtRec n);

(* returns the types of all datatypes in descr, for indices 0 .. length descr - 1 *)
fun all_dtyps descr sorts =
  map (nth_dtyp descr sorts) (0 upto (length descr - 1))

(* returns info about constructors: for each datatype in descr, the list of
   its constructors, each paired with the datatype's name and type arguments *)
fun all_dtyp_constrs_info descr =
  map (fn (_, (ty_name, dt_args, constrs)) =>
         map (fn constr => ((ty_name, dt_args), constr)) constrs)
    descr

(* returns, per datatype, the constants of its constructors plus the
   datatype type, the argument types and the recursion flags (cns_info) *)
fun all_dtyp_constrs_types descr sorts =
  let
    val typ_of = Datatype_Aux.typ_of_dtyp descr sorts
    fun constr_info ((ty_name, dt_args), (cname, args)) =
      let
        val result_ty = Type (ty_name, map typ_of dt_args)
        val arg_tys = map typ_of args
        val rec_flags = map Datatype_Aux.is_rec_type args
      in
        (Const (cname, arg_tys ---> result_ty), result_ty, arg_tys, rec_flags)
      end
  in
    all_dtyp_constrs_info descr |> map (map constr_info)
  end

(* constructor infos of the nth datatype in descr *)
fun nth_dtyp_constrs_types descr sorts n =
  let
    val infos = all_dtyp_constrs_types descr sorts
  in
    nth infos n
  end


(* generates for every datatype a name str ^ dt_name,
   plus an index for multiple occurrences of the same name *)
fun prefix_dt_names descr sorts str =
  descr
  |> map (fn (i, _) => str ^ Datatype_Aux.name_of_typ (nth_dtyp descr sorts i))
  |> Datatype_Prop.indexify_names



(** function package tactics **)

(* pattern-completeness tactic for the function package: applies
   Pat_Completeness.pat_completeness_tac to goal 1, then simplifies all
   goals with sum.inject, sum.distinct and the given simps *)
fun pat_completeness_simp simps lthy =
  let
    val sum_simps = @{thms sum.inject sum.distinct}
    val ss = HOL_basic_ss addsimps (sum_simps @ simps)
    val simp_all = ALLGOALS (asm_full_simp_tac ss)
  in
    Pat_Completeness.pat_completeness_tac lthy 1 THEN simp_all
  end


(* Termination tactic for the function package. The termination relation is
   instantiated (via Function_Relation.relation_tac) with a measure built from
   size functions. All resulting goals are then simplified with the
   measure/size facts below plus size_simps. *)
fun prove_termination_tac size_simps ctxt =
  let
    val natT = @{typ nat}
    (* the constant prod_size at type (fT => nat) => (sT => nat) => fT * sT => nat *)
    fun prod_size_const fT sT = 
      let
        val fT_fun = fT --> natT
        val sT_fun = sT --> natT
        val prodT = Type (@{type_name prod}, [fT, sT])
      in
        Const (@{const_name prod_size}, [fT_fun, sT_fun, prodT] ---> natT)
      end

    (* size function for T: sum types get a sum case over the size functions
       of the two summands, products get prod_size of the component size
       functions, and all other types get HOLogic.size_const *)
    fun mk_size_measure T =
      case T of    
        (Type (@{type_name Sum_Type.sum}, [fT, sT])) =>
           SumTree.mk_sumcase fT sT natT (mk_size_measure fT) (mk_size_measure sT)
      | (Type (@{type_name Product_Type.prod}, [fT, sT])) =>
           prod_size_const fT sT $ (mk_size_measure fT) $ (mk_size_measure sT)
      | _ => HOLogic.size_const T

    (* T is the type of the relation to be found, a set of pairs. The size
       measure is built for the first component type of the pair. The dummyT
       on "measure" is presumably resolved by Syntax.check_term. *)
    fun mk_measure_trm T = 
      HOLogic.dest_setT T
      |> fst o HOLogic.dest_prodT
      |> mk_size_measure 
      |> curry (op $) (Const (@{const_name "measure"}, dummyT))
      |> Syntax.check_term ctxt
      
    val ss = HOL_ss addsimps @{thms in_measure wf_measure sum.cases add_Suc_right add.right_neutral 
      zero_less_Suc prod.size(1) mult_Suc_right} @ size_simps 
    (* additionally use the cancel_numerals simprocs *)
    val ss' = ss addsimprocs Nat_Numeral_Simprocs.cancel_numerals
  in
    Function_Relation.relation_tac ctxt mk_measure_trm
    THEN_ALL_NEW simp_tac ss'
  end

(* proves termination of the pending function definition in ctxt using
   prove_termination_tac on the head goal *)
fun prove_termination size_simps ctxt =
  let
    val termination_tac = HEADGOAL (prove_termination_tac size_simps ctxt)
  in
    Function.prove_termination NONE termination_tac ctxt
  end

(** transformations of premises (in inductive proofs) **)

(* 
 given the theorem F[t]; proves the theorem F[f t] 

  - F needs to be monotone
  - f returns either SOME for a term it fires on 
    and NONE elsewhere 
*)

(* map_term f t rewrites t at the outermost positions where f returns
   SOME (results of f are not traversed further); NONE if f fires nowhere *)
fun map_term f t =
  (case f t of
     NONE => map_term' f t
   | changed => changed)
and map_term' f (t $ u) =
      (case (map_term f t, map_term f u) of
         (NONE, NONE) => NONE
       | (t', u') => SOME (the_default t t' $ the_default u u'))
  | map_term' f (Abs (s, T, body)) =
      Option.map (fn body' => Abs (s, T, body')) (map_term f body)
  | map_term' _ _ = NONE;

(* Tactic used by map_thm. It inserts thm as a premise and turns it into an
   implication (rev_mp). It then repeatedly simplifies with split_def and
   resolves with the monotonicity rules from Inductive.get_monos. Remaining
   implications are introduced (impI) and closed by assumption or by tac. *)
fun map_thm_tac ctxt tac thm =
  let
    val monos = Inductive.get_monos ctxt
    val simps = HOL_basic_ss addsimps @{thms split_def}
  in
    EVERY [cut_facts_tac [thm] 1, etac rev_mp 1, 
      REPEAT_DETERM (FIRSTGOAL (simp_tac simps THEN' resolve_tac monos)),
      REPEAT_DETERM (rtac impI 1 THEN (atac 1 ORELSE tac))]
  end

(* rewrites the proposition of thm with map_term f; when f fires nowhere thm
   is returned as is, otherwise the rewritten proposition is proved with
   map_thm_tac *)
fun map_thm ctxt f tac thm =
  (case map_term f (prop_of thm) of
     SOME goal => Goal.prove ctxt [] [] goal (fn _ => map_thm_tac ctxt tac thm)
   | NONE => thm)

(*
 inductive premises can be of the form R ... /\ P ...; when the head of
 the first conjunct is a constant whose name is in names, split_conj1
 picks out the R part and split_conj2 the P part (NONE otherwise)
*)
fun split_conj_with select names (Const (@{const_name "conj"}, _) $ f1 $ f2) =
      (case head_of f1 of
         Const (name, _) =>
           if member (op =) names name then SOME (select (f1, f2)) else NONE
       | _ => NONE)
  | split_conj_with _ _ _ = NONE;

fun split_conj1 names = split_conj_with fst names;

fun split_conj2 names = split_conj_with snd names;

(* in premises R ... /\ P ... with R in names, keep only the R part
   (transform_prem1, via conjunct1) or only the P part (transform_prem2,
   via conjunct2) *)
fun transform_prem1 ctxt names th =
  th |> map_thm ctxt (split_conj1 names) (etac conjunct1 1)

fun transform_prem2 ctxt names th =
  th |> map_thm ctxt (split_conj2 names) (etac conjunct2 1)


(* transforms a theorem into one of the object logic: first quantifies over
   its variables (forall_intr_vars), then rewrites with Object_Logic.atomize *)
fun atomize thm =
  thm
  |> forall_intr_vars
  |> Conv.fconv_rule Object_Logic.atomize

end (* structure *)

(* make the library's functions available without the Nominal_Library prefix *)
open Nominal_Library;