Nominal/nominal_dt_rawfuns.ML
changeset 2957 01ff621599bc
parent 2888 eda5aeb056a6
child 3045 d0ad264f8c4f
--- a/Nominal/nominal_dt_rawfuns.ML	Wed Jul 06 15:59:11 2011 +0200
+++ b/Nominal/nominal_dt_rawfuns.ML	Thu Jul 07 16:16:42 2011 +0200
@@ -7,59 +7,30 @@
 
 signature NOMINAL_DT_RAWFUNS =
 sig
-  (* info of raw datatypes *)
-  type dt_info = (string list * binding * mixfix * ((binding * typ list * mixfix) list)) list
-
-  (* info of raw binding functions *)
-  type bn_info = term * int * (int * term option) list list
-
-  (* binding modes and binding clauses *)
-  datatype bmode = Lst | Res | Set
-  datatype bclause = BC of bmode * (term option * int) list * int list
-
   val get_all_binders: bclause list -> (term option * int) list
   val is_recursive_binder: bclause -> bool
 
-  val define_raw_bns: string list -> dt_info -> (binding * typ option * mixfix) list ->
-    (Attrib.binding * term) list -> thm list -> thm list -> local_theory ->
+  val define_raw_bns: raw_dt_info -> (binding * typ option * mixfix) list -> 
+    (Attrib.binding * term) list -> local_theory ->
     (term list * thm list * bn_info list * thm list * local_theory) 
 
-  val define_raw_fvs: string list -> typ list -> cns_info list -> bn_info list -> bclause list list list -> 
-    thm list -> thm list -> Proof.context -> term list * term list * thm list * thm list * local_theory
+  val define_raw_fvs: raw_dt_info -> bn_info list -> bclause list list list -> 
+    Proof.context -> term list * term list * thm list * thm list * local_theory
 
-  val define_raw_bn_perms: typ list -> bn_info list -> cns_info list -> thm list -> thm list -> 
-    local_theory -> (term list * thm list * local_theory)
+  val define_raw_bn_perms: raw_dt_info -> bn_info list -> local_theory -> 
+    (term list * thm list * local_theory)
+ 
+  val define_raw_perms: raw_dt_info -> local_theory -> (term list * thm list * thm list) * local_theory
  
   val raw_prove_eqvt: term list -> thm list -> thm list -> Proof.context -> thm list
 
-  val define_raw_perms: string list -> typ list -> (string * sort) list -> term list -> thm -> 
-    local_theory -> (term list * thm list * thm list) * local_theory
 end
 
 
 structure Nominal_Dt_RawFuns: NOMINAL_DT_RAWFUNS =
 struct
 
-open Nominal_Permeq
-
-(* string list      - type variables of a datatype
-   binding          - name of the datatype
-   mixfix           - its mixfix
-   (binding * typ list * mixfix) list  - datatype constructors of the type
-*)  
-type dt_info = (string list * binding * mixfix * ((binding * typ list * mixfix) list)) list
-
-
-(* term              - is constant of the bn-function 
-   int               - is datatype number over which the bn-function is defined
-   int * term option - is number of the corresponding argument with possibly
-                       recursive call with bn-function term 
-*)  
-type bn_info = term * int * (int * term option) list list
-
-
-datatype bmode = Lst | Res | Set
-datatype bclause = BC of bmode * (term option * int) list * int list
+open Nominal_Permeq 
 
 fun get_all_binders bclauses = 
   bclauses
@@ -136,22 +107,25 @@
   |> order (op=) bn_funs    (* ordered according to bn_functions *)
 end
 
-fun define_raw_bns dt_names dts raw_bn_funs raw_bn_eqs constr_thms size_thms lthy =
+fun define_raw_bns raw_dt_info raw_bn_funs raw_bn_eqs lthy =
   if null raw_bn_funs 
   then ([], [], [], [], lthy)
   else 
     let
+      val RawDtInfo 
+        {raw_dt_names, raw_dts, raw_inject_thms, raw_distinct_thms, raw_size_thms, ...} = raw_dt_info 
+
       val (_, lthy1) = Function.add_function raw_bn_funs raw_bn_eqs
