improved runtime slightly, by constructing an explicit size measure for the function definitions
authorChristian Urban <urbanc@in.tum.de>
Wed, 18 Aug 2010 00:19:15 +0800
changeset 2410 2bbdb9c427b5
parent 2409 83990a42a068
child 2411 dceaf2d9fedd
improved runtime slightly, by constructing an explicit size measure for the function definitions
Nominal-General/nominal_library.ML
Nominal/Ex/CoreHaskell.thy
Nominal/Ex/SingleLet.thy
Nominal/NewParser.thy
Nominal/nominal_dt_rawfuns.ML
--- a/Nominal-General/nominal_library.ML	Tue Aug 17 18:17:53 2010 +0800
+++ b/Nominal-General/nominal_library.ML	Wed Aug 18 00:19:15 2010 +0800
@@ -12,6 +12,9 @@
 
   val size_const: typ -> term 
 
+  val sum_case_const: typ -> typ -> typ -> term
+  val mk_sum_case: term -> term -> term
+ 
   val mk_minus: term -> term
   val mk_plus: term -> term -> term
 
@@ -51,7 +54,7 @@
 
   (* tactics for function package *)
   val pat_completeness_simp: thm list -> Proof.context -> tactic
-  val prove_termination: Proof.context -> Function.info * local_theory
+  val prove_termination: thm list -> Proof.context -> Function.info * local_theory
 
   (* transformations of premises in inductions *)
   val transform_prem1: Proof.context -> string list -> thm -> thm
@@ -76,6 +79,18 @@
 
 fun size_const ty = Const (@{const_name size}, ty --> @{typ nat})
 
+fun sum_case_const ty1 ty2 ty3 = 
+  Const (@{const_name sum_case}, [ty1 --> ty3, ty2 --> ty3, Type (@{type_name sum}, [ty1, ty2])] ---> ty3)
+fun mk_sum_case trm1 trm2 =
+let
+  val ([ty1], ty3) = strip_type (fastype_of trm1)
+  val ty2 = domain_type (fastype_of trm2)
+in
+  sum_case_const ty1 ty2 ty3 $ trm1 $ trm2
+end 
+
+
+
 fun mk_minus p = @{term "uminus::perm => perm"} $ p
 
 fun mk_plus p q = @{term "plus::perm => perm => perm"} $ p $ q
@@ -187,10 +202,29 @@
     THEN ALLGOALS (asm_full_simp_tac simp_set)
 end
 
-fun prove_termination lthy =
-  Function.prove_termination NONE
-    (Lexicographic_Order.lexicographic_order_tac true lthy) lthy
+fun prove_termination_tac size_simps ctxt i st  =
+let
+  fun mk_size_measure (Type (@{type_name Sum_Type.sum}, [fT, sT])) =
+      SumTree.mk_sumcase fT sT @{typ nat} (mk_size_measure fT) (mk_size_measure sT)
+    | mk_size_measure T = size_const T
 
+  val ((_ $ (_ $ rel)) :: tl) = prems_of st
+  val measure_trm = 
+    fastype_of rel 
+    |> HOLogic.dest_setT
+    |> mk_size_measure 
+    |> curry (op $) (Const (@{const_name measure}, dummyT))
+    |> Syntax.check_term ctxt
+  val ss = HOL_ss addsimps @{thms in_measure wf_measure sum.cases add_Suc_right add.right_neutral 
+    zero_less_Suc} @ size_simps addsimprocs Nat_Numeral_Simprocs.cancel_numerals
+in
+  (Function_Relation.relation_tac ctxt measure_trm
+   THEN_ALL_NEW  simp_tac ss) i st
+end
+
+fun prove_termination size_simps ctxt = 
+  Function.prove_termination NONE 
+    (HEADGOAL (prove_termination_tac size_simps ctxt)) ctxt
 
 (** transformations of premises (in inductive proofs) **)
 
