(*  Nominal Mutual Functions
    Author:  Christian Urban

    heavily based on the code of Alexander Krauss
    (code forked on 14 January 2011)

    Mutually recursive nominal function definitions.
*)
(* Interface of the layer that reduces a set of mutually recursive nominal
   function definitions to a single definition over a sum type. *)
signature NOMINAL_FUNCTION_MUTUAL =
sig

  (* Arguments, in order:
       - the function-package configuration,
       - the base name used for the combined definition ("<defname>_sum"),
       - the fixed functions with their mixfix syntax,
       - the defining equations,
       - the local theory.
     Returns the goal state to be proved, paired with a continuation that
     turns the proved goal into a function_result for the original
     (unsummed) functions, together with the updated local theory. *)
  val prepare_nominal_function_mutual : Function_Common.function_config
    -> string (* defname *)
    -> ((string * typ) * mixfix) list
    -> term list
    -> local_theory
    -> ((thm (* goalstate *)
        * (thm -> Function_Common.function_result) (* proof continuation *)
       ) * local_theory)

end

(* Mutual recursion is handled by encoding the functions f_1 ... f_n as one
   function "<defname>_sum" from a sum of (tupled) argument types to a sum of
   result types. That single function is defined by
   Nominal_Function_Core, after which each f_i is defined as a projection
   and the resulting theorems are translated back to the f_i. *)
structure Nominal_Function_Mutual: NOMINAL_FUNCTION_MUTUAL =
struct

open Function_Lib
open Function_Common

(* One equation as split by split_def:
   (function name, quantified variables, guards, arguments, right-hand side). *)
type qgar = string * (string * typ) list * term list * term list * term

(* Per-function data of a mutual definition.
     i        - index of this function's injection into the argument sum ST
     i'       - index of its result type in the result sum RST
     fvar     - name and type of the fixed function
     cargTs   - types of the curried arguments covered by the equations
     f_def    - defining term; the sum function occurs as a loose Bound 0
     f        - the defined constant (NONE until define_projections)
     f_defthm - its definition theorem (NONE until define_projections) *)
datatype mutual_part = MutualPart of
 {i : int,
  i' : int,
  fvar : string * typ,
  cargTs: typ list,
  f_def: term,

  f: term option,
  f_defthm : thm option}

(* Data describing the whole mutual definition.
     n        - number of functions
     n'       - number of distinct result types
     fsum_var - name and type of the combined sum function
     ST, RST  - argument sum type and result sum type
     parts    - one mutual_part per function
     fqgars   - the split original equations
     qglrs    - the equations restated for the sum function
     fsum     - the defined sum function (NONE until define_projections) *)
datatype mutual_info = Mutual of
 {n : int,
  n' : int,
  fsum_var : string * typ,

  ST: typ,
  RST: typ,

  parts: mutual_part list,
  fqgars: qgar list,
  qglrs: ((string * typ) list * term list * term * term) list,

  fsum : term option}

(* Names for the induction predicates: P, Q, R, S for up to four
   functions, otherwise P1 ... Pn. *)
fun mutual_induct_Pnames n =
  if n < 5 then fst (chop n ["P","Q","R","S"])
  else map (fn i => "P" ^ string_of_int i) (1 upto n)

(* The part whose fixed variable is named fname; fails (via "the") if
   there is none. *)
fun get_part fname =
  the o find_first (fn (MutualPart {fvar=(n,_), ...}) => n = fname)

(* FIXME *)
(* Pairs t1 and t2. Their types are computed with the types of e (reversed)
   as the bound-variable context, since the terms may contain loose bounds
   referring to the variables in e. *)
fun mk_prod_abs e (t1, t2) =
  let
    val bTs = rev (map snd e)
    val T1 = fastype_of1 (bTs, t1)
    val T2 = fastype_of1 (bTs, t2)
  in
    HOLogic.pair_const T1 T2 $ t1 $ t2
  end

(* Analyzes the equations of the functions fs and builds the sum-type
   encoding: argument/result sum types, the per-function projection terms,
   and the equations restated for the single sum function. *)
