Nominal/Nominal2.thy
changeset 2957 01ff621599bc
parent 2950 0911cb7bf696
child 2959 9bd97ed202f7
--- a/Nominal/Nominal2.thy	Wed Jul 06 15:59:11 2011 +0200
+++ b/Nominal/Nominal2.thy	Thu Jul 07 16:16:42 2011 +0200
@@ -145,7 +145,7 @@
 ML {*
 (* definition of the raw datatype *)
 
-fun define_raw_dts dts bn_funs bn_eqs bclauses lthy =
+fun define_raw_dts dts cnstr_names cnstr_tys bn_funs bn_eqs bclauses lthy =
 let
   val thy = Local_Theory.exit_global lthy
   val thy_name = Context.theory_name thy
@@ -155,8 +155,6 @@
   val dt_full_names' = add_raws dt_full_names
   val dts_env = dt_full_names ~~ dt_full_names'
 
-  val cnstr_names = get_cnstr_strs dts
-  val cnstr_tys = get_typed_cnstrs dts
   val cnstr_full_names = map (Long_Name.qualify thy_name) cnstr_names
   val cnstr_full_names' = map (fn (x, y) => Long_Name.qualify thy_name 
     (Long_Name.qualify (add_raw x) (add_raw y))) cnstr_tys
@@ -168,16 +166,52 @@
   val bn_fun_full_env = map (pairself (Long_Name.qualify thy_name)) 
     (bn_fun_strs ~~ bn_fun_strs')
   
-  val (raw_dt_names, raw_dts) = rawify_dts dt_names dts dts_env
+  val (raw_full_dt_names, raw_dts) = rawify_dts dt_names dts dts_env
   val (raw_bn_funs, raw_bn_eqs) = rawify_bn_funs dts_env cnstrs_env bn_fun_env bn_funs bn_eqs 
   val raw_bclauses = rawify_bclauses dts_env cnstrs_env bn_fun_full_env bclauses 
 
-  val (raw_dt_full_names, thy1) = 
-    Datatype.add_datatype Datatype.default_config raw_dt_names raw_dts thy
+  val (raw_full_dt_names', thy1) = 
+    Datatype.add_datatype Datatype.default_config raw_full_dt_names raw_dts thy
 
   val lthy1 = Named_Target.theory_init thy1
+
+  val dtinfos = map (Datatype.the_info (ProofContext.theory_of lthy1)) raw_full_dt_names' 
+  val {descr, sorts, ...} = hd dtinfos
+
+  val raw_tys = Datatype_Aux.get_rec_types descr sorts
+  val raw_ty_args = hd raw_tys
+    |> snd o dest_Type
+    |> map dest_TFree 
+
+  val raw_cns_info = all_dtyp_constrs_types descr sorts
+  val raw_all_cns = (map o map) (fn (c, _, _, _) => c) raw_cns_info
+
+  val raw_inject_thms = flat (map #inject dtinfos)
+  val raw_distinct_thms = flat (map #distinct dtinfos)
+  val raw_induct_thm = #induct (hd dtinfos)
+  val raw_induct_thms = #inducts (hd dtinfos)
+  val raw_exhaust_thms = map #exhaust dtinfos
+  val raw_size_trms = map HOLogic.size_const raw_tys
+  val raw_size_thms = Size.size_thms (ProofContext.theory_of lthy1) (hd raw_full_dt_names')
+    |> `(fn thms => (length thms) div 2)
+    |> uncurry drop
+
+  val raw_result = RawDtInfo 
+    {raw_dt_names = raw_full_dt_names',
+     raw_dts = raw_dts,
+     raw_tys = raw_tys,
+     raw_ty_args = raw_ty_args,
+     raw_cns_info = raw_cns_info,
+     raw_all_cns = raw_all_cns,
+     raw_inject_thms = raw_inject_thms,
+     raw_distinct_thms = raw_distinct_thms,
+     raw_induct_thm = raw_induct_thm,
+     raw_induct_thms = raw_induct_thms,
+     raw_exhaust_thms = raw_exhaust_thms,
+     raw_size_trms = raw_size_trms,
+     raw_size_thms = raw_size_thms}
 in
-  (raw_dt_full_names, raw_dts, raw_bclauses, raw_bn_funs, raw_bn_eqs, cnstr_names, lthy1)
+  (raw_bclauses, raw_bn_funs, raw_bn_eqs, raw_result, lthy1)
 end
 *}
 
@@ -185,62 +219,52 @@
 ML {*
 fun nominal_datatype2 opt_thms_name dts bn_funs bn_eqs bclauses lthy =
 let
-  val _ = trace_msg (K "Defining raw datatypes...")
-  val (raw_dt_names, raw_dts, raw_bclauses, raw_bn_funs, raw_bn_eqs, cnstr_names, lthy0) =
-    define_raw_dts dts bn_funs bn_eqs bclauses lthy   
+  val cnstr_names = get_cnstr_strs dts
+  val cnstr_tys = get_typed_cnstrs dts
 
-  val dtinfos = map (Datatype.the_info (ProofContext.theory_of lthy0)) raw_dt_names 
-  val {descr, sorts, ...} = hd dtinfos
-
-  val raw_tys = Datatype_Aux.get_rec_types descr sorts
-  val tvs = hd raw_tys
-    |> snd o dest_Type
-    |> map dest_TFree  
+  val _ = trace_msg (K "Defining raw datatypes...")
+  val (raw_bclauses, raw_bn_funs, raw_bn_eqs, raw_dt_info, lthy0) =
+    define_raw_dts dts cnstr_names cnstr_tys bn_funs bn_eqs bclauses lthy   
 
-  val raw_cns_info = all_dtyp_constrs_types descr sorts
-  val raw_constrs = (map o map) (fn (c, _, _, _) => c) raw_cns_info
-
-  val raw_inject_thms = flat (map #inject dtinfos)
-  val raw_distinct_thms = flat (map #distinct dtinfos)
-  val raw_induct_thm = #induct (hd dtinfos)
-  val raw_induct_thms = #inducts (hd dtinfos)
-  val raw_exhaust_thms = map #exhaust dtinfos
-  val raw_size_trms = map HOLogic.size_const raw_tys
-  val raw_size_thms = Size.size_thms (ProofContext.theory_of lthy0) (hd raw_dt_names)
-    |> `(fn thms => (length thms) div 2)
-    |> uncurry drop
+  val RawDtInfo 
+    {raw_dt_names,
+     raw_tys,
+     raw_ty_args,
+     raw_all_cns,
+     raw_inject_thms,
+     raw_distinct_thms,
+     raw_induct_thm,
+     raw_induct_thms,
+     raw_exhaust_thms,
+     raw_size_trms,
+     raw_size_thms, ...} = raw_dt_info
   
   val _ = trace_msg (K "Defining raw permutations...")
-  val ((raw_perm_funs, raw_perm_simps, raw_perm_laws), lthy2a) =
-    define_raw_perms raw_dt_names raw_tys tvs (flat raw_constrs) raw_induct_thm lthy0
+  val ((raw_perm_funs, raw_perm_simps, raw_perm_laws), lthy2a) = define_raw_perms raw_dt_info lthy0
  
   (* noting the raw permutations as eqvt theorems *)
   val (_, lthy3) = Local_Theory.note ((Binding.empty, [eqvt_attr]), raw_perm_simps) lthy2a
 
   val _ = trace_msg (K "Defining raw fv- and bn-functions...")
   val (raw_bns, raw_bn_defs, raw_bn_info, raw_bn_inducts, lthy3a) =
-    define_raw_bns raw_dt_names raw_dts raw_bn_funs raw_bn_eqs 
-      (raw_inject_thms @ raw_distinct_thms) raw_size_thms lthy3
+    define_raw_bns raw_dt_info raw_bn_funs raw_bn_eqs lthy3
     
   (* defining the permute_bn functions *)
   val (raw_perm_bns, raw_perm_bn_simps, lthy3b) = 
-    define_raw_bn_perms raw_tys raw_bn_info raw_cns_info 
-      (raw_inject_thms @ raw_distinct_thms) raw_size_thms lthy3a
+    define_raw_bn_perms raw_dt_info raw_bn_info lthy3a
     
   val (raw_fvs, raw_fv_bns, raw_fv_defs, raw_fv_bns_induct, lthy3c) = 
-    define_raw_fvs raw_dt_names raw_tys raw_cns_info raw_bn_info raw_bclauses 
-      (raw_inject_thms @ raw_distinct_thms) raw_size_thms lthy3b
+    define_raw_fvs raw_dt_info raw_bn_info raw_bclauses lthy3b
     
   val _ = trace_msg (K "Defining alpha relations...")
   val (alpha_result, lthy4) =
-    define_raw_alpha raw_dt_names raw_tys raw_cns_info raw_bn_info raw_bclauses raw_fvs lthy3c
+    define_raw_alpha raw_dt_info raw_bn_info raw_bclauses raw_fvs lthy3c
     
   val _ = trace_msg (K "Proving distinct theorems...")
-  val alpha_distincts = 
-    raw_prove_alpha_distincts lthy4 alpha_result raw_distinct_thms raw_dt_names
+  val alpha_distincts = raw_prove_alpha_distincts lthy4 alpha_result raw_dt_info
 
   val _ = trace_msg (K "Proving eq-iff theorems...")
-  val alpha_eq_iff = raw_prove_alpha_eq_iff lthy4 alpha_result raw_distinct_thms raw_inject_thms
+  val alpha_eq_iff = raw_prove_alpha_eq_iff lthy4 alpha_result raw_dt_info
     
   val _ = trace_msg (K "Proving equivariance of bns, fvs, size and alpha...")
   val raw_bn_eqvt = 
@@ -253,11 +277,15 @@
     raw_prove_eqvt (raw_fvs @ raw_fv_bns) raw_fv_bns_induct (raw_fv_defs @ raw_perm_simps) 
       (Local_Theory.restore lthy_tmp)
     
-  val raw_size_eqvt = 
-    raw_prove_eqvt raw_size_trms raw_induct_thms (raw_size_thms @ raw_perm_simps) 
-      (Local_Theory.restore lthy_tmp)
-      |> map (rewrite_rule @{thms permute_nat_def[THEN eq_reflection]})
-      |> map (fn thm => thm RS @{thm sym})
+  val raw_size_eqvt =
+    let
+      val RawDtInfo {raw_size_trms, raw_size_thms, raw_induct_thms, ...} = raw_dt_info
+    in
+      raw_prove_eqvt raw_size_trms raw_induct_thms (raw_size_thms @ raw_perm_simps) 
+        (Local_Theory.restore lthy_tmp)
+        |> map (rewrite_rule @{thms permute_nat_def[THEN eq_reflection]})
+        |> map (fn thm => thm RS @{thm sym})
+    end 
      
   val lthy5 = snd (Local_Theory.note ((Binding.empty, [eqvt_attr]), raw_fv_eqvt) lthy_tmp)
 
@@ -293,7 +321,7 @@
       |> map mk_funs_rsp
 
   val raw_constrs_rsp = 
-    raw_constrs_rsp lthy5 alpha_result (flat raw_constrs) (alpha_bn_imp_thms @ raw_funs_rsp_aux) 
+    raw_constrs_rsp lthy5 alpha_result (flat raw_all_cns) (alpha_bn_imp_thms @ raw_funs_rsp_aux) 
     
   val alpha_permute_rsp = map mk_alpha_permute_rsp alpha_eqvt
 
@@ -324,7 +352,7 @@
 
   val _ = trace_msg (K "Defining the quotient constants...")
   val qconstrs_descrs =
-    (map2 o map2) (fn (b, _, mx) => fn t => (Variable.check_name b, t, mx)) (get_cnstrs dts) raw_constrs
+    (map2 o map2) (fn (b, _, mx) => fn t => (Variable.check_name b, t, mx)) (get_cnstrs dts) raw_all_cns
 
   val qbns_descr =
     map2 (fn (b, _, mx) => fn t => (Variable.check_name b, t, mx)) bn_funs raw_bns
@@ -362,10 +390,10 @@
       ||>> define_qconsts qtys qperm_bn_descr
 
   val lthy9 = 
-    define_qperms qtys qty_full_names tvs qperm_descr raw_perm_laws lthy8 
+    define_qperms qtys qty_full_names raw_ty_args qperm_descr raw_perm_laws lthy8 
   
   val lthy9a = 
-    define_qsizes qtys qty_full_names tvs qsize_descr lthy9
+    define_qsizes qtys qty_full_names raw_ty_args qsize_descr lthy9
 
   val qtrms = (map o map) #qconst qconstrs_infos
   val qbns = map #qconst qbns_info
@@ -408,7 +436,7 @@
   val qfsupp_thms = prove_fsupp lthyB qtys qinduct qsupports_thms
 
   (* fs instances *)
-  val lthyC = fs_instance qtys qty_full_names tvs qfsupp_thms lthyB
+  val lthyC = fs_instance qtys qty_full_names raw_ty_args qfsupp_thms lthyB
 
   val _ = trace_msg (K "Proving equality between fv and supp...")
   val qfv_supp_thms =