Nominal/nominal_mutual.ML
changeset 2665 16b5a67ee279
child 2745 34df2cffe259
equal deleted inserted replaced
2664:a9a1ed3f5023 2665:16b5a67ee279
       
     1 (*  Nominal Mutual Functions
       
     2     Author:  Christian Urban
       
     3 
       
     4     heavily based on the code of Alexander Krauss
       
     5     (code forked on 14 January 2011)
       
     6 
       
     7 
       
      8 Mutually recursive nominal function definitions.
       
     9 *)
       
    10 
       
    11 signature NOMINAL_FUNCTION_MUTUAL =
       
    12 sig
       
    13 
       
    14   val prepare_nominal_function_mutual : Function_Common.function_config
       
    15     -> string (* defname *)
       
    16     -> ((string * typ) * mixfix) list
       
    17     -> term list
       
    18     -> local_theory
       
    19     -> ((thm (* goalstate *)
       
    20         * (thm -> Function_Common.function_result) (* proof continuation *)
       
    21        ) * local_theory)
       
    22 
       
    23 end
       
    24 
       
    25 
       
    26 structure Nominal_Function_Mutual: NOMINAL_FUNCTION_MUTUAL =
       
    27 struct
       
    28 
       
    29 open Function_Lib
       
    30 open Function_Common
       
    31 
       
    32 type qgar = string * (string * typ) list * term list * term list * term
       
    33 
       
    34 datatype mutual_part = MutualPart of
       
    35  {i : int,
       
    36   i' : int,
       
    37   fvar : string * typ,
       
    38   cargTs: typ list,
       
    39   f_def: term,
       
    40 
       
    41   f: term option,
       
    42   f_defthm : thm option}
       
    43 
       
    44 datatype mutual_info = Mutual of
       
    45  {n : int,
       
    46   n' : int,
       
    47   fsum_var : string * typ,
       
    48 
       
    49   ST: typ,
       
    50   RST: typ,
       
    51 
       
    52   parts: mutual_part list,
       
    53   fqgars: qgar list,
       
    54   qglrs: ((string * typ) list * term list * term * term) list,
       
    55 
       
    56   fsum : term option}
       
    57 
       
    58 fun mutual_induct_Pnames n =
       
    59   if n < 5 then fst (chop n ["P","Q","R","S"])
       
    60   else map (fn i => "P" ^ string_of_int i) (1 upto n)
       
    61 
       
    62 fun get_part fname =
       
    63   the o find_first (fn (MutualPart {fvar=(n,_), ...}) => n = fname)
       
    64 
       
    65 (* FIXME *)
       
    66 fun mk_prod_abs e (t1, t2) =
       
    67   let
       
    68     val bTs = rev (map snd e)
       
    69     val T1 = fastype_of1 (bTs, t1)
       
    70     val T2 = fastype_of1 (bTs, t2)
       
    71   in
       
    72     HOLogic.pair_const T1 T2 $ t1 $ t2
       
    73   end
       
    74 
       
    75 fun analyze_eqs ctxt defname fs eqs =
       
    76   let
       
    77     val num = length fs
       
    78     val fqgars = map (split_def ctxt (K true)) eqs
       
    79     val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
       
    80       |> AList.lookup (op =) #> the
       
    81 
       
    82     fun curried_types (fname, fT) =
       
    83       let
       
    84         val (caTs, uaTs) = chop (arity_of fname) (binder_types fT)
       
    85       in
       
    86         (caTs, uaTs ---> body_type fT)
       
    87       end
       
    88 
       
    89     val (caTss, resultTs) = split_list (map curried_types fs)
       
    90     val argTs = map (foldr1 HOLogic.mk_prodT) caTss
       
    91 
       
    92     val dresultTs = distinct (op =) resultTs
       
    93     val n' = length dresultTs
       
    94 
       
    95     val RST = Balanced_Tree.make (uncurry SumTree.mk_sumT) dresultTs
       
    96     val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) argTs
       
    97 
       
    98     val fsum_type = ST --> RST
       
    99 
       
   100     val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt
       
   101     val fsum_var = (fsum_var_name, fsum_type)
       
   102 
       
   103     fun define (fvar as (n, _)) caTs resultT i =
       
   104       let
       
   105         val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *)
       
   106         val i' = find_index (fn Ta => Ta = resultT) dresultTs + 1
       
   107 
       
   108         val f_exp = SumTree.mk_proj RST n' i' (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
       
   109         val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp)
       
   110 
       
   111         val rew = (n, fold_rev lambda vars f_exp)
       
   112       in
       
   113         (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew)
       
   114       end
       
   115 
       
   116     val (parts, rews) = split_list (map4 define fs caTss resultTs (1 upto num))
       
   117 
       
   118     fun convert_eqs (f, qs, gs, args, rhs) =
       
   119       let
       
   120         val MutualPart {i, i', ...} = get_part f parts
       
   121       in
       
   122         (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
       
   123          SumTree.mk_inj RST n' i' (replace_frees rews rhs)
       
   124          |> Envir.beta_norm)
       
   125       end
       
   126 
       
   127     val qglrs = map convert_eqs fqgars
       
   128   in
       
   129     Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST,
       
   130       parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE}
       
   131   end
       
   132 
       
   133 fun define_projections fixes mutual fsum lthy =
       
   134   let
       
   135     fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy =
       
   136       let
       
   137         val ((f, (_, f_defthm)), lthy') =
       
   138           Local_Theory.define
       
   139             ((Binding.name fname, mixfix),
       
   140               ((Binding.conceal (Binding.name (fname ^ "_def")), []),
       
   141               Term.subst_bound (fsum, f_def))) lthy
       
   142       in
       
   143         (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def,
       
   144            f=SOME f, f_defthm=SOME f_defthm },
       
   145          lthy')
       
   146       end
       
   147 
       
   148     val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual
       
   149     val (parts', lthy') = fold_map def (parts ~~ fixes) lthy
       
   150   in
       
   151     (Mutual { n=n, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, parts=parts',
       
   152        fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum },
       
   153      lthy')
       
   154   end
       
   155 
       
   156 fun in_context ctxt (f, pre_qs, pre_gs, pre_args, pre_rhs) F =
       
   157   let
       
   158     val thy = ProofContext.theory_of ctxt
       
   159 
       
   160     val oqnames = map fst pre_qs
       
   161     val (qs, _) = Variable.variant_fixes oqnames ctxt
       
   162       |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs
       
   163 
       
   164     fun inst t = subst_bounds (rev qs, t)
       
   165     val gs = map inst pre_gs
       
   166     val args = map inst pre_args
       
   167     val rhs = inst pre_rhs
       
   168 
       
   169     val cqs = map (cterm_of thy) qs
       
   170     val ags = map (Thm.assume o cterm_of thy) gs
       
   171 
       
   172     val import = fold Thm.forall_elim cqs
       
   173       #> fold Thm.elim_implies ags
       
   174 
       
   175     val export = fold_rev (Thm.implies_intr o cprop_of) ags
       
   176       #> fold_rev forall_intr_rename (oqnames ~~ cqs)
       
   177   in
       
   178     F ctxt (f, qs, gs, args, rhs) import export
       
   179   end
       
   180 
       
   181 fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs)
       
   182   import (export : thm -> thm) sum_psimp_eq =
       
   183   let
       
   184     val (MutualPart {f=SOME f, ...}) = get_part fname parts
       
   185 
       
   186     val psimp = import sum_psimp_eq
       
   187     val (simp, restore_cond) =
       
   188       case cprems_of psimp of
       
   189         [] => (psimp, I)
       
   190       | [cond] => (Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond)
       
   191       | _ => raise General.Fail "Too many conditions"
       
   192 
       
   193   in
       
   194     Goal.prove ctxt [] []
       
   195       (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs))
       
   196       (fn _ => (Local_Defs.unfold_tac ctxt all_orig_fdefs)
       
   197          THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
       
   198          THEN (simp_tac (simpset_of ctxt)) 1) (* FIXME: global simpset?!! *)
       
   199     |> restore_cond
       
   200     |> export
       
   201   end
       
   202 
       
   203 fun mk_applied_form ctxt caTs thm =
       
   204   let
       
   205     val thy = ProofContext.theory_of ctxt
       
   206     val xs = map_index (fn (i,T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs (* FIXME: Bind xs properly *)
       
   207   in
       
   208     fold (fn x => fn thm => Thm.combination thm (Thm.reflexive x)) xs thm
       
   209     |> Conv.fconv_rule (Thm.beta_conversion true)
       
   210     |> fold_rev Thm.forall_intr xs
       
   211     |> Thm.forall_elim_vars 0
       
   212   end
       
   213 
       
   214 fun mutual_induct_rules lthy induct all_f_defs (Mutual {n, ST, parts, ...}) =
       
   215   let
       
   216     val cert = cterm_of (ProofContext.theory_of lthy)
       
   217     val newPs =
       
   218       map2 (fn Pname => fn MutualPart {cargTs, ...} =>
       
   219           Free (Pname, cargTs ---> HOLogic.boolT))
       
   220         (mutual_induct_Pnames (length parts)) parts
       
   221 
       
   222     fun mk_P (MutualPart {cargTs, ...}) P =
       
   223       let
       
   224         val avars = map_index (fn (i,T) => Var (("a", i), T)) cargTs
       
   225         val atup = foldr1 HOLogic.mk_prod avars
       
   226       in
       
   227         HOLogic.tupled_lambda atup (list_comb (P, avars))
       
   228       end
       
   229 
       
   230     val Ps = map2 mk_P parts newPs
       
   231     val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps
       
   232 
       
   233     val induct_inst =
       
   234       Thm.forall_elim (cert case_exp) induct
       
   235       |> full_simplify SumTree.sumcase_split_ss
       
   236       |> full_simplify (HOL_basic_ss addsimps all_f_defs)
       
   237 
       
   238     fun project rule (MutualPart {cargTs, i, ...}) k =
       
   239       let
       
   240         val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *)
       
   241         val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
       
   242       in
       
   243         (rule
       
   244          |> Thm.forall_elim (cert inj)
       
   245          |> full_simplify SumTree.sumcase_split_ss
       
   246          |> fold_rev (Thm.forall_intr o cert) (afs @ newPs),
       
   247          k + length cargTs)
       
   248       end
       
   249   in
       
   250     fst (fold_map (project induct_inst) parts 0)
       
   251   end
       
   252 
       
   253 fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof =
       
   254   let
       
   255     val result = inner_cont proof
       
   256     val FunctionResult {G, R, cases, psimps, trsimps, simple_pinducts=[simple_pinduct],
       
   257       termination, domintros, ...} = result
       
   258 
       
   259     val (all_f_defs, fs) =
       
   260       map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
       
   261         (mk_applied_form lthy cargTs (Thm.symmetric f_def), f))
       
   262       parts
       
   263       |> split_list
       
   264 
       
   265     val all_orig_fdefs =
       
   266       map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts
       
   267 
       
   268     fun mk_mpsimp fqgar sum_psimp =
       
   269       in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp
       
   270 
       
   271     val rew_ss = HOL_basic_ss addsimps all_f_defs
       
   272     val mpsimps = map2 mk_mpsimp fqgars psimps
       
   273     val mtrsimps = Option.map (map2 mk_mpsimp fqgars) trsimps
       
   274     val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
       
   275     val mtermination = full_simplify rew_ss termination
       
   276     val mdomintros = Option.map (map (full_simplify rew_ss)) domintros
       
   277   in
       
   278     FunctionResult { fs=fs, G=G, R=R,
       
   279       psimps=mpsimps, simple_pinducts=minducts,
       
   280       cases=cases, termination=mtermination,
       
   281       domintros=mdomintros, trsimps=mtrsimps}
       
   282   end
       
   283 
       
   284 (* nominal *)
       
   285 fun prepare_nominal_function_mutual config defname fixes eqss lthy =
       
   286   let
       
   287     val mutual as Mutual {fsum_var=(n, T), qglrs, ...} =
       
   288       analyze_eqs lthy defname (map fst fixes) (map Envir.beta_eta_contract eqss)
       
   289 
       
   290     val ((fsum, goalstate, cont), lthy') =
       
   291       Nominal_Function_Core.prepare_nominal_function config defname [((n, T), NoSyn)] qglrs lthy
       
   292 
       
   293     val (mutual', lthy'') = define_projections fixes mutual fsum lthy'
       
   294 
       
   295     val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual'
       
   296   in
       
   297     ((goalstate, mutual_cont), lthy'')
       
   298   end
       
   299 
       
   300 end