diff -r 7d3d86beacd6 -r f46dc0ca08c3 quotient_def.ML --- a/quotient_def.ML Fri Nov 20 13:03:01 2009 +0100 +++ b/quotient_def.ML Sat Nov 21 02:49:39 2009 +0100 @@ -2,7 +2,7 @@ signature QUOTIENT_DEF = sig datatype flag = absF | repF - val get_fun: flag -> (typ * typ) list -> Proof.context -> typ -> term + val get_fun: flag -> Proof.context -> typ * typ -> term val make_def: binding -> typ -> mixfix -> Attrib.binding -> term -> Proof.context -> (term * thm) * local_theory @@ -10,7 +10,6 @@ local_theory -> (term * thm) * local_theory val quotdef_cmd: (binding * string * mixfix) * (Attrib.binding * string) -> local_theory -> local_theory - val diff: (typ * typ) -> (typ * typ) list -> (typ * typ) list end; structure Quotient_Def: QUOTIENT_DEF = @@ -25,7 +24,6 @@ ((rhs, thm), lthy') end - (* calculates the aggregate abs and rep functions for a given type; repF is for constants' arguments; absF is for constants; function types need to be treated specially, since repF and absF @@ -36,99 +34,65 @@ fun negF absF = repF | negF repF = absF -fun get_fun flag qenv lthy ty = -let - - fun get_fun_aux s fs = - (case (maps_lookup (ProofContext.theory_of lthy) s) of - SOME info => list_comb (Const (#mapfun info, dummyT), fs) - | NONE => error ("no map association for type " ^ s)) - - fun get_const flag qty = - let - val thy = ProofContext.theory_of lthy - val qty_name = Long_Name.base_name (fst (dest_Type qty)) - in - case flag of - absF => Const (Sign.full_bname thy ("ABS_" ^ qty_name), dummyT) - | repF => Const (Sign.full_bname thy ("REP_" ^ qty_name), dummyT) - end +fun mk_identity ty = Const (@{const_name "id"}, ty --> ty) - fun mk_identity ty = Abs ("", ty, Bound 0) - +fun ty_lift_error lthy rty qty = +let + val rty_str = quote (Syntax.string_of_typ lthy rty) + val qty_str = quote (Syntax.string_of_typ lthy qty) + val msg = ["quotient type", qty_str, "and lifted type", rty_str, "do not match."] in - if (AList.defined (op=) qenv ty) - then (get_const flag ty) - else (case ty of - TFree _ => mk_identity ty - | Type (_, []) => mk_identity ty - | Type ("fun" , [ty1, ty2]) => - let - val fs_ty1 = get_fun (negF flag) qenv lthy ty1 - val fs_ty2 = get_fun flag qenv lthy ty2 - in - get_fun_aux "fun" [fs_ty1, fs_ty2] - end - | Type (s, tys) => get_fun_aux s (map (get_fun flag qenv lthy) tys) - | _ => error ("no type variables allowed")) + raise LIFT_MATCH (space_implode " " msg) end -(* returns all subterms where two types differ *) -fun diff (T, S) Ds = - case (T, S) of - (TVar v, TVar u) => if v = u then Ds else (T, S)::Ds - | (TFree x, TFree y) => if x = y then Ds else (T, S)::Ds - | (Type (a, Ts), Type (b, Us)) => - if a = b then diffs (Ts, Us) Ds else (T, S)::Ds - | _ => (T, S)::Ds -and diffs (T::Ts, U::Us) Ds = diffs (Ts, Us) (diff (T, U) Ds) - | diffs ([], []) Ds = Ds - | diffs _ _ = error "Unequal length of type arguments" +fun get_fun_aux lthy s fs = + case (maps_lookup (ProofContext.theory_of lthy) s) of + SOME info => list_comb (Const (#mapfun info, dummyT), fs) + | NONE => raise LIFT_MATCH ("no map association for type " ^ s) - -(* sanity check that the calculated quotient environment - matches with the stored quotient environment. *) -fun sanity_chk qenv lthy = -let - val global_qenv = Quotient_Info.mk_qenv lthy +fun get_const flag lthy _ qty = +(* FIXME: check here that _ and qty are related *) +let val thy = ProofContext.theory_of lthy + val qty_name = Long_Name.base_name (fst (dest_Type qty)) +in + case flag of + absF => Const (Sign.full_bname thy ("ABS_" ^ qty_name), dummyT) + | repF => Const (Sign.full_bname thy ("REP_" ^ qty_name), dummyT) +end - fun error_msg lthy (qty, rty) = - let - val qtystr = quote (Syntax.string_of_typ lthy qty) - val rtystr = quote (Syntax.string_of_typ lthy rty) - in - error (implode ["Quotient type ", qtystr, " does not match with ", rtystr]) - end - - fun is_inst (qty, rty) (qty', rty') = - if Sign.typ_instance thy (qty, qty') - then let - val inst = Sign.typ_match thy (qty', qty) Vartab.empty - in - rty = Envir.subst_type inst rty' - end - else false - - fun chk_inst (qty, rty) = - if exists (is_inst (qty, rty)) global_qenv - then true - else error_msg lthy (qty, rty) -in - map chk_inst qenv -end +fun get_fun flag lthy (rty, qty) = + case (rty, qty) of + (Type ("fun", [ty1, ty2]), Type ("fun", [ty1', ty2'])) => + let + val fs_ty1 = get_fun (negF flag) lthy (ty1, ty1') + val fs_ty2 = get_fun flag lthy (ty2, ty2') + in + get_fun_aux lthy "fun" [fs_ty1, fs_ty2] + end + | (Type (s, []), Type (s', [])) => + if s = s' + then mk_identity qty + else get_const flag lthy rty qty + | (Type (s, tys), Type (s', tys')) => + if s = s' + then get_fun_aux lthy s' (map (get_fun flag lthy) (tys ~~ tys')) + else get_const flag lthy rty qty + | (TFree x, TFree x') => + if x = x' + then mk_identity qty + else ty_lift_error lthy rty qty + | (TVar _, TVar _) => raise LIFT_MATCH "no type variables allowed" + | _ => ty_lift_error lthy rty qty fun make_def nconst_bname qty mx attr rhs lthy = let - val (arg_tys, res_ty) = strip_type qty - val rty = fastype_of rhs - val qenv = distinct (op=) (diff (qty, rty) []) - - val _ = sanity_chk qenv lthy - - val rep_fns = map (get_fun repF qenv lthy) arg_tys - val abs_fn = (get_fun absF qenv lthy) res_ty + val (arg_rtys, res_rty) = strip_type rty + val (arg_qtys, res_qty) = strip_type qty + + val rep_fns = map (get_fun repF lthy) (arg_rtys ~~ arg_qtys) + val abs_fn = get_fun absF lthy (res_rty, res_qty) fun mk_fun_map t s = Const (@{const_name "fun_map"}, dummyT) $ t $ s @@ -139,10 +103,9 @@ val ((trm, thm), lthy') = define nconst_bname mx attr absrep_trm lthy val nconst_str = Binding.name_of nconst_bname - val qcinfo = {qconst = trm, rconst = rhs} + fun qcinfo phi = qconsts_transfer phi {qconst = trm, rconst = rhs} val lthy'' = Local_Theory.declaration true - (fn phi => qconsts_update_generic nconst_str - (qconsts_transfer phi qcinfo)) lthy' + (fn phi => qconsts_update_gen nconst_str (qcinfo phi)) lthy' in ((trm, thm), lthy'') end @@ -186,3 +149,5 @@ end; (* structure *) open Quotient_Def; + +