derived equivariance for the function graph and function relation
authorChristian Urban <urbanc@in.tum.de>
Fri, 07 Jan 2011 02:30:00 +0000
changeset 2650 e5fa8de0e4bd
parent 2649 a8ebcb368a15
child 2651 4aa72a88b2c1
derived equivariance for the function graph and function relation
Nominal/Ex/LamTest.thy
Nominal/Nominal2.thy
Nominal/nominal_eqvt.ML
Nominal/nominal_thmdecls.ML
--- a/Nominal/Ex/LamTest.thy	Thu Jan 06 23:06:45 2011 +0000
+++ b/Nominal/Ex/LamTest.thy	Fri Jan 07 02:30:00 2011 +0000
@@ -12,6 +12,10 @@
 definition
  "eqvt x \<equiv> \<forall>p. p \<bullet> x = x"
 
+lemma eqvtI:
+  "(\<And>p. p \<bullet> x \<equiv> x) \<Longrightarrow> eqvt x"
+apply(auto simp add: eqvt_def)
+done
 
 ML {*
 
@@ -96,6 +100,16 @@
   end
 *}
 
+ML {*
+fun mk_eqvt trm =
+  let
+    val ty = fastype_of trm
+  in
+    Const (@{const_name eqvt}, ty --> @{typ bool}) $ trm
+    |> HOLogic.mk_Trueprop
+  end
+*}
+
 
 ML {*
 (** building proof obligations *)
@@ -114,6 +128,7 @@
           (HOLogic.mk_Trueprop (HOLogic.eq_const domT $ shift lhs $ lhs'),
             HOLogic.mk_Trueprop (HOLogic.eq_const ranT $ shift rhs $ rhs'))
         |> fold_rev (curry Logic.mk_implies) (map shift gs @ gs')
+        (* HERE |> (curry Logic.mk_implies) (mk_eqvt fvar) *)
         |> (curry Logic.mk_implies) @{term "Trueprop True"}
         |> fold_rev (fn (n,T) => fn b => Term.all T $ Abs(n,T,b)) (qs @ qs')
         |> curry abstract_over fvar
@@ -228,7 +243,6 @@
   end
 *}
 
-
 ML {*
 (* expects i <= j *)
 fun lookup_compat_thm i j cts =
@@ -244,7 +258,6 @@
     val lhsi_eq_lhsj = cterm_of thy (HOLogic.mk_Trueprop (mk_eq (lhsi, lhsj)))
   in if j < i then
     let
-      val _ = tracing "then"
       val compat = lookup_compat_thm j i cts
     in
       compat         (* "!!qj qi. Gsj => Gsi => lhsj = lhsi ==> rhsj = rhsi" *)
@@ -256,22 +269,14 @@
     end
     else
     let
-      val _ = tracing "else"
       val compat = lookup_compat_thm i j cts
-
-      fun pp s (th:thm) = tracing (s ^ " compat thm: " ^ @{make_string} th)
     in
       compat        (* "!!qi qj. Gsi => Gsj => lhsi = lhsj ==> rhsi = rhsj" *)
-      |> (tap (fn th => pp "*0" th))
       |> fold Thm.forall_elim (cqsi @ cqsj) (* "Gsi => Gsj => lhsi = lhsj ==> rhsi = rhsj" *)
-      |> (tap (fn th => pp "*1" th))
       |> Thm.elim_implies @{thm TrueI}
       |> fold Thm.elim_implies agsi
-      |> (tap (fn th => pp "*2" th))
       |> fold Thm.elim_implies agsj
-      |> (tap (fn th => pp "*3" th))
       |> Thm.elim_implies (Thm.assume lhsi_eq_lhsj)
-      |> (tap (fn th => pp "*4" th))
       |> (fn thm => thm RS sym) (* "Gsi, Gsj, lhsi = lhsj |-- rhsj = rhsi" *)
     end
   end
@@ -320,33 +325,20 @@
     val cctxj as ClauseContext {ags = agsj', lhs = lhsj', rhs = rhsj', qs = qsj', cqs = cqsj', ...} =
       mk_clause_context x ctxti cdescj
 
-    val _ = tracing "1"
-
     val rhsj'h = Pattern.rewrite_term thy [(fvar,h)] [] rhsj'
-   
-    val _ = tracing "1A"
-
     val compat = get_compat_thm thy compat_store i j cctxi cctxj
-   
-    val _ = tracing "1B"
 
     val Ghsj' = map 
       (fn RCInfo {h_assum, ...} => Thm.assume (cterm_of thy (subst_bounds (rev qsj', h_assum)))) RCsj
 
-    val _ = tracing "2"
-
     val RLj_import = RLj
       |> fold Thm.forall_elim cqsj'
       |> fold Thm.elim_implies agsj'
       |> fold Thm.elim_implies Ghsj'
 
-    val _ = tracing "3"
-
     val y_eq_rhsj'h = Thm.assume (cterm_of thy (HOLogic.mk_Trueprop (mk_eq (y, rhsj'h))))
     val lhsi_eq_lhsj' = Thm.assume (cterm_of thy (HOLogic.mk_Trueprop (mk_eq (lhsi, lhsj'))))
        (* lhs_i = lhs_j' |-- lhs_i = lhs_j' *)
-
-    val _ = tracing "4"
   in
     (trans OF [case_hyp, lhsi_eq_lhsj']) (* lhs_i = lhs_j' |-- x = lhs_j' *)
     |> Thm.implies_elim RLj_import
@@ -375,31 +367,21 @@
 
     val ih_intro_case = full_simplify (HOL_basic_ss addsimps [case_hyp]) ih_intro
 
-    val _ = tracing "A"
-
     fun prep_RC (RCInfo {llRI, RIvs, CCas, ...}) = (llRI RS ih_intro_case)
       |> fold_rev (Thm.implies_intr o cprop_of) CCas
       |> fold_rev (Thm.forall_intr o cterm_of thy o Free) RIvs
 
     val existence = fold (curry op COMP o prep_RC) RCs lGI
 
-    val _ = tracing "B"
-
     val P = cterm_of thy (mk_eq (y, rhsC))
     val G_lhs_y = Thm.assume (cterm_of thy (HOLogic.mk_Trueprop (G $ lhs $ y)))
 
-    val _ = tracing "B2"
-
     val unique_clauses =
       map2 (mk_uniqueness_clause thy globals compat_store clausei) clauses rep_lemmas
   
-    val _ = tracing "C"
-
     fun elim_implies_eta A AB =
       Thm.compose_no_flatten true (A, 0) 1 AB |> Seq.list_of |> the_single
 
-    val _ = tracing "D"
-
     val uniqueness = G_cases
       |> Thm.forall_elim (cterm_of thy lhs)
       |> Thm.forall_elim (cterm_of thy y)
@@ -409,8 +391,6 @@
       |> Thm.implies_intr (cprop_of G_lhs_y)
       |> Thm.forall_intr (cterm_of thy y)
 
-    val _ = tracing "E"
-
     val P2 = cterm_of thy (lambda y (G $ lhs $ y)) (* P2 y := (lhs, y): G *)
 
     val exactly_one =
@@ -422,8 +402,6 @@
       |> fold_rev (Thm.implies_intr o cprop_of) ags
       |> fold_rev Thm.forall_intr cqs
 
-    val _ = tracing "F"
-
     val function_value =
       existence
       |> Thm.implies_intr ihyp
@@ -431,8 +409,6 @@
       |> Thm.forall_intr (cterm_of thy x)
       |> Thm.forall_elim (cterm_of thy lhs)
       |> curry (op RS) refl
-
-    val _ = tracing "G" 
   in
     (exactly_one, function_value)
   end
@@ -482,13 +458,20 @@
   in
     (goalstate, values)
   end
+*}
 
+ML {* 
+Inductive.add_inductive_i;
+Local_Theory.conceal
+*}
+
+ML {*
 (* wrapper -- restores quantifiers in rule specifications *)
 fun inductive_def (binding as ((R, T), _)) intrs lthy =
   let
-    val ({intrs = intrs_gen, elims = [elim_gen], preds = [ Rdef ], induct, ...}, lthy) =
+    val ({intrs = intrs_gen, elims = [elim_gen], preds = [ Rdef ], induct, raw_induct, ...}, lthy) =
       lthy
-      |> Local_Theory.conceal
+      |> Local_Theory.conceal 
       |> Inductive.add_inductive_i
           {quiet_mode = true,
             verbose = ! trace,
@@ -504,6 +487,9 @@
           [] (* no special monos *)
       ||> Local_Theory.restore_naming lthy
 
+    val ([eqvt_thm], lthy) = Nominal_Eqvt.raw_equivariance false [Rdef] raw_induct intrs_gen lthy
+    val eqvt_thm' = (Nominal_ThmDecls.eqvt_transform lthy eqvt_thm) RS @{thm eqvtI}
+
     val cert = cterm_of (ProofContext.theory_of lthy)
     fun requantify orig_intro thm =
       let
@@ -517,7 +503,7 @@
           forall_intr_rename (n, cert (Var (varmap (n, T), T)))) qs thm
       end
   in
-    ((Rdef, map2 requantify intrs intrs_gen, forall_intr_vars elim_gen, induct), lthy)
+    ((Rdef, map2 requantify intrs intrs_gen, forall_intr_vars elim_gen, induct, eqvt_thm'), lthy)
   end
 *}
 
@@ -574,10 +560,10 @@
 
     val R_intross = map2 (map o mk_RIntro) (clauses ~~ qglrs) RCss
 
-    val ((R, RIntro_thms, R_elim, _), lthy) =
+    val ((R, RIntro_thms, R_elim, _, R_eqvt), lthy) =
       inductive_def ((Binding.name n, T), NoSyn) (flat R_intross) lthy
   in
-    ((R, Library.unflat R_intross RIntro_thms, R_elim), lthy)
+    ((R, Library.unflat R_intross RIntro_thms, R_elim, R_eqvt), lthy)
   end
 
 
@@ -928,8 +914,9 @@
   in
     map2 mk_trsimp clauses psimps
   end
+*}
 
-
+ML {*
 fun prepare_function config defname [((fname, fT), mixfix)] abstract_qglrs lthy =
   let
     val FunctionConfig {domintros, tailrec, default=default_opt, ...} = config
@@ -956,11 +943,13 @@
     val trees = map build_tree clauses
     val RCss = map find_calls trees
 
-    val ((G, GIntro_thms, G_elim, G_induct), lthy) =
+    val ((G, GIntro_thms, G_elim, G_induct, G_eqvt), lthy) =
       PROFILE "def_graph" (define_graph (graph_name defname) fvar domT ranT clauses RCss) lthy
 
     val _ = tracing ("Graph - name: " ^ @{make_string} G)
     val _ = tracing ("Graph intros:\n" ^ cat_lines (map @{make_string} GIntro_thms))
+    val _ = tracing ("Graph Equivariance" ^ @{make_string} G_eqvt)
+    
 
     val ((f, (_, f_defthm)), lthy) =
       PROFILE "def_fun" (define_function (defname ^ "_sumC_def") (fname, mixfix) domT ranT G default) lthy
@@ -971,11 +960,12 @@
     val RCss = map (map (inst_RC (ProofContext.theory_of lthy) fvar f)) RCss
     val trees = map (Function_Ctx_Tree.inst_tree (ProofContext.theory_of lthy) fvar f) trees
 
-    val ((R, RIntro_thmss, R_elim), lthy) =
+    val ((R, RIntro_thmss, R_elim, R_eqvt), lthy) =
       PROFILE "def_rel" (define_recursion_relation (rel_name defname) domT abstract_qglrs clauses RCss) lthy
 
     val _ = tracing ("Relation - name: " ^ @{make_string} R) 
     val _ = tracing ("Relation intros:\n" ^ cat_lines (map @{make_string} RIntro_thmss))
+    val _ = tracing ("Relation Equivariance" ^ @{make_string} R_eqvt)
 
     val (_, lthy) =
       Local_Theory.abbrev Syntax.mode_default ((Binding.name (dom_name defname), NoSyn), mk_acc domT R) lthy
@@ -998,10 +988,12 @@
 
     val compat_store = store_compat_thms n compat
 
+    (*
     val _ = tracing ("globals:\n" ^ @{make_string} globals)
     val _ = tracing ("complete:\n" ^ @{make_string} complete)
     val _ = tracing ("compat:\n" ^ @{make_string} compat)
     val _ = tracing ("compat_store:\n" ^ @{make_string} compat_store)
+    *)
 
     val (goalstate, values) = PROFILE "prove_stuff"
       (prove_stuff lthy globals G f R xclauses complete compat
@@ -1501,6 +1493,7 @@
   "depth (Var x) = 1"
 | "depth (App t1 t2) = (max (depth t1) (depth t2)) + 1"
 | "depth (Lam x t) = (depth t) + 1"
+thm depth_graph.intros
 apply(rule_tac y="x" in lam.exhaust)
 apply(simp_all)[3]
 apply(simp_all only: lam.distinct)
--- a/Nominal/Nominal2.thy	Thu Jan 06 23:06:45 2011 +0000
+++ b/Nominal/Nominal2.thy	Fri Jan 07 02:30:00 2011 +0000
@@ -250,7 +250,7 @@
   val lthy5 = snd (Local_Theory.note ((Binding.empty, [eqvt_attr]), raw_fv_eqvt) lthy_tmp)
 
   val (alpha_eqvt, lthy6) =
-    Nominal_Eqvt.equivariance true (alpha_trms @ alpha_bn_trms) alpha_induct alpha_intros lthy5
+    Nominal_Eqvt.raw_equivariance true (alpha_trms @ alpha_bn_trms) alpha_induct alpha_intros lthy5
 
   val _ = trace_msg (K "Proving equivalence of alpha...")
   val alpha_refl_thms = 
--- a/Nominal/nominal_eqvt.ML	Thu Jan 06 23:06:45 2011 +0000
+++ b/Nominal/nominal_eqvt.ML	Fri Jan 07 02:30:00 2011 +0000
@@ -10,7 +10,8 @@
   val eqvt_rel_tac: Proof.context -> string list -> term -> thm -> thm list -> int -> tactic
   val eqvt_rel_single_case_tac: Proof.context -> string list -> term -> thm -> int -> tactic
   
-  val equivariance: bool -> term list -> thm -> thm list -> Proof.context -> thm list * local_theory
+  val raw_equivariance: bool -> term list -> thm -> thm list -> Proof.context -> thm list * local_theory
+  val equivariance: string -> Proof.context -> (thm list * local_theory)
   val equivariance_cmd: string -> Proof.context -> local_theory
 end
 
@@ -60,7 +61,7 @@
   let
     val cases = map (eqvt_rel_single_case_tac ctxt pred_names pi) intros
   in
-    EVERY' (rtac induct :: cases)
+    EVERY' ((DETERM o rtac induct) :: cases)
   end
 
 
@@ -85,7 +86,10 @@
     (thm', ctxt')
   end
 
-fun equivariance note_flag pred_trms raw_induct intrs ctxt = 
+fun get_name (Const (a, _)) = a
+  | get_name (Free  (a, _)) = a
+
+fun raw_equivariance note_flag pred_trms raw_induct intrs ctxt = 
   let
     val is_already_eqvt = 
       filter (is_eqvt ctxt) pred_trms
@@ -93,7 +97,7 @@
     val _ = if null is_already_eqvt then ()
       else error ("Already equivariant: " ^ commas is_already_eqvt)
 
-    val pred_names = map (fst o dest_Const) pred_trms
+    val pred_names = map get_name pred_trms
     val raw_induct' = atomize_induct ctxt raw_induct
     val intrs' = map atomize_intr intrs
   
@@ -121,13 +125,22 @@
     else (thms, ctxt) 
   end
 
+fun equivariance pred_name ctxt =
+  let
+    val thy = ProofContext.theory_of ctxt
+    val (_, {preds, raw_induct, intrs, ...}) =
+      Inductive.the_inductive ctxt (Sign.intern_const thy pred_name)
+  in
+    raw_equivariance false preds raw_induct intrs ctxt 
+  end
+
 fun equivariance_cmd pred_name ctxt =
   let
     val thy = ProofContext.theory_of ctxt
     val (_, {preds, raw_induct, intrs, ...}) =
       Inductive.the_inductive ctxt (Sign.intern_const thy pred_name)
   in
-    equivariance true preds raw_induct intrs ctxt |> snd
+    raw_equivariance true preds raw_induct intrs ctxt |> snd
   end
 
 local structure P = Parse and K = Keyword in
--- a/Nominal/nominal_thmdecls.ML	Thu Jan 06 23:06:45 2011 +0000
+++ b/Nominal/nominal_thmdecls.ML	Fri Jan 07 02:30:00 2011 +0000
@@ -74,7 +74,7 @@
 fun is_eqvt ctxt trm =
   case trm of 
     (c as Const _) => Termtab.defined (EqvtRawData.get (Context.Proof ctxt)) c
-  | _ => raise TERM ("Term must be a constsnt.", [trm])
+  | _ => false (* raise TERM ("Term must be a constant.", [trm]) *)
 
 
 
@@ -89,13 +89,21 @@
    | Abs (_, _, t) => get_perms t 
    | _ => []
 
-fun put_perm p trm =
-  case trm of 
-     Bound _ => trm
-   | Const _ => trm
-   | t $ u => put_perm p t $ put_perm p u
-   | Abs (x, ty, t) => Abs (x, ty, put_perm p t)
-   | _ => mk_perm p trm
+fun add_perm p trm =
+  let
+    fun aux trm = 
+      case trm of 
+        Bound _ => trm
+      | Const _ => trm
+      | t $ u => aux t $ aux u
+      | Abs (x, ty, t) => Abs (x, ty, aux t)
+      | _ => mk_perm p trm
+  in
+    strip_comb trm
+    ||> map aux
+    |> list_comb
+  end  
+
 
 (* tests whether there is a disagreement between the permutations, 
    and that there is at least one permutation *)
@@ -132,7 +140,7 @@
       then error (msg ^ "(constants do not agree).")
     else if is_bad_list (p :: ps)  
       then error (msg ^ "(permutations do not agree).") 
-    else if not (rhs aconv (put_perm p t))
+    else if not (rhs aconv (add_perm p t))
       then error (msg ^ "(arguments do not agree).")
     else if is_Const t 
       then safe_mk_equiv thm
@@ -175,7 +183,7 @@
       then error (msg ^ "(constants do not agree).")
     else if is_bad_list ps  
       then error (msg ^ "(permutations do not agree).") 
-    else if not (concl aconv (put_perm (the p) prem)) 
+    else if not (concl aconv (add_perm (the p) prem)) 
       then error (msg ^ "(arguments do not agree).")
     else 
       let