--- a/Nominal/Ex/CoreHaskell.thy	Tue Aug 17 18:17:53 2010 +0800
+++ b/Nominal/Ex/CoreHaskell.thy	Wed Aug 18 00:19:15 2010 +0800
@@ -86,6 +86,7 @@
 | "bv_cvs (CvsCons v k t) = (atom v) # bv_cvs t"
 
 (* can lift *)
+
 thm distinct
 thm fv_defs
 thm alpha_bn_imps alpha_equivp
@@ -95,7 +96,6 @@
 thm perm_laws
 thm tkind_raw_ckind_raw_ty_raw_ty_lst_raw_co_raw_co_lst_raw_trm_raw_assoc_lst_raw_pat_raw_vars_raw_tvars_raw_cvars_raw.size(50 - 98)
 
-(* cannot lift yet *)
 thm eq_iff
 thm eq_iff_simps
 
--- a/Nominal/Ex/SingleLet.thy	Tue Aug 17 18:17:53 2010 +0800
+++ b/Nominal/Ex/SingleLet.thy	Wed Aug 18 00:19:15 2010 +0800
@@ -6,6 +6,7 @@
 
 declare [[STEPS = 20]]
 
+
 nominal_datatype trm  =
   Var "name"
 | App "trm" "trm"
@@ -21,6 +22,8 @@
 where
   "bn (As x y t) = {atom x}"
 
+
+
 ML {* Function.prove_termination *}
 
 text {* can lift *}
@@ -29,14 +32,14 @@
 thm trm_raw_assg_raw.inducts
 thm trm_raw.exhaust
 thm assg_raw.exhaust
-thm fv_defs
+thm FV_defs
 thm perm_simps
 thm perm_laws
 thm trm_raw_assg_raw.size(9 - 16)
 thm eq_iff
 thm eq_iff_simps
 thm bn_defs
-thm fv_eqvt
+thm FV_eqvt
 thm bn_eqvt
 thm size_eqvt
 
