(* Interface of the package that defines constants lifted to quotient types. *)
signature QUOTIENT_DEF =
sig
  (* Direction selector for get_fun: absF builds abstraction functions,
     repF builds representation functions. *)
  datatype flag = absF | repF

  (* Aggregate abs/rep function for a type, given a list of
     (quotient type, raw type) pairs. *)
  val get_fun: flag -> (typ * typ) list -> Proof.context -> typ -> term

  (* Defines a lifted constant from a raw right-hand side. *)
  val make_def: binding -> term -> typ -> mixfix -> Attrib.binding ->
    (typ * typ) list -> Proof.context -> (term * thm) * local_theory

  (* Lifted definition from an already-parsed type and proposition. *)
  val quotdef: (binding * typ * mixfix) * (Attrib.binding * term) ->
    local_theory -> (term * thm) * local_theory

  (* String-input variant of quotdef, used by the outer-syntax command. *)
  val quotdef_cmd: (binding * string * mixfix) * (Attrib.binding * string) ->
    local_theory -> local_theory
end;
structure Quotient_Def: QUOTIENT_DEF =
struct

(* Thin wrapper around LocalTheory.define: defines constant `name`
   (mixfix `mx`, theorem attributes `attr`) as `rhs` and returns the
   defined term, its defining theorem and the updated local theory. *)
