Nominal/nominal_induct.ML
changeset 2631 e73bd379e839
child 2632 e8732350a29f
equal deleted inserted replaced
2630:8268b277d240 2631:e73bd379e839
       
     1 (*  Author:     Christian Urban and Makarius
       
     2 
       
     3 The nominal induct proof method.
       
     4 *)
       
     5 
       
     6 structure NominalInduct:
       
     7 sig
       
     8   val nominal_induct_tac: Proof.context -> bool -> (binding option * (term * bool)) option list list ->
       
     9     (string * typ) list -> (string * typ) list list -> thm list -> thm list -> int -> Rule_Cases.cases_tactic
       
    10 
       
    11   val nominal_induct_method: (Proof.context -> Proof.method) context_parser
       
    12 end =
       
    13 
       
    14 struct
       
    15 
       
    16 (* proper tuples -- nested left *)
       
    17 
       
    18 fun tupleT Ts = HOLogic.unitT |> fold (fn T => fn U => HOLogic.mk_prodT (U, T)) Ts;
       
    19 fun tuple ts = HOLogic.unit |> fold (fn t => fn u => HOLogic.mk_prod (u, t)) ts;
       
    20 
       
    21 fun tuple_fun Ts (xi, T) =
       
    22   Library.funpow (length Ts) HOLogic.mk_split
       
    23     (Var (xi, (HOLogic.unitT :: Ts) ---> Term.range_type T));
       
    24 
       
    25 val split_all_tuples =
       
    26   Simplifier.full_simplify (HOL_basic_ss addsimps
       
    27     @{thms split_conv split_paired_all unit_all_eq1})
       
    28 (* 
       
    29      @{thm fresh_unit_elim}, @{thm fresh_prod_elim}] @
       
    30      @{thms fresh_star_unit_elim} @ @{thms fresh_star_prod_elim})
       
    31 *)
       
    32 
       
    33 
       
    34 (* prepare rule *)
       
    35 
       
    36 fun inst_mutual_rule ctxt insts avoiding rules =
       
    37   let
       
    38     val (nconcls, joined_rule) = Rule_Cases.strict_mutual_rule ctxt rules;
       
    39     val concls = Logic.dest_conjunctions (Thm.concl_of joined_rule);
       
    40     val (cases, consumes) = Rule_Cases.get joined_rule;
       
    41 
       
    42     val l = length rules;
       
    43     val _ =
       
    44       if length insts = l then ()
       
    45       else error ("Bad number of instantiations for " ^ string_of_int l ^ " rules");
       
    46 
       
    47     fun subst inst concl =
       
    48       let
       
    49         val vars = Induct.vars_of concl;
       
    50         val m = length vars and n = length inst;
       
    51         val _ = if m >= n + 2 then () else error "Too few variables in conclusion of rule";
       
    52         val P :: x :: ys = vars;
       
    53         val zs = drop (m - n - 2) ys;
       
    54       in
       
    55         (P, tuple_fun (map #2 avoiding) (Term.dest_Var P)) ::
       
    56         (x, tuple (map Free avoiding)) ::
       
    57         map_filter (fn (z, SOME t) => SOME (z, t) | _ => NONE) (zs ~~ inst)
       
    58       end;
       
    59      val substs =
       
    60        map2 subst insts concls |> flat |> distinct (op =)
       
    61        |> map (pairself (Thm.cterm_of (ProofContext.theory_of ctxt)));
       
    62   in 
       
    63     (((cases, nconcls), consumes), Drule.cterm_instantiate substs joined_rule) 
       
    64   end;
       
    65 
       
    66 fun rename_params_rule internal xs rule =
       
    67   let
       
    68     val tune =
       
    69       if internal then Name.internal
       
    70       else fn x => the_default x (try Name.dest_internal x);
       
    71     val n = length xs;
       
    72     fun rename prem =
       
    73       let
       
    74         val ps = Logic.strip_params prem;
       
    75         val p = length ps;
       
    76         val ys =
       
    77           if p < n then []
       
    78           else map (tune o #1) (take (p - n) ps) @ xs;
       
    79       in Logic.list_rename_params (ys, prem) end;
       
    80     fun rename_prems prop =
       
    81       let val (As, C) = Logic.strip_horn prop
       
    82       in Logic.list_implies (map rename As, C) end;
       
    83   in Thm.equal_elim (Thm.reflexive (Drule.cterm_fun rename_prems (Thm.cprop_of rule))) rule end;
       
    84 
       
    85 
       
    86 (* nominal_induct_tac *)
       
    87 
       
    88 fun nominal_induct_tac ctxt simp def_insts avoiding fixings rules facts =
       
    89   let
       
    90     val thy = ProofContext.theory_of ctxt;
       
    91     val cert = Thm.cterm_of thy;
       
    92 
       
    93     val ((insts, defs), defs_ctxt) = fold_map Induct.add_defs def_insts ctxt |>> split_list;
       
    94     val atomized_defs = map (map (Conv.fconv_rule Induct.atomize_cterm)) defs;
       
    95 
       
    96     val finish_rule =
       
    97       split_all_tuples
       
    98       #> rename_params_rule true
       
    99         (map (Name.clean o ProofContext.revert_skolem defs_ctxt o fst) avoiding);
       
   100 
       
   101     fun rule_cases ctxt r =
       
   102       let val r' = if simp then Induct.simplified_rule ctxt r else r
       
   103       in Rule_Cases.make_nested (Thm.prop_of r') (Induct.rulified_term r') end;
       
   104   in
       
   105     (fn i => fn st =>
       
   106       rules
       
   107       |> inst_mutual_rule ctxt insts avoiding
       
   108       |> Rule_Cases.consume (flat defs) facts
       
   109       |> Seq.maps (fn (((cases, concls), (more_consumes, more_facts)), rule) =>
       
   110         (PRECISE_CONJUNCTS (length concls) (ALLGOALS (fn j =>
       
   111           (CONJUNCTS (ALLGOALS
       
   112             let
       
   113               val adefs = nth_list atomized_defs (j - 1);
       
   114               val frees = fold (Term.add_frees o prop_of) adefs [];
       
   115               val xs = nth_list fixings (j - 1);
       
   116               val k = nth concls (j - 1) + more_consumes
       
   117             in
       
   118               Method.insert_tac (more_facts @ adefs) THEN'
       
   119                 (if simp then
       
   120                    Induct.rotate_tac k (length adefs) THEN'
       
   121                    Induct.fix_tac defs_ctxt k
       
   122                      (List.partition (member op = frees) xs |> op @)
       
   123                  else
       
   124                    Induct.fix_tac defs_ctxt k xs)
       
   125             end)
       
   126           THEN' Induct.inner_atomize_tac) j))
       
   127         THEN' Induct.atomize_tac) i st |> Seq.maps (fn st' =>
       
   128             Induct.guess_instance ctxt
       
   129               (finish_rule (Induct.internalize more_consumes rule)) i st'
       
   130             |> Seq.maps (fn rule' =>
       
   131               CASES (rule_cases ctxt rule' cases)
       
   132                 (Tactic.rtac (rename_params_rule false [] rule') i THEN
       
   133                   PRIMITIVE (singleton (ProofContext.export defs_ctxt ctxt))) st'))))
       
   134     THEN_ALL_NEW_CASES
       
   135       ((if simp then Induct.simplify_tac ctxt THEN' (TRY o Induct.trivial_tac)
       
   136         else K all_tac)
       
   137        THEN_ALL_NEW Induct.rulify_tac)
       
   138   end;
       
   139 
       
   140 
       
   141 (* concrete syntax *)
       
   142 
       
   143 local
       
   144 
       
   145 val avoidingN = "avoiding";
       
   146 val fixingN = "arbitrary";  (* to be consistent with induct; hopefully this changes again *)
       
   147 val ruleN = "rule";
       
   148 
       
   149 val inst = Scan.lift (Args.$$$ "_") >> K NONE ||
       
   150   Args.term >> (SOME o rpair false) ||
       
   151   Scan.lift (Args.$$$ "(") |-- (Args.term >> (SOME o rpair true)) --|
       
   152     Scan.lift (Args.$$$ ")");
       
   153 
       
   154 val def_inst =
       
   155   ((Scan.lift (Args.binding --| (Args.$$$ "\<equiv>" || Args.$$$ "==")) >> SOME)
       
   156       -- (Args.term >> rpair false)) >> SOME ||
       
   157     inst >> Option.map (pair NONE);
       
   158 
       
   159 val free = Args.context -- Args.term >> (fn (_, Free v) => v | (ctxt, t) =>
       
   160   error ("Bad free variable: " ^ Syntax.string_of_term ctxt t));
       
   161 
       
   162 fun unless_more_args scan = Scan.unless (Scan.lift
       
   163   ((Args.$$$ avoidingN || Args.$$$ fixingN || Args.$$$ ruleN) -- Args.colon)) scan;
       
   164 
       
   165 
       
   166 val avoiding = Scan.optional (Scan.lift (Args.$$$ avoidingN -- Args.colon) |--
       
   167   Scan.repeat (unless_more_args free)) [];
       
   168 
       
   169 val fixing = Scan.optional (Scan.lift (Args.$$$ fixingN -- Args.colon) |--
       
   170   Parse.and_list' (Scan.repeat (unless_more_args free))) [];
       
   171 
       
   172 val rule_spec = Scan.lift (Args.$$$ "rule" -- Args.colon) |-- Attrib.thms;
       
   173 
       
   174 in
       
   175 
       
   176 val nominal_induct_method =
       
   177   Args.mode Induct.no_simpN -- (Parse.and_list' (Scan.repeat (unless_more_args def_inst)) --
       
   178   avoiding -- fixing -- rule_spec) >>
       
   179   (fn (no_simp, (((x, y), z), w)) => fn ctxt =>
       
   180     RAW_METHOD_CASES (fn facts =>
       
   181       HEADGOAL (nominal_induct_tac ctxt (not no_simp) x y z w facts)));
       
   182 
       
   183 end
       
   184 
       
   185 end;