-        Function_Common.default_config (pat_completeness_simp constr_thms) lthy
+        Function_Common.default_config (pat_completeness_simp (raw_inject_thms @ raw_distinct_thms)) lthy
 
-      val (info, lthy2) = prove_termination_fun size_thms (Local_Theory.restore lthy1)
+      val (info, lthy2) = prove_termination_fun raw_size_thms (Local_Theory.restore lthy1)
       val {fs, simps, inducts, ...} = info
 
       val raw_bn_induct = (the inducts)
       val raw_bn_eqs = the simps
 
       val raw_bn_info = 
-        prep_bn_info lthy dt_names dts fs (map prop_of raw_bn_eqs)
+        prep_bn_info lthy raw_dt_names raw_dts fs (map prop_of raw_bn_eqs)
     in
       (fs, raw_bn_eqs, raw_bn_info, raw_bn_induct, lthy2)
     end
@@ -255,9 +229,13 @@
     map2 (mk_fv_bn_eq lthy bn_trm fv_map fv_bn_map) (bn_argss ~~ nth_constrs_info) nth_bclausess
   end
 
-fun define_raw_fvs raw_full_ty_names raw_tys cns_info bn_info bclausesss constr_thms size_simps lthy =
+fun define_raw_fvs raw_dt_info bn_info bclausesss lthy =
   let
-    val fv_names = map (prefix "fv_" o Long_Name.base_name) raw_full_ty_names
+    val RawDtInfo 
+      {raw_dt_names, raw_tys, raw_cns_info, raw_inject_thms, raw_distinct_thms, raw_size_thms, ...} = 
+        raw_dt_info
+
+    val fv_names = map (prefix "fv_" o Long_Name.base_name) raw_dt_names
     val fv_tys = map (fn ty => ty --> @{typ "atom set"}) raw_tys
     val fv_frees = map Free (fv_names ~~ fv_tys);
     val fv_map = raw_tys ~~ fv_frees
@@ -270,16 +248,16 @@
     val fv_bn_frees = map Free (fv_bn_names ~~ fv_bn_tys)
     val fv_bn_map = bns ~~ fv_bn_frees
 
