Nominal/nominal_inductive.ML
changeset 2639 a8fc346deda3
child 2645 09cf78bb53d4
equal deleted inserted replaced
2638:e1e2ca92760b 2639:a8fc346deda3
       
     1 (*  Title:      nominal_inductive.ML
       
     2     Author:     Christian Urban
       
     3 
       
     4     Infrastructure for proving strong induction theorems
       
     5     for inductive predicates involving nominal datatypes.
       
     6 
       
     7     Code based on an earlier version by Stefan Berghofer.
       
     8 *)
       
     9 
       
    10 
       
(* Interface of the package that proves strong induction principles for
   inductive predicates over nominal datatypes. *)
signature NOMINAL_INDUCTIVE =
sig
  (* prove_strong_inductive pred_names rule_names avoids raw_induct intrs ctxt
     avoids: one list of terms to be avoided per introduction rule.
     Opens a proof state whose goals are the variable-convention
     compatibility conditions; once that proof is finished, the strong
     induction theorem is noted in the local theory. *)
  val prove_strong_inductive: string list -> string list -> term list list -> thm -> thm list -> 
    Proof.context -> Proof.state

  (* Outer-syntax entry point: (predicate name, [(case name, avoid term strings)]). *)
  val prove_strong_inductive_cmd: xstring * (string * string list) list -> Proof.context -> Proof.state
