(*  Title:      nominal_dt_rawfuns.ML
    Author:     Cezary Kaliszyk
    Author:     Christian Urban

  Definitions of the raw bn, fv and fv_bn
  functions
*)

signature NOMINAL_DT_RAWFUNS =
sig
  (* binding modes and binding clauses *)

  (* the three binding modes a binding clause can have *)
  datatype bmode = Lst | Res | Set

  (* BC (mode, binders, bodies): each binder is an optional bn function
     paired with the index of the constructor argument it applies to;
     bodies is a list of constructor-argument indices *)
  datatype bclause = BC of bmode * (term option * int) list * int list

  (* setify coerces a term of atom, atom-set or atom-fset type into a term
     of type "atom set"; listify coerces an atom term or an atom-list term.
     Both raise TERM for terms of any other type. *)
  val setify: Proof.context -> term -> term
  val listify: Proof.context -> term -> term

  (* define_raw_fvs dt_descr sorts bn_funs bn_funs2 bclausesss lthy defines,
     with the function package, one "fv_" function per datatype of dt_descr
     and one "fv_" ^ bn function per entry of bn_funs2. bclausesss holds the
     binding clauses per datatype, per constructor. The result is the list
     of fv functions, the list of fv_bn functions and their simp rules (all
     exported back to lthy), together with the extended local theory. *)
  val define_raw_fvs: Datatype_Aux.descr -> (string * sort) list ->
    (term * 'a * 'b) list -> (term * int * (int * term option) list list) list ->
      bclause list list list -> Proof.context -> term list * term list * thm list * local_theory
end


structure Nominal_Dt_RawFuns: NOMINAL_DT_RAWFUNS =
struct

(* binding modes and binding clauses; must agree with the signature *)
datatype bmode = Lst | Res | Set
datatype bclause = BC of bmode * (term option * int) list * int list


(* atom types *)

(* true iff ty belongs to the sort at_base in the theory of ctxt *)
fun is_atom ctxt ty =
  Sign.of_sort (ProofContext.theory_of ctxt) (ty, @{sort at_base})

(* true iff the type is a set type t => bool whose element type t is an atom type *)
fun is_atom_set ctxt (Type ("fun", [t, @{typ bool}])) = is_atom ctxt t
  | is_atom_set _ _ = false;

(* true iff the type is "t fset" with t an atom type *)
fun is_atom_fset ctxt (Type (@{type_name "fset"}, [t])) = is_atom ctxt t
  | is_atom_fset _ _ = false;

(* true iff the type is "t list" with t an atom type *)
fun is_atom_list ctxt (Type (@{type_name "list"}, [t])) = is_atom ctxt t
  | is_atom_list _ _ = false


(* functions for producing sets, fsets and lists of generic atoms *)

(* Maps a set t of concrete atoms to an "atom set" by taking the image of
   t under the function "atom". *)
fun mk_atom_set t =
let
  val ty = fastype_of t;
  val atom_ty = HOLogic.dest_setT ty --> @{typ atom};
  val img_ty = atom_ty --> ty --> @{typ "atom set"};
in
  (* image needs both arguments: the function "atom" and the set itself *)
  Const (@{const_name image}, img_ty) $ Const (@{const_name atom}, atom_ty) $ t
end;

(* Maps an fset t of concrete atoms to an "atom set": fmap "atom" over t,
   then convert the resulting "atom fset" with fset_to_set. *)
fun mk_atom_fset t =
let
  val ty = fastype_of t;
  val atom_ty = dest_fsetT ty --> @{typ atom};
  val fmap_ty = atom_ty --> ty --> @{typ "atom fset"};
  val fset_to_set = @{term "fset_to_set :: atom fset => atom set"}
in
  fset_to_set $ (Const (@{const_name fmap}, fmap_ty) $ Const (@{const_name atom}, atom_ty) $ t)
end;

(* Maps a list t of concrete atoms to an "atom list" by mapping the
   function "atom" over it. *)
fun mk_atom_list t =
let
  val ty = fastype_of t;
  val atom_ty = dest_listT ty --> @{typ atom};
  val map_ty = atom_ty --> ty --> @{typ "atom list"};
in
  (* map needs both arguments: the function "atom" and the list itself *)
  Const (@{const_name map}, map_ty) $ Const (@{const_name atom}, atom_ty) $ t
end;


(* Coerces a term of atom, atom-set or atom-fset type into an "atom set".
   Raises TERM ("setify", [t]) for any other type. *)
fun setify ctxt t =
let
  val ty = fastype_of t;
