(* @chunk SIMPLE_INDUCTIVE_PACKAGE *)
signature SIMPLE_INDUCTIVE_PACKAGE =
sig
  (* Checked interface: predicates come with their types and mixfix
     syntax, parameters with their types, and every rule is a term
     paired with a name and attributes. *)
  val add_inductive_i :
       ((Binding.binding * typ) * mixfix) list            (* predicates *)
    -> (Binding.binding * typ) list                       (* parameters *)
    -> ((Binding.binding * Attrib.src list) * term) list  (* rules *)
    -> local_theory
    -> local_theory

  (* Unchecked interface: predicates and parameters carry optional type
     strings, and every rule is a named source string. *)
  val add_inductive :
       (Binding.binding * string option * mixfix) list    (* predicates *)
    -> (Binding.binding * string option * mixfix) list    (* parameters *)
    -> (Attrib.binding * string) list                     (* rules *)
    -> local_theory
    -> local_theory
end;
(* @end *)

structure SimpleInductivePackage: SIMPLE_INDUCTIVE_PACKAGE =
struct

(* The HOL formula ALL x. P, obtained by abstracting the term x out of P. *)
fun mk_all x P = HOLogic.all_const (fastype_of x) $ lambda x P

(* @chunk definitions *) 
(* Thin wrapper around LocalTheory.define: performs the definition with
   kind s and returns only the defining theorem and the new local theory. *)
fun define_aux s ((binding, syn), (attr, trm)) lthy =
let 
  val ((_, (_ , thm)), lthy) = LocalTheory.define s ((binding, syn), (attr, trm)) lthy
in 
  (thm, lthy) 
end

(* Defines every predicate R of preds (argument types Ts) as

     %params zs. ALL preds'. r1 --> ... --> rn --> R params zs

   where r1 ... rn are the rules in object-logic form.  preds' are the
   predicates already applied to the parameters (see add_inductive_i), so
   the quantifiers abstract over these applied terms.  Returns the list of
   defining theorems and the extended local theory. *)
fun DEFINITION params' rules preds preds' Tss lthy =
let
  (* object-logic versions of the rules, so they can be chained with --> *)
  val rules' = map (ObjectLogic.atomize_term (ProofContext.theory_of lthy)) rules