-    val fv_eqs = map2 (map2 (mk_fv_eq lthy fv_map fv_bn_map)) cns_info bclausesss 
-    val fv_bn_eqs = map (mk_fv_bn_eqs lthy fv_map fv_bn_map cns_info bclausesss) bn_info
+    val fv_eqs = map2 (map2 (mk_fv_eq lthy fv_map fv_bn_map)) raw_cns_info bclausesss 
+    val fv_bn_eqs = map (mk_fv_bn_eqs lthy fv_map fv_bn_map raw_cns_info bclausesss) bn_info
   
     val all_fun_names = map (fn s => (Binding.name s, NONE, NoSyn)) (fv_names @ fv_bn_names)
     val all_fun_eqs = map (pair Attrib.empty_binding) (flat fv_eqs @ flat fv_bn_eqs)
 
     val (_, lthy') = Function.add_function all_fun_names all_fun_eqs
-      Function_Common.default_config (pat_completeness_simp constr_thms) lthy
+      Function_Common.default_config (pat_completeness_simp (raw_inject_thms @ raw_distinct_thms)) lthy
 
-    val (info, lthy'') = prove_termination_fun size_simps (Local_Theory.restore lthy')
+    val (info, lthy'') = prove_termination_fun raw_size_thms (Local_Theory.restore lthy')
  
     val {fs, simps, inducts, ...} = info; 
 
@@ -325,11 +303,14 @@
     map2 (mk_perm_bn_eq lthy bn_trm perm_bn_map) bn_argss nth_cns_info
   end
 
-fun define_raw_bn_perms raw_tys bn_info cns_info cns_thms size_thms lthy =
+fun define_raw_bn_perms raw_dt_info bn_info lthy =
   if null bn_info
   then ([], [], lthy)
   else
     let
+      val RawDtInfo
+        {raw_tys, raw_cns_info, raw_inject_thms, raw_distinct_thms, raw_size_thms, ...} = raw_dt_info
+
       val (bns, bn_tys) = split_list (map (fn (bn, i, _) => (bn, i)) bn_info)
       val bn_names = map (fn bn => Long_Name.base_name (fst (dest_Const bn))) bns
       val perm_bn_names = map (prefix "permute_") bn_names
@@ -338,7 +319,7 @@
       val perm_bn_frees = map Free (perm_bn_names ~~ perm_bn_tys)
       val perm_bn_map = bns ~~ perm_bn_frees
 
-      val perm_bn_eqs = map (mk_perm_bn_eqs lthy perm_bn_map cns_info) bn_info
+      val perm_bn_eqs = map (mk_perm_bn_eqs lthy perm_bn_map raw_cns_info) bn_info
 
       val all_fun_names = map (fn s => (Binding.name s, NONE, NoSyn)) perm_bn_names
       val all_fun_eqs = map (pair Attrib.empty_binding) (flat perm_bn_eqs)
@@ -346,9 +327,10 @@
       val prod_simps = @{thms prod.inject HOL.simp_thms}
 
       val (_, lthy') = Function.add_function all_fun_names all_fun_eqs
-        Function_Common.default_config (pat_completeness_simp (prod_simps @ cns_thms)) lthy
+        Function_Common.default_config 
+          (pat_completeness_simp (prod_simps @ raw_inject_thms @ raw_distinct_thms)) lthy
     
-      val (info, lthy'') = prove_termination_fun size_thms (Local_Theory.restore lthy')
+      val (info, lthy'') = prove_termination_fun raw_size_thms (Local_Theory.restore lthy')
 
       val {fs, simps, ...} = info;
 
@@ -358,57 +340,6 @@
       (fs, simps_exp, lthy'')
     end
 
-
-(** equivarance proofs **)
-
-val eqvt_apply_sym = @{thm eqvt_apply[symmetric]}
-
-fun subproof_tac const_names simps = 
-  SUBPROOF (fn {prems, context, ...} => 
-    HEADGOAL 
-      (simp_tac (HOL_basic_ss addsimps simps)
-       THEN' eqvt_tac context (eqvt_relaxed_config addexcls const_names)
-       THEN' simp_tac (HOL_basic_ss addsimps (prems @ [eqvt_apply_sym]))))
-
-fun prove_eqvt_tac insts ind_thms const_names simps ctxt = 
-  HEADGOAL
-    (Object_Logic.full_atomize_tac
-     THEN' (DETERM o (InductTacs.induct_rules_tac ctxt insts ind_thms))  
-     THEN_ALL_NEW  subproof_tac const_names simps ctxt)
-
-fun mk_eqvt_goal pi const arg =
-  let
-    val lhs = mk_perm pi (const $ arg)
-    val rhs = const $ (mk_perm pi arg)  
-  in
-    HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs))
-  end
-
-
-fun raw_prove_eqvt consts ind_thms simps ctxt =
-  if null consts then []
-  else
-    let 
-      val ([p], ctxt') = Variable.variant_fixes ["p"] ctxt
-      val p = Free (p, @{typ perm})
-      val arg_tys = 
-        consts
-        |> map fastype_of
-        |> map domain_type 
-      val (arg_names, ctxt'') = 
-        Variable.variant_fixes (Datatype_Prop.make_tnames arg_tys) ctxt'
-      val args = map Free (arg_names ~~ arg_tys)
-      val goals = map2 (mk_eqvt_goal p) consts args
-      val insts = map (single o SOME) arg_names
-      val const_names = map (fst o dest_Const) consts      
-    in
-      Goal.prove_multi ctxt'' [] [] goals (fn {context, ...} => 
-        prove_eqvt_tac insts ind_thms const_names simps context)
-      |> ProofContext.export ctxt'' ctxt
-    end
-
-
-
 (*** raw permutation functions ***)
 
 (** proves the two pt-type class properties **)
@@ -483,17 +414,20 @@
   end
 
 
-fun define_raw_perms full_ty_names tys tvs constrs induct_thm lthy =
+fun define_raw_perms raw_dt_info lthy =
   let
-    val perm_fn_names = full_ty_names
+    val RawDtInfo 
+      {raw_dt_names, raw_tys, raw_ty_args, raw_all_cns, raw_induct_thm, ...} = raw_dt_info
+
+    val perm_fn_names = raw_dt_names
       |> map Long_Name.base_name
       |> map (prefix "permute_")
 
-    val perm_fn_types = map perm_ty tys
+    val perm_fn_types = map perm_ty raw_tys
     val perm_fn_frees = map Free (perm_fn_names ~~ perm_fn_types)
     val perm_fn_binds = map (fn s => (Binding.name s, NONE, NoSyn)) perm_fn_names
 
-    val perm_eqs = map (mk_perm_eq (tys ~~ perm_fn_frees)) constrs
+    val perm_eqs = map (mk_perm_eq (raw_tys ~~ perm_fn_frees)) (flat raw_all_cns)
 
     fun tac _ (_, _, simps) =
       Class.intro_classes_tac [] THEN ALLGOALS (resolve_tac simps)
@@ -506,11 +440,11 @@
     val ((perm_funs, perm_eq_thms), lthy') =
       lthy
       |> Local_Theory.exit_global
-      |> Class.instantiation (full_ty_names, tvs, @{sort pt}) 
+      |> Class.instantiation (raw_dt_names, raw_ty_args, @{sort pt}) 
       |> Primrec.add_primrec perm_fn_binds perm_eqs
     
-    val perm_zero_thms = prove_permute_zero induct_thm perm_eq_thms perm_funs lthy'
-    val perm_plus_thms = prove_permute_plus induct_thm perm_eq_thms perm_funs lthy'  
+    val perm_zero_thms = prove_permute_zero raw_induct_thm perm_eq_thms perm_funs lthy'
+    val perm_plus_thms = prove_permute_plus raw_induct_thm perm_eq_thms perm_funs lthy'  
   in
     lthy'
     |> Class.prove_instantiation_exit_result morphism tac 
@@ -519,5 +453,54 @@
   end
 
 
+(** equivarance proofs **)
+
+val eqvt_apply_sym = @{thm eqvt_apply[symmetric]}
+
+fun subproof_tac const_names simps = 
+  SUBPROOF (fn {prems, context, ...} => 
+    HEADGOAL 
+      (simp_tac (HOL_basic_ss addsimps simps)
+       THEN' eqvt_tac context (eqvt_relaxed_config addexcls const_names)
+       THEN' simp_tac (HOL_basic_ss addsimps (prems @ [eqvt_apply_sym]))))
+
+fun prove_eqvt_tac insts ind_thms const_names simps ctxt = 
+  HEADGOAL
+    (Object_Logic.full_atomize_tac
+     THEN' (DETERM o (InductTacs.induct_rules_tac ctxt insts ind_thms))  
+     THEN_ALL_NEW  subproof_tac const_names simps ctxt)
+
+fun mk_eqvt_goal pi const arg =
+  let
+    val lhs = mk_perm pi (const $ arg)
+    val rhs = const $ (mk_perm pi arg)  
+  in
+    HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs))
+  end
+
+
+fun raw_prove_eqvt consts ind_thms simps ctxt =
+  if null consts then []
+  else
+    let 
+      val ([p], ctxt') = Variable.variant_fixes ["p"] ctxt
+      val p = Free (p, @{typ perm})
+      val arg_tys = 
+        consts
+        |> map fastype_of
+        |> map domain_type 
+      val (arg_names, ctxt'') = 
+        Variable.variant_fixes (Datatype_Prop.make_tnames arg_tys) ctxt'
+      val args = map Free (arg_names ~~ arg_tys)
+      val goals = map2 (mk_eqvt_goal p) consts args
+      val insts = map (single o SOME) arg_names
+      val const_names = map (fst o dest_Const) consts      
+    in
+      Goal.prove_multi ctxt'' [] [] goals (fn {context, ...} => 
+        prove_eqvt_tac insts ind_thms const_names simps context)
+      |> ProofContext.export ctxt'' ctxt
+    end
+
+
 end (* structure *)