(*  Nominal Mutual Functions
    Author:  Christian Urban

    heavily based on the code of Alexander Krauss
    (code forked on 14 January 2011)

Mutual recursive nominal function definitions.
*)

(* Public interface of the mutual-recursion layer for nominal functions.
   Only the preparation entry point is exported; all other definitions of
   the implementing structure stay private. *)
signature NOMINAL_FUNCTION_MUTUAL =
sig

  (* Arguments, in order: the function configuration, the definition name,
     the fixes (variable name and type, plus mixfix) and the list of
     defining equations, followed by the local theory to work in.
     The result pairs the goal state and a proof continuation with the
     updated local theory. *)
  val prepare_nominal_function_mutual : Function_Common.function_config
    -> string (* defname *)
    -> ((string * typ) * mixfix) list
    -> term list
    -> local_theory
    -> ((thm (* goalstate *)
        * (thm -> Function_Common.function_result) (* proof continuation *)
       ) * local_theory)

end

(* Implementation of mutually recursive nominal function definitions.
   The individual functions are encoded as one function over a sum type
   (see analyze_eqs), handed to Nominal_Function_Core, and afterwards
   recovered as projections of that sum function. *)
structure Nominal_Function_Mutual: NOMINAL_FUNCTION_MUTUAL =
struct

open Function_Lib
open Function_Common

(* One split equation, as consumed below in the order
   (function name, quantified variables, guards, arguments, right-hand side). *)
type qgar = string * (string * typ) list * term list * term list * term

(* Per-function data of a mutual definition. *)
datatype mutual_part = MutualPart of
 {i : int,               (* position of the function among the fixes, counted from 1 *)
  i' : int,              (* position of its result type among the distinct result types, counted from 1 *)
  fvar : string * typ,   (* name and type of the function variable *)
  cargTs: typ list,      (* types of the arguments that are consumed by the equations *)
  f_def: term,           (* projection term with the sum function abstracted out *)

  f: term option,        (* the defined constant; NONE until define_projections ran *)
  f_defthm : thm option} (* its definitional theorem; NONE until define_projections ran *)

(* Data describing the whole mutual definition. *)
datatype mutual_info = Mutual of
 {n : int,                (* number of functions *)
  n' : int,               (* number of distinct result types *)
  fsum_var : string * typ, (* name and type of the variable standing for the sum function *)

  ST: typ,                (* sum of the (tupled) argument types *)
  RST: typ,               (* sum of the distinct result types *)

  parts: mutual_part list,
  fqgars: qgar list,      (* the original equations, split *)
  qglrs: ((string * typ) list * term list * term * term) list, (* equations rephrased for the sum function *)

  fsum : term option}     (* the sum function itself; NONE until define_projections ran *)

(* Names for the predicates of the mutual induction rules: the first n of
   P, Q, R, S when n is below 5, otherwise P1 up to Pn. *)
fun mutual_induct_Pnames n =
  if n < 5 then fst (chop n ["P","Q","R","S"])
  else map (fn i => "P" ^ string_of_int i) (1 upto n)

(* Looks up the part whose function variable is called fname.
   Uses "the", so a missing name raises rather than returning NONE. *)
fun get_part fname =
  the o find_first (fn (MutualPart {fvar=(n,_), ...}) => n = fname)

(* FIXME *)
(* Builds the pair of t1 and t2. The types are computed relative to the
   bound-variable types taken (in reverse order) from the list e, so the
   terms may contain loose bounds. *)
fun mk_prod_abs e (t1, t2) =
  let
    val bTs = rev (map snd e)
    val T1 = fastype_of1 (bTs, t1)
    val T2 = fastype_of1 (bTs, t2)
  in
    HOLogic.pair_const T1 T2 $ t1 $ t2
  end

(* Analyses the equations eqs of the functions fs and computes the sum
   encoding: argument and result sum types, the variable for the sum
   function, one MutualPart per function, and the equations restated in
   terms of the sum function. The fsum field of the result is still NONE. *)
