Nominal/nominal_mutual.ML
changeset 3229 b52e8651591f
parent 3227 35bb5b013f0e
child 3231 188826f1ccdb
equal deleted inserted replaced
3228:040519ec99e9 3229:b52e8651591f
    11 *)
    11 *)
    12 
    12 
    13 
    13 
    14 signature NOMINAL_FUNCTION_MUTUAL =
    14 signature NOMINAL_FUNCTION_MUTUAL =
    15 sig
    15 sig
    16 
       
    17   val prepare_nominal_function_mutual : Nominal_Function_Common.nominal_function_config
    16   val prepare_nominal_function_mutual : Nominal_Function_Common.nominal_function_config
    18     -> string (* defname *)
    17     -> string (* defname *)
    19     -> ((string * typ) * mixfix) list
    18     -> ((string * typ) * mixfix) list
    20     -> term list
    19     -> term list
    21     -> local_theory
    20     -> local_theory
    22     -> ((thm (* goalstate *)
    21     -> ((thm (* goalstate *)
    23         * (thm -> Nominal_Function_Common.nominal_function_result) (* proof continuation *)
    22         * (thm -> Nominal_Function_Common.nominal_function_result) (* proof continuation *)
    24        ) * local_theory)
    23        ) * local_theory)
    25 
       
    26 end
    24 end
    27 
       
    28 
    25 
    29 structure Nominal_Function_Mutual: NOMINAL_FUNCTION_MUTUAL =
    26 structure Nominal_Function_Mutual: NOMINAL_FUNCTION_MUTUAL =
    30 struct
    27 struct
    31 
    28 
    32 open Function_Lib
    29 open Function_Lib
    93     val argTs = map (foldr1 HOLogic.mk_prodT) caTss
    90     val argTs = map (foldr1 HOLogic.mk_prodT) caTss
    94 
    91 
    95     val dresultTs = distinct (op =) resultTs
    92     val dresultTs = distinct (op =) resultTs
    96     val n' = length dresultTs
    93     val n' = length dresultTs
    97 
    94 
    98     val RST = Balanced_Tree.make (uncurry SumTree.mk_sumT) dresultTs
    95     val RST = Balanced_Tree.make (uncurry Sum_Tree.mk_sumT) dresultTs
    99     val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) argTs
    96     val ST = Balanced_Tree.make (uncurry Sum_Tree.mk_sumT) argTs
   100 
    97 
   101     val fsum_type = ST --> RST
    98     val fsum_type = ST --> RST
   102 
    99 
   103     val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt
   100     val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt
   104     val fsum_var = (fsum_var_name, fsum_type)
   101     val fsum_var = (fsum_var_name, fsum_type)
   106     fun define (fvar as (n, _)) caTs resultT i =
   103     fun define (fvar as (n, _)) caTs resultT i =
   107       let
   104       let
   108         val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *)
   105         val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *)
   109         val i' = find_index (fn Ta => Ta = resultT) dresultTs + 1
   106         val i' = find_index (fn Ta => Ta = resultT) dresultTs + 1
   110 
   107 
   111         val f_exp = SumTree.mk_proj RST n' i' (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
   108         val f_exp = Sum_Tree.mk_proj RST n' i' (Free fsum_var $ Sum_Tree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
   112         val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp)
   109         val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp)
   113 
   110 
   114         val rew = (n, fold_rev lambda vars f_exp)
   111         val rew = (n, fold_rev lambda vars f_exp)
   115       in
   112       in
   116         (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew)
   113         (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew)
   122       let
   119       let
   123         val MutualPart {i, i', ...} = get_part f parts
   120         val MutualPart {i, i', ...} = get_part f parts
   124         val rhs' = rhs
   121         val rhs' = rhs
   125              |> map_aterms (fn t as Free (n, _) => the_default t (AList.lookup (op =) rews n) | t => t)
   122              |> map_aterms (fn t as Free (n, _) => the_default t (AList.lookup (op =) rews n) | t => t)
   126       in
   123       in
   127         (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
   124         (qs, gs, Sum_Tree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
   128          Envir.beta_norm (SumTree.mk_inj RST n' i' rhs'))
   125          Envir.beta_norm (Sum_Tree.mk_inj RST n' i' rhs'))
   129       end
   126       end
   130 
   127 
   131     val qglrs = map convert_eqs fqgars
   128     val qglrs = map convert_eqs fqgars
   132   in
   129   in
   133     Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST,
   130     Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST,
   203          THEN (simp_tac ctxt') 1)
   200          THEN (simp_tac ctxt') 1)
   204     |> restore_cond
   201     |> restore_cond
   205     |> export
   202     |> export
   206   end
   203   end
   207 
   204 
   208 val inl_perm = @{lemma "x = Inl y ==> Sum_Type.Projl (permute p x) = permute p (Sum_Type.Projl x)" by simp}
   205 val inl_perm = @{lemma "x = Inl y ==> projl (permute p x) = permute p (projl x)" by simp}
   209 val inr_perm = @{lemma "x = Inr y ==> Sum_Type.Projr (permute p x) = permute p (Sum_Type.Projr x)" by simp}
   206 val inr_perm = @{lemma "x = Inr y ==> projr (permute p x) = permute p (projr x)" by simp}
   210 
   207 
   211 fun recover_mutual_eqvt eqvt_thm all_orig_fdefs parts ctxt (fname, _, _, args, _)
   208 fun recover_mutual_eqvt eqvt_thm all_orig_fdefs parts ctxt (fname, _, _, args, _)
   212   import (export : thm -> thm) sum_psimp_eq =
   209   import (export : thm -> thm) sum_psimp_eq =
   213   let
   210   let
   214     val (MutualPart {f=SOME f, ...}) = get_part fname parts
   211     val (MutualPart {f=SOME f, ...}) = get_part fname parts
   222           in (([asm], Thm.implies_elim psimp asm, Thm.implies_intr cond), ctxt') end
   219           in (([asm], Thm.implies_elim psimp asm, Thm.implies_intr cond), ctxt') end
   223       | _ => raise General.Fail "Too many conditions"
   220       | _ => raise General.Fail "Too many conditions"
   224 
   221 
   225     val ([p], ctxt'') = ctxt'
   222     val ([p], ctxt'') = ctxt'
   226       |> fold Variable.declare_term args  
   223       |> fold Variable.declare_term args  
   227       |> Variable.variant_fixes ["p"] 		   
   224       |> Variable.variant_fixes ["p"]
   228     val p = Free (p, @{typ perm})
   225     val p = Free (p, @{typ perm})
   229 
   226 
   230     val simpset =
   227     val simpset =
   231       put_simpset HOL_basic_ss ctxt'' addsimps 
   228       put_simpset HOL_basic_ss ctxt'' addsimps 
   232       @{thms permute_sum.simps[symmetric] Pair_eqvt[symmetric]} @
   229       @{thms permute_sum.simps[symmetric] Pair_eqvt[symmetric] sum.sel} @
   233       @{thms Projr.simps Projl.simps} @
       
   234       [(cond MRS eqvt_thm) RS @{thm sym}] @ 
   230       [(cond MRS eqvt_thm) RS @{thm sym}] @ 
   235       [inl_perm, inr_perm, simp] 
   231       [inl_perm, inr_perm, simp] 
   236     val goal_lhs = mk_perm p (list_comb (f, args))
   232     val goal_lhs = mk_perm p (list_comb (f, args))
   237     val goal_rhs = list_comb (f, map (mk_perm p) args)
   233     val goal_rhs = list_comb (f, map (mk_perm p) args)
   238   in
   234   in
   270       in
   266       in
   271         HOLogic.tupled_lambda atup (list_comb (P, avars))
   267         HOLogic.tupled_lambda atup (list_comb (P, avars))
   272       end
   268       end
   273 
   269 
   274     val Ps = map2 mk_P parts newPs
   270     val Ps = map2 mk_P parts newPs
   275     val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps
   271     val case_exp = Sum_Tree.mk_sumcases HOLogic.boolT Ps
   276 
   272 
   277     val induct_inst =
   273     val induct_inst =
   278       Thm.forall_elim (cert case_exp) induct
   274       Thm.forall_elim (cert case_exp) induct
   279       |> full_simplify (put_simpset SumTree.sumcase_split_ss ctxt)
   275       |> full_simplify (put_simpset Sum_Tree.sumcase_split_ss ctxt)
   280       |> full_simplify (put_simpset HOL_basic_ss ctxt addsimps all_f_defs)
   276       |> full_simplify (put_simpset HOL_basic_ss ctxt addsimps all_f_defs)
   281 
   277 
   282     fun project rule (MutualPart {cargTs, i, ...}) k =
   278     fun project rule (MutualPart {cargTs, i, ...}) k =
   283       let
   279       let
   284         val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *)
   280         val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *)
   285         val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
   281         val inj = Sum_Tree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
   286       in
   282       in
   287         (rule
   283         (rule
   288          |> Thm.forall_elim (cert inj)
   284          |> Thm.forall_elim (cert inj)
   289          |> full_simplify (put_simpset SumTree.sumcase_split_ss ctxt)
   285          |> full_simplify (put_simpset Sum_Tree.sumcase_split_ss ctxt)
   290          |> fold_rev (Thm.forall_intr o cert) (afs @ newPs),
   286          |> fold_rev (Thm.forall_intr o cert) (afs @ newPs),
   291          k + length cargTs)
   287          k + length cargTs)
   292       end
   288       end
   293   in
   289   in
   294     fst (fold_map (project induct_inst) parts 0)
   290     fst (fold_map (project induct_inst) parts 0)
   425     (* defining the auxiliary graph *)
   421     (* defining the auxiliary graph *)
   426     fun mk_cases (MutualPart {i', fvar as (n, T), ...}) =
   422     fun mk_cases (MutualPart {i', fvar as (n, T), ...}) =
   427       let
   423       let
   428         val (tys, ty) = strip_type T
   424         val (tys, ty) = strip_type T
   429         val fun_var = Free (n ^ "_aux", HOLogic.mk_tupleT tys --> ty)
   425         val fun_var = Free (n ^ "_aux", HOLogic.mk_tupleT tys --> ty)
   430         val inj_fun = absdummy dummyT (SumTree.mk_inj RST n' i' (Bound 0))
   426         val inj_fun = absdummy dummyT (Sum_Tree.mk_inj RST n' i' (Bound 0))
   431       in
   427       in
   432         Syntax.check_term lthy'' (mk_comp_dummy inj_fun fun_var)
   428         Syntax.check_term lthy'' (mk_comp_dummy inj_fun fun_var)
   433       end
   429       end
   434 
   430 
   435     val sum_case_exp = map mk_cases parts
   431     val case_sum_exp = map mk_cases parts
   436       |> SumTree.mk_sumcases RST 
   432       |> Sum_Tree.mk_sumcases RST 
   437    
   433    
   438     val (G_name, G_type) = dest_Free G 
   434     val (G_name, G_type) = dest_Free G 
   439     val G_name_aux = G_name ^ "_aux"
   435     val G_name_aux = G_name ^ "_aux"
   440     val subst = [(G, Free (G_name_aux, G_type))]
   436     val subst = [(G, Free (G_name_aux, G_type))]
   441     val GIntros_aux = GIntro_thms
   437     val GIntros_aux = GIntro_thms
   442       |> map prop_of
   438       |> map prop_of
   443       |> map (Term.subst_free subst)
   439       |> map (Term.subst_free subst)
   444       |> map (subst_all sum_case_exp)
   440       |> map (subst_all case_sum_exp)
   445 
   441 
   446     val ((G_aux, GIntro_aux_thms, _, G_aux_induct), lthy''') = 
   442     val ((G_aux, GIntro_aux_thms, _, G_aux_induct), lthy''') = 
   447       Nominal_Function_Core.inductive_def ((Binding.name G_name_aux, G_type), NoSyn) GIntros_aux lthy''
   443       Nominal_Function_Core.inductive_def ((Binding.name G_name_aux, G_type), NoSyn) GIntros_aux lthy''
   448 
   444 
   449     val mutual_cont = mk_partial_rules_mutual lthy''' cont mutual'
   445     val mutual_cont = mk_partial_rules_mutual lthy''' cont mutual'
   454     val G_aux_prem = HOLogic.mk_Trueprop (G_aux $ x $ y)
   450     val G_aux_prem = HOLogic.mk_Trueprop (G_aux $ x $ y)
   455     val G_prem = HOLogic.mk_Trueprop (G $ x $ y)
   451     val G_prem = HOLogic.mk_Trueprop (G $ x $ y)
   456 
   452 
   457     fun mk_inj_goal  (MutualPart {i', ...}) =
   453     fun mk_inj_goal  (MutualPart {i', ...}) =
   458       let
   454       let
   459         val injs = SumTree.mk_inj ST n' i' (Bound 0)
   455         val injs = Sum_Tree.mk_inj ST n' i' (Bound 0)
   460         val projs = y
   456         val projs = y
   461           |> SumTree.mk_proj RST n' i'
   457           |> Sum_Tree.mk_proj RST n' i'
   462           |> SumTree.mk_inj RST n' i'
   458           |> Sum_Tree.mk_inj RST n' i'
   463       in
   459       in
   464         Const (@{const_name "All"}, dummyT) $ absdummy dummyT
   460         Const (@{const_name "All"}, dummyT) $ absdummy dummyT
   465           (HOLogic.mk_imp (HOLogic.mk_eq(x, injs), HOLogic.mk_eq(projs, y)))
   461           (HOLogic.mk_imp (HOLogic.mk_eq(x, injs), HOLogic.mk_eq(projs, y)))
   466       end
   462       end
   467 
   463 
   472     val goal_iff1 = Logic.mk_implies (G_aux_prem, G_prem)
   468     val goal_iff1 = Logic.mk_implies (G_aux_prem, G_prem)
   473       |> all x |> all y
   469       |> all x |> all y
   474     val goal_iff2 = Logic.mk_implies (G_prem, G_aux_prem)
   470     val goal_iff2 = Logic.mk_implies (G_prem, G_aux_prem)
   475       |> all x |> all y
   471       |> all x |> all y
   476 
   472 
   477     val simp_thms = @{thms Projl.simps Projr.simps sum.inject sum.cases sum.distinct o_apply}
   473     val simp_thms = @{thms sum.sel sum.inject sum.case sum.distinct o_apply}
   478     val simpset0 = put_simpset HOL_basic_ss lthy''' addsimps simp_thms
   474     val simpset0 = put_simpset HOL_basic_ss lthy''' addsimps simp_thms
   479     val simpset1 = put_simpset HOL_ss lthy''' addsimps simp_thms
   475     val simpset1 = put_simpset HOL_ss lthy''' addsimps simp_thms
   480 
   476 
   481     val inj_thm = Goal.prove lthy''' [] [] goal_inj 
   477     val inj_thm = Goal.prove lthy''' [] [] goal_inj 
   482       (K (HEADGOAL (DETERM o etac G_aux_induct THEN_ALL_NEW asm_simp_tac simpset1)))
   478       (K (HEADGOAL (DETERM o etac G_aux_induct THEN_ALL_NEW asm_simp_tac simpset1)))