in
  if is_atom ctxt ty
    then  HOLogic.mk_set @{typ atom} [mk_atom t]
  else if is_atom_set ctxt ty
    then mk_atom_set t
  else if is_atom_fset ctxt ty
    then mk_atom_fset t
  else raise TERM ("setify", [t])
end

(* Coerces a term of atom or atom-list type into an "atom list".
   Raises TERM ("listify", [t]) for any other type. *)
fun listify ctxt t =
let
  val ty = fastype_of t;
in
  if is_atom ctxt ty
    then HOLogic.mk_list @{typ atom} [mk_atom t]
  else if is_atom_list ctxt ty
    then mk_atom_list t
  else raise TERM ("listify", [t])
end

(* Coerces an "atom list" into an "atom set" via set; leaves terms of any
   other type unchanged. *)
fun to_set x =
  if fastype_of x = @{typ "atom list"}
  then @{term "set::atom list => atom set"} $ x
  else x



(* Free-atom term for constructor argument i: the fv function for the
   argument's type if fv_map has one, otherwise mk_supp of the argument. *)
fun make_body fv_map args i =
let
  val arg = nth args i
  val ty = fastype_of arg
in
  case (AList.lookup (op=) fv_map ty) of
    NONE => mk_supp arg
  | SOME fv => fv $ arg
end

(* For a binder (bn_option, i) returns the pair (atoms bound, atoms that
   stay free). A plain binder binds setify of argument i and leaves nothing
   free. A binder with a bn function binds the atoms of bn arg (coerced to
   a set) and leaves fv_bn arg free. *)
fun make_binder lthy fv_bn_map args (bn_option, i) =
let
  val arg = nth args i
in
  case bn_option of
    NONE => (setify lthy arg, @{term "{}::atom set"})
  | SOME bn => (to_set (bn $ arg), the (AList.lookup (op=) fv_bn_map bn) $ arg)
end

(* Right-hand side contributed by one binding clause: (union of the bodies
   minus union of the bound atoms) union the binders' free-atom terms.
   fold_union and mk_diff presumably build set union and difference. *)
fun make_fv_rhs lthy fv_map fv_bn_map args (BC (_, binders, bodies)) =
let
  val body_trms = map (make_body fv_map args) bodies
  val (bound_trms, free_trms) = split_list (map (make_binder lthy fv_bn_map args) binders)
in
  fold_union (mk_diff (fold_union body_trms, fold_union bound_trms) :: free_trms)
end

(* Equation "fv_ty (constr x1 ... xn) = rhs" for one constructor, where
   rhs is the union of the right-hand sides of all its binding clauses. *)
fun make_fv_eq lthy fv_map fv_bn_map (constr, ty, arg_tys) bclauses =
let
  val arg_names = Datatype_Prop.make_tnames arg_tys
  val args = map Free (arg_names ~~ arg_tys)
  val fv = the (AList.lookup (op=) fv_map ty)
  val lhs = fv $ list_comb (constr, args)
  val rhs_trms = map (make_fv_rhs lthy fv_map fv_bn_map args) bclauses
  val rhs = fold_union rhs_trms
in
  HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs))
end


(* Free-atom term for argument i inside an fv_bn equation. Arguments not
   mentioned in bn_args are treated as in make_body; SOME NONE yields the
   empty set; SOME (SOME bn) yields fv_bn of that bn applied to the
   argument. *)
fun make_bn_body fv_map fv_bn_map bn_args args i =
let
  val arg = nth args i
  val ty = fastype_of arg
in
  case AList.lookup (op=) bn_args i of
    NONE => (case (AList.lookup (op=) fv_map ty) of
              NONE => mk_supp arg
            | SOME fv => fv $ arg)
  | SOME (NONE) => @{term "{}::atom set"}
  | SOME (SOME bn) => the (AList.lookup (op=) fv_bn_map bn) $ arg
end

(* Right-hand side of one binding clause in an fv_bn equation. Clauses
   without binders use make_bn_body; clauses with binders are handled
   exactly as in the fv equations. *)
fun make_fv_bn_rhs lthy fv_map fv_bn_map bn_args args bclause =
  case bclause of
    BC (_, [], bodies) => fold_union (map (make_bn_body fv_map fv_bn_map bn_args args) bodies)
  | _ => make_fv_rhs lthy fv_map fv_bn_map args bclause

(* Equation "fv_bn (constr x1 ... xn) = rhs" for one constructor and the
   bn function bn_trm. *)
