11 *) |
11 *) |
12 |
12 |
13 |
13 |
14 signature NOMINAL_FUNCTION_MUTUAL = |
14 signature NOMINAL_FUNCTION_MUTUAL = |
15 sig |
15 sig |
16 |
|
17 val prepare_nominal_function_mutual : Nominal_Function_Common.nominal_function_config |
16 val prepare_nominal_function_mutual : Nominal_Function_Common.nominal_function_config |
18 -> string (* defname *) |
17 -> string (* defname *) |
19 -> ((string * typ) * mixfix) list |
18 -> ((string * typ) * mixfix) list |
20 -> term list |
19 -> term list |
21 -> local_theory |
20 -> local_theory |
22 -> ((thm (* goalstate *) |
21 -> ((thm (* goalstate *) |
23 * (thm -> Nominal_Function_Common.nominal_function_result) (* proof continuation *) |
22 * (thm -> Nominal_Function_Common.nominal_function_result) (* proof continuation *) |
24 ) * local_theory) |
23 ) * local_theory) |
25 |
|
26 end |
24 end |
27 |
|
28 |
25 |
29 structure Nominal_Function_Mutual: NOMINAL_FUNCTION_MUTUAL = |
26 structure Nominal_Function_Mutual: NOMINAL_FUNCTION_MUTUAL = |
30 struct |
27 struct |
31 |
28 |
32 open Function_Lib |
29 open Function_Lib |
93 val argTs = map (foldr1 HOLogic.mk_prodT) caTss |
90 val argTs = map (foldr1 HOLogic.mk_prodT) caTss |
94 |
91 |
95 val dresultTs = distinct (op =) resultTs |
92 val dresultTs = distinct (op =) resultTs |
96 val n' = length dresultTs |
93 val n' = length dresultTs |
97 |
94 |
98 val RST = Balanced_Tree.make (uncurry SumTree.mk_sumT) dresultTs |
95 val RST = Balanced_Tree.make (uncurry Sum_Tree.mk_sumT) dresultTs |
99 val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) argTs |
96 val ST = Balanced_Tree.make (uncurry Sum_Tree.mk_sumT) argTs |
100 |
97 |
101 val fsum_type = ST --> RST |
98 val fsum_type = ST --> RST |
102 |
99 |
103 val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt |
100 val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt |
104 val fsum_var = (fsum_var_name, fsum_type) |
101 val fsum_var = (fsum_var_name, fsum_type) |
106 fun define (fvar as (n, _)) caTs resultT i = |
103 fun define (fvar as (n, _)) caTs resultT i = |
107 let |
104 let |
108 val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *) |
105 val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *) |
109 val i' = find_index (fn Ta => Ta = resultT) dresultTs + 1 |
106 val i' = find_index (fn Ta => Ta = resultT) dresultTs + 1 |
110 |
107 |
111 val f_exp = SumTree.mk_proj RST n' i' (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars)) |
108 val f_exp = Sum_Tree.mk_proj RST n' i' (Free fsum_var $ Sum_Tree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars)) |
112 val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp) |
109 val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp) |
113 |
110 |
114 val rew = (n, fold_rev lambda vars f_exp) |
111 val rew = (n, fold_rev lambda vars f_exp) |
115 in |
112 in |
116 (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew) |
113 (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew) |
122 let |
119 let |
123 val MutualPart {i, i', ...} = get_part f parts |
120 val MutualPart {i, i', ...} = get_part f parts |
124 val rhs' = rhs |
121 val rhs' = rhs |
125 |> map_aterms (fn t as Free (n, _) => the_default t (AList.lookup (op =) rews n) | t => t) |
122 |> map_aterms (fn t as Free (n, _) => the_default t (AList.lookup (op =) rews n) | t => t) |
126 in |
123 in |
127 (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args), |
124 (qs, gs, Sum_Tree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args), |
128 Envir.beta_norm (SumTree.mk_inj RST n' i' rhs')) |
125 Envir.beta_norm (Sum_Tree.mk_inj RST n' i' rhs')) |
129 end |
126 end |
130 |
127 |
131 val qglrs = map convert_eqs fqgars |
128 val qglrs = map convert_eqs fqgars |
132 in |
129 in |
133 Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, |
130 Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, |
203 THEN (simp_tac ctxt') 1) |
200 THEN (simp_tac ctxt') 1) |
204 |> restore_cond |
201 |> restore_cond |
205 |> export |
202 |> export |
206 end |
203 end |
207 |
204 |
208 val inl_perm = @{lemma "x = Inl y ==> Sum_Type.Projl (permute p x) = permute p (Sum_Type.Projl x)" by simp} |
205 val inl_perm = @{lemma "x = Inl y ==> projl (permute p x) = permute p (projl x)" by simp} |
209 val inr_perm = @{lemma "x = Inr y ==> Sum_Type.Projr (permute p x) = permute p (Sum_Type.Projr x)" by simp} |
206 val inr_perm = @{lemma "x = Inr y ==> projr (permute p x) = permute p (projr x)" by simp} |
210 |
207 |
211 fun recover_mutual_eqvt eqvt_thm all_orig_fdefs parts ctxt (fname, _, _, args, _) |
208 fun recover_mutual_eqvt eqvt_thm all_orig_fdefs parts ctxt (fname, _, _, args, _) |
212 import (export : thm -> thm) sum_psimp_eq = |
209 import (export : thm -> thm) sum_psimp_eq = |
213 let |
210 let |
214 val (MutualPart {f=SOME f, ...}) = get_part fname parts |
211 val (MutualPart {f=SOME f, ...}) = get_part fname parts |
222 in (([asm], Thm.implies_elim psimp asm, Thm.implies_intr cond), ctxt') end |
219 in (([asm], Thm.implies_elim psimp asm, Thm.implies_intr cond), ctxt') end |
223 | _ => raise General.Fail "Too many conditions" |
220 | _ => raise General.Fail "Too many conditions" |
224 |
221 |
225 val ([p], ctxt'') = ctxt' |
222 val ([p], ctxt'') = ctxt' |
226 |> fold Variable.declare_term args |
223 |> fold Variable.declare_term args |
227 |> Variable.variant_fixes ["p"] |
224 |> Variable.variant_fixes ["p"] |
228 val p = Free (p, @{typ perm}) |
225 val p = Free (p, @{typ perm}) |
229 |
226 |
230 val simpset = |
227 val simpset = |
231 put_simpset HOL_basic_ss ctxt'' addsimps |
228 put_simpset HOL_basic_ss ctxt'' addsimps |
232 @{thms permute_sum.simps[symmetric] Pair_eqvt[symmetric]} @ |
229 @{thms permute_sum.simps[symmetric] Pair_eqvt[symmetric] sum.sel} @ |
233 @{thms Projr.simps Projl.simps} @ |
|
234 [(cond MRS eqvt_thm) RS @{thm sym}] @ |
230 [(cond MRS eqvt_thm) RS @{thm sym}] @ |
235 [inl_perm, inr_perm, simp] |
231 [inl_perm, inr_perm, simp] |
236 val goal_lhs = mk_perm p (list_comb (f, args)) |
232 val goal_lhs = mk_perm p (list_comb (f, args)) |
237 val goal_rhs = list_comb (f, map (mk_perm p) args) |
233 val goal_rhs = list_comb (f, map (mk_perm p) args) |
238 in |
234 in |
270 in |
266 in |
271 HOLogic.tupled_lambda atup (list_comb (P, avars)) |
267 HOLogic.tupled_lambda atup (list_comb (P, avars)) |
272 end |
268 end |
273 |
269 |
274 val Ps = map2 mk_P parts newPs |
270 val Ps = map2 mk_P parts newPs |
275 val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps |
271 val case_exp = Sum_Tree.mk_sumcases HOLogic.boolT Ps |
276 |
272 |
277 val induct_inst = |
273 val induct_inst = |
278 Thm.forall_elim (cert case_exp) induct |
274 Thm.forall_elim (cert case_exp) induct |
279 |> full_simplify (put_simpset SumTree.sumcase_split_ss ctxt) |
275 |> full_simplify (put_simpset Sum_Tree.sumcase_split_ss ctxt) |
280 |> full_simplify (put_simpset HOL_basic_ss ctxt addsimps all_f_defs) |
276 |> full_simplify (put_simpset HOL_basic_ss ctxt addsimps all_f_defs) |
281 |
277 |
282 fun project rule (MutualPart {cargTs, i, ...}) k = |
278 fun project rule (MutualPart {cargTs, i, ...}) k = |
283 let |
279 let |
284 val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *) |
280 val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *) |
285 val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs) |
281 val inj = Sum_Tree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs) |
286 in |
282 in |
287 (rule |
283 (rule |
288 |> Thm.forall_elim (cert inj) |
284 |> Thm.forall_elim (cert inj) |
289 |> full_simplify (put_simpset SumTree.sumcase_split_ss ctxt) |
285 |> full_simplify (put_simpset Sum_Tree.sumcase_split_ss ctxt) |
290 |> fold_rev (Thm.forall_intr o cert) (afs @ newPs), |
286 |> fold_rev (Thm.forall_intr o cert) (afs @ newPs), |
291 k + length cargTs) |
287 k + length cargTs) |
292 end |
288 end |
293 in |
289 in |
294 fst (fold_map (project induct_inst) parts 0) |
290 fst (fold_map (project induct_inst) parts 0) |
425 (* defining the auxiliary graph *) |
421 (* defining the auxiliary graph *) |
426 fun mk_cases (MutualPart {i', fvar as (n, T), ...}) = |
422 fun mk_cases (MutualPart {i', fvar as (n, T), ...}) = |
427 let |
423 let |
428 val (tys, ty) = strip_type T |
424 val (tys, ty) = strip_type T |
429 val fun_var = Free (n ^ "_aux", HOLogic.mk_tupleT tys --> ty) |
425 val fun_var = Free (n ^ "_aux", HOLogic.mk_tupleT tys --> ty) |
430 val inj_fun = absdummy dummyT (SumTree.mk_inj RST n' i' (Bound 0)) |
426 val inj_fun = absdummy dummyT (Sum_Tree.mk_inj RST n' i' (Bound 0)) |
431 in |
427 in |
432 Syntax.check_term lthy'' (mk_comp_dummy inj_fun fun_var) |
428 Syntax.check_term lthy'' (mk_comp_dummy inj_fun fun_var) |
433 end |
429 end |
434 |
430 |
435 val sum_case_exp = map mk_cases parts |
431 val case_sum_exp = map mk_cases parts |
436 |> SumTree.mk_sumcases RST |
432 |> Sum_Tree.mk_sumcases RST |
437 |
433 |
438 val (G_name, G_type) = dest_Free G |
434 val (G_name, G_type) = dest_Free G |
439 val G_name_aux = G_name ^ "_aux" |
435 val G_name_aux = G_name ^ "_aux" |
440 val subst = [(G, Free (G_name_aux, G_type))] |
436 val subst = [(G, Free (G_name_aux, G_type))] |
441 val GIntros_aux = GIntro_thms |
437 val GIntros_aux = GIntro_thms |
442 |> map prop_of |
438 |> map prop_of |
443 |> map (Term.subst_free subst) |
439 |> map (Term.subst_free subst) |
444 |> map (subst_all sum_case_exp) |
440 |> map (subst_all case_sum_exp) |
445 |
441 |
446 val ((G_aux, GIntro_aux_thms, _, G_aux_induct), lthy''') = |
442 val ((G_aux, GIntro_aux_thms, _, G_aux_induct), lthy''') = |
447 Nominal_Function_Core.inductive_def ((Binding.name G_name_aux, G_type), NoSyn) GIntros_aux lthy'' |
443 Nominal_Function_Core.inductive_def ((Binding.name G_name_aux, G_type), NoSyn) GIntros_aux lthy'' |
448 |
444 |
449 val mutual_cont = mk_partial_rules_mutual lthy''' cont mutual' |
445 val mutual_cont = mk_partial_rules_mutual lthy''' cont mutual' |
454 val G_aux_prem = HOLogic.mk_Trueprop (G_aux $ x $ y) |
450 val G_aux_prem = HOLogic.mk_Trueprop (G_aux $ x $ y) |
455 val G_prem = HOLogic.mk_Trueprop (G $ x $ y) |
451 val G_prem = HOLogic.mk_Trueprop (G $ x $ y) |
456 |
452 |
457 fun mk_inj_goal (MutualPart {i', ...}) = |
453 fun mk_inj_goal (MutualPart {i', ...}) = |
458 let |
454 let |
459 val injs = SumTree.mk_inj ST n' i' (Bound 0) |
455 val injs = Sum_Tree.mk_inj ST n' i' (Bound 0) |
460 val projs = y |
456 val projs = y |
461 |> SumTree.mk_proj RST n' i' |
457 |> Sum_Tree.mk_proj RST n' i' |
462 |> SumTree.mk_inj RST n' i' |
458 |> Sum_Tree.mk_inj RST n' i' |
463 in |
459 in |
464 Const (@{const_name "All"}, dummyT) $ absdummy dummyT |
460 Const (@{const_name "All"}, dummyT) $ absdummy dummyT |
465 (HOLogic.mk_imp (HOLogic.mk_eq(x, injs), HOLogic.mk_eq(projs, y))) |
461 (HOLogic.mk_imp (HOLogic.mk_eq(x, injs), HOLogic.mk_eq(projs, y))) |
466 end |
462 end |
467 |
463 |
472 val goal_iff1 = Logic.mk_implies (G_aux_prem, G_prem) |
468 val goal_iff1 = Logic.mk_implies (G_aux_prem, G_prem) |
473 |> all x |> all y |
469 |> all x |> all y |
474 val goal_iff2 = Logic.mk_implies (G_prem, G_aux_prem) |
470 val goal_iff2 = Logic.mk_implies (G_prem, G_aux_prem) |
475 |> all x |> all y |
471 |> all x |> all y |
476 |
472 |
477 val simp_thms = @{thms Projl.simps Projr.simps sum.inject sum.cases sum.distinct o_apply} |
473 val simp_thms = @{thms sum.sel sum.inject sum.case sum.distinct o_apply} |
478 val simpset0 = put_simpset HOL_basic_ss lthy''' addsimps simp_thms |
474 val simpset0 = put_simpset HOL_basic_ss lthy''' addsimps simp_thms |
479 val simpset1 = put_simpset HOL_ss lthy''' addsimps simp_thms |
475 val simpset1 = put_simpset HOL_ss lthy''' addsimps simp_thms |
480 |
476 |
481 val inj_thm = Goal.prove lthy''' [] [] goal_inj |
477 val inj_thm = Goal.prove lthy''' [] [] goal_inj |
482 (K (HEADGOAL (DETERM o etac G_aux_induct THEN_ALL_NEW asm_simp_tac simpset1))) |
478 (K (HEADGOAL (DETERM o etac G_aux_induct THEN_ALL_NEW asm_simp_tac simpset1))) |