Nominal/nominal_function_core.ML
changeset 2821 c7d4bd9e89e0
parent 2819 4bd584ff4fab
child 2822 23befefc6e73
--- a/Nominal/nominal_function_core.ML	Mon Jun 06 13:11:04 2011 +0100
+++ b/Nominal/nominal_function_core.ML	Tue Jun 07 08:52:59 2011 +0100
@@ -7,6 +7,49 @@
 Core of the nominal function package.
 *)
 
+
+structure Nominal_Function_Common =
+struct
+
+
+(* Configuration management *)
+datatype nominal_function_opt
+  = Sequential
+  | Default of string
+  | DomIntros
+  | No_Partials
+  | Invariant of string
+
+datatype nominal_function_config = NominalFunctionConfig of
+ {sequential: bool,
+  default: string option,
+  domintros: bool,
+  partials: bool,
+  inv: string option}
+
+fun apply_opt Sequential (NominalFunctionConfig {sequential, default, domintros, partials, inv}) =
+    NominalFunctionConfig 
+      {sequential=true, default=default, domintros=domintros, partials=partials, inv=inv}
+  | apply_opt (Default d) (NominalFunctionConfig {sequential, default, domintros, partials, inv}) =
+    NominalFunctionConfig 
+      {sequential=sequential, default=SOME d, domintros=domintros, partials=partials, inv=inv}
+  | apply_opt DomIntros (NominalFunctionConfig {sequential, default, domintros, partials, inv}) =
+    NominalFunctionConfig 
+      {sequential=sequential, default=default, domintros=true, partials=partials, inv=inv}
+  | apply_opt No_Partials (NominalFunctionConfig {sequential, default, domintros, partials, inv}) =
+    NominalFunctionConfig 
+      {sequential=sequential, default=default, domintros=domintros, partials=false, inv=inv}
+  | apply_opt (Invariant s) (NominalFunctionConfig {sequential, default, domintros, partials, inv}) =
+    NominalFunctionConfig 
+      {sequential=sequential, default=default, domintros=domintros, partials=partials, inv = SOME s}
+
+val nominal_default_config =
+  NominalFunctionConfig { sequential=false, default=NONE,
+    domintros=false, partials=true, inv=NONE}
+
+end
+
+
 signature NOMINAL_FUNCTION_CORE =
 sig
   val trace: bool Unsynchronized.ref
@@ -18,7 +61,7 @@
     -> local_theory
     -> (term   (* f *)
         * thm  (* goalstate *)
-        * (thm -> Nominal_Function_Common.function_result) (* continuation *)
+        * (thm -> Function_Common.function_result) (* continuation *)
        ) * local_theory
 
 end
@@ -33,6 +76,7 @@
 val mk_eq = HOLogic.mk_eq
 
 open Function_Lib
+open Function_Common
 open Nominal_Function_Common
 
 datatype globals = Globals of
@@ -123,6 +167,20 @@
     |> HOLogic.mk_Trueprop
   end
 
+fun mk_inv inv (f_trm, arg_trm) = 
+  betapplys (inv, [arg_trm, (f_trm $ arg_trm)])
+  |> HOLogic.mk_Trueprop
+
+fun mk_invariant (Globals {x, y, ...}) G invariant = 
+  let
+    val prem = HOLogic.mk_Trueprop (G $ x $ y)
+    val concl = HOLogic.mk_Trueprop (betapplys (invariant, [x, y]))
+  in
+    Logic.mk_implies (prem, concl)
+    |> mk_forall_rename ("y", y)
+    |> mk_forall_rename ("x", x)
+  end  
+
 (** building proof obligations *)
 fun mk_eqvt_proof_obligation qs fvar (vs, assms, arg) = 
   mk_eqvt_at (fvar, arg)
@@ -131,18 +189,27 @@
   |> curry Term.list_abs_free qs
   |> strip_abs_body
 
+fun mk_inv_proof_obligation inv qs fvar (vs, assms, arg) = 
+  mk_inv inv (fvar, arg)
+  |> curry Logic.list_implies (map prop_of assms)
+  |> curry Term.list_all_free vs
+  |> curry Term.list_abs_free qs
+  |> strip_abs_body
+
 (** building proof obligations *)
