Nominal/nominal_mutual.ML
changeset 2982 4a00077c008f
parent 2981 c8acaded1777
child 2983 4436039cc5e1
equal deleted inserted replaced
2981:c8acaded1777 2982:4a00077c008f
     5     (code forked on 14 January 2011)
     5     (code forked on 14 January 2011)
     6 
     6 
     7 
     7 
     8 Mutual recursive nominal function definitions.
     8 Mutual recursive nominal function definitions.
     9 *)
     9 *)
       
    10 
    10 
    11 
    11 signature NOMINAL_FUNCTION_MUTUAL =
    12 signature NOMINAL_FUNCTION_MUTUAL =
    12 sig
    13 sig
    13 
    14 
    14   val prepare_nominal_function_mutual : Nominal_Function_Common.nominal_function_config
    15   val prepare_nominal_function_mutual : Nominal_Function_Common.nominal_function_config
   191       | [cond] => (Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond)
   192       | [cond] => (Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond)
   192       | _ => raise General.Fail "Too many conditions"
   193       | _ => raise General.Fail "Too many conditions"
   193   in
   194   in
   194     Goal.prove ctxt [] []
   195     Goal.prove ctxt [] []
   195       (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs))
   196       (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs))
   196       (fn _ => print_tac "start" 
   197       (fn _ => (Local_Defs.unfold_tac ctxt all_orig_fdefs)
   197          THEN (Local_Defs.unfold_tac ctxt all_orig_fdefs)
       
   198          THEN (print_tac "second")
       
   199          THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
   198          THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
   200          THEN (print_tac "third")
   199          THEN (simp_tac (simpset_of ctxt)) 1) (* FIXME: global simpset?!! *)
   201          THEN (simp_tac (simpset_of ctxt)) 1
       
   202          THEN (print_tac "fourth")
       
   203       ) (* FIXME: global simpset?!! *)
       
   204     |> restore_cond
   200     |> restore_cond
   205     |> export
   201     |> export
   206   end
   202   end
   207 
   203 
   208 val test1 = @{lemma "x = Inl y ==> permute p (Sum_Type.Projl x) = Sum_Type.Projl (permute p x)" by simp}
   204 val inl_perm = @{lemma "x = Inl y ==> Sum_Type.Projl (permute p x) = permute p (Sum_Type.Projl x)" by simp}
   209 val test2 = @{lemma "x = Inr y ==> permute p (Sum_Type.Projr x) = Sum_Type.Projr (permute p x)" by simp}
   205 val inr_perm = @{lemma "x = Inr y ==> Sum_Type.Projr (permute p x) = permute p (Sum_Type.Projr x)" by simp}
   210 
   206 
   211 fun recover_mutual_eqvt eqvt_thm all_orig_fdefs parts ctxt (fname, _, _, args, rhs)
   207 fun recover_mutual_eqvt eqvt_thm all_orig_fdefs parts ctxt (fname, _, _, args, _)
   212   import (export : thm -> thm) sum_psimp_eq =
   208   import (export : thm -> thm) sum_psimp_eq =
   213   let
   209   let
   214     val (MutualPart {f=SOME f, ...}) = get_part fname parts
   210     val (MutualPart {f=SOME f, ...}) = get_part fname parts
   215  
   211     
   216     val psimp = import sum_psimp_eq
   212     val psimp = import sum_psimp_eq
   217     val (simp, restore_cond) =
   213     val (cond, simp, restore_cond) =
   218       case cprems_of psimp of
   214       case cprems_of psimp of
   219         [] => (psimp, I)
   215         [] => ([], psimp, I)
   220       | [cond] => (Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond)
   216       | [cond] => ([Thm.assume cond], Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond)
   221       | _ => raise General.Fail "Too many conditions"
   217       | _ => raise General.Fail "Too many conditions"
   222 
   218 
   223     val eqvt_thm' = import eqvt_thm
       
   224     val (simp', restore_cond') =
       
   225       case cprems_of eqvt_thm' of
       
   226         [] => (eqvt_thm, I)
       
   227       | [cond] => (Thm.implies_elim eqvt_thm' (Thm.assume cond), Thm.implies_intr cond)
       
   228       | _ => raise General.Fail "Too many conditions"
       
   229 
       
   230     val _ = tracing ("sum_psimp:\n" ^ @{make_string} sum_psimp_eq)
       
   231     val _ = tracing ("psimp:\n" ^ @{make_string} psimp)
       
   232     val _ = tracing ("simp:\n" ^ @{make_string} simp)
       
   233     val _ = tracing ("eqvt:\n" ^ @{make_string} eqvt_thm)
       
   234     
       
   235     val ([p], ctxt') = Variable.variant_fixes ["p"] ctxt		   
   219     val ([p], ctxt') = Variable.variant_fixes ["p"] ctxt		   
   236     val p = Free (p, @{typ perm})
   220     val p = Free (p, @{typ perm})
   237     val ss = HOL_basic_ss addsimps [simp RS test1, simp']
   221     val ss = HOL_basic_ss addsimps 
   238   in
   222       @{thms permute_sum.simps[symmetric] Pair_eqvt[symmetric]} @
   239     Goal.prove ctxt' [] []
   223       @{thms Projr.simps Projl.simps} @
   240       (HOLogic.Trueprop $ 
   224       [(cond MRS eqvt_thm) RS @{thm sym}] @ 
   241          HOLogic.mk_eq (mk_perm p (list_comb (f, args)), list_comb (f, map (mk_perm p) args)))
   225       [inl_perm, inr_perm, simp] 
   242       (fn _ => print_tac "eqvt start" 
   226     val goal_lhs = mk_perm p (list_comb (f, args))
   243          THEN (Local_Defs.unfold_tac ctxt all_orig_fdefs)
   227     val goal_rhs = list_comb (f, map (mk_perm p) args)
   244          THEN (asm_full_simp_tac ss 1)
   228   in
   245          THEN all_tac) 
   229     Goal.prove ctxt' [] [] (HOLogic.Trueprop $ HOLogic.mk_eq (goal_lhs, goal_rhs))
       
   230       (fn _ => (Local_Defs.unfold_tac ctxt all_orig_fdefs)
       
   231          THEN (asm_full_simp_tac ss 1))
       
   232     |> singleton (ProofContext.export ctxt' ctxt)
   246     |> restore_cond
   233     |> restore_cond
   247     |> export
   234     |> export
   248   end
   235   end
   249 
       
   250 fun mk_meqvts ctxt eqvt_thm f_defs =
       
   251   let
       
   252     val ctrm1 = eqvt_thm
       
   253       |> cprop_of
       
   254       |> snd o Thm.dest_implies
       
   255       |> Thm.dest_arg
       
   256       |> Thm.dest_arg1
       
   257       |> Thm.dest_arg
       
   258 
       
   259     fun resolve f_def =
       
   260       let
       
   261         val ctrm2 = f_def
       
   262           |> cprop_of
       
   263           |> Thm.dest_equals_lhs
       
   264         val _ = tracing ("ctrm1:\n" ^ @{make_string} ctrm1)
       
   265         val _ = tracing ("ctrm2:\n" ^ @{make_string} ctrm2)
       
   266       in
       
   267         eqvt_thm
       
   268 	|> Thm.instantiate (Thm.match (ctrm1, ctrm2))
       
   269         |> simplify (HOL_basic_ss addsimps (@{thm Pair_eqvt} :: @{thms permute_sum.simps}))
       
   270         |> Local_Defs.unfold ctxt [f_def] 
       
   271       end
       
   272   in
       
   273     map resolve f_defs
       
   274   end
       
   275 
       
   276 
   236 
   277 fun mk_applied_form ctxt caTs thm =
   237 fun mk_applied_form ctxt caTs thm =
   278   let
   238   let
   279     val thy = ProofContext.theory_of ctxt
   239     val thy = ProofContext.theory_of ctxt
   280     val xs = map_index (fn (i,T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs (* FIXME: Bind xs properly *)
   240     val xs = map_index (fn (i,T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs (* FIXME: Bind xs properly *)
   322       end
   282       end
   323   in
   283   in
   324     fst (fold_map (project induct_inst) parts 0)
   284     fst (fold_map (project induct_inst) parts 0)
   325   end
   285   end
   326 
   286 
       
   287 
       
   288 fun forall_elim s (Const ("all", _) $ Abs (_, _, t)) = subst_bound (s, t)
       
   289   | forall_elim _ t = t
       
   290 
       
   291 val forall_elim_list = fold forall_elim
       
   292 
       
   293 fun split_conj_thm th =
       
   294   (split_conj_thm (th RS conjunct1)) @ (split_conj_thm (th RS conjunct2)) handle THM _ => [th];
       
   295 
       
   296 fun prove_eqvt ctxt fs argTss eqvts_thms induct_thms =
       
   297   let
       
   298     fun aux argTs s = argTs
       
   299       |> map (pair s)
       
   300       |> Variable.variant_frees ctxt fs
       
   301     val argss' = map2 aux argTss (Name.invent (Variable.names_of ctxt) "" (length fs)) 
       
   302     val argss = (map o map) Free argss'
       
   303     val arg_namess = (map o map) fst argss'
       
   304     val insts = (map o map) SOME arg_namess 
       
   305    
       
   306     val ([p_name], ctxt') = Variable.variant_fixes ["p"] ctxt
       
   307     val p = Free (p_name, @{typ perm})
       
   308 
       
   309     val acc_prems = 
       
   310      map prop_of induct_thms
       
   311      |> map2 forall_elim_list argss 
       
   312      |> map (strip_qnt_body "all")
       
   313      |> map (curry Logic.nth_prem 1)
       
   314      |> map HOLogic.dest_Trueprop
       
   315 
       
   316     fun mk_goal acc_prem (f, args) = 
       
   317       let
       
   318         val goal_lhs = mk_perm p (list_comb (f, args))
       
   319         val goal_rhs = list_comb (f, map (mk_perm p) args)
       
   320       in
       
   321         HOLogic.mk_imp (acc_prem, HOLogic.mk_eq (goal_lhs, goal_rhs))
       
   322       end
       
   323 
       
   324     val goal = fold_conj_balanced (map2 mk_goal acc_prems (fs ~~ argss))
       
   325       |> HOLogic.mk_Trueprop
       
   326 
       
   327     val induct_thm = case induct_thms of
       
   328         [thm] => thm
       
   329           |> Drule.gen_all 
       
   330           |> Thm.permute_prems 0 1
       
   331           |> (fn thm => atomize_rule (length (prems_of thm) - 1) thm)
       
   332       | thms => thms
       
   333           |> map Drule.gen_all 
       
   334           |> map (Rule_Cases.add_consumes 1)
       
   335           |> snd o Rule_Cases.strict_mutual_rule ctxt'
       
   336           |> atomize_concl
       
   337 
       
   338     fun tac thm = rtac (Drule.gen_all thm) THEN_ALL_NEW atac
       
   339   in
       
   340     Goal.prove ctxt' (flat arg_namess) [] goal
       
   341       (fn {context, ...} => HEADGOAL (DETERM o (rtac induct_thm) THEN' RANGE (map tac eqvts_thms)))
       
   342     |> singleton (ProofContext.export ctxt' ctxt)
       
   343     |> split_conj_thm
       
   344     |> map (fn th => th RS mp)
       
   345   end
       
   346 
   327 fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof =
   347 fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof =
   328   let
   348   let
   329     val result = inner_cont proof
   349     val result = inner_cont proof
   330     val NominalFunctionResult {G, R, cases, psimps, simple_pinducts=[simple_pinduct],
   350     val NominalFunctionResult {G, R, cases, psimps, simple_pinducts=[simple_pinduct],
   331       termination, domintros, eqvts=[eqvt],...} = result
   351       termination, domintros, eqvts=[eqvt],...} = result
   332 
   352 
   333     val (all_f_defs, fs) =
   353     val (all_f_defs, fs) =
   334       map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
   354       map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
   335         (mk_applied_form lthy cargTs (Thm.symmetric f_def), f))
   355           (mk_applied_form lthy cargTs (Thm.symmetric f_def), f))
   336       parts
   356       parts
   337       |> split_list
   357       |> split_list
   338 
   358 
   339     val all_orig_fdefs =
   359     val all_orig_fdefs =
   340       map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts
   360       map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts
       
   361 
       
   362     val cargTss =
       
   363       map (fn MutualPart {f = SOME f, cargTs, ...} => cargTs) parts
   341 
   364 
   342     fun mk_mpsimp fqgar sum_psimp =
   365     fun mk_mpsimp fqgar sum_psimp =
   343       in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp
   366       in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp
   344 
   367 
   345     fun mk_meqvts fqgar sum_psimp =
   368     fun mk_meqvts fqgar sum_psimp =
   349     val mpsimps = map2 mk_mpsimp fqgars psimps
   372     val mpsimps = map2 mk_mpsimp fqgars psimps
   350     val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
   373     val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
   351     val mtermination = full_simplify rew_ss termination
   374     val mtermination = full_simplify rew_ss termination
   352     val mdomintros = Option.map (map (full_simplify rew_ss)) domintros
   375     val mdomintros = Option.map (map (full_simplify rew_ss)) domintros
   353     val meqvts = map2 mk_meqvts fqgars psimps
   376     val meqvts = map2 mk_meqvts fqgars psimps
       
   377     val meqvt_funs = prove_eqvt lthy fs cargTss meqvts minducts
   354  in
   378  in
   355     NominalFunctionResult { fs=fs, G=G, R=R,
   379     NominalFunctionResult { fs=fs, G=G, R=R,
   356       psimps=mpsimps, simple_pinducts=minducts,
   380       psimps=mpsimps, simple_pinducts=minducts,
   357       cases=cases, termination=mtermination,
   381       cases=cases, termination=mtermination,
   358       domintros=mdomintros, eqvts=meqvts }
   382       domintros=mdomintros, eqvts=meqvt_funs }
   359   end
   383   end
   360 
   384 
   361 (* nominal *)
   385 (* nominal *)
   362 fun prepare_nominal_function_mutual config defname fixes eqss lthy =
   386 fun prepare_nominal_function_mutual config defname fixes eqss lthy =
   363   let
   387   let