fun analyze_eqs ctxt defname fs eqs =
  let
    val num = length fs
    val fqgars = map (split_def ctxt (K true)) eqs
    (* number of arguments with which a function occurs in its equations;
       raises for a function name without any equation *)
    val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
      |> AList.lookup (op =) #> the

    (* splits the binder types of fT into those consumed by the equations
       and the remaining function type, which serves as result type *)
    fun curried_types (fname, fT) =
      let
        val (caTs, uaTs) = chop (arity_of fname) (binder_types fT)
      in
        (caTs, uaTs ---> body_type fT)
      end

    val (caTss, resultTs) = split_list (map curried_types fs)
    val argTs = map (foldr1 HOLogic.mk_prodT) caTss

    val dresultTs = distinct (op =) resultTs
    val n' = length dresultTs

    val RST = Balanced_Tree.make (uncurry SumTree.mk_sumT) dresultTs
    val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) argTs

    val fsum_type = ST --> RST

    val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt
    val fsum_var = (fsum_var_name, fsum_type)

    (* builds the part for the i-th function together with a rewrite pair
       (name, projection term) that is used to replace the function in
       right-hand sides *)
    fun define (fvar as (n, _)) caTs resultT i =
      let
        val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *)
        val i' = find_index (fn Ta => Ta = resultT) dresultTs + 1

        val f_exp = SumTree.mk_proj RST n' i' (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
        val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp)

        val rew = (n, fold_rev lambda vars f_exp)
      in
        (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew)
      end

    val (parts, rews) = split_list (map4 define fs caTss resultTs (1 upto num))

    (* restates one equation for the sum function: the tupled arguments
       are injected into ST, the rewritten right-hand side into RST *)
    fun convert_eqs (f, qs, gs, args, rhs) =
      let
        val MutualPart {i, i', ...} = get_part f parts
      in
        (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
         SumTree.mk_inj RST n' i' (replace_frees rews rhs)
         |> Envir.beta_norm)
      end

    val qglrs = map convert_eqs fqgars
  in
    Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST,
      parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE}
  end

(* Defines every function of the mutual block via Local_Theory.define,
   instantiating the abstracted sum function in f_def with fsum. Returns
   the mutual_info with f, f_defthm and fsum filled in, together with the
   extended local theory. The parts are paired with fixes positionally. *)
fun define_projections fixes mutual fsum lthy =
  let
    fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy =
      let
        val ((f, (_, f_defthm)), lthy') =
          Local_Theory.define
            ((Binding.name fname, mixfix),
              ((Binding.conceal (Binding.name (fname ^ "_def")), []),
              Term.subst_bound (fsum, f_def))) lthy
      in
        (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def,
           f=SOME f, f_defthm=SOME f_defthm },
         lthy')
      end

    val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual
    val (parts', lthy') = fold_map def (parts ~~ fixes) lthy
  in
    (Mutual { n=n, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, parts=parts',
       fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum },
     lthy')
  end

(* Runs F on an opened version of a split equation: the quantified
   variables are replaced by Frees with variant names, and the guards,
   arguments and right-hand side are instantiated accordingly. F also
   receives "import" (eliminates the quantifiers and the assumed guards of
   a theorem) and "export" (re-introduces the guards as premises and
   quantifies again under the original variable names). *)
fun in_context ctxt (f, pre_qs, pre_gs, pre_args, pre_rhs) F =
  let
    val thy = ProofContext.theory_of ctxt

    val oqnames = map fst pre_qs
    val (qs, _) = Variable.variant_fixes oqnames ctxt
      |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs

    fun inst t = subst_bounds (rev qs, t)
    val gs = map inst pre_gs
    val args = map inst pre_args
    val rhs = inst pre_rhs

    val cqs = map (cterm_of thy) qs
    val ags = map (Thm.assume o cterm_of thy) gs

    val import = fold Thm.forall_elim cqs
      #> fold Thm.elim_implies ags

    val export = fold_rev (Thm.implies_intr o cprop_of) ags
      #> fold_rev forall_intr_rename (oqnames ~~ cqs)
  in
    F ctxt (f, qs, gs, args, rhs) import export
  end

(* Derives the simp rule "f args = rhs" of an individual function from the
   corresponding rule of the sum function. The imported sum rule may carry
   at most one premise; it is assumed for the proof and re-introduced
   afterwards. Raises Fail "Too many conditions" for more than one premise. *)
fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs)
  import (export : thm -> thm) sum_psimp_eq =
  let
    val (MutualPart {f=SOME f, ...}) = get_part fname parts

    val psimp = import sum_psimp_eq
    val (simp, restore_cond) =
      case cprems_of psimp of
        [] => (psimp, I)
      | [cond] => (Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond)
      | _ => raise General.Fail "Too many conditions"

  in
    Goal.prove ctxt [] []
      (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs))
      (fn _ => (Local_Defs.unfold_tac ctxt all_orig_fdefs)
         THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
         THEN (simp_tac (simpset_of ctxt)) 1) (* FIXME: global simpset?!! *)
    |> restore_cond
    |> export
  end

(* Applies both sides of the equation thm to variables x0, x1, ... of the
   types caTs, beta-normalises the result, and turns the variables into
   schematic ones by quantifying over them and eliminating the quantifiers
   again with Thm.forall_elim_vars. *)
