Nominal/nominal_mutual.ML
changeset 2982 4a00077c008f
parent 2981 c8acaded1777
child 2983 4436039cc5e1
--- a/Nominal/nominal_mutual.ML	Tue Jul 19 19:09:06 2011 +0100
+++ b/Nominal/nominal_mutual.ML	Fri Jul 22 11:37:16 2011 +0100
@@ -8,6 +8,7 @@
 Mutual recursive nominal function definitions.
 *)
 
+
 signature NOMINAL_FUNCTION_MUTUAL =
 sig
 
@@ -193,87 +194,46 @@
   in
     Goal.prove ctxt [] []
       (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs))
-      (fn _ => print_tac "start" 
-         THEN (Local_Defs.unfold_tac ctxt all_orig_fdefs)
-         THEN (print_tac "second")
+      (fn _ => (Local_Defs.unfold_tac ctxt all_orig_fdefs)
          THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
-         THEN (print_tac "third")
-         THEN (simp_tac (simpset_of ctxt)) 1
-         THEN (print_tac "fourth")
-      ) (* FIXME: global simpset?!! *)
+         THEN (simp_tac (simpset_of ctxt)) 1) (* FIXME: global simpset?!! *)
     |> restore_cond
     |> export
   end
 
-val test1 = @{lemma "x = Inl y ==> permute p (Sum_Type.Projl x) = Sum_Type.Projl (permute p x)" by simp}
-val test2 = @{lemma "x = Inr y ==> permute p (Sum_Type.Projr x) = Sum_Type.Projr (permute p x)" by simp}
+val inl_perm = @{lemma "x = Inl y ==> Sum_Type.Projl (permute p x) = permute p (Sum_Type.Projl x)" by simp}
+val inr_perm = @{lemma "x = Inr y ==> Sum_Type.Projr (permute p x) = permute p (Sum_Type.Projr x)" by simp}
 
-fun recover_mutual_eqvt eqvt_thm all_orig_fdefs parts ctxt (fname, _, _, args, rhs)
+fun recover_mutual_eqvt eqvt_thm all_orig_fdefs parts ctxt (fname, _, _, args, _)
   import (export : thm -> thm) sum_psimp_eq =
   let
     val (MutualPart {f=SOME f, ...}) = get_part fname parts
- 
+    
     val psimp = import sum_psimp_eq
-    val (simp, restore_cond) =
+    val (cond, simp, restore_cond) =
       case cprems_of psimp of
