(* Interface of the quotient type package. A quotient type is declared  *)
(* from a raw type and a relation; the user must first prove that the   *)
(* relation is an equivalence (equivp) before the type is constructed.  *)
signature QUOTIENT_TYPE =
sig
  (* exported, but not raised by any function of Quotient_Type itself *)
  exception LIFT_MATCH of string
  (* ML interface: entries ((name, mixfix), (raw type, relation));  *)
  (* yields a proof state with one "equivp relation" goal per entry *)
  val quotient_type: ((binding * mixfix) * (typ * term)) list -> Proof.context -> Proof.state
  (* command interface: like quotient_type, but the raw type and the *)
  (* relation are given as strings that are parsed in the context    *)
  val quotient_type_cmd: (((bstring * mixfix) * string) * string) list -> Proof.context -> Proof.state
end;
structure Quotient_Type: QUOTIENT_TYPE =
struct
open Quotient_Info;
(* Exported through QUOTIENT_TYPE; no function in this structure raises it. *)
exception LIFT_MATCH of string
(* thin wrappers around Local_Theory.define, Local_Theory.note, Attrib.internal and Proof.theorem_i *)
(* define (name, mx, rhs) lthy introduces a constant called name (with
   mixfix mx) whose definition is rhs, without any fact name or attributes
   for the definitional theorem (Attrib.empty_binding).
   Result: ((term for the new constant, its defining theorem), lthy'). *)
fun define (name, mx, rhs) lthy =
  let
    val def_spec = ((name, mx), (Attrib.empty_binding, rhs))
    val ((const_trm, (_, def_thm)), lthy_new) = Local_Theory.define def_spec lthy
  in
    ((const_trm, def_thm), lthy_new)
  end
(* note (name, thm, attrs) lthy stores the single theorem thm under name
   with the given attributes and returns (stored theorem, lthy'). *)
fun note (name, thm, attrs) lthy =
  let
    val fact_spec = ((name, attrs), [thm])
    val ((_, noted), lthy_new) = Local_Theory.note fact_spec lthy
    (* exactly one theorem was noted, so exactly one comes back *)
    val [noted_thm] = noted
  in
    (noted_thm, lthy_new)
  end
(* wraps a fixed attribute for Attrib.internal; the argument that
   Attrib.internal supplies to the function is ignored *)
fun intern_attr attr = Attrib.internal (fn _ => attr)
(* theorem after_qed goals ctxt opens a proof of all goals as one
   statement group (no pattern terms, no initial method). After the proof,
   after_qed receives the list of proved theorems for that single group. *)
fun theorem after_qed goals ctxt =
  let
    val stmt = map (fn goal => (goal, [])) goals
  in
    Proof.theorem_i NONE (after_qed o the_single) [stmt] ctxt
  end
(* definition of quotient types *)
(********************************)
(* Conversions between set membership and predicate application, proved by
   unfolding mem_def. mem_def1: y : S ==> S y;  mem_def2: S y ==> y : S. *)
val mem_def1 = @{lemma "y : S ==> S y" by (simp add: mem_def)}
val mem_def2 = @{lemma "S y ==> y : S" by (simp add: mem_def)}
(* typedef_term rel rty lthy builds the representing set of the new type:
     %c :: rty set. EX x :: rty. c = rel x
   i.e. it holds of exactly those c that are of the form "rel x".
   The bound names "x" and "c" are varied if needed so that they do not
   clash with the free variables of rel. *)
fun typedef_term rel rty lthy =
  let
    val fresh_names =
      Variable.variant_frees lthy [rel] [("x", rty), ("c", HOLogic.mk_setT rty)]
    val [x, c] = map Free fresh_names
    val c_is_class = HOLogic.mk_eq (c, rel $ x)
  in
    lambda c (HOLogic.exists_const rty $ lambda x c_is_class)
  end
(* typedef_make (qty_name, mx, rel, rty) lthy introduces the typedef
   qty_name over the set built by typedef_term and proves that this set is
   non-empty. The type parameters of the new type are the names of the type
   variables (TFrees) occurring in rty. The typedef is made at theory level
   and lifted back into the local theory with Local_Theory.theory_result. *)
fun typedef_make (qty_name, mx, rel, rty) lthy =
  let
    (* non-emptiness: introduce the existential, turn membership into
       predicate application (mem_def2), then close "c = rel x" by refl *)
    val nonempty_tac =
      EVERY1 (map rtac [@{thm exI}, mem_def2, @{thm exI}, @{thm refl}])
    val type_params = map fst (Term.add_tfreesT rty [])
    val set_trm = typedef_term rel rty lthy
    val add_typedef =
      Typedef.add_typedef false NONE (qty_name, type_params, mx) set_trm NONE nonempty_tac
  in
    Local_Theory.theory_result add_typedef lthy
  end
(* Tactic for subgoal 1 of "Quot_Type rel Abs Rep", where Abs and Rep are
   the typedef's constants. It applies Quot_Type.intro and then discharges
   the resulting subgoals, in order, with the equivalence theorem and the
   typedef facts (Rep, Rep_inverse, Abs_inverse, Rep_inject), moving between
   set membership and predicate application via mem_def1 / mem_def2. *)
fun typedef_quot_type_tac equiv_thm (typedef_info: Typedef.info) =
  let
    val {Rep = rep_fact, Rep_inverse = rep_inverse_fact,
         Abs_inverse = abs_inverse_fact, Rep_inject = rep_inject_fact, ...} = typedef_info
    val rep_pred_fact = rep_fact RS mem_def1
    val abs_inverse_pred = mem_def2 RS abs_inverse_fact
    val subgoal_tacs =
      [rtac equiv_thm,
       rtac rep_pred_fact,
       rtac rep_inverse_fact,
       EVERY' [rtac abs_inverse_pred, rtac @{thm exI}, rtac @{thm refl}],
       rtac rep_inject_fact]
  in
    (rtac @{thm Quot_Type.intro} THEN' RANGE subgoal_tacs) 1
  end
(* typedef_quot_type_thm (rel, abs, rep, equiv_thm, typedef_info) lthy
   proves "Quot_Type rel abs rep" with typedef_quot_type_tac. The goal is
   built with a dummy-typed Quot_Type constant and type-checked in lthy. *)
fun typedef_quot_type_thm (rel, abs, rep, equiv_thm, typedef_info) lthy =
  let
    val raw_goal =
      HOLogic.mk_Trueprop (Const (@{const_name "Quot_Type"}, dummyT) $ rel $ abs $ rep)
    val goal = Syntax.check_term lthy raw_goal
    val tac = typedef_quot_type_tac equiv_thm typedef_info
  in
    Goal.prove lthy [] [] goal (fn _ => tac)
  end
(* typedef_quotient_thm (rel, abs, rep, abs_def, rep_def, quot_type_thm) lthy
   proves "Quotient rel abs rep": it unfolds the definitions of abs and rep,
   reduces the goal with Quot_Type.Quotient, and closes it with the given
   Quot_Type theorem. *)
fun typedef_quotient_thm (rel, abs, rep, abs_def, rep_def, quot_type_thm) lthy =
  let
    val goal =
      Const (@{const_name "Quotient"}, dummyT) $ rel $ abs $ rep
      |> HOLogic.mk_Trueprop
      |> Syntax.check_term lthy
    val unfold_tac = rewrite_goals_tac [abs_def, rep_def]
    val tac =
      unfold_tac
      THEN rtac @{thm Quot_Type.Quotient} 1
      THEN rtac quot_type_thm 1
  in
    Goal.prove lthy [] [] goal (K tac)
  end
(* main function for constructing a quotient type.
   mk_typedef_main (((qty_name, mx), (rty, rel)), equiv_thm) lthy:
   - introduces the typedef qty_name (typedef_make);
   - defines abs_<qty_name> := Quot_Type.abs rel Abs and
     rep_<qty_name> := Quot_Type.rep Rep from the typedef's Abs/Rep;
   - proves "Quot_Type rel Abs Rep" and from it "Quotient rel abs rep";
   - records (abstract type, raw type, relation, equiv_thm) for
     qty_full_name via quotdata_update;
   - stores the Quotient theorem as Quotient_<qty_name> (attribute
     quotient_rules_add) and equiv_thm as <qty_name>_equivp (attribute
     equiv_rules_add).
   Returns ((stored Quotient theorem, stored equivp theorem), lthy'). *)
fun mk_typedef_main (((qty_name, mx), (rty, rel)), equiv_thm) lthy =
let
  (* generates the typedef *)
  val ((qty_full_name, typedef_info), lthy1) = typedef_make (qty_name, mx, rel, rty) lthy
  (* abs and rep functions from the typedef *)
  val Abs_ty = #abs_type typedef_info
  val Rep_ty = #rep_type typedef_info
  val Abs_name = #Abs_name typedef_info
  val Rep_name = #Rep_name typedef_info
  val Abs_const = Const (Abs_name, Rep_ty --> Abs_ty)
  val Rep_const = Const (Rep_name, Abs_ty --> Rep_ty)
  (* more abstract abs and rep definitions, built from Quot_Type.abs/rep; *)
  (* their types are inferred by check_term in lthy1                      *)
  val abs_const = Const (@{const_name "Quot_Type.abs"}, dummyT )
  val rep_const = Const (@{const_name "Quot_Type.rep"}, dummyT )
  val abs_trm = Syntax.check_term lthy1 (abs_const $ rel $ Abs_const)
  val rep_trm = Syntax.check_term lthy1 (rep_const $ Rep_const)
  val abs_name = Binding.prefix_name "abs_" qty_name
  val rep_name = Binding.prefix_name "rep_" qty_name
  val ((abs, abs_def), lthy2) = define (abs_name, NoSyn, abs_trm) lthy1
  val ((rep, rep_def), lthy3) = define (rep_name, NoSyn, rep_trm) lthy2
  (* quot_type theorem (about the typedef's Abs/Rep) - needed below *)
  val quot_thm = typedef_quot_type_thm (rel, Abs_const, Rep_const, equiv_thm, typedef_info) lthy3
  (* quotient theorem (about the newly defined abs/rep) *)
  val quotient_thm = typedef_quotient_thm (rel, abs, rep, abs_def, rep_def, quot_thm) lthy3
  val quotient_thm_name = Binding.prefix_name "Quotient_" qty_name
  (* name of the stored equivalence theorem *)
  val equiv_thm_name = Binding.suffix_name "_equivp" qty_name
  (* storing the quot-info; types are varified (see FIXMEs below) *)
  val lthy4 = quotdata_update qty_full_name 
               (Logic.varifyT Abs_ty, Logic.varifyT rty, map_types Logic.varifyT rel, equiv_thm) lthy3  
  (* FIXME: VarifyT should not be used - at the moment it allows matching against the types. *)
  (* FIXME: The relation can be any term, that later maybe needs to be given *)
  (* FIXME: a different type (in regularize_trm); how should this be done?   *)
in
  (* note both theorems, threading the local theory through *)
  lthy4
  |> note (quotient_thm_name, quotient_thm, [intern_attr quotient_rules_add])
  ||>> note (equiv_thm_name, equiv_thm, [intern_attr equiv_rules_add])
end
(* Interface and syntax setup.                                         *)
(*                                                                     *)
(* quotient_type takes a list of entries ((name, mixfix), (rty, rel)): *)
(* - name:   the name of the new quotient type                         *)
(* - mixfix: its mixfix annotation                                     *)
(* - rty:    the raw type to be quotiented                             *)
(* - rel:    the relation by which rty is quotiented                   *)
(* It returns a proof state with one goal "equivp rel" per entry; once *)
(* these are proved, every quotient type is built by mk_typedef_main.  *)
fun quotient_type quot_list lthy =
  let
    (* the obligation "equivp rel" for a relation over rty *)
    fun equivp_goal (rty, rel) =
      let val rel_ty = [rty, rty] ---> @{typ bool}
      in HOLogic.mk_Trueprop (Const (@{const_name equivp}, rel_ty --> @{typ bool}) $ rel) end
    fun after_qed equiv_thms lthy' =
      snd (fold_map mk_typedef_main (quot_list ~~ equiv_thms) lthy')
  in
    theorem after_qed (map (fn (_, spec) => equivp_goal spec) quot_list) lthy
  end
           
(* quotient_type_cmd spec lthy: the command-level entry point. Each spec
   entry (((name, mixfix), raw type string), relation string) is parsed in
   lthy and the resulting list is handed to quotient_type. *)
fun quotient_type_cmd spec lthy =
  let
    fun read_spec (((qty_str, mx), rty_str), rel_str) =
      let
        val rty = Syntax.read_typ lthy rty_str
        (* NOTE(review): rel is read on its own, without a type constraint
           relating it to rty, so a type mismatch between them is not
           detected here. Consider reading rel constrained to
           rty => rty => bool. *)
        val rel = Syntax.read_term lthy rel_str
      in
        ((Binding.name qty_str, mx), (rty, rel))
      end
  in
    quotient_type (map read_spec spec) lthy
  end
(* Parser for the "quotient_type" specifications: one or more entries
   combined with and_list1, each of the shape
     name [infix annotation] = raw_type / relation *)
val quotspec_parser =
  let
    val name_and_mixfix = OuterParse.short_ident -- OuterParse.opt_infix
    val raw_type = OuterParse.$$$ "=" |-- OuterParse.typ
    val relation = OuterParse.$$$ "/" |-- OuterParse.term
  in
    OuterParse.and_list1 (name_and_mixfix -- raw_type -- relation)
  end
(* "/" separates the raw type from the relation in quotspec_parser,
   so it is declared as an outer-syntax keyword *)
val _ = OuterKeyword.keyword "/"
(* registers the "quotient_type" command; it opens a proof of the
   equivalence goals (kind thy_goal) via quotient_type_cmd *)
val _ = 
    OuterSyntax.local_theory_to_proof "quotient_type" 
      "quotient type definitions (require equivalence proofs)"
         OuterKeyword.thy_goal (quotspec_parser >> quotient_type_cmd)
end; (* structure *)