fun mk_applied_form ctxt caTs thm =
  let
    val thy = ProofContext.theory_of ctxt
    val xs = map_index (fn (i,T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs (* FIXME: Bind xs properly *)
  in
    fold (fn x => fn thm => Thm.combination thm (Thm.reflexive x)) xs thm
    |> Conv.fconv_rule (Thm.beta_conversion true)
    |> fold_rev Thm.forall_intr xs
    |> Thm.forall_elim_vars 0
  end

(* Produces one induction rule per part from the induction rule of the sum
   function: the predicate of that rule is instantiated with a sum-case
   distinction over fresh predicates (one per function), and the result is
   projected to each function by instantiating with an injected tuple of
   variables. *)
fun mutual_induct_rules lthy induct all_f_defs (Mutual {n, ST, parts, ...}) =
  let
    val cert = cterm_of (ProofContext.theory_of lthy)
    val newPs =
      map2 (fn Pname => fn MutualPart {cargTs, ...} =>
          Free (Pname, cargTs ---> HOLogic.boolT))
        (mutual_induct_Pnames (length parts)) parts

    (* the predicate P, uncurried so that it accepts a tuple of arguments *)
    fun mk_P (MutualPart {cargTs, ...}) P =
      let
        val avars = map_index (fn (i,T) => Var (("a", i), T)) cargTs
        val atup = foldr1 HOLogic.mk_prod avars
      in
        HOLogic.tupled_lambda atup (list_comb (P, avars))
      end

    val Ps = map2 mk_P parts newPs
    val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps

    val induct_inst =
      Thm.forall_elim (cert case_exp) induct
      |> full_simplify SumTree.sumcase_split_ss
      |> full_simplify (HOL_basic_ss addsimps all_f_defs)

    (* k is the running offset used to number the variables a<k>, a<k+1>, ...
       differently for each part *)
    fun project rule (MutualPart {cargTs, i, ...}) k =
      let
        val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *)
        val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
      in
        (rule
         |> Thm.forall_elim (cert inj)
         |> full_simplify SumTree.sumcase_split_ss
         |> fold_rev (Thm.forall_intr o cert) (afs @ newPs),
         k + length cargTs)
      end
  in
    fst (fold_map (project induct_inst) parts 0)
  end

(* Proof continuation for the mutual definition: runs the continuation
   inner_cont of the sum function on proof and translates its result
   (simp rules, induction rule, termination and domain introduction rules)
   back to the individual functions. Expects exactly one induction rule in
   the inner result and parts whose f and f_defthm fields are filled in. *)
fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof =
  let
    val result = inner_cont proof
    val FunctionResult {G, R, cases, psimps, trsimps, simple_pinducts=[simple_pinduct],
      termination, domintros, ...} = result

    (* definitions in symmetric, applied form (for folding) and the constants *)
    val (all_f_defs, fs) =
      map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
        (mk_applied_form lthy cargTs (Thm.symmetric f_def), f))
      parts
      |> split_list

    val all_orig_fdefs =
      map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts

    fun mk_mpsimp fqgar sum_psimp =
      in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp

    val rew_ss = HOL_basic_ss addsimps all_f_defs
    val mpsimps = map2 mk_mpsimp fqgars psimps
    val mtrsimps = Option.map (map2 mk_mpsimp fqgars) trsimps
    val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
    val mtermination = full_simplify rew_ss termination
    val mdomintros = Option.map (map (full_simplify rew_ss)) domintros
  in
    FunctionResult { fs=fs, G=G, R=R,
      psimps=mpsimps, simple_pinducts=minducts,
      cases=cases, termination=mtermination,
      domintros=mdomintros, trsimps=mtrsimps}
  end

(* nominal *)
(* Entry point: analyses the (beta-eta-contracted) equations, passes the
   sum function to Nominal_Function_Core.prepare_nominal_function, defines
   the individual functions as projections, and returns the goal state
   with the mutual proof continuation and the resulting local theory. *)
fun prepare_nominal_function_mutual config defname fixes eqss lthy =
  let
    val mutual as Mutual {fsum_var=(n, T), qglrs, ...} =
      analyze_eqs lthy defname (map fst fixes) (map Envir.beta_eta_contract eqss)

    val ((fsum, goalstate, cont), lthy') =
      Nominal_Function_Core.prepare_nominal_function config defname [((n, T), NoSyn)] qglrs lthy

    val (mutual', lthy'') = define_projections fixes mutual fsum lthy'

    val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual'
  in
    ((goalstate, mutual_cont), lthy'')
  end

end