quotient_def.ML
changeset 277 37636f2b1c19
child 279 b2fd070c8833
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/quotient_def.ML	Wed Nov 04 11:59:15 2009 +0100
@@ -0,0 +1,185 @@
+
+signature QUOTIENT_DEF =
+sig
+  datatype flag = absF | repF
+  val get_fun: flag -> (typ * typ) list -> Proof.context -> typ -> term * (typ * typ)
+  val make_def: binding -> term -> typ -> mixfix -> Attrib.binding -> (typ * typ) list ->
+    Proof.context -> (term * thm) * local_theory
+
+  val quotdef: (binding * typ * mixfix) * (Attrib.binding * term) ->
+    local_theory -> (term * thm) * local_theory
+  val quotdef_cmd: (binding * string * mixfix) * (Attrib.binding * string) ->
+    local_theory -> local_theory
+end;
+
+structure Quotient_Def: QUOTIENT_DEF =
+struct
+
+fun define name mx attr rhs lthy =
+let
+  val ((rhs, (_ , thm)), lthy') =
+     LocalTheory.define Thm.internalK ((name, mx), (attr, rhs)) lthy
+in
+  ((rhs, thm), lthy')
+end
+
+fun lookup_qenv qenv qty =
+  (case (AList.lookup (op=) qenv qty) of
+    SOME rty => SOME (qty, rty)
+  | NONE => NONE)
+
+
+(* calculates the aggregate abs and rep functions for a given type; 
+   repF is for constants' arguments; absF is for constants;
+   function types need to be treated specially, since repF and absF
+   change *)
+
+datatype flag = absF | repF
+
+fun negF absF = repF
+  | negF repF = absF
+
+fun get_fun flag qenv lthy ty =
+let
+  
+  fun get_fun_aux s fs_tys =
+  let
+    val (fs, tys) = split_list fs_tys
+    val (otys, ntys) = split_list tys
+    val oty = Type (s, otys)
+    val nty = Type (s, ntys)
+    val ftys = map (op -->) tys
+  in
+   (case (maps_lookup (ProofContext.theory_of lthy) s) of
+      SOME info => (list_comb (Const (#mapfun info, ftys ---> (oty --> nty)), fs), (oty, nty))
+    | NONE      => error ("no map association for type " ^ s))
+  end
+
+  fun get_fun_fun fs_tys =
+  let
+    val (fs, tys) = split_list fs_tys
+    val ([oty1, oty2], [nty1, nty2]) = split_list tys
+    val oty = nty1 --> oty2
+    val nty = oty1 --> nty2
+    val ftys = map (op -->) tys
+  in
+    (list_comb (Const (@{const_name "fun_map"}, ftys ---> oty --> nty), fs), (oty, nty))
+  end
+
+  fun get_const flag (qty, rty) =
+  let 
+    val thy = ProofContext.theory_of lthy
+    val qty_name = Long_Name.base_name (fst (dest_Type qty))
+  in
+    case flag of
+      absF => (Const (Sign.full_bname thy ("ABS_" ^ qty_name), rty --> qty), (rty, qty))
+    | repF => (Const (Sign.full_bname thy ("REP_" ^ qty_name), qty --> rty), (qty, rty))
+  end
+
+  fun mk_identity ty = Abs ("", ty, Bound 0)
+
+in
+  if (AList.defined (op=) qenv ty)
+  then (get_const flag (the (lookup_qenv qenv ty)))
+  else (case ty of
+          TFree _ => (mk_identity ty, (ty, ty))
+        | Type (_, []) => (mk_identity ty, (ty, ty)) 
+        | Type ("fun" , [ty1, ty2]) => 
+                 get_fun_fun [get_fun (negF flag) qenv lthy ty1, get_fun flag qenv lthy ty2]
+        | Type (s, tys) => get_fun_aux s (map (get_fun flag qenv lthy) tys)
+        | _ => raise ERROR ("no type variables"))
+end
+
+fun make_def nconst_bname rhs qty mx attr qenv lthy =
+let
+  val (arg_tys, res_ty) = strip_type qty
+
+  val rep_fns = map (fst o get_fun repF qenv lthy) arg_tys
+  val abs_fn  = (fst o get_fun absF qenv lthy) res_ty
+
+  fun mk_fun_map t s = 
+        Const (@{const_name "fun_map"}, dummyT) $ t $ s
+
+  val absrep_fn = fold_rev mk_fun_map rep_fns abs_fn
+                  |> Syntax.check_term lthy 
+in
+  define nconst_bname mx attr (absrep_fn $ rhs) lthy
+end
+
+
+(* returns all subterms where two types differ *)
+fun diff (T, S) Ds =
+  case (T, S) of
+    (TVar v, TVar u) => if v = u then Ds else (T, S)::Ds 
+  | (TFree x, TFree y) => if x = y then Ds else (T, S)::Ds
+  | (Type (a, Ts), Type (b, Us)) => 
+      if a = b then diffs (Ts, Us) Ds else (T, S)::Ds
+  | _ => (T, S)::Ds
+and diffs (T::Ts, U::Us) Ds = diffs (Ts, Us) (diff (T, U) Ds)
+  | diffs ([], []) Ds = Ds
+  | diffs _ _ = error "Unequal length of type arguments"
+
+
+fun error_msg lthy (qty, rty) =
+let 
+  val qtystr = quote (Syntax.string_of_typ lthy qty)
+  val rtystr = quote (Syntax.string_of_typ lthy rty)
+in
+  error (implode ["Quotient type ", qtystr, " does not match with ", rtystr])
+end
+
+fun sanity_chk lthy qenv =
+let
+   val qenv' = Quotient_Info.mk_qenv lthy
+   val thy = ProofContext.theory_of lthy
+
+   fun is_inst thy (qty, rty) (qty', rty') =
+   if Sign.typ_instance thy (qty, qty')
+   then let
+          val inst = Sign.typ_match thy (qty', qty) Vartab.empty
+        in
+          rty = Envir.subst_type inst rty'
+        end
+   else false
+
+   fun chk_inst (qty, rty) = 
+     if exists (is_inst thy (qty, rty)) qenv' then true
+     else error_msg lthy (qty, rty)
+in
+  forall chk_inst qenv
+end
+
+
+fun quotdef ((bind, qty, mx), (attr, prop)) lthy =
+let   
+  val (_, prop') = PrimitiveDefs.dest_def lthy (K true) (K false) (K false) prop
+  val (_, rhs) = PrimitiveDefs.abs_def prop'
+
+  val rty = fastype_of rhs
+  val qenv = distinct (op=) (diff (qty, rty) []) 
+in
+  sanity_chk lthy qenv;
+  make_def bind rhs qty mx attr qenv lthy 
+end
+
+
+val quotdef_parser =
+  (OuterParse.binding --
+    (OuterParse.$$$ "::" |-- OuterParse.!!! (OuterParse.typ -- 
+      OuterParse.opt_mixfix' --| OuterParse.where_)) >> OuterParse.triple2) -- 
+       (SpecParse.opt_thm_name ":" -- OuterParse.prop)
+
+fun quotdef_cmd ((bind, qtystr, mx), (attr, propstr)) lthy = 
+let
+  val qty  = (Syntax.check_typ lthy o Syntax.parse_typ lthy) qtystr
+  val prop = (Syntax.check_prop lthy o Syntax.parse_prop lthy) propstr
+in
+  quotdef ((bind, qty, mx), (attr, prop)) lthy |> snd
+end
+
+val _ = OuterSyntax.local_theory "quotient_def" "lifted definition of constants"
+  OuterKeyword.thy_decl (quotdef_parser >> quotdef_cmd)
+
+end; (* structure *)
+
+open Quotient_Def;
\ No newline at end of file