end
       
    18 
       
    19 structure Nominal_Inductive : NOMINAL_INDUCTIVE =
       
    20 struct
       
    21 
       
    22 
       
    23 fun mk_cplus p q = Thm.capply (Thm.capply @{cterm "plus :: perm => perm => perm"} p) q 
       
    24 
       
    25 fun mk_cminus p = Thm.capply @{cterm "uminus :: perm => perm"} p 
       
    26 
       
    27 
       
    28 fun minus_permute_intro_tac p = 
       
    29   rtac (Drule.instantiate' [] [SOME (mk_cminus p)] @{thm permute_boolE})
       
    30 
       
    31 fun minus_permute_elim p thm = 
       
    32   thm RS (Drule.instantiate' [] [NONE, SOME (mk_cminus p)] @{thm permute_boolI})
       
    33 
       
    34 fun real_head_of (@{term Trueprop} $ t) = real_head_of t
       
    35   | real_head_of (Const ("==>", _) $ _ $ t) = real_head_of t
       
    36   | real_head_of (Const (@{const_name all}, _) $ Abs (_, _, t)) = real_head_of t
       
    37   | real_head_of (Const (@{const_name All}, _) $ Abs (_, _, t)) = real_head_of t
       
    38   | real_head_of (Const ("HOL.induct_forall", _) $ Abs (_, _, t)) = real_head_of t
       
    39   | real_head_of t = head_of t  
       
    40 
       
    41 
       
    42 fun mk_vc_compat (avoid, avoid_trm) prems concl_args params = 
       
    43   let
       
    44     val vc_goal = concl_args
       
    45       |> HOLogic.mk_tuple
       
    46       |> mk_fresh_star avoid_trm 
       
    47       |> HOLogic.mk_Trueprop
       
    48       |> (curry Logic.list_implies) prems
       
    49       |> (curry list_all_free) params
       
    50     val finite_goal = avoid_trm
       
    51       |> mk_finite
       
    52       |> HOLogic.mk_Trueprop
       
    53       |> (curry Logic.list_implies) prems
       
    54       |> (curry list_all_free) params
       
    55   in 
       
    56     if null avoid then [] else [vc_goal, finite_goal]
       
    57   end
       
    58 
       
    59 fun map_term prop f trm =
       
    60   if prop trm 
       
    61   then f trm
       
    62   else case trm of
       
    63     (t1 $ t2) => map_term prop f t1 $ map_term prop f t2
       
    64   | Abs (x, T, t) => Abs (x, T, map_term prop f t)
       
    65   | _ => trm
       
    66 
       
    67 fun add_p_c p (c, c_ty) trm =
       
    68   let
       
    69     val (P, args) = strip_comb trm
       
    70     val (P_name, P_ty) = dest_Free P
       
    71     val (ty_args, bool) = strip_type P_ty
       
    72     val args' = map (mk_perm p) args
       
    73   in
       
    74     list_comb (Free (P_name, (c_ty :: ty_args) ---> bool),  c :: args')
       
    75     |> (fn t => HOLogic.all_const c_ty $ lambda c t )
       
    76     |> (fn t => HOLogic.all_const @{typ perm} $  lambda p t)
       
    77   end
       
    78 
       
    79 fun induct_forall_const T = Const ("HOL.induct_forall", (T --> @{typ bool}) --> @{typ bool})
       
    80 fun mk_induct_forall (a, T) t =  induct_forall_const T $ Abs (a, T, t)
       
    81 
       
    82 fun add_c_prop qnt Ps (c, c_name, c_ty) trm =
       
    83   let
       
    84     fun add t = 
       
    85       let
       
    86         val (P, args) = strip_comb t
       
    87         val (P_name, P_ty) = dest_Free P
       
    88         val (ty_args, bool) = strip_type P_ty
       
    89         val args' = args
       
    90           |> qnt ? map (incr_boundvars 1)
       
    91       in
       
    92         list_comb (Free (P_name, (c_ty :: ty_args) ---> bool), c :: args')
       
    93         |> qnt ? mk_induct_forall (c_name, c_ty)
       
    94       end
       
    95   in
       
    96     map_term (member (op =) Ps o head_of) add trm
       
    97   end
       
    98 
       
    99 fun prep_prem Ps c_name c_ty (avoid, avoid_trm) (params, prems, concl) =
       
   100   let
       
   101     val prems' = prems
       
   102       |> map (incr_boundvars 1) 
       
   103       |> map (add_c_prop true Ps (Bound 0, c_name, c_ty))
       
   104 
       
   105     val avoid_trm' = avoid_trm
       
   106       |> (curry list_abs_free) (params @ [(c_name, c_ty)])
       
   107       |> strip_abs_body
       
   108       |> (fn t => mk_fresh_star_ty c_ty t (Bound 0))
       
   109       |> HOLogic.mk_Trueprop
       
   110 
       
   111     val prems'' = 
       
   112       if null avoid 
       
   113       then prems' 
       
   114       else avoid_trm' :: prems'
       
   115 
       
   116     val concl' = concl
       
   117       |> incr_boundvars 1 
       
   118       |> add_c_prop false Ps (Bound 0, c_name, c_ty)  
       
   119   in
       
   120     mk_full_horn (params @ [(c_name, c_ty)]) prems'' concl'
       
   121   end
       
   122 
       
   123 fun same_name (Free (a1, _), Free (a2, _)) = (a1 = a2)
       
   124   | same_name (Var (a1, _), Var (a2, _)) = (a1 = a2)
       
   125   | same_name (Const (a1, _), Const (a2, _)) = (a1 = a2)
       
   126   | same_name _ = false
       
   127 
       
   128 fun map7 _ [] [] [] [] [] [] [] = []
       
   129   | map7 f (x :: xs) (y :: ys) (z :: zs) (u :: us) (v :: vs) (r :: rs) (s :: ss) = 
       
   130       f x y z u v r s :: map7 f xs ys zs us vs rs ss
       
   131 
       
   132 (* local abbreviations *)
       
   133 fun eqvt_stac ctxt = Nominal_Permeq.eqvt_strict_tac ctxt @{thms permute_minus_cancel} []  
       
   134 fun eqvt_srule ctxt = Nominal_Permeq.eqvt_strict_rule ctxt @{thms permute_minus_cancel} []  
       
   135 
       
   136 val all_elims = 
       
   137   let
       
   138      fun spec' ct = Drule.instantiate' [SOME (ctyp_of_term ct)] [NONE, SOME ct] @{thm spec}
       
   139   in
       
   140     fold (fn ct => fn th => th RS spec' ct)
       
   141   end
       
   142 
       
   143 fun helper_tac flag prm p ctxt =
       
   144   Subgoal.SUBPROOF (fn {context, prems, ...} =>
       
   145     let
       
   146       val prems' = prems
       
   147         |> map (minus_permute_elim p)
       
   148         |> map (eqvt_srule context)
       
   149 
       
   150       val prm' = (prems' MRS prm)
       
   151         |> flag ? (all_elims [p])
       
   152         |> flag ? (eqvt_srule context)
       
   153 
       
   154       val _ = tracing ("prm':" ^ @{make_string} prm')
       
   155     in
       
   156       print_tac "start helper"
       
   157       THEN asm_full_simp_tac (HOL_ss addsimps (prm' :: @{thms induct_forall_def})) 1
       
   158       THEN print_tac "final helper"
       
   159     end) ctxt
       
   160 
       
   161 fun non_binder_tac prem intr_cvars Ps ctxt = 
       
   162   Subgoal.SUBPROOF (fn {context, params, prems, ...} =>
       
   163     let
       
   164       val thy = ProofContext.theory_of context
       
   165       val (prms, p, _) = split_last2 (map snd params)
       
   166       val prm_tys = map (fastype_of o term_of) prms
       
   167       val cperms = map (cterm_of thy o perm_const) prm_tys
       
   168       val p_prms = map2 (fn ct1 => fn ct2 => Thm.mk_binop ct1 p ct2) cperms prms 
       
   169       val prem' = cterm_instantiate (intr_cvars ~~ p_prms) prem
       
   170 
       
   171       (* for inductive-premises*)
       
   172       fun tac1 prm = helper_tac true prm p context 
       
   173 
       
   174       (* for non-inductive premises *)   
       
   175       fun tac2 prm =  
       
   176         EVERY' [ minus_permute_intro_tac p, 
       
   177                  eqvt_stac context, 
       
   178                  helper_tac false prm p context ]
       
   179 
       
   180       fun select prm (t, i) =
       
   181         (if member same_name Ps (real_head_of t) then tac1 prm else tac2 prm) i
       
   182     in
       
   183       EVERY1 [eqvt_stac ctxt, rtac prem', RANGE (map (SUBGOAL o select) prems) ]
       
   184     end) ctxt
       
   185 
       
   186 fun fresh_thm ctxt user_thm p c concl_args avoid_trm =
       
   187   let
       
   188     val conj1 = 
       
   189       mk_fresh_star (mk_perm (Bound 0) (mk_perm p avoid_trm)) c
       
   190     val conj2 =
       
   191       mk_fresh_star_ty @{typ perm} (mk_supp (HOLogic.mk_tuple (map (mk_perm p) concl_args))) (Bound 0)
       
   192     val fresh_goal = mk_exists ("q", @{typ perm}) (HOLogic.mk_conj (conj1, conj2))
       
   193       |> HOLogic.mk_Trueprop
       
   194 
       
   195     val ss = @{thms finite_supp supp_Pair finite_Un permute_finite} @ 
       
   196              @{thms fresh_star_Pair fresh_star_permute_iff}
       
   197     val simp = asm_full_simp_tac (HOL_ss addsimps ss)
       
   198   in 
       
   199     Goal.prove ctxt [] [] fresh_goal
       
   200       (K (HEADGOAL (rtac @{thm at_set_avoiding2} 
       
   201           THEN_ALL_NEW EVERY' [cut_facts_tac user_thm, REPEAT o etac @{thm conjE}, simp])))
       
   202   end
       
   203 
       
(* Meta-equation: if supp (permute p x) is fresh_star for q, then
   permute p x == permute (q + p) x (proved with supp_perm_eq). *)
val supp_perm_eq' = @{lemma "fresh_star (supp (permute p x)) q ==> permute p x == permute (q + p) x" 
  by (simp add: supp_perm_eq)}
(* Restates freshness of permute q (permute p x) as freshness of
   permute (q + p) x (proved with permute_plus). *)
val fresh_star_plus = @{lemma "fresh_star (permute q (permute p x)) c ==> fresh_star (permute (q + p) x) c" 
  by (simp add: permute_plus)}
       
   208 
       
   209 
       
   210 fun binder_tac prem intr_cvars param_trms Ps user_thm avoid avoid_trm concl_args ctxt = 
       
   211   Subgoal.FOCUS (fn {context = ctxt, params, prems, concl, ...} =>
       
   212     let
       
   213       val thy = ProofContext.theory_of ctxt
       
   214       val (prms, p, c) = split_last2 (map snd params)
       
   215       val prm_trms = map term_of prms
       
   216       val prm_tys = map fastype_of prm_trms
       
   217 
       
   218       val avoid_trm' = subst_free (param_trms ~~ prm_trms) avoid_trm 
       
   219       val concl_args' = map (subst_free (param_trms ~~ prm_trms)) concl_args 
       
   220       
       
   221       val user_thm' = map (cterm_instantiate (intr_cvars ~~ prms)) user_thm
       
   222         |> map (full_simplify (HOL_ss addsimps (@{thm fresh_star_Pair}::prems)))
       
   223       
       
   224       val fthm = fresh_thm ctxt user_thm' (term_of p) (term_of c) concl_args' avoid_trm'
       
   225 
       
   226       val (([(_, q)], fprop :: fresh_eqs), ctxt') = Obtain.result
       
   227               (K (EVERY1 [etac @{thm exE}, 
       
   228                           full_simp_tac (HOL_basic_ss addsimps @{thms supp_Pair fresh_star_Un}), 
       
   229                           REPEAT o etac @{thm conjE},
       
   230                           dtac fresh_star_plus,
       
   231                           REPEAT o dtac supp_perm_eq'])) [fthm] ctxt 
       
   232 
       
   233       val expand_conv = Conv.try_conv (Conv.rewrs_conv fresh_eqs)
       
   234       fun expand_conv_bot ctxt = Conv.bottom_conv (K expand_conv) ctxt
       
   235 
       
   236       val cperms = map (cterm_of thy o perm_const) prm_tys
       
   237       val qp_prms = map2 (fn ct1 => fn ct2 => Thm.mk_binop ct1 (mk_cplus q p) ct2) cperms prms 
       
   238       val prem' = cterm_instantiate (intr_cvars ~~ qp_prms) prem
       
   239 
       
   240       val fprop' = eqvt_srule ctxt' fprop 
       
   241       val tac_fresh = simp_tac (HOL_basic_ss addsimps [fprop'])
       
   242 
       
   243       (* for inductive-premises*)
       
   244       fun tac1 prm = helper_tac true prm (mk_cplus q p) ctxt' 
       
   245 
       
   246       (* for non-inductive premises *)   
       
   247       fun tac2 prm =  
       
   248         EVERY' [ minus_permute_intro_tac (mk_cplus q p), 
       
   249                  eqvt_stac ctxt, 
       
   250                  helper_tac false prm (mk_cplus q p) ctxt' ]
       
   251 
       
   252       fun select prm (t, i) =
       
   253         (if member same_name Ps (real_head_of t) then tac1 prm else tac2 prm) i
       
   254 
       
   255       val _ = tracing ("fthm:\n" ^ @{make_string} fthm)
       
   256       val _ = tracing ("fr_eqs:\n" ^ cat_lines (map @{make_string} fresh_eqs))
       
   257       val _ = tracing ("fprop:\n" ^ @{make_string} fprop)
       
   258       val _ = tracing ("fprop':\n" ^ @{make_string} fprop')
       
   259       val _ = tracing ("fperm:\n" ^ @{make_string} q)
       
   260       val _ = tracing ("prem':\n" ^ @{make_string} prem')
       
   261 
       
   262       val side_thm = Goal.prove ctxt' [] [] (term_of concl)
       
   263         (fn {context, ...} => 
       
   264            EVERY1 [ CONVERSION (expand_conv_bot context),
       
   265                     eqvt_stac context,
       
   266                     rtac prem',
       
   267                     RANGE (tac_fresh :: map (SUBGOAL o select) prems),
       
   268                     K (print_tac "GOAL") ])
       
   269         |> singleton (ProofContext.export ctxt' ctxt)        
       
   270     in
       
   271       rtac side_thm 1
       
   272     end) ctxt
       
   273 
       
   274 fun case_tac ctxt Ps avoid avoid_trm intr_cvars param_trms prem user_thm concl_args =
       
   275   let
       
   276     val tac1 = non_binder_tac prem intr_cvars Ps ctxt
       
   277     val tac2 = binder_tac prem intr_cvars param_trms Ps user_thm avoid avoid_trm concl_args ctxt
       
   278   in 
       
   279     EVERY' [ rtac @{thm allI}, rtac @{thm allI}, if null avoid then tac1 else tac2 ]
       
   280   end
       
   281 
       
   282 fun prove_sinduct_tac raw_induct user_thms Ps avoids avoid_trms intr_cvars param_trms concl_args 
       
   283   {prems, context} =
       
   284   let
       
   285     val cases_tac = 
       
   286       map7 (case_tac context Ps) avoids avoid_trms intr_cvars param_trms prems user_thms concl_args
       
   287   in 
       
   288     EVERY1 [ DETERM o rtac raw_induct, RANGE cases_tac ]
       
   289   end
       
   290 
       
(* Turns a conjunct of the form  Q --> (ALL p c. P p c)  into the meta-level
   statement  !!c. Q ==> P 0 c, i.e. the permutation is instantiated with 0. *)
val normalise = @{lemma "(Q --> (!p c. P p c)) ==> (!!c. Q ==> P (0::perm) c)" by simp}
       
   292 
       
   293 fun prove_strong_inductive pred_names rule_names avoids raw_induct intrs ctxt =
       
   294   let
       
   295     val thy = ProofContext.theory_of ctxt
       
   296     val ((_, [raw_induct']), ctxt') = Variable.import true [raw_induct] ctxt
       
   297 
       
   298     val (ind_prems, ind_concl) = raw_induct'
       
   299       |> prop_of
       
   300       |> Logic.strip_horn
       
   301       |>> map strip_full_horn
       
   302     val params = map (fn (x, _, _) => x) ind_prems
       
   303     val param_trms = (map o map) Free params  
       
   304 
       
   305     val intr_vars_tys = map (fn t => rev (Term.add_vars (prop_of t) [])) intrs
       
   306     val intr_vars = (map o map) fst intr_vars_tys
       
   307     val intr_vars_substs = map2 (curry (op ~~)) intr_vars param_trms
       
   308     val intr_cvars = (map o map) (cterm_of thy o Var) intr_vars_tys      
       
   309 
       
   310     val (intr_prems, intr_concls) = intrs
       
   311       |> map prop_of
       
   312       |> map2 subst_Vars intr_vars_substs
       
   313       |> map Logic.strip_horn
       
   314       |> split_list
       
   315 
       
   316     val intr_concls_args = map (snd o strip_comb o HOLogic.dest_Trueprop) intr_concls 
       
   317       
       
   318     val avoid_trms = avoids
       
   319       |> (map o map) (setify ctxt') 
       
   320       |> map fold_union
       
   321 
       
   322     val vc_compat_goals = 
       
   323       map4 mk_vc_compat (avoids ~~ avoid_trms) intr_prems intr_concls_args params
       
   324 
       
   325     val ([c_name, a, p], ctxt'') = Variable.variant_fixes ["c", "'a", "p"] ctxt'
       
   326     val c_ty = TFree (a, @{sort fs})
       
   327     val c = Free (c_name, c_ty)
       
   328     val p = Free (p, @{typ perm})
       
   329 
       
   330     val (preconds, ind_concls) = ind_concl
       
   331       |> HOLogic.dest_Trueprop
       
   332       |> HOLogic.dest_conj 
       
   333       |> map HOLogic.dest_imp
       
   334       |> split_list
       
   335 
       
   336     val Ps = map (fst o strip_comb) ind_concls
       
   337 
       
   338     val ind_concl' = ind_concls
       
   339       |> map (add_p_c p (c, c_ty))
       
   340       |> (curry (op ~~)) preconds  
       
   341       |> map HOLogic.mk_imp
       
   342       |> fold_conj
       
   343       |> HOLogic.mk_Trueprop
       
   344 
       
   345     val ind_prems' = ind_prems
       
   346       |> map2 (prep_prem Ps c_name c_ty) (avoids ~~ avoid_trms)   
       
   347 
       
   348     fun after_qed ctxt_outside user_thms ctxt = 
       
   349       let
       
   350         val strong_ind_thms = Goal.prove ctxt [] ind_prems' ind_concl' 
       
   351         (prove_sinduct_tac raw_induct user_thms Ps avoids avoid_trms intr_cvars param_trms intr_concls_args) 
       
   352           |> singleton (ProofContext.export ctxt ctxt_outside)
       
   353           |> Datatype_Aux.split_conj_thm
       
   354           |> map (fn thm => thm RS normalise)
       
   355           |> map (asm_full_simplify (HOL_basic_ss addsimps @{thms permute_zero induct_rulify})) 
       
   356           |> map (Drule.rotate_prems (length ind_prems'))
       
   357           |> map zero_var_indexes
       
   358 
       
   359         val qualified_thm_name = pred_names
       
   360           |> map Long_Name.base_name
       
   361           |> space_implode "_"
       
   362           |> (fn s => Binding.qualify false s (Binding.name "strong_induct"))
       
   363 
       
   364         val attrs = 
       
   365           [ Attrib.internal (K (Rule_Cases.consumes 1)),
       
   366             Attrib.internal (K (Rule_Cases.case_names rule_names)) ]
       
   367         val _ = tracing ("RESULTS\n" ^ cat_lines (map (Syntax.string_of_term ctxt o prop_of) strong_ind_thms))
       
   368         val _ = tracing ("rule_names: " ^ commas rule_names)
       
   369         val _ = tracing ("pred_names: " ^ commas pred_names)
       
   370       in
       
   371         ctxt
       
   372         |> Local_Theory.note ((qualified_thm_name, attrs), strong_ind_thms)    
       
   373         |> snd   
       
   374       end
       
   375   in
       
   376     Proof.theorem NONE (after_qed ctxt) ((map o map) (rpair []) vc_compat_goals) ctxt''
       
   377   end
       
   378 
       
   379 fun prove_strong_inductive_cmd (pred_name, avoids) ctxt =
       
   380   let
       
   381     val thy = ProofContext.theory_of ctxt;
       
   382     val ({names, ...}, {raw_induct, intrs, ...}) =
       
   383       Inductive.the_inductive ctxt (Sign.intern_const thy pred_name);
       
   384 
       
   385     val rule_names = 
       
   386       hd names
       
   387       |> the o Induct.lookup_inductP ctxt
       
   388       |> fst o Rule_Cases.get
       
   389       |> map fst
       
   390 
       
   391     val _ = (case duplicates (op = o pairself fst) avoids of
       
   392         [] => ()
       
   393       | xs => error ("Duplicate case names: " ^ commas_quote (map fst xs)))
       
   394 
       
   395     val _ = (case subtract (op =) rule_names (map fst avoids) of
       
   396         [] => ()
       
   397       | xs => error ("No such case(s) in inductive definition: " ^ commas_quote xs))
       
   398 
       
   399     val avoids_ordered = order_default (op =) [] rule_names avoids
       
   400       
       
   401     fun read_avoids avoid_trms intr =
       
   402       let
       
   403         (* fixme hack *)
       
   404         val (((_, ctrms), _), ctxt') = Variable.import true [intr] ctxt
       
   405         val trms = map (term_of o snd) ctrms
       
   406         val ctxt'' = fold Variable.declare_term trms ctxt' 
       
   407       in
       
   408         map (Syntax.read_term ctxt'') avoid_trms 
       
   409       end 
       
   410 
       
   411     val avoid_trms = map2 read_avoids avoids_ordered intrs
       
   412   in
       
   413     prove_strong_inductive names rule_names avoid_trms raw_induct intrs ctxt
       
   414   end
       
   415 
       
   416 (* outer syntax *)
       
   417 local
       
   418   structure P = Parse;
       
   419   structure S = Scan
       
   420   
       
   421   val _ = Keyword.keyword "avoids"
       
   422 
       
   423   val single_avoid_parser = 
       
   424     P.name -- (P.$$$ ":" |-- P.and_list1 P.term)
       
   425 
       
   426   val avoids_parser = 
       
   427     S.optional (P.$$$ "avoids" |-- P.enum1 "|" single_avoid_parser) []
       
   428 
       
   429   val main_parser = P.xname -- avoids_parser
       
   430 in
       
   431   val _ =
       
   432   Outer_Syntax.local_theory_to_proof "nominal_inductive"
       
   433     "prove strong induction theorem for inductive predicate involving nominal datatypes"
       
   434       Keyword.thy_goal (main_parser >> prove_strong_inductive_cmd)
       
   435 end
       
   436 
       
   437 end