diff -r a9a1ed3f5023 -r 16b5a67ee279 Nominal/nominal_mutual.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/Nominal/nominal_mutual.ML Mon Jan 17 14:37:18 2011 +0100 @@ -0,0 +1,300 @@ +(* Nominal Mutual Functions + Author: Christian Urban + + heavily based on the code of Alexander Krauss + (code forked on 14 January 2011) + + +Mutual recursive nominal function definitions. +*) + +signature NOMINAL_FUNCTION_MUTUAL = +sig + + val prepare_nominal_function_mutual : Function_Common.function_config + -> string (* defname *) + -> ((string * typ) * mixfix) list + -> term list + -> local_theory + -> ((thm (* goalstate *) + * (thm -> Function_Common.function_result) (* proof continuation *) + ) * local_theory) + +end + + +structure Nominal_Function_Mutual: NOMINAL_FUNCTION_MUTUAL = +struct + +open Function_Lib +open Function_Common + +type qgar = string * (string * typ) list * term list * term list * term + +datatype mutual_part = MutualPart of + {i : int, + i' : int, + fvar : string * typ, + cargTs: typ list, + f_def: term, + + f: term option, + f_defthm : thm option} + +datatype mutual_info = Mutual of + {n : int, + n' : int, + fsum_var : string * typ, + + ST: typ, + RST: typ, + + parts: mutual_part list, + fqgars: qgar list, + qglrs: ((string * typ) list * term list * term * term) list, + + fsum : term option} + +fun mutual_induct_Pnames n = + if n < 5 then fst (chop n ["P","Q","R","S"]) + else map (fn i => "P" ^ string_of_int i) (1 upto n) + +fun get_part fname = + the o find_first (fn (MutualPart {fvar=(n,_), ...}) => n = fname) + +(* FIXME *) +fun mk_prod_abs e (t1, t2) = + let + val bTs = rev (map snd e) + val T1 = fastype_of1 (bTs, t1) + val T2 = fastype_of1 (bTs, t2) + in + HOLogic.pair_const T1 T2 $ t1 $ t2 + end + +fun analyze_eqs ctxt defname fs eqs = + let + val num = length fs + val fqgars = map (split_def ctxt (K true)) eqs + val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars + |> AList.lookup (op =) #> the + + fun curried_types (fname, fT) = + let + val (caTs, uaTs) = chop (arity_of fname) (binder_types fT) + in + (caTs, uaTs ---> body_type fT) + end + + val (caTss, resultTs) = split_list (map curried_types fs) + val argTs = map (foldr1 HOLogic.mk_prodT) caTss + + val dresultTs = distinct (op =) resultTs + val n' = length dresultTs + + val RST = Balanced_Tree.make (uncurry SumTree.mk_sumT) dresultTs + val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) argTs + + val fsum_type = ST --> RST + + val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt + val fsum_var = (fsum_var_name, fsum_type) + + fun define (fvar as (n, _)) caTs resultT i = + let + val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *) + val i' = find_index (fn Ta => Ta = resultT) dresultTs + 1 + + val f_exp = SumTree.mk_proj RST n' i' (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars)) + val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp) + + val rew = (n, fold_rev lambda vars f_exp) + in + (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew) + end + + val (parts, rews) = split_list (map4 define fs caTss resultTs (1 upto num)) + + fun convert_eqs (f, qs, gs, args, rhs) = + let + val MutualPart {i, i', ...} = get_part f parts + in + (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args), + SumTree.mk_inj RST n' i' (replace_frees rews rhs) + |> Envir.beta_norm) + end + + val qglrs = map convert_eqs fqgars + in + Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, + parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE} + end + +fun define_projections fixes mutual fsum lthy = + let + fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy = + let + val ((f, (_, f_defthm)), lthy') = + Local_Theory.define + ((Binding.name fname, mixfix), + ((Binding.conceal (Binding.name (fname ^ "_def")), []), + Term.subst_bound (fsum, f_def))) lthy + in + (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def, + f=SOME f, f_defthm=SOME f_defthm }, + lthy') + end + + val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual + val (parts', lthy') = fold_map def (parts ~~ fixes) lthy + in + (Mutual { n=n, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, parts=parts', + fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum }, + lthy') + end + +fun in_context ctxt (f, pre_qs, pre_gs, pre_args, pre_rhs) F = + let + val thy = ProofContext.theory_of ctxt + + val oqnames = map fst pre_qs + val (qs, _) = Variable.variant_fixes oqnames ctxt + |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs + + fun inst t = subst_bounds (rev qs, t) + val gs = map inst pre_gs + val args = map inst pre_args + val rhs = inst pre_rhs + + val cqs = map (cterm_of thy) qs + val ags = map (Thm.assume o cterm_of thy) gs + + val import = fold Thm.forall_elim cqs + #> fold Thm.elim_implies ags + + val export = fold_rev (Thm.implies_intr o cprop_of) ags + #> fold_rev forall_intr_rename (oqnames ~~ cqs) + in + F ctxt (f, qs, gs, args, rhs) import export + end + +fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs) + import (export : thm -> thm) sum_psimp_eq = + let + val (MutualPart {f=SOME f, ...}) = get_part fname parts + + val psimp = import sum_psimp_eq + val (simp, restore_cond) = + case cprems_of psimp of + [] => (psimp, I) + | [cond] => (Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond) + | _ => raise General.Fail "Too many conditions" + + in + Goal.prove ctxt [] [] + (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs)) + (fn _ => (Local_Defs.unfold_tac ctxt all_orig_fdefs) + THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1 + THEN (simp_tac (simpset_of ctxt)) 1) (* FIXME: global simpset?!! *) + |> restore_cond + |> export + end + +fun mk_applied_form ctxt caTs thm = + let + val thy = ProofContext.theory_of ctxt + val xs = map_index (fn (i,T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs (* FIXME: Bind xs properly *) + in + fold (fn x => fn thm => Thm.combination thm (Thm.reflexive x)) xs thm + |> Conv.fconv_rule (Thm.beta_conversion true) + |> fold_rev Thm.forall_intr xs + |> Thm.forall_elim_vars 0 + end + +fun mutual_induct_rules lthy induct all_f_defs (Mutual {n, ST, parts, ...}) = + let + val cert = cterm_of (ProofContext.theory_of lthy) + val newPs = + map2 (fn Pname => fn MutualPart {cargTs, ...} => + Free (Pname, cargTs ---> HOLogic.boolT)) + (mutual_induct_Pnames (length parts)) parts + + fun mk_P (MutualPart {cargTs, ...}) P = + let + val avars = map_index (fn (i,T) => Var (("a", i), T)) cargTs + val atup = foldr1 HOLogic.mk_prod avars + in + HOLogic.tupled_lambda atup (list_comb (P, avars)) + end + + val Ps = map2 mk_P parts newPs + val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps + + val induct_inst = + Thm.forall_elim (cert case_exp) induct + |> full_simplify SumTree.sumcase_split_ss + |> full_simplify (HOL_basic_ss addsimps all_f_defs) + + fun project rule (MutualPart {cargTs, i, ...}) k = + let + val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *) + val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs) + in + (rule + |> Thm.forall_elim (cert inj) + |> full_simplify SumTree.sumcase_split_ss + |> fold_rev (Thm.forall_intr o cert) (afs @ newPs), + k + length cargTs) + end + in + fst (fold_map (project induct_inst) parts 0) + end + +fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof = + let + val result = inner_cont proof + val FunctionResult {G, R, cases, psimps, trsimps, simple_pinducts=[simple_pinduct], + termination, domintros, ...} = result + + val (all_f_defs, fs) = + map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} => + (mk_applied_form lthy cargTs (Thm.symmetric f_def), f)) + parts + |> split_list + + val all_orig_fdefs = + map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts + + fun mk_mpsimp fqgar sum_psimp = + in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp + + val rew_ss = HOL_basic_ss addsimps all_f_defs + val mpsimps = map2 mk_mpsimp fqgars psimps + val mtrsimps = Option.map (map2 mk_mpsimp fqgars) trsimps + val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m + val mtermination = full_simplify rew_ss termination + val mdomintros = Option.map (map (full_simplify rew_ss)) domintros + in + FunctionResult { fs=fs, G=G, R=R, + psimps=mpsimps, simple_pinducts=minducts, + cases=cases, termination=mtermination, + domintros=mdomintros, trsimps=mtrsimps} + end + +(* nominal *) +fun prepare_nominal_function_mutual config defname fixes eqss lthy = + let + val mutual as Mutual {fsum_var=(n, T), qglrs, ...} = + analyze_eqs lthy defname (map fst fixes) (map Envir.beta_eta_contract eqss) + + val ((fsum, goalstate, cont), lthy') = + Nominal_Function_Core.prepare_nominal_function config defname [((n, T), NoSyn)] qglrs lthy + + val (mutual', lthy'') = define_projections fixes mutual fsum lthy' + + val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual' + in + ((goalstate, mutual_cont), lthy'') + end + +end