fun define name mx attr rhs lthy =
  let
    val ((rhs, (_, thm)), lthy') =
      LocalTheory.define Thm.internalK ((name, mx), (attr, rhs)) lthy
  in
    ((rhs, thm), lthy')
  end

(* Calculates the aggregate abs or rep function for a given type.
   repF is used for a constant's argument types, absF for its result
   type. Function types are treated specially: the flag is flipped for
   the argument type and kept for the result type. *)
datatype flag = absF | repF

fun negF absF = repF
  | negF repF = absF

(* qenv is a list of (quotient type, raw type) pairs; a type found as a
   key in qenv maps directly to its ABS_/REP_ constant, any other type
   is decomposed structurally via registered map functions. *)
fun get_fun flag qenv lthy ty =
  let
    val thy = ProofContext.theory_of lthy

    (* applies the map function registered for type constructor s to fs;
       fails if no map function is registered *)
    fun get_fun_aux s fs =
      (case maps_lookup thy s of
        SOME info => list_comb (Const (#mapfun info, dummyT), fs)
      | NONE => error ("no map association for type " ^ s))

    (* the ABS_<name> / REP_<name> constant for quotient type qty.
       NOTE(review): Sign.full_bname qualifies the base name relative to
       the current theory; presumably this misses quotient types declared
       in another theory -- confirm. *)
    fun get_const flag qty =
      let
        val qty_name = Long_Name.base_name (fst (dest_Type qty))
      in
        case flag of
          absF => Const (Sign.full_bname thy ("ABS_" ^ qty_name), dummyT)
        | repF => Const (Sign.full_bname thy ("REP_" ^ qty_name), dummyT)
      end

    fun mk_identity ty = Abs ("", ty, Bound 0)
  in
    if AList.defined (op =) qenv ty
    then get_const flag ty
    else
      (case ty of
        TFree _ => mk_identity ty
      | Type (_, []) => mk_identity ty
      | Type ("fun", [ty1, ty2]) =>
          let
            (* the argument position swaps abs and rep *)
            val fs_ty1 = get_fun (negF flag) qenv lthy ty1
            val fs_ty2 = get_fun flag qenv lthy ty2
          in
            get_fun_aux "fun" [fs_ty1, fs_ty2]
          end
      | Type (s, tys) => get_fun_aux s (map (get_fun flag qenv lthy) tys)
      | _ => error ("no type variables allowed"))
  end

(* Defines the lifted constant nconst_bname of type qty: each argument of
   the raw term rhs is wrapped with the corresponding rep function and the
   result with the abs function, combined via fun_map; the combined term
   is type-checked before definition. *)
fun make_def nconst_bname rhs qty mx attr qenv lthy =
  let
    val (arg_tys, res_ty) = strip_type qty
    val rep_fns = map (get_fun repF qenv lthy) arg_tys
    val abs_fn = get_fun absF qenv lthy res_ty
    fun mk_fun_map t s =
      Const (@{const_name "fun_map"}, dummyT) $ t $ s
    val absrep_trm =
      (fold_rev mk_fun_map rep_fns abs_fn $ rhs)
      |> Syntax.check_term lthy
  in
    define nconst_bname mx attr absrep_trm lthy
  end

(* Returns all pairs of corresponding subtypes at which T and S differ,
   prepended to Ds. *)
fun diff (T, S) Ds =
  (case (T, S) of
    (TVar v, TVar u) => if v = u then Ds else (T, S) :: Ds
  | (TFree x, TFree y) => if x = y then Ds else (T, S) :: Ds
  | (Type (a, Ts), Type (b, Us)) =>
      if a = b then diffs (Ts, Us) Ds else (T, S) :: Ds
  | _ => (T, S) :: Ds)
and diffs (T :: Ts, U :: Us) Ds = diffs (Ts, Us) (diff (T, U) Ds)
  | diffs ([], []) Ds = Ds
  | diffs _ _ = error "Unequal length of type arguments"

(* Raises an error reporting that quotient type qty does not match raw
   type rty. *)
fun error_msg lthy (qty, rty) =
  let
    val qtystr = quote (Syntax.string_of_typ lthy qty)
    val rtystr = quote (Syntax.string_of_typ lthy rty)
  in
    error (implode ["Quotient type ", qtystr, " does not match with ", rtystr])
  end

(* Sanity check that every pair of the calculated quotient environment is
   an instance of some pair in the stored (global) quotient environment;
   raises an error for the first pair that is not. *)
fun sanity_chk lthy qenv =
  let
    val global_qenv = Quotient_Info.mk_qenv lthy
    val thy = ProofContext.theory_of lthy

    (* (qty, rty) is an instance of (qty', rty') when qty is a type
       instance of qty' and rty equals rty' under the same instantiation *)
    fun is_inst (qty, rty) (qty', rty') =
      Sign.typ_instance thy (qty, qty') andalso
        rty = Envir.subst_type (Sign.typ_match thy (qty', qty) Vartab.empty) rty'

    fun chk_inst (qty, rty) =
      if exists (is_inst (qty, rty)) global_qenv
      then ()
      else error_msg lthy (qty, rty)
  in
    (* executed only for its checks; no result list is built *)
    List.app chk_inst qenv
  end

(* Lifted definition: extracts the raw right-hand side from the
   definitional proposition prop, computes where the requested type qty
   differs from the raw type, checks those pairs against the stored
   quotient environment and defines the lifted constant. *)
fun quotdef ((bind, qty, mx), (attr, prop)) lthy =
  let
    val (_, prop') = PrimitiveDefs.dest_def lthy (K true) (K false) (K false) prop
    val (_, rhs) = PrimitiveDefs.abs_def prop'
    val rty = fastype_of rhs
    val qenv = distinct (op =) (diff (qty, rty) [])
  in
    sanity_chk lthy qenv;
    make_def bind rhs qty mx attr qenv lthy
  end

(* String-input front end of quotdef: parses the type and proposition in
   lthy and returns only the updated local theory. *)
fun quotdef_cmd ((bind, qtystr, mx), (attr, propstr)) lthy =
  let
    val qty = Syntax.read_typ lthy qtystr
    val prop = Syntax.read_prop lthy propstr
  in
    quotdef ((bind, qty, mx), (attr, prop)) lthy |> snd
  end

(* Parser for:  name :: type [mixfix] where [thm_name:] prop *)
val quotdef_parser =
  (OuterParse.binding --
    (OuterParse.$$$ "::" |-- OuterParse.!!! (OuterParse.typ --
      OuterParse.opt_mixfix' --| OuterParse.where_)) >> OuterParse.triple2) --
  (SpecParse.opt_thm_name ":" -- OuterParse.prop)

(* Registers the quotient_def outer-syntax command. *)
val _ =
  OuterSyntax.local_theory "quotient_def" "lifted definition of constants"
    OuterKeyword.thy_decl (quotdef_parser >> quotdef_cmd)

end; (* structure *)
(* Brings the Quotient_Def operations (get_fun, make_def, quotdef,
   quotdef_cmd, flag) into the enclosing name space. *)
open Quotient_Def;