-        [] => (psimp, I)
-      | [cond] => (Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond)
-      | _ => raise General.Fail "Too many conditions"
-
-    val eqvt_thm' = import eqvt_thm
-    val (simp', restore_cond') =
-      case cprems_of eqvt_thm' of
-        [] => (eqvt_thm, I)
-      | [cond] => (Thm.implies_elim eqvt_thm' (Thm.assume cond), Thm.implies_intr cond)
+        [] => ([], psimp, I)
+      | [cond] => ([Thm.assume cond], Thm.implies_elim psimp (Thm.assume cond), Thm.implies_intr cond)
       | _ => raise General.Fail "Too many conditions"
 
-    val _ = tracing ("sum_psimp:\n" ^ @{make_string} sum_psimp_eq)
-    val _ = tracing ("psimp:\n" ^ @{make_string} psimp)
-    val _ = tracing ("simp:\n" ^ @{make_string} simp)
-    val _ = tracing ("eqvt:\n" ^ @{make_string} eqvt_thm)
-    
     val ([p], ctxt') = Variable.variant_fixes ["p"] ctxt		   
     val p = Free (p, @{typ perm})
-    val ss = HOL_basic_ss addsimps [simp RS test1, simp']
+    val ss = HOL_basic_ss addsimps 
+      @{thms permute_sum.simps[symmetric] Pair_eqvt[symmetric]} @
+      @{thms Projr.simps Projl.simps} @
+      [(cond MRS eqvt_thm) RS @{thm sym}] @ 
+      [inl_perm, inr_perm, simp] 
+    val goal_lhs = mk_perm p (list_comb (f, args))
+    val goal_rhs = list_comb (f, map (mk_perm p) args)
   in
-    Goal.prove ctxt' [] []
-      (HOLogic.Trueprop $ 
-         HOLogic.mk_eq (mk_perm p (list_comb (f, args)), list_comb (f, map (mk_perm p) args)))
-      (fn _ => print_tac "eqvt start" 
-         THEN (Local_Defs.unfold_tac ctxt all_orig_fdefs)
-         THEN (asm_full_simp_tac ss 1)
-         THEN all_tac) 
+    Goal.prove ctxt' [] [] (HOLogic.Trueprop $ HOLogic.mk_eq (goal_lhs, goal_rhs))
+      (fn _ => (Local_Defs.unfold_tac ctxt all_orig_fdefs)
+         THEN (asm_full_simp_tac ss 1))
+    |> singleton (ProofContext.export ctxt' ctxt)
     |> restore_cond
     |> export
   end
 
-fun mk_meqvts ctxt eqvt_thm f_defs =
-  let
-    val ctrm1 = eqvt_thm
-      |> cprop_of
-      |> snd o Thm.dest_implies
-      |> Thm.dest_arg
-      |> Thm.dest_arg1
-      |> Thm.dest_arg
-
-    fun resolve f_def =
-      let
-        val ctrm2 = f_def
-          |> cprop_of
-          |> Thm.dest_equals_lhs
-        val _ = tracing ("ctrm1:\n" ^ @{make_string} ctrm1)
-        val _ = tracing ("ctrm2:\n" ^ @{make_string} ctrm2)
-      in
-        eqvt_thm
-	|> Thm.instantiate (Thm.match (ctrm1, ctrm2))
-        |> simplify (HOL_basic_ss addsimps (@{thm Pair_eqvt} :: @{thms permute_sum.simps}))
-        |> Local_Defs.unfold ctxt [f_def] 
-      end
-  in
-    map resolve f_defs
-  end
-
-
 fun mk_applied_form ctxt caTs thm =
   let
     val thy = ProofContext.theory_of ctxt
@@ -324,6 +284,66 @@
     fst (fold_map (project induct_inst) parts 0)
   end
 
+
+fun forall_elim s (Const ("all", _) $ Abs (_, _, t)) = subst_bound (s, t)
+  | forall_elim _ t = t
+
+val forall_elim_list = fold forall_elim
+
+fun split_conj_thm th =
+  (split_conj_thm (th RS conjunct1)) @ (split_conj_thm (th RS conjunct2)) handle THM _ => [th];
+
+fun prove_eqvt ctxt fs argTss eqvts_thms induct_thms =
+  let
+    fun aux argTs s = argTs
+      |> map (pair s)
+      |> Variable.variant_frees ctxt fs
+    val argss' = map2 aux argTss (Name.invent (Variable.names_of ctxt) "" (length fs)) 
+    val argss = (map o map) Free argss'
+    val arg_namess = (map o map) fst argss'
+    val insts = (map o map) SOME arg_namess 
+   
+    val ([p_name], ctxt') = Variable.variant_fixes ["p"] ctxt
+    val p = Free (p_name, @{typ perm})
+
+    val acc_prems = 
+     map prop_of induct_thms
+     |> map2 forall_elim_list argss 
+     |> map (strip_qnt_body "all")
+     |> map (curry Logic.nth_prem 1)
+     |> map HOLogic.dest_Trueprop
+
+    fun mk_goal acc_prem (f, args) = 
+      let
+        val goal_lhs = mk_perm p (list_comb (f, args))
+        val goal_rhs = list_comb (f, map (mk_perm p) args)
+      in
+        HOLogic.mk_imp (acc_prem, HOLogic.mk_eq (goal_lhs, goal_rhs))
+      end
+
+    val goal = fold_conj_balanced (map2 mk_goal acc_prems (fs ~~ argss))
+      |> HOLogic.mk_Trueprop
+
+    val induct_thm = case induct_thms of
+        [thm] => thm
+          |> Drule.gen_all 
+          |> Thm.permute_prems 0 1
+          |> (fn thm => atomize_rule (length (prems_of thm) - 1) thm)
+      | thms => thms
+          |> map Drule.gen_all 
+          |> map (Rule_Cases.add_consumes 1)
+          |> snd o Rule_Cases.strict_mutual_rule ctxt'
+          |> atomize_concl
+
+    fun tac thm = rtac (Drule.gen_all thm) THEN_ALL_NEW atac
+  in
+    Goal.prove ctxt' (flat arg_namess) [] goal
+      (fn {context, ...} => HEADGOAL (DETERM o (rtac induct_thm) THEN' RANGE (map tac eqvts_thms)))
+    |> singleton (ProofContext.export ctxt' ctxt)
+    |> split_conj_thm
+    |> map (fn th => th RS mp)
+  end
+
 fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof =
   let
     val result = inner_cont proof
@@ -332,13 +352,16 @@
 
     val (all_f_defs, fs) =
       map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
-        (mk_applied_form lthy cargTs (Thm.symmetric f_def), f))
+          (mk_applied_form lthy cargTs (Thm.symmetric f_def), f))
       parts
       |> split_list
 
     val all_orig_fdefs =
       map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts
 
+    val cargTss =
+      map (fn MutualPart {f = SOME f, cargTs, ...} => cargTs) parts
+
     fun mk_mpsimp fqgar sum_psimp =
       in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp
 
@@ -351,11 +374,12 @@
     val mtermination = full_simplify rew_ss termination
     val mdomintros = Option.map (map (full_simplify rew_ss)) domintros
     val meqvts = map2 mk_meqvts fqgars psimps
+    val meqvt_funs = prove_eqvt lthy fs cargTss meqvts minducts
  in
     NominalFunctionResult { fs=fs, G=G, R=R,
       psimps=mpsimps, simple_pinducts=minducts,
       cases=cases, termination=mtermination,
-      domintros=mdomintros, eqvts=meqvts }
+      domintros=mdomintros, eqvts=meqvt_funs }
   end
 
 (* nominal *)