quotient_def.ML
changeset 321 f46dc0ca08c3
parent 319 0ae9d9e66cb7
child 324 bdbb52979790
--- a/quotient_def.ML	Fri Nov 20 13:03:01 2009 +0100
+++ b/quotient_def.ML	Sat Nov 21 02:49:39 2009 +0100
@@ -2,7 +2,7 @@
 signature QUOTIENT_DEF =
 sig
   datatype flag = absF | repF
-  val get_fun: flag -> (typ * typ) list -> Proof.context -> typ -> term
+  val get_fun: flag -> Proof.context -> typ * typ -> term
   val make_def: binding -> typ -> mixfix -> Attrib.binding -> term ->
     Proof.context -> (term * thm) * local_theory
 
@@ -10,7 +10,6 @@
     local_theory -> (term * thm) * local_theory
   val quotdef_cmd: (binding * string * mixfix) * (Attrib.binding * string) ->
     local_theory -> local_theory
-  val diff: (typ * typ) -> (typ * typ) list -> (typ * typ) list
 end;
 
 structure Quotient_Def: QUOTIENT_DEF =
@@ -25,7 +24,6 @@
   ((rhs, thm), lthy')
 end
 
-
 (* calculates the aggregate abs and rep functions for a given type; 
    repF is for constants' arguments; absF is for constants;
    function types need to be treated specially, since repF and absF
@@ -36,99 +34,65 @@
 fun negF absF = repF
   | negF repF = absF
 
-fun get_fun flag qenv lthy ty =
-let
-  
-  fun get_fun_aux s fs =
-   (case (maps_lookup (ProofContext.theory_of lthy) s) of
-      SOME info => list_comb (Const (#mapfun info, dummyT), fs)
-    | NONE      => error ("no map association for type " ^ s))
-
-  fun get_const flag qty =
-  let 
-    val thy = ProofContext.theory_of lthy
-    val qty_name = Long_Name.base_name (fst (dest_Type qty))
-  in
-    case flag of
-      absF => Const (Sign.full_bname thy ("ABS_" ^ qty_name), dummyT)
-    | repF => Const (Sign.full_bname thy ("REP_" ^ qty_name), dummyT)
-  end
+fun mk_identity ty = Const (@{const_name "id"}, ty --> ty)
 
-  fun mk_identity ty = Abs ("", ty, Bound 0)
-
+fun ty_lift_error lthy rty qty =
+let
+  val rty_str = quote (Syntax.string_of_typ lthy rty) 
+  val qty_str = quote (Syntax.string_of_typ lthy qty)
+  val msg = ["quotient type", qty_str, "and lifted type", rty_str, "do not match."]
 in
-  if (AList.defined (op=) qenv ty)
-  then (get_const flag ty)
-  else (case ty of
-          TFree _ => mk_identity ty
-        | Type (_, []) => mk_identity ty 
-        | Type ("fun" , [ty1, ty2]) => 
-            let
-              val fs_ty1 = get_fun (negF flag) qenv lthy ty1
-              val fs_ty2 = get_fun flag qenv lthy ty2
-            in  
-              get_fun_aux "fun" [fs_ty1, fs_ty2]
-            end 
-        | Type (s, tys) => get_fun_aux s (map (get_fun flag qenv lthy) tys)
-        | _ => error ("no type variables allowed"))
+  raise LIFT_MATCH (space_implode " " msg)
 end
 
-(* returns all subterms where two types differ *)
-fun diff (T, S) Ds =
-  case (T, S) of
-    (TVar v, TVar u) => if v = u then Ds else (T, S)::Ds 
-  | (TFree x, TFree y) => if x = y then Ds else (T, S)::Ds
-  | (Type (a, Ts), Type (b, Us)) => 
-      if a = b then diffs (Ts, Us) Ds else (T, S)::Ds
-  | _ => (T, S)::Ds
-and diffs (T::Ts, U::Us) Ds = diffs (Ts, Us) (diff (T, U) Ds)
-  | diffs ([], []) Ds = Ds
-  | diffs _ _ = error "Unequal length of type arguments"
+fun get_fun_aux lthy s fs =
+  case (maps_lookup (ProofContext.theory_of lthy) s) of
+    SOME info => list_comb (Const (#mapfun info, dummyT), fs)
+  | NONE      => raise LIFT_MATCH ("no map association for type " ^ s)
 
-
-(* sanity check that the calculated quotient environment
-   matches with the stored quotient environment. *)
-fun sanity_chk qenv lthy =
-let
-  val global_qenv = Quotient_Info.mk_qenv lthy
+fun get_const flag lthy _ qty =
+(* FIXME: check here that _ and qty are related *)
+let 
   val thy = ProofContext.theory_of lthy
+  val qty_name = Long_Name.base_name (fst (dest_Type qty))
+in
+  case flag of
+    absF => Const (Sign.full_bname thy ("ABS_" ^ qty_name), dummyT)
+  | repF => Const (Sign.full_bname thy ("REP_" ^ qty_name), dummyT)
+end
 
-  fun error_msg lthy (qty, rty) =
-  let 
-    val qtystr = quote (Syntax.string_of_typ lthy qty)
-    val rtystr = quote (Syntax.string_of_typ lthy rty)
-  in
-    error (implode ["Quotient type ", qtystr, " does not match with ", rtystr])
-  end
-
-  fun is_inst (qty, rty) (qty', rty') =
-    if Sign.typ_instance thy (qty, qty')
-    then let
-           val inst = Sign.typ_match thy (qty', qty) Vartab.empty
-         in
-           rty = Envir.subst_type inst rty'       
-         end
-    else false
-
-  fun chk_inst (qty, rty) = 
-    if exists (is_inst (qty, rty)) global_qenv 
-    then true
-    else error_msg lthy (qty, rty)
-in
-  map chk_inst qenv
-end
+fun get_fun flag lthy (rty, qty) =
+  case (rty, qty) of 
+    (Type ("fun", [ty1, ty2]), Type ("fun", [ty1', ty2'])) =>
+     let
+       val fs_ty1 = get_fun (negF flag) lthy (ty1, ty1')
+       val fs_ty2 = get_fun flag lthy (ty2, ty2')
+     in  
+       get_fun_aux lthy "fun" [fs_ty1, fs_ty2]
+     end 
+  | (Type (s, []), Type (s', [])) =>
+     if s = s'
+     then mk_identity qty 
+     else get_const flag lthy rty qty
+  | (Type (s, tys), Type (s', tys')) =>
+     if s = s'
+     then get_fun_aux lthy s' (map (get_fun flag lthy) (tys ~~ tys'))
+     else get_const flag lthy rty qty
+  | (TFree x, TFree x') =>
+     if x = x'
+     then mk_identity qty 
+     else ty_lift_error lthy rty qty
+  | (TVar _, TVar _) => raise LIFT_MATCH "no type variables allowed"
+  | _ => ty_lift_error lthy rty qty
 
 fun make_def nconst_bname qty mx attr rhs lthy =
 let
-  val (arg_tys, res_ty) = strip_type qty
-
   val rty = fastype_of rhs
-  val qenv = distinct (op=) (diff (qty, rty) [])   
-
-  val _ = sanity_chk qenv lthy
-
-  val rep_fns = map (get_fun repF qenv lthy) arg_tys
-  val abs_fn  = (get_fun absF qenv lthy) res_ty
+  val (arg_rtys, res_rty) = strip_type rty
+  val (arg_qtys, res_qty) = strip_type qty
+  
+  val rep_fns = map (get_fun repF lthy) (arg_rtys ~~ arg_qtys)
+  val abs_fn  = get_fun absF lthy (res_rty, res_qty)
 
   fun mk_fun_map t s =  
         Const (@{const_name "fun_map"}, dummyT) $ t $ s
@@ -139,10 +103,9 @@
   val ((trm, thm), lthy') = define nconst_bname mx attr absrep_trm lthy
 
   val nconst_str = Binding.name_of nconst_bname
-  val qcinfo = {qconst = trm, rconst = rhs}
+  fun qcinfo phi = qconsts_transfer phi {qconst = trm, rconst = rhs}
   val lthy'' = Local_Theory.declaration true
-                (fn phi => qconsts_update_generic nconst_str 
-                             (qconsts_transfer phi qcinfo)) lthy'
+                 (fn phi => qconsts_update_gen nconst_str (qcinfo phi)) lthy'
 in
   ((trm, thm), lthy'')
 end
@@ -186,3 +149,5 @@
 end; (* structure *)
 
 open Quotient_Def;
+
+