--- a/Nominal/NewParser.thy	Tue Aug 17 18:17:53 2010 +0800
+++ b/Nominal/NewParser.thy	Wed Aug 18 00:19:15 2010 +0800
@@ -226,7 +226,7 @@
 *}
 
 ML {*
-fun raw_nominal_decls dts bn_funs bn_eqs binds lthy =
+fun define_raw_dts dts bn_funs bn_eqs binds lthy =
 let
   val thy = ProofContext.theory_of lthy
   val thy_name = Context.theory_name thy
@@ -261,7 +261,8 @@
 *}
 
 ML {*
-fun raw_bn_decls dt_names dts raw_bn_funs raw_bn_eqs constr_thms lthy =
+(* should be in nominal_dt_rawfuns *)
+fun define_raw_bns dt_names dts raw_bn_funs raw_bn_eqs constr_thms size_thms lthy =
   if null raw_bn_funs 
   then ([], [], [], [], lthy)
   else 
@@ -269,8 +270,8 @@
       val (_, lthy1) = Function.add_function raw_bn_funs raw_bn_eqs
         Function_Common.default_config (pat_completeness_simp constr_thms) lthy
 
-      val (info, lthy2) = prove_termination (Local_Theory.restore lthy1)
-      val {fs, simps, inducts, ...} = info;
+      val (info, lthy2) = prove_termination size_thms (Local_Theory.restore lthy1)
+      val {fs, simps, inducts, ...} = info
 
       val raw_bn_induct = (the inducts)
       val raw_bn_eqs = the simps
@@ -301,7 +302,7 @@
   val _ = warning "Definition of raw datatypes";
   val (raw_dt_names, raw_dts, raw_bclauses, raw_bn_funs, raw_bn_eqs, lthy0) =
     if get_STEPS lthy > 0 
-    then raw_nominal_decls dts bn_funs bn_eqs bclauses lthy
+    then define_raw_dts dts bn_funs bn_eqs bclauses lthy
     else raise TEST lthy
 
   val dtinfo = Datatype.the_info (ProofContext.theory_of lthy0) (hd raw_dt_names)
@@ -339,14 +340,14 @@
   val _ = warning "Definition of raw fv-functions";
   val (raw_bns, raw_bn_defs, raw_bn_info, raw_bn_induct, lthy3a) =
     if get_STEPS lthy3 > 2 
-    then raw_bn_decls raw_full_ty_names raw_dts raw_bn_funs raw_bn_eqs 
-      (raw_inject_thms @ raw_distinct_thms) lthy3
+    then define_raw_bns raw_full_ty_names raw_dts raw_bn_funs raw_bn_eqs 
+      (raw_inject_thms @ raw_distinct_thms) raw_size_thms lthy3
     else raise TEST lthy3
 
   val (raw_fvs, raw_fv_bns, raw_fv_defs, raw_fv_bns_induct, lthy3b) = 
     if get_STEPS lthy3a > 3 
     then define_raw_fvs raw_full_ty_names raw_tys raw_cns_info raw_bn_info raw_bclauses 
-      (raw_inject_thms @ raw_distinct_thms) lthy3a
+      (raw_inject_thms @ raw_distinct_thms) raw_size_thms lthy3a
     else raise TEST lthy3a
 
   (* definition of raw alphas *)
--- a/Nominal/nominal_dt_rawfuns.ML	Tue Aug 17 18:17:53 2010 +0800
+++ b/Nominal/nominal_dt_rawfuns.ML	Wed Aug 18 00:19:15 2010 +0800
@@ -25,7 +25,7 @@
   val listify: Proof.context -> term -> term
 
   val define_raw_fvs: string list -> typ list -> cns_info list -> bn_info -> bclause list list list -> 
-    thm list -> Proof.context -> term list * term list * thm list * thm list * local_theory
+    thm list -> thm list -> Proof.context -> term list * term list * thm list * thm list * local_theory
  
   val raw_prove_eqvt: term list -> thm list -> thm list -> Proof.context -> thm list
 end
@@ -210,7 +210,7 @@
   map2 (mk_fv_bn_eq lthy bn_trm fv_map fv_bn_map) (bn_argss ~~ nth_constrs_info) nth_bclausess
 end
 
-fun define_raw_fvs raw_full_ty_names raw_tys cns_info bn_info bclausesss constr_thms lthy =
+fun define_raw_fvs raw_full_ty_names raw_tys cns_info bn_info bclausesss constr_thms size_simps lthy =
 let
   val fv_names = map (prefix "fv_" o Long_Name.base_name) raw_full_ty_names
   val fv_tys = map (fn ty => ty --> @{typ "atom set"}) raw_tys
@@ -233,19 +233,18 @@
 
   val (_, lthy') = Function.add_function all_fun_names all_fun_eqs
     Function_Common.default_config (pat_completeness_simp constr_thms) lthy
-
-  val (info, lthy'') = prove_termination (Local_Theory.restore lthy')
-
+  
+  val (info, lthy'') = prove_termination size_simps (Local_Theory.restore lthy')
+ 
   val {fs, simps, inducts, ...} = info;
 
   val morphism = ProofContext.export_morphism lthy'' lthy
   val fs_exp = map (Morphism.term morphism) fs
   val simps_exp = map (Morphism.thm morphism) (the simps)
   val inducts_exp = map (Morphism.thm morphism) (the inducts)
-  val (fv_frees_exp, fv_bns_exp) = chop (length fv_frees) fs_exp
-  
+  val (fvs_exp, fv_bns_exp) = chop (length fv_frees) fs_exp
 in
-  (fv_frees_exp, fv_bns_exp, simps_exp, inducts_exp, lthy'')
+  (fvs_exp, fv_bns_exp, simps_exp, inducts_exp, lthy'')
 end