Quot/quotient_term.ML
changeset 790 3a48ffcf0f9a
parent 786 d6407afb913c
child 791 fb4bfbb1a291
--- a/Quot/quotient_term.ML	Fri Dec 25 00:58:06 2009 +0100
+++ b/Quot/quotient_term.ML	Sat Dec 26 07:15:30 2009 +0100
@@ -34,27 +34,59 @@
 fun negF absF = repF
   | negF repF = absF
 
+val mk_id = Const (@{const_name "id"}, dummyT)
 fun mk_identity ty = Const (@{const_name "id"}, ty --> ty)
 
+fun mk_Free (TVar ((x, i), _)) = Free (unprefix "'" x ^ string_of_int i, dummyT)
+
 fun mk_compose flag (trm1, trm2) = 
   case flag of
     absF => Const (@{const_name "comp"}, dummyT) $ trm1 $ trm2
   | repF => Const (@{const_name "comp"}, dummyT) $ trm2 $ trm1
 
-fun get_map ctxt ty_str =
+fun get_mapfun ctxt s =
 let
   val thy = ProofContext.theory_of ctxt
-  val exc = LIFT_MATCH (space_implode " " ["absrep_fun: no map for type", quote ty_str, "."])
-  val mapfun = #mapfun (maps_lookup thy ty_str) handle NotFound => raise exc
+  val exc = LIFT_MATCH ("No map function for type " ^ (quote s) ^ " found.")
+  val mapfun = #mapfun (maps_lookup thy s) handle NotFound => raise exc
 in
   Const (mapfun, dummyT)
 end
 
-fun get_absrep_const flag ctxt _ qty =
-(* FIXME: check here that the type-constructors of _ and qty are related *)
+fun mk_mapfun ctxt vs ty =
+let
+  val vs' = map (mk_Free) vs
+
+  fun mk_mapfun_aux ty =
+    case ty of
+      TVar _ => mk_Free ty
+    | Type (_, []) => mk_id
+    | Type (s, tys) => list_comb (get_mapfun ctxt s, map mk_mapfun_aux tys)
+    | _ => raise LIFT_MATCH ("mk_mapfun_aux (default)")
+in
+  fold_rev Term.lambda vs' (mk_mapfun_aux ty)
+end
+
+fun get_rty_qty ctxt s =
 let
   val thy = ProofContext.theory_of ctxt
-  val qty_name = Long_Name.base_name (fst (dest_Type qty))
+  val exc = LIFT_MATCH ("No quotient type " ^ (quote s) ^ " found.")
+  val qdata = (quotdata_lookup thy s) handle NotFound => raise exc
+in
+  (#rtyp qdata, #qtyp qdata)
+end
+
+fun double_lookup rtyenv qtyenv v =
+let
+  val v' = fst (dest_TVar v)
+in
+  (snd (the (Vartab.lookup rtyenv v')), snd (the (Vartab.lookup qtyenv v')))
+end
+
+fun absrep_const flag ctxt qty_str =
+let
+  val thy = ProofContext.theory_of ctxt
+  val qty_name = Long_Name.base_name qty_str
 in
   case flag of
     absF => Const (Sign.full_bname thy ("abs_" ^ qty_name), dummyT)
@@ -62,39 +94,62 @@
 end
 
 fun absrep_fun flag ctxt (rty, qty) =
-  if rty = qty then mk_identity qty else
-  case (rty, qty) of
-    (Type ("fun", [ty1, ty2]), Type ("fun", [ty1', ty2'])) =>
-      let
-        val arg1 = absrep_fun (negF flag) ctxt (ty1, ty1')
-        val arg2 = absrep_fun flag ctxt (ty2, ty2')
-      in
-        list_comb (get_map ctxt "fun", [arg1, arg2])
-      end
-  | (Type (s, _), Type (s', [])) =>
-      if s = s'
-      then mk_identity qty
-      else get_absrep_const flag ctxt rty qty
-  | (Type (s, tys), Type (s', tys')) =>
-      let
-        val args = map (absrep_fun flag ctxt) (tys ~~ tys')
-        val result = list_comb (get_map ctxt s, args)
-      in
+  if rty = qty  
+  then mk_identity qty
+  else
+    case (rty, qty) of
+      (Type ("fun", [ty1, ty2]), Type ("fun", [ty1', ty2'])) =>
+        let
+          val arg1 = absrep_fun (negF flag) ctxt (ty1, ty1')
+          val arg2 = absrep_fun flag ctxt (ty2, ty2')
+        in
+          list_comb (get_mapfun ctxt "fun", [arg1, arg2])
+        end
+    | (Type (s, tys), Type (s', tys')) =>
         if s = s'
-        then result
-        else mk_compose flag (get_absrep_const flag ctxt rty qty, result)
-      end
-  | (TFree x, TFree x') =>
-      if x = x'
-      then mk_identity qty
-      else raise (LIFT_MATCH "absrep_fun (frees)")
-  | (TVar _, TVar _) => raise (LIFT_MATCH "absrep_fun (vars)")
-  | _ => raise (LIFT_MATCH "absrep_fun (default)")
+        then 
+           let
+             val args = map (absrep_fun flag ctxt) (tys ~~ tys')
+           in
+             list_comb (get_mapfun ctxt s, args)
+           end
+        else
+           let
+             val thy = ProofContext.theory_of ctxt
+             val (rty_pat, qty_pat as Type (_, vs)) = get_rty_qty ctxt s'
+             val rtyenv = Sign.typ_match thy (rty_pat, rty) Vartab.empty
+             val qtyenv = Sign.typ_match thy (qty_pat, qty) Vartab.empty
+             val args_aux = map (double_lookup rtyenv qtyenv) vs            
+             val args = map (absrep_fun flag ctxt) args_aux
+             val map_fun = mk_mapfun ctxt vs rty_pat       
+             val result = list_comb (map_fun, args) 
+           in
+             mk_compose flag (absrep_const flag ctxt s', result)
+           end 
+    | (TFree x, TFree x') =>
+        if x = x'
+        then mk_identity qty
+        else raise (LIFT_MATCH "absrep_fun (frees)")
+    | (TVar _, TVar _) => raise (LIFT_MATCH "absrep_fun (vars)")
+    | _ => 
+         let
+           val rty_str = Syntax.string_of_typ ctxt rty
+           val qty_str = Syntax.string_of_typ ctxt qty
+         in
+           raise (LIFT_MATCH ("absrep_fun (default) " ^ rty_str ^ " " ^ qty_str))
+         end
 
 fun absrep_fun_chk flag ctxt (rty, qty) =
+let
+  val rty_str = Syntax.string_of_typ ctxt rty
+  val qty_str = Syntax.string_of_typ ctxt qty
+  val _ = tracing "rty / qty"
+  val _ = tracing rty_str
+  val _ = tracing qty_str
+in
   absrep_fun flag ctxt (rty, qty)
   |> Syntax.check_term ctxt
-
+end
 
 (* Regularizing an rtrm means: