Nominal/nominal_dt_rawperm.ML
changeset 2398 1e6160690546
parent 2396 f2f611daf480
child 2401 7645e18e8b19
--- a/Nominal/nominal_dt_rawperm.ML	Sat Aug 14 16:54:41 2010 +0800
+++ b/Nominal/nominal_dt_rawperm.ML	Sat Aug 14 23:33:23 2010 +0800
@@ -9,7 +9,7 @@
 
 signature NOMINAL_DT_RAWPERM =
 sig
-  val define_raw_perms: Datatype.descr -> (string * sort) list -> thm -> int -> theory -> 
+  val define_raw_perms: string list -> typ list -> term list -> thm -> theory -> 
     (term list * thm list * thm list) * theory
 end
 
@@ -18,39 +18,6 @@
 struct
 
 
-(* permutation function for one argument 
-   
-    - in case the argument is recursive it returns 
-
-         permute_fn p arg
-
-    - in case the argument is non-recursive it will return
-
-         p o arg
-*)
-fun perm_arg permute_fn_frees p (arg_dty, arg) =
-  if Datatype_Aux.is_rec_type arg_dty 
-  then (nth permute_fn_frees (Datatype_Aux.body_index arg_dty)) $ p $ arg
-  else mk_perm p arg
-
-
-(* generates the equation for the permutation function for one constructor;
-   i is the index of the corresponding datatype *)
-fun perm_eq_constr dt_descr sorts permute_fn_frees i (cnstr_name, dts) =
-let
-  val p = Free ("p", @{typ perm})
-  val arg_tys = map (Datatype_Aux.typ_of_dtyp dt_descr sorts) dts
-  val arg_names = Name.variant_list ["p"] (Datatype_Prop.make_tnames arg_tys)
-  val args = map Free (arg_names ~~ arg_tys)
-  val cnstr = Const (cnstr_name, arg_tys ---> (nth_dtyp dt_descr sorts i))
-  val lhs = (nth permute_fn_frees i) $ p $ list_comb (cnstr, args)
-  val rhs = list_comb (cnstr, map (perm_arg permute_fn_frees p) (dts ~~ args))
-  val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs))
-in
-  (Attrib.empty_binding, eq)
-end
-
-
 (** proves the two pt-type class properties **)
 
 fun prove_permute_zero lthy induct perm_defs perm_fns =
@@ -99,49 +66,62 @@
 end
 
 
-(* user_dt_nos refers to the number of "un-unfolded" datatypes
-   given by the user
-*)
-fun define_raw_perms dt_descr sorts induct_thm user_dt_nos thy =
+fun mk_perm_eq ty_perm_assoc cnstr = 
 let
-  val all_full_tnames = map (fn (_, (n, _, _)) => n) dt_descr;
-  val user_full_tnames = List.take (all_full_tnames, user_dt_nos);
+  fun lookup_perm p (ty, arg) = 
+    case (AList.lookup (op=) ty_perm_assoc ty) of
+      SOME perm => perm $ p $ arg
+    | NONE => Const (@{const_name permute}, perm_ty ty) $ p $ arg
+
+  val p = Free ("p", @{typ perm})
+  val (arg_tys, ty) =
+    fastype_of cnstr
+    |> strip_type
+
+  val arg_names = Name.variant_list ["p"] (Datatype_Prop.make_tnames arg_tys)
+  val args = map Free (arg_names ~~ arg_tys)
 
-  val perm_fn_names = prefix_dt_names dt_descr sorts "permute_"
-  val perm_fn_types = map (fn (i, _) => perm_ty (nth_dtyp dt_descr sorts i)) dt_descr
-  val perm_fn_frees = map Free (perm_fn_names ~~ perm_fn_types)
+  val lhs = lookup_perm p (ty, list_comb (cnstr, args))
+  val rhs = list_comb (cnstr, map (lookup_perm p) (arg_tys ~~ args))
+  
+  val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs))  
+in
+  (Attrib.empty_binding, eq)
+end
+
 
-  fun perm_eq (i, (_, _, constrs)) = 
-    map (perm_eq_constr dt_descr sorts perm_fn_frees i) constrs;
+fun define_raw_perms full_ty_names tys constrs induct_thm thy =
+let
+  val perm_fn_names = full_ty_names
+    |> map Long_Name.base_name
+    |> map (prefix "permute_")
 
-  val perm_eqs = maps perm_eq dt_descr;
+  val perm_fn_types = map perm_ty tys
+  val perm_fn_frees = map Free (perm_fn_names ~~ perm_fn_types)
+  val perm_fn_binds = map (fn s => (Binding.name s, NONE, NoSyn)) perm_fn_names
+
+  val perm_eqs = map (mk_perm_eq (tys ~~ perm_fn_frees)) constrs
 
   val lthy =
-    Class.instantiation (user_full_tnames, [], @{sort pt}) thy;
+    Class.instantiation (full_ty_names, [], @{sort pt}) thy
    
   val ((perm_funs, perm_eq_thms), lthy') =
-    Primrec.add_primrec
-      (map (fn s => (Binding.name s, NONE, NoSyn)) perm_fn_names) perm_eqs lthy;
+    Primrec.add_primrec perm_fn_binds perm_eqs lthy;
     
   val perm_zero_thms = prove_permute_zero lthy' induct_thm perm_eq_thms perm_funs
   val perm_plus_thms = prove_permute_plus lthy' induct_thm perm_eq_thms perm_funs
-  val perm_zero_thms' = List.take (perm_zero_thms, user_dt_nos);
-  val perm_plus_thms' = List.take (perm_plus_thms, user_dt_nos)
-  val perms_name = space_implode "_" perm_fn_names
-  val perms_zero_bind = Binding.name (perms_name ^ "_zero")
-  val perms_plus_bind = Binding.name (perms_name ^ "_plus")
   
   fun tac _ (_, _, simps) =
     Class.intro_classes_tac [] THEN ALLGOALS (resolve_tac simps)
   
   fun morphism phi (fvs, dfs, simps) =
-    (map (Morphism.term phi) fvs, map (Morphism.thm phi) dfs, map (Morphism.thm phi) simps);
+    (map (Morphism.term phi) fvs, 
+     map (Morphism.thm phi) dfs, 
+     map (Morphism.thm phi) simps);
 in
   lthy'
-  |> snd o (Local_Theory.note ((perms_zero_bind, []), perm_zero_thms'))
-  |> snd o (Local_Theory.note ((perms_plus_bind, []), perm_plus_thms'))
   |> Class.prove_instantiation_exit_result morphism tac 
-       (perm_funs, perm_eq_thms, perm_zero_thms' @ perm_plus_thms')
+       (perm_funs, perm_eq_thms, perm_zero_thms @ perm_plus_thms)
 end