-fun mk_compat_proof_obligations domT ranT fvar f RCss glrs =
+fun mk_compat_proof_obligations domT ranT fvar f RCss inv glrs =
   let
     fun mk_impl (((qs, gs, lhs, rhs), RCs), ((qs', gs', lhs', rhs'), _)) =
       let
         val shift = incr_boundvars (length qs')
         val eqvts_proof_obligations = map (shift o mk_eqvt_proof_obligation qs fvar) RCs
+        val invs_proof_obligations = map (shift o mk_inv_proof_obligation inv qs fvar) RCs
       in
         Logic.mk_implies
           (HOLogic.mk_Trueprop (HOLogic.eq_const domT $ shift lhs $ lhs'),
             HOLogic.mk_Trueprop (HOLogic.eq_const ranT $ shift rhs $ rhs'))
         |> fold_rev (curry Logic.mk_implies) (map shift gs @ gs')
+        |> fold_rev (curry Logic.mk_implies) invs_proof_obligations (* nominal *)
         |> fold_rev (curry Logic.mk_implies) eqvts_proof_obligations (* nominal *)
         |> fold_rev (fn (n,T) => fn b => Term.all T $ Abs(n,T,b)) (qs @ qs')
         |> curry abstract_over fvar
@@ -152,7 +219,6 @@
     map mk_impl (unordered_pairs (glrs ~~ RCss))
   end
 
-
 fun mk_completeness (Globals {x, Pbool, ...}) clauses qglrs =
   let
     fun mk_case (ClauseContext {qs, gs, lhs, ...}, (oqs, _, _, _)) =
@@ -260,7 +326,7 @@
 (* nominal *)
 (* Returns "Gsi, Gsj, lhs_i = lhs_j |-- rhs_j_f = rhs_i_f" *)
 (* if j < i, then turn around *)
-fun get_compat_thm thy cts eqvtsi eqvtsj i j ctxi ctxj =
+fun get_compat_thm thy cts eqvtsi eqvtsj invsi invsj i j ctxi ctxj =
   let
     val ClauseContext {cqs=cqsi,ags=agsi,lhs=lhsi,case_hyp=case_hypi,...} = ctxi
     val ClauseContext {cqs=cqsj,ags=agsj,lhs=lhsj,case_hyp=case_hypj,...} = ctxj
@@ -273,6 +339,7 @@
       compat         (* "!!qj qi. Gsj => Gsi => lhsj = lhsi ==> rhsj = rhsi" *)
       |> fold Thm.forall_elim (cqsj @ cqsi) (* "Gsj => Gsi => lhsj = lhsi ==> rhsj = rhsi" *)
       |> fold Thm.elim_implies eqvtsj (* nominal *)
+      |> fold Thm.elim_implies invsj (* nominal *)
       |> fold Thm.elim_implies agsj
       |> fold Thm.elim_implies agsi
       |> Thm.elim_implies ((Thm.assume lhsi_eq_lhsj) RS sym) (* "Gsj, Gsi, lhsi = lhsj |-- rhsj = rhsi" *)
@@ -284,6 +351,7 @@
       compat        (* "!!qi qj. Gsi => Gsj => lhsi = lhsj ==> rhsi = rhsj" *)
       |> fold Thm.forall_elim (cqsi @ cqsj) (* "Gsi => Gsj => lhsi = lhsj ==> rhsi = rhsj" *)
       |> fold Thm.elim_implies eqvtsi  (* nominal *)
+      |> fold Thm.elim_implies invsi  (* nominal *)
       |> fold Thm.elim_implies agsi
       |> fold Thm.elim_implies agsj
       |> Thm.elim_implies (Thm.assume lhsi_eq_lhsj)
@@ -347,7 +415,31 @@
   end
 
 (* nominal *)
-fun mk_uniqueness_clause thy globals compat_store eqvts clausei clausej RLj =
+fun mk_invariant_lemma thy ih_inv clause =
+  let
+    val ClauseInfo {cdata=ClauseContext {cqs, ags, case_hyp, ...}, RCs, ...} = clause
+     
+    local open Conv in
+      val ih_conv = arg1_conv o arg_conv o arg_conv
+    end
+
+    val ih_inv_case =
+      Conv.fconv_rule (ih_conv (K (case_hyp RS eq_reflection))) ih_inv
+
+    fun prep_inv (RCInfo {llRI, RIvs, CCas, ...}) = 
+        (llRI RS ih_inv_case)
+        |> fold_rev (Thm.implies_intr o cprop_of) CCas
+        |> fold_rev (Thm.forall_intr o cterm_of thy o Free) RIvs 
+  in
+    map prep_inv RCs
+    |> map (fold_rev (Thm.implies_intr o cprop_of) ags)
+    |> map (Thm.implies_intr (cprop_of case_hyp))
+    |> map (fold_rev Thm.forall_intr cqs)
+    |> map (Thm.close_derivation) 
+  end
+
+(* nominal *)
+fun mk_uniqueness_clause thy globals compat_store eqvts invs clausei clausej RLj =
   let
     val Globals {h, y, x, fvar, ...} = globals
     val ClauseInfo {no=i, cdata=cctxi as ClauseContext {ctxt=ctxti, lhs=lhsi, case_hyp, cqs = cqsi, 
@@ -383,7 +475,17 @@
       |> map (fold Thm.elim_implies [case_hypj'])
       |> map (fold Thm.elim_implies agsj')
 
-    val compat = get_compat_thm thy compat_store eqvtsi eqvtsj i j cctxi cctxj
+    val invsi = nth invs (i - 1)
+      |> map (fold Thm.forall_elim cqsi)
+      |> map (fold Thm.elim_implies [case_hyp])
+      |> map (fold Thm.elim_implies agsi)
+
+    val invsj = nth invs (j - 1)
+      |> map (fold Thm.forall_elim cqsj')
+      |> map (fold Thm.elim_implies [case_hypj'])
+      |> map (fold Thm.elim_implies agsj')
+
+    val compat = get_compat_thm thy compat_store eqvtsi eqvtsj invsi invsj i j cctxi cctxj
 
   in
     (trans OF [case_hyp, lhsi_eq_lhsj']) (* lhs_i = lhs_j' |-- x = lhs_j' *)
@@ -402,7 +504,8 @@
   end
 
 (* nominal *)
-fun mk_uniqueness_case thy globals G f ihyp ih_intro G_cases compat_store clauses replems eqvtlems clausei =
+fun mk_uniqueness_case thy globals G f ihyp ih_intro G_cases compat_store clauses replems eqvtlems invlems 
+  clausei =
   let
     val Globals {x, y, ranT, fvar, ...} = globals
     val ClauseInfo {cdata = ClauseContext {lhs, rhs, cqs, ags, case_hyp, ...}, lGI, RCs, ...} = clausei
@@ -421,7 +524,7 @@
     val G_lhs_y = Thm.assume (cterm_of thy (HOLogic.mk_Trueprop (G $ lhs $ y)))
 
     val unique_clauses =
-      map2 (mk_uniqueness_clause thy globals compat_store eqvtlems clausei) clauses replems
+      map2 (mk_uniqueness_clause thy globals compat_store eqvtlems invlems clausei) clauses replems
 
     fun elim_implies_eta A AB =
       Thm.compose_no_flatten true (A, 0) 1 AB |> Seq.list_of |> the_single
@@ -459,7 +562,7 @@
 
 
 (* nominal *)
-fun prove_stuff ctxt globals G f R clauses complete compat compat_store G_elim G_eqvt f_def =
+fun prove_stuff ctxt globals G f R clauses complete compat compat_store G_elim G_eqvt invariant f_def =
   let
     val Globals {h, domT, ranT, x, ...} = globals
     val thy = ProofContext.theory_of ctxt
@@ -476,17 +579,21 @@
     val ih_elim = ihyp_thm RS (f_def RS ex1_implies_un)
       |> instantiate' [] [NONE, SOME (cterm_of thy h)]
     val ih_eqvt = ihyp_thm RS (G_eqvt RS (f_def RS @{thm fundef_ex1_eqvt_at}))
- 
+    val ih_inv =  ihyp_thm RS (invariant COMP (f_def RS @{thm fundef_ex1_prop}))
+
     val _ = trace_msg (K "Proving Replacement lemmas...")
     val repLemmas = map (mk_replacement_lemma thy h ih_elim) clauses
 
     val _ = trace_msg (K "Proving Equivariance lemmas...")
     val eqvtLemmas = map (mk_eqvt_lemma thy ih_eqvt) clauses
 
+    val _ = trace_msg (K "Proving Invariance lemmas...")
+    val invLemmas = map (mk_invariant_lemma thy ih_inv) clauses
+
     val _ = trace_msg (K "Proving cases for unique existence...")
     val (ex1s, values) =
       split_list (map (mk_uniqueness_case thy globals G f 
-        ihyp ih_intro G_elim compat_store clauses repLemmas eqvtLemmas) clauses)
+        ihyp ih_intro G_elim compat_store clauses repLemmas eqvtLemmas invLemmas) clauses)
      
     val _ = trace_msg (K "Proving: Graph is a function")
     val graph_is_function = complete
@@ -499,11 +606,12 @@
       |> (fn it => fold (Thm.forall_intr o cterm_of thy o Var) (Term.add_vars (prop_of it) []) it)
 
     val goalstate =  
-         Conjunction.intr (Conjunction.intr graph_is_function complete) G_eqvt 
+         Conjunction.intr (Conjunction.intr (Conjunction.intr graph_is_function complete) invariant) G_eqvt 
       |> Thm.close_derivation
       |> Goal.protect
       |> fold_rev (Thm.implies_intr o cprop_of) compat
       |> Thm.implies_intr (cprop_of complete)
+      |> Thm.implies_intr (cprop_of invariant)
       |> Thm.implies_intr (cprop_of G_eqvt)
   in
     (goalstate, values)
@@ -905,9 +1013,10 @@
 (* nominal *)
 fun prepare_nominal_function config defname [((fname, fT), mixfix)] abstract_qglrs lthy =
   let
-    val NominalFunctionConfig {domintros, default=default_opt, ...} = config
+    val NominalFunctionConfig {domintros, default=default_opt, inv=invariant_opt,...} = config
 
     val default_str = the_default "%x. undefined" default_opt (*FIXME dynamic scoping*)
+    val invariant_str = the_default "%x y. True" invariant_opt
     val fvar = Free (fname, fT)
     val domT = domain_type fT
     val ranT = range_type fT
@@ -915,6 +1024,9 @@
     val default = Syntax.parse_term lthy default_str
       |> Type.constraint fT |> Syntax.check_term lthy
 
+    val invariant_trm = Syntax.parse_term lthy invariant_str
+      |> Type.constraint ([domT, ranT] ---> @{typ bool}) |> Syntax.check_term lthy
+    
     val (globals, ctxt') = fix_globals domT ranT fvar lthy
 
     val Globals { x, h, ... } = globals
@@ -957,26 +1069,29 @@
       mk_completeness globals clauses abstract_qglrs |> cert |> Thm.assume
 
     val compat =
-      mk_compat_proof_obligations domT ranT fvar f RCss abstract_qglrs 
+      mk_compat_proof_obligations domT ranT fvar f RCss invariant_trm abstract_qglrs 
       |> map (cert #> Thm.assume)
 
     val G_eqvt = mk_eqvt G |> cert |> Thm.assume
  
+    val invariant = mk_invariant globals G invariant_trm |> cert |> Thm.assume
+ 
     val compat_store = store_compat_thms n compat
 
     val (goalstate, values) = PROFILE "prove_stuff"
       (prove_stuff lthy globals G f R xclauses complete compat
-         compat_store G_elim G_eqvt) f_defthm
+         compat_store G_elim G_eqvt invariant) f_defthm
      
     fun mk_partial_rules provedgoal =
       let
         val newthy = theory_of_thm provedgoal (*FIXME*)
 
-        val ((graph_is_function, complete_thm), _) =
+        val (graph_is_function, complete_thm) =
           provedgoal
+          |> fst o Conjunction.elim
+          |> fst o Conjunction.elim
           |> Conjunction.elim
-          |>> Conjunction.elim
-          |>> apfst (Thm.forall_elim_vars 0)
+          |> apfst (Thm.forall_elim_vars 0)
 
         val f_iff = graph_is_function RS (f_defthm RS ex1_implies_iff)