fun analyze_eqs ctxt defname fs eqs =
  let
    val num = length fs
    val fqgars = map (split_def ctxt (K true)) eqs
    (* Arity of a function = number of arguments in its (first) equation. *)
    val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
      |> AList.lookup (op =) #> the

    (* Splits a function type into the argument types covered by the
       equations and the remaining (possibly functional) result type. *)
    fun curried_types (fname, fT) =
      let
        val (caTs, uaTs) = chop (arity_of fname) (binder_types fT)
      in
        (caTs, uaTs ---> body_type fT)
      end

    val (caTss, resultTs) = split_list (map curried_types fs)
    val argTs = map (foldr1 HOLogic.mk_prodT) caTss

    (* Functions with equal result types share one summand of RST. *)
    val dresultTs = distinct (op =) resultTs
    val n' = length dresultTs

    val RST = Balanced_Tree.make (uncurry SumTree.mk_sumT) dresultTs
    val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) argTs

    val fsum_type = ST --> RST

    val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt
    val fsum_var = (fsum_var_name, fsum_type)

    (* Expresses function number i as: project the result of the sum
       function applied to the i-th injection of its tupled arguments. *)
    fun define (fvar as (n, _)) caTs resultT i =
      let
        val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *)
        val i' = find_index (fn Ta => Ta = resultT) dresultTs + 1

        val f_exp = SumTree.mk_proj RST n' i' (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
        (* Built once; used both for the definition and for rewriting rhs. *)
        val f_abs = fold_rev lambda vars f_exp
        val def = Term.abstract_over (Free fsum_var, f_abs)

        val rew = (n, f_abs)
      in
        (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew)
      end

    val (parts, rews) = split_list (map4 define fs caTss resultTs (1 upto num))

    (* Restates one equation for the sum function: the left-hand side is
       injected into ST, the right-hand side (with every f_j replaced by
       its projection term) into RST. *)
    fun convert_eqs (f, qs, gs, args, rhs) =
      let
        val MutualPart {i, i', ...} = get_part f parts
      in
        (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
         SumTree.mk_inj RST n' i' (replace_frees rews rhs)
           |> Envir.beta_norm)
      end

    val qglrs = map convert_eqs fqgars
  in
    Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST,
      parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE}
  end

(* Defines every f_i as a constant (with its mixfix syntax) in terms of the
   now-defined sum function fsum. parts and fixes are zipped, so they are
   expected in the same order and of equal length. The definition theorem
   is named "<fname>_def" and concealed. *)
fun define_projections fixes mutual fsum lthy =
  let
    fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy =
      let
        val ((f, (_, f_defthm)), lthy') =
          Local_Theory.define
            ((Binding.name fname, mixfix),
             ((Binding.conceal (Binding.name (fname ^ "_def")), []),
              Term.subst_bound (fsum, f_def))) lthy
      in
        (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def,
                     f=SOME f, f_defthm=SOME f_defthm },
         lthy')
      end

    val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual
    val (parts', lthy') = fold_map def (parts ~~ fixes) lthy
  in
    (Mutual { n=n, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, parts=parts',
              fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum },
     lthy')
  end

(* Opens the context of one split equation. The quantified variables are
   fixed under fresh names and substituted for the loose bounds. F is then
   called with an "import" function (instantiate the variables, discharge
   the guards by assumption) and an "export" function (re-introduce the
   guards as premises and generalize over the original variable names). *)
fun in_context ctxt (f, pre_qs, pre_gs, pre_args, pre_rhs) F =
  let
    val thy = ProofContext.theory_of ctxt

    val oqnames = map fst pre_qs
    val (qs, _) = Variable.variant_fixes oqnames ctxt
      |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs

    fun inst t = subst_bounds (rev qs, t)
    val gs = map inst pre_gs
    val args = map inst pre_args
    val rhs = inst pre_rhs

    val cqs = map (cterm_of thy) qs
    val ags = map (Thm.assume o cterm_of thy) gs

    val import = fold Thm.forall_elim cqs
      #> fold Thm.elim_implies ags

    val export = fold_rev (Thm.implies_intr o cprop_of) ags
      #> fold_rev forall_intr_rename (oqnames ~~ cqs)
  in
    F ctxt (f, qs, gs, args, rhs) import export
  end

(* Turns a psimp rule about the sum function into the corresponding rule
   "f args = rhs" about the original function f. At most one premise (the
   domain condition) is supported; it is assumed during the proof and
   re-introduced afterwards. *)
fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs)
  import (export : thm -> thm) sum_psimp_eq =
  let
    val (MutualPart {f=SOME f, ...}) = get_part fname parts

    val psimp = import sum_psimp_eq
    val (simp, restore_cond) =
      case cprems_of psimp of
        [] => (psimp, I)
      | [cond] => (Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond)
      | _ => raise General.Fail "Too many conditions"
  in
    Goal.prove ctxt [] []
      (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs))
      (fn _ => (Local_Defs.unfold_tac ctxt all_orig_fdefs)
         THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
         THEN (simp_tac (simpset_of ctxt)) 1) (* FIXME: global simpset?!! *)
    |> restore_cond
    |> export
  end

(* Applies both sides of the equation thm to fresh variables of types caTs,
   beta-reduces, and generalizes those variables into schematic ones. *)
fun mk_applied_form ctxt caTs thm =
  let
    val thy = ProofContext.theory_of ctxt
    val xs = map_index (fn (i,T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs (* FIXME: Bind xs properly *)
  in
    fold (fn x => fn thm => Thm.combination thm (Thm.reflexive x)) xs thm
    |> Conv.fconv_rule (Thm.beta_conversion true)
    |> fold_rev Thm.forall_intr xs
    |> Thm.forall_elim_vars 0
  end

(* Derives one induction rule per function from the induction rule of the
   sum function. The rule is instantiated with a case distinction over the
   sum that dispatches to a fresh predicate per function, then projected
   onto each summand. *)
fun mutual_induct_rules lthy induct all_f_defs (Mutual {n, ST, parts, ...}) =
  let
    val cert = cterm_of (ProofContext.theory_of lthy)
    val newPs =
      map2 (fn Pname => fn MutualPart {cargTs, ...} =>
          Free (Pname, cargTs ---> HOLogic.boolT))
        (mutual_induct_Pnames (length parts)) parts

    (* Tupled version of the curried predicate P. *)
    fun mk_P (MutualPart {cargTs, ...}) P =
      let
        val avars = map_index (fn (i,T) => Var (("a", i), T)) cargTs
        val atup = foldr1 HOLogic.mk_prod avars
      in
        HOLogic.tupled_lambda atup (list_comb (P, avars))
      end

    val Ps = map2 mk_P parts newPs
    val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps

    val induct_inst =
      Thm.forall_elim (cert case_exp) induct
      |> full_simplify SumTree.sumcase_split_ss
      |> full_simplify (HOL_basic_ss addsimps all_f_defs)

    (* k counts the argument variables used so far, so that the names
       a<k> stay distinct across parts. *)
    fun project rule (MutualPart {cargTs, i, ...}) k =
      let
        val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *)
        val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
      in
        (rule
         |> Thm.forall_elim (cert inj)
         |> full_simplify SumTree.sumcase_split_ss
         |> fold_rev (Thm.forall_intr o cert) (afs @ newPs),
         k + length cargTs)
      end
  in
    fst (fold_map (project induct_inst) parts 0)
  end

(* Proof continuation for the mutual case. It runs the inner (sum-level)
   continuation and translates its psimps, trsimps, induction rule,
   termination rule and domintros into statements about the individual
   functions. The inner result must have exactly one simple_pinduct. *)
fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof =
  let
    val result = inner_cont proof
    val FunctionResult {G, R, cases, psimps, trsimps, simple_pinducts=[simple_pinduct],
      termination, domintros, ...} = result

    (* Rewrite rules "projection of sum application = f_i xs", oriented
       from the sum encoding back to the original functions. *)
    val (all_f_defs, fs) =
      map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
          (mk_applied_form lthy cargTs (Thm.symmetric f_def), f))
        parts
      |> split_list

    val all_orig_fdefs =
      map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts

    fun mk_mpsimp fqgar sum_psimp =
      in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp

    val rew_ss = HOL_basic_ss addsimps all_f_defs
    val mpsimps = map2 mk_mpsimp fqgars psimps
    val mtrsimps = Option.map (map2 mk_mpsimp fqgars) trsimps
    val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
    val mtermination = full_simplify rew_ss termination
    val mdomintros = Option.map (map (full_simplify rew_ss)) domintros
  in
    FunctionResult { fs=fs, G=G, R=R,
      psimps=mpsimps, simple_pinducts=minducts,
      cases=cases, termination=mtermination,
      domintros=mdomintros, trsimps=mtrsimps}
  end

(* nominal *)
(* Entry point: encodes the equations for a single sum function, lets
   Nominal_Function_Core set up that definition, defines the original
   functions as projections, and wraps the core continuation so that its
   results are stated for the original functions. *)
fun prepare_nominal_function_mutual config defname fixes eqss lthy =
  let
    val mutual as Mutual {fsum_var=(n, T), qglrs, ...} =
      analyze_eqs lthy defname (map fst fixes) (map Envir.beta_eta_contract eqss)

    val ((fsum, goalstate, cont), lthy') =
      Nominal_Function_Core.prepare_nominal_function config defname [((n, T), NoSyn)] qglrs lthy

    val (mutual', lthy'') = define_projections fixes mutual fsum lthy'

    val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual'
  in
    ((goalstate, mutual_cont), lthy'')
  end

end