in
  fold_map (fn ((((R, _), syn), pred), Ts) =>
    let 
      (* one variable per argument type; Variable.variant_frees picks the
         names so that they differ from those in lthy and rules' *)
      val zs = map Free (Variable.variant_frees lthy rules' (map (pair "z") Ts))

      val t0 = list_comb (pred, zs);                       (* R params zs *)
      val t1 = fold_rev (curry HOLogic.mk_imp) rules' t0;  (* r1 --> ... --> t0 *)
      val t2 = fold_rev mk_all preds' t1;                  (* ALL preds'. t1 *)
      val t3 = fold_rev lambda (params' @ zs) t2;          (* %params zs. t2 *)
    in
      define_aux Thm.internalK ((R, syn), (Attrib.empty_binding, t3))
    end) (preds ~~ preds' ~~ Tss) lthy
end
(* @end *)

(* The HOL rule spec with its bound variable instantiated to ct; the type
   slot is instantiated with the type of ct. *)
fun inst_spec ct = 
  Drule.instantiate' [SOME (ctyp_of_term ct)] [NONE, SOME ct] @{thm spec};

(* Strips leading universal quantifiers from a theorem, instantiating them
   in order with the given cterms. *)
val all_elims = fold (fn ct => fn th => th RS inst_spec ct);

(* Discharges leading implications of a theorem, in order, with the given
   theorems via modus ponens. *)
val imp_elims = fold (fn th => fn th' => [th', th] MRS @{thm mp});


(* @chunk induction_rules *)
(* Proves one induction principle per predicate.  For R with argument
   types Ts the goal is

     R zs ==> rules[Ps/preds'] ==> P zs

   where the Ps are fresh predicate variables substituted for preds' in
   the rules.  The proof atomizes the goal, inserts R zs, unfolds defs,
   instantiates the quantified predicates with the Ps and finishes by
   assumption.  Each result is exported from lthy4 back to lthy1, which
   generalizes the variables fixed after lthy1. *)
fun INDUCTION rules preds' Tss defs lthy1 lthy2 =
let
    val (Pnames, lthy3) = Variable.variant_fixes (replicate (length preds') "P") lthy2;
    val Ps = map (fn (s, Ts) => Free (s, Ts ---> HOLogic.boolT)) (Pnames ~~ Tss);
    val cPs = map (cterm_of (ProofContext.theory_of lthy3)) Ps;
    val rules'' = map (subst_free (preds' ~~ Ps)) rules;

    fun prove_indrule ((R, P), Ts)  =
      let
        val (znames, lthy4) = Variable.variant_fixes (replicate (length Ts) "z") lthy3;
        val zs = map Free (znames ~~ Ts)

        val prem = HOLogic.mk_Trueprop (list_comb (R, zs))
        val goal = Logic.list_implies (rules'', HOLogic.mk_Trueprop (list_comb (P, zs)))
      in
        Goal.prove lthy4 [] [prem] goal
          (fn {prems, ...} => EVERY1
             ([ObjectLogic.full_atomize_tac,
               cut_facts_tac prems,
               K (rewrite_goals_tac defs)] @
              (* one spec step per quantified predicate *)
              map (fn ct => dtac (inst_spec ct)) cPs @
              [assume_tac])) |>
           singleton (ProofContext.export lthy4 lthy1)
      end;
in
  map prove_indrule (preds' ~~ Ps ~~ Tss)
end
(* @end *)

(* @chunk intro_rules *) 
(* Proves the introduction rules.  Rule number i is stated for the
   defined predicates; after rulifying, unfolding defs, and applying
   allI/impI, the SUBPROOF sees the rule's own parameters and premises
   followed by the quantified predicates and the assumed rules.  The
   i-th assumed rule is applied and its subgoals are closed from the
   rule's premises.  Results are exported from lthy2 back to lthy1. *)
fun INTROS rules preds' defs lthy1 lthy2 = 
let
  fun prove_intro (i, r) =
      Goal.prove lthy2 [] [] r
        (fn {context = ctxt, ...} => EVERY
           [ObjectLogic.rulify_tac 1,
            rewrite_goals_tac defs,
            REPEAT (resolve_tac [@{thm allI},@{thm impI}] 1),
            SUBPROOF (fn {params, prems, context = ctxt', ...} =>
              let
                (* prems2: the last (length rules) premises, i.e. the
                   assumed rules; prems1: the premises of rule r *)
                val (prems1, prems2) = chop (length prems - length rules) prems;
                (* params2: the last (length preds') parameters, i.e. the
                   quantified predicates; params1: the parameters of rule r *)
                val (params1, params2) = chop (length params - length preds') params;
              in
                (* apply the i-th assumed rule, instantiated with params1 ... *)
                rtac (ObjectLogic.rulify (all_elims params1 (nth prems2 i))) 1 
                THEN
                (* ... then close the resulting subgoals with prems1; this
                   assumes the subgoals come in the same order as prems1 *)
                EVERY1 (map (fn prem =>
                  SUBPROOF (fn {prems = prems', ...} =>
                    let
                      val prem' = prems' MRS prem;
                      (* an All-formula is instantiated with the predicates
                         params2 and its implications discharged with prems2 *)
                      val prem'' = case prop_of prem' of
                          _ $ (Const (@{const_name All}, _) $ _) =>
                            prem' |> all_elims params2 
                                  |> imp_elims prems2
                        | _ => prem';
                    in rtac prem'' 1 end) ctxt') prems1)
              end) ctxt 1]) |>
      singleton (ProofContext.export lthy2 lthy1)
in
  map_index prove_intro rules
end
(* @end *)

(* @chunk add_inductive_i *)
(* Checked entry point: defines the predicates, proves their introduction
   and induction rules and stores them in the local theory:
     - every intro rule under its own name qualified with mut_name (the
       predicate names joined by underscores), and all of them together
       as mut_name.intros;
     - every induction rule as R.induct, with case names taken from the
       rule names and the attributes consumes 1 and induct_pred.
   NOTE(review): induct_pred is given the empty string as predicate name;
   confirm that this registration is intended. *)
fun add_inductive_i preds params specs lthy =
  let
    val params' = map (fn (p, T) => Free (Binding.base_name p, T)) params;
    (* predicates applied to the parameters *)
    val preds' = map (fn ((R, T), _) => list_comb (Free (Binding.base_name R, T), params')) preds;
    (* argument types of each predicate beyond the parameters *)
    val Tss = map (binder_types o fastype_of) preds';
    val rules = map snd specs;

    val (defs, lthy1) = DEFINITION params' rules preds preds' Tss lthy
    (* the parameters are fixed for the proofs; the proved rules are
       exported back to lthy1 *)
    val (_, lthy2) = Variable.add_fixes (map (Binding.base_name o fst) params) lthy1;

    val inducts = INDUCTION rules preds' Tss defs lthy1 lthy2

    val intros = INTROS rules preds' defs lthy1 lthy2

    val mut_name = space_implode "_" (map (Binding.base_name o fst o fst) preds);
    val case_names = map (Binding.base_name o fst o fst) specs
  in
    lthy1 
    |> LocalTheory.notes Thm.theoremK (map (fn (((a, atts), _), th) =>
        ((Binding.qualify mut_name a, atts), [([th], [])])) (specs ~~ intros)) 
    |-> (fn intross => LocalTheory.note Thm.theoremK
         ((Binding.qualify mut_name (Binding.name "intros"), []), maps snd intross)) 
    |>> snd 
    ||>> (LocalTheory.notes Thm.theoremK (map (fn (((R, _), _), th) =>
         ((Binding.qualify (Binding.base_name R) (Binding.name "induct"),
          [Attrib.internal (K (RuleCases.case_names case_names)),
           Attrib.internal (K (RuleCases.consumes 1)),
           Attrib.internal (K (Induct.induct_pred ""))]), [([th], [])]))
          (preds ~~ inducts)) #>> maps snd) 
    |> snd
  end
(* @end *)

(* @chunk add_inductive *)
(* Variant of Specification.read_specification in which every
   specification is exactly one proposition.  The context returned by the
   reader is discarded. *)
fun read_specification' vars specs lthy =
let 
  val specs' = map (fn (a, s) => [(a, [s])]) specs
  val ((varst, specst), _) = Specification.read_specification vars specs' lthy
  val specst' = map (apsnd the_single) specst
in   
  (varst, specst')
end 

(* Unchecked entry point: reads predicates and parameters together with
   the rules, splits the checked variables back into predicates and
   parameters (dropping the parameters' mixfix), and calls
   add_inductive_i. *)
fun add_inductive preds params specs lthy =
let
  val (vars, specs') = read_specification' (preds @ params) specs lthy;
  val (preds', params') = chop (length preds) vars;
  val params'' = map fst params'
in
  add_inductive_i preds' params'' specs' lthy
end;
(* @end *)

(* @chunk syntax *)
(* Outer syntax: optional target, the predicates (fixes), the parameters
   (for_fixes), and an optional where-clause of |-separated, optionally
   named propositions. *)
val parser = 
   OuterParse.opt_target --
   OuterParse.fixes -- 
   OuterParse.for_fixes --
   Scan.optional 
       (OuterParse.$$$ "where" |--
          OuterParse.!!! 
            (OuterParse.enum1 "|" 
               (SpecParse.opt_thm_name ":" -- OuterParse.prop))) []

val ind_decl =
  parser >>
    (fn (((loc, preds), params), specs) =>
      Toplevel.local_theory loc (add_inductive preds params specs))

val _ = OuterSyntax.command "simple_inductive" "define inductive predicates"
  OuterKeyword.thy_decl ind_decl;
(* @end *)

end;