fun make_fv_bn_eq lthy bn_trm fv_map fv_bn_map (bn_args, (constr, _, arg_tys)) bclauses =
let
  val arg_names = Datatype_Prop.make_tnames arg_tys
  val args = map Free (arg_names ~~ arg_tys)
  val fv_bn = the (AList.lookup (op=) fv_bn_map bn_trm)
  val lhs = fv_bn $ list_comb (constr, args)
  val rhs_trms = map (make_fv_bn_rhs lthy fv_map fv_bn_map bn_args args) bclauses
  val rhs = fold_union rhs_trms
in
  HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs))
end

(* All fv_bn equations for one bn function: bn_n selects the datatype, and
   bn_argss supplies one bn-argument list per constructor of it. *)
fun make_fv_bn_eqs lthy fv_map fv_bn_map constrs_info bclausesss (bn_trm, bn_n, bn_argss) =
let
  val nth_constrs_info = nth constrs_info bn_n
  val nth_bclausess = nth bclausesss bn_n
in
  map2 (make_fv_bn_eq lthy bn_trm fv_map fv_bn_map) (bn_argss ~~ nth_constrs_info) nth_bclausess
end

(* Defines the raw fv_ functions (one per datatype) and fv_bn functions (one
   per bn function in bn_funs2) with the function package, proves
   termination by lexicographic order, and returns the defined constants
   and simp rules exported to lthy, plus the extended local theory. *)
fun define_raw_fvs dt_descr sorts bn_funs bn_funs2 bclausesss lthy =
let

  val fv_names = prefix_dt_names dt_descr sorts "fv_"
  val fv_arg_tys = map (fn (i, _) => nth_dtyp dt_descr sorts i) dt_descr;
  val fv_tys = map (fn ty => ty --> @{typ "atom set"}) fv_arg_tys;
  val fv_frees = map Free (fv_names ~~ fv_tys);
  val fv_map = fv_arg_tys ~~ fv_frees

  val bns = map (fn (bn, _, _) => bn) bn_funs
  val (bns2, bn_tys2) = split_list (map (fn (bn, i, _) => (bn, i)) bn_funs2)
  val fv_bn_names2 = map (fn bn => "fv_" ^ (fst (dest_Free bn))) bns2
  val fv_bn_arg_tys2 = map (fn i => nth_dtyp dt_descr sorts i) bn_tys2
  val fv_bn_tys2 = map (fn ty => ty --> @{typ "atom set"}) fv_bn_arg_tys2
  val fv_bn_frees2 = map Free (fv_bn_names2 ~~ fv_bn_tys2)
  (* NOTE(review): fv_bn_map2 pairs the bn terms of bn_funs with the fv_bn
     functions named after bn_funs2; this assumes both lists describe the
     same bn functions in the same order (~~ raises on unequal lengths).
     The fv equations use fv_bn_map2, the fv_bn equations fv_bn_map3;
     confirm that is intended. *)
  val fv_bn_map2 = bns ~~ fv_bn_frees2
  val fv_bn_map3 = bns2 ~~ fv_bn_frees2

  val constrs_info = all_dtyp_constrs_types dt_descr sorts

  val fv_eqs2 = map2 (map2 (make_fv_eq lthy fv_map fv_bn_map2)) constrs_info bclausesss
  val fv_bn_eqs2 = map (make_fv_bn_eqs lthy fv_map fv_bn_map3 constrs_info bclausesss) bn_funs2

  val all_fv_names = map (fn s => (Binding.name s, NONE, NoSyn)) (fv_names @ fv_bn_names2)
  val all_fv_eqs = map (pair Attrib.empty_binding) (flat fv_eqs2 @ flat fv_bn_eqs2)

  fun pat_completeness_auto lthy =
    Pat_Completeness.pat_completeness_tac lthy 1
      THEN auto_tac (clasimpset_of lthy)

  fun prove_termination lthy =
    Function.prove_termination NONE
      (Lexicographic_Order.lexicographic_order_tac true lthy) lthy

  val (_, lthy') = Function.add_function all_fv_names all_fv_eqs
    Function_Common.default_config pat_completeness_auto lthy

  val (info, lthy'') = prove_termination (Local_Theory.restore lthy')

  val {fs, simps, ...} = info;

  (* move the defined constants and simp rules back into the context lthy *)
  val morphism = ProofContext.export_morphism lthy'' lthy
  val fs_exp = map (Morphism.term morphism) fs

  (* fs lists the fv_ functions first, then the fv_bn functions *)
  val (fv_frees_exp, fv_bns_exp) = chop (length fv_frees) fs_exp
  val simps_exp = Morphism.fact morphism (the simps)
in
  (fv_frees_exp, fv_bns_exp, simps_exp, lthy'')
end

end (* structure *)

