diff -r c8acaded1777 -r 4a00077c008f Nominal/nominal_mutual.ML --- a/Nominal/nominal_mutual.ML Tue Jul 19 19:09:06 2011 +0100 +++ b/Nominal/nominal_mutual.ML Fri Jul 22 11:37:16 2011 +0100 @@ -8,6 +8,7 @@ Mutual recursive nominal function definitions. *) + signature NOMINAL_FUNCTION_MUTUAL = sig @@ -193,87 +194,46 @@ in Goal.prove ctxt [] [] (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs)) - (fn _ => print_tac "start" - THEN (Local_Defs.unfold_tac ctxt all_orig_fdefs) - THEN (print_tac "second") + (fn _ => (Local_Defs.unfold_tac ctxt all_orig_fdefs) THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1 - THEN (print_tac "third") - THEN (simp_tac (simpset_of ctxt)) 1 - THEN (print_tac "fourth") - ) (* FIXME: global simpset?!! *) + THEN (simp_tac (simpset_of ctxt)) 1) (* FIXME: global simpset?!! *) |> restore_cond |> export end -val test1 = @{lemma "x = Inl y ==> permute p (Sum_Type.Projl x) = Sum_Type.Projl (permute p x)" by simp} -val test2 = @{lemma "x = Inr y ==> permute p (Sum_Type.Projr x) = Sum_Type.Projr (permute p x)" by simp} +val inl_perm = @{lemma "x = Inl y ==> Sum_Type.Projl (permute p x) = permute p (Sum_Type.Projl x)" by simp} +val inr_perm = @{lemma "x = Inr y ==> Sum_Type.Projr (permute p x) = permute p (Sum_Type.Projr x)" by simp} -fun recover_mutual_eqvt eqvt_thm all_orig_fdefs parts ctxt (fname, _, _, args, rhs) +fun recover_mutual_eqvt eqvt_thm all_orig_fdefs parts ctxt (fname, _, _, args, _) import (export : thm -> thm) sum_psimp_eq = let val (MutualPart {f=SOME f, ...}) = get_part fname parts - + val psimp = import sum_psimp_eq - val (simp, restore_cond) = + val (cond, simp, restore_cond) = case cprems_of psimp of - [] => (psimp, I) - | [cond] => (Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond) - | _ => raise General.Fail "Too many conditions" - - val eqvt_thm' = import eqvt_thm - val (simp', restore_cond') = - case cprems_of eqvt_thm' of - [] => (eqvt_thm, I) - | [cond] => (Thm.implies_elim eqvt_thm' (Thm.assume cond), Thm.implies_intr cond) + [] => ([], psimp, I) + | [cond] => ([Thm.assume cond], Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond) | _ => raise General.Fail "Too many conditions" - val _ = tracing ("sum_psimp:\n" ^ @{make_string} sum_psimp_eq) - val _ = tracing ("psimp:\n" ^ @{make_string} psimp) - val _ = tracing ("simp:\n" ^ @{make_string} simp) - val _ = tracing ("eqvt:\n" ^ @{make_string} eqvt_thm) - val ([p], ctxt') = Variable.variant_fixes ["p"] ctxt val p = Free (p, @{typ perm}) - val ss = HOL_basic_ss addsimps [simp RS test1, simp'] + val ss = HOL_basic_ss addsimps + @{thms permute_sum.simps[symmetric] Pair_eqvt[symmetric]} @ + @{thms Projr.simps Projl.simps} @ + [(cond MRS eqvt_thm) RS @{thm sym}] @ + [inl_perm, inr_perm, simp] + val goal_lhs = mk_perm p (list_comb (f, args)) + val goal_rhs = list_comb (f, map (mk_perm p) args) in - Goal.prove ctxt' [] [] - (HOLogic.Trueprop $ - HOLogic.mk_eq (mk_perm p (list_comb (f, args)), list_comb (f, map (mk_perm p) args))) - (fn _ => print_tac "eqvt start" - THEN (Local_Defs.unfold_tac ctxt all_orig_fdefs) - THEN (asm_full_simp_tac ss 1) - THEN all_tac) + Goal.prove ctxt' [] [] (HOLogic.Trueprop $ HOLogic.mk_eq (goal_lhs, goal_rhs)) + (fn _ => (Local_Defs.unfold_tac ctxt all_orig_fdefs) + THEN (asm_full_simp_tac ss 1)) + |> singleton (ProofContext.export ctxt' ctxt) |> restore_cond |> export end -fun mk_meqvts ctxt eqvt_thm f_defs = - let - val ctrm1 = eqvt_thm - |> cprop_of - |> snd o Thm.dest_implies - |> Thm.dest_arg - |> Thm.dest_arg1 - |> Thm.dest_arg - - fun resolve f_def = - let - val ctrm2 = f_def - |> cprop_of - |> Thm.dest_equals_lhs - val _ = tracing ("ctrm1:\n" ^ @{make_string} ctrm1) - val _ = tracing ("ctrm2:\n" ^ @{make_string} ctrm2) - in - eqvt_thm - |> Thm.instantiate (Thm.match (ctrm1, ctrm2)) - |> simplify (HOL_basic_ss addsimps (@{thm Pair_eqvt} :: @{thms permute_sum.simps})) - |> Local_Defs.unfold ctxt [f_def] - end - in - map resolve f_defs - end - - fun mk_applied_form ctxt caTs thm = let val thy = ProofContext.theory_of ctxt @@ -324,6 +284,66 @@ fst (fold_map (project induct_inst) parts 0) end + +fun forall_elim s (Const ("all", _) $ Abs (_, _, t)) = subst_bound (s, t) + | forall_elim _ t = t + +val forall_elim_list = fold forall_elim + +fun split_conj_thm th = + (split_conj_thm (th RS conjunct1)) @ (split_conj_thm (th RS conjunct2)) handle THM _ => [th]; + +fun prove_eqvt ctxt fs argTss eqvts_thms induct_thms = + let + fun aux argTs s = argTs + |> map (pair s) + |> Variable.variant_frees ctxt fs + val argss' = map2 aux argTss (Name.invent (Variable.names_of ctxt) "" (length fs)) + val argss = (map o map) Free argss' + val arg_namess = (map o map) fst argss' + val insts = (map o map) SOME arg_namess + + val ([p_name], ctxt') = Variable.variant_fixes ["p"] ctxt + val p = Free (p_name, @{typ perm}) + + val acc_prems = + map prop_of induct_thms + |> map2 forall_elim_list argss + |> map (strip_qnt_body "all") + |> map (curry Logic.nth_prem 1) + |> map HOLogic.dest_Trueprop + + fun mk_goal acc_prem (f, args) = + let + val goal_lhs = mk_perm p (list_comb (f, args)) + val goal_rhs = list_comb (f, map (mk_perm p) args) + in + HOLogic.mk_imp (acc_prem, HOLogic.mk_eq (goal_lhs, goal_rhs)) + end + + val goal = fold_conj_balanced (map2 mk_goal acc_prems (fs ~~ argss)) + |> HOLogic.mk_Trueprop + + val induct_thm = case induct_thms of + [thm] => thm + |> Drule.gen_all + |> Thm.permute_prems 0 1 + |> (fn thm => atomize_rule (length (prems_of thm) - 1) thm) + | thms => thms + |> map Drule.gen_all + |> map (Rule_Cases.add_consumes 1) + |> snd o Rule_Cases.strict_mutual_rule ctxt' + |> atomize_concl + + fun tac thm = rtac (Drule.gen_all thm) THEN_ALL_NEW atac + in + Goal.prove ctxt' (flat arg_namess) [] goal + (fn {context, ...} => HEADGOAL (DETERM o (rtac induct_thm) THEN' RANGE (map tac eqvts_thms))) + |> singleton (ProofContext.export ctxt' ctxt) + |> split_conj_thm + |> map (fn th => th RS mp) + end + fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof = let val result = inner_cont proof @@ -332,13 +352,16 @@ val (all_f_defs, fs) = map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} => - (mk_applied_form lthy cargTs (Thm.symmetric f_def), f)) + (mk_applied_form lthy cargTs (Thm.symmetric f_def), f)) parts |> split_list val all_orig_fdefs = map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts + val cargTss = + map (fn MutualPart {f = SOME f, cargTs, ...} => cargTs) parts + fun mk_mpsimp fqgar sum_psimp = in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp @@ -351,11 +374,12 @@ val mtermination = full_simplify rew_ss termination val mdomintros = Option.map (map (full_simplify rew_ss)) domintros val meqvts = map2 mk_meqvts fqgars psimps + val meqvt_funs = prove_eqvt lthy fs cargTss meqvts minducts in NominalFunctionResult { fs=fs, G=G, R=R, psimps=mpsimps, simple_pinducts=minducts, cases=cases, termination=mtermination, - domintros=mdomintros, eqvts=meqvts } + domintros=mdomintros, eqvts=meqvt